aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.gitignore3
-rw-r--r--LICENSE (renamed from COPYING)0
-rw-r--r--Makefile13
-rw-r--r--README.md8
-rw-r--r--conn/bind_std.go544
-rw-r--r--conn/bind_std_test.go250
-rw-r--r--conn/bind_windows.go601
-rw-r--r--conn/bindtest/bindtest.go136
-rw-r--r--conn/boundif_android.go34
-rw-r--r--conn/conn.go133
-rw-r--r--conn/conn_test.go24
-rw-r--r--conn/controlfns.go43
-rw-r--r--conn/controlfns_linux.go69
-rw-r--r--conn/controlfns_unix.go35
-rw-r--r--conn/controlfns_windows.go23
-rw-r--r--conn/default.go10
-rw-r--r--conn/errors_default.go12
-rw-r--r--conn/errors_linux.go26
-rw-r--r--conn/features_default.go15
-rw-r--r--conn/features_linux.go29
-rw-r--r--conn/gso_default.go21
-rw-r--r--conn/gso_linux.go65
-rw-r--r--conn/mark_default.go12
-rw-r--r--conn/mark_unix.go (renamed from device/mark_unix.go)16
-rw-r--r--conn/sticky_default.go42
-rw-r--r--conn/sticky_linux.go112
-rw-r--r--conn/sticky_linux_test.go266
-rw-r--r--conn/winrio/rio_windows.go254
-rw-r--r--device/allowedips.go359
-rw-r--r--device/allowedips_rand_test.go120
-rw-r--r--device/allowedips_test.go75
-rw-r--r--device/bind_test.go23
-rw-r--r--device/boundif_android.go44
-rw-r--r--device/boundif_windows.go64
-rw-r--r--device/channels.go137
-rw-r--r--device/conn.go187
-rw-r--r--device/conn_default.go178
-rw-r--r--device/conn_linux.go757
-rw-r--r--device/constants.go7
-rw-r--r--device/cookie.go7
-rw-r--r--device/cookie_test.go7
-rw-r--r--device/device.go515
-rw-r--r--device/device_test.go594
-rw-r--r--device/devicestate_string.go16
-rw-r--r--device/endpoint_test.go40
-rw-r--r--device/indextable.go7
-rw-r--r--device/ip.go2
-rw-r--r--device/kdf_test.go4
-rw-r--r--device/keypair.go9
-rw-r--r--device/logger.go67
-rw-r--r--device/mark_default.go12
-rw-r--r--device/misc.go48
-rw-r--r--device/mobilequirks.go19
-rw-r--r--device/noise-helpers.go12
-rw-r--r--device/noise-protocol.go251
-rw-r--r--device/noise-types.go30
-rw-r--r--device/noise_test.go49
-rw-r--r--device/peer.go267
-rw-r--r--device/pools.go157
-rw-r--r--device/pools_test.go139
-rw-r--r--device/queueconstants_android.go7
-rw-r--r--device/queueconstants_default.go7
-rw-r--r--device/queueconstants_ios.go23
-rw-r--r--device/queueconstants_windows.go15
-rw-r--r--device/race_disabled_test.go10
-rw-r--r--device/race_enabled_test.go10
-rw-r--r--device/receive.go635
-rw-r--r--device/send.go652
-rw-r--r--device/sticky_default.go12
-rw-r--r--device/sticky_linux.go224
-rw-r--r--device/timers.go76
-rw-r--r--device/tun.go49
-rw-r--r--device/tun_test.go56
-rw-r--r--device/uapi.go632
-rw-r--r--device/version.go3
-rw-r--r--format_test.go51
-rw-r--r--go.mod16
-rw-r--r--go.sum28
-rw-r--r--ipc/namedpipe/file.go (renamed from ipc/winpipe/file.go)138
-rw-r--r--ipc/namedpipe/namedpipe.go485
-rw-r--r--ipc/namedpipe/namedpipe_test.go674
-rw-r--r--ipc/uapi_bsd.go79
-rw-r--r--ipc/uapi_linux.go81
-rw-r--r--ipc/uapi_unix.go66
-rw-r--r--ipc/uapi_wasm.go15
-rw-r--r--ipc/uapi_windows.go14
-rw-r--r--ipc/winpipe/mksyscall.go9
-rw-r--r--ipc/winpipe/pipe.go509
-rw-r--r--ipc/winpipe/zsyscall_windows.go238
-rw-r--r--main.go81
-rw-r--r--main_windows.go31
-rw-r--r--ratelimiter/ratelimiter.go134
-rw-r--r--ratelimiter/ratelimiter_test.go90
-rw-r--r--replay/replay.go101
-rw-r--r--replay/replay_test.go44
-rw-r--r--rwcancel/fdset.go22
-rw-r--r--rwcancel/rwcancel.go60
-rw-r--r--rwcancel/rwcancel_stub.go9
-rw-r--r--rwcancel/select_default.go14
-rw-r--r--rwcancel/select_linux.go13
-rw-r--r--tai64n/tai64n.go25
-rw-r--r--tai64n/tai64n_test.go42
-rwxr-xr-xtests/netns.sh2
-rw-r--r--tun/alignment_windows_test.go67
-rw-r--r--tun/checksum.go118
-rw-r--r--tun/checksum_test.go35
-rw-r--r--tun/errors.go12
-rw-r--r--tun/netstack/examples/http_client.go54
-rw-r--r--tun/netstack/examples/http_server.go51
-rw-r--r--tun/netstack/examples/ping_client.go75
-rw-r--r--tun/netstack/tun.go1055
-rw-r--r--tun/offload_linux.go993
-rw-r--r--tun/offload_linux_test.go752
-rw-r--r--tun/operateonfd.go4
-rw-r--r--tun/tun.go42
-rw-r--r--tun/tun_darwin.go259
-rw-r--r--tun/tun_freebsd.go398
-rw-r--r--tun/tun_linux.go415
-rw-r--r--tun/tun_openbsd.go117
-rw-r--r--tun/tun_windows.go265
-rw-r--r--tun/tuntest/tuntest.go155
-rw-r--r--tun/wintun/iphlpapi/conversion_windows.go25
-rw-r--r--tun/wintun/iphlpapi/mksyscall.go8
-rw-r--r--tun/wintun/iphlpapi/zsyscall_windows.go60
-rw-r--r--tun/wintun/namespace_windows.go98
-rw-r--r--tun/wintun/namespaceapi/mksyscall.go8
-rw-r--r--tun/wintun/namespaceapi/namespaceapi_windows.go83
-rw-r--r--tun/wintun/namespaceapi/zsyscall_windows.go116
-rw-r--r--tun/wintun/nci/mksyscall.go8
-rw-r--r--tun/wintun/nci/nci_windows.go28
-rw-r--r--tun/wintun/nci/zsyscall_windows.go60
-rw-r--r--tun/wintun/registry/mksyscall.go8
-rw-r--r--tun/wintun/registry/registry_windows.go272
-rw-r--r--tun/wintun/registry/registry_windows_test.go103
-rw-r--r--tun/wintun/registry/zregistry_windows.go63
-rw-r--r--tun/wintun/ring_windows.go97
-rw-r--r--tun/wintun/setupapi/mksyscall.go8
-rw-r--r--tun/wintun/setupapi/setupapi_windows.go506
-rw-r--r--tun/wintun/setupapi/setupapi_windows_test.go488
-rw-r--r--tun/wintun/setupapi/types_windows.go568
-rw-r--r--tun/wintun/setupapi/types_windows_386.go11
-rw-r--r--tun/wintun/setupapi/types_windows_amd64.go11
-rw-r--r--tun/wintun/setupapi/zsetupapi_windows.go398
-rw-r--r--tun/wintun/setupapi/zsetupapi_windows_test.go20
-rw-r--r--tun/wintun/wintun_windows.go803
-rw-r--r--version.go3
146 files changed, 11661 insertions, 9540 deletions
diff --git a/.gitignore b/.gitignore
index 96650bb..e460293 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1 @@
wireguard-go
-vendor
-.gopath
-ireallywantobuildon_linux.go
diff --git a/COPYING b/LICENSE
index f85e365..f85e365 100644
--- a/COPYING
+++ b/LICENSE
diff --git a/Makefile b/Makefile
index 47f22d6..3f6e407 100644
--- a/Makefile
+++ b/Makefile
@@ -10,10 +10,10 @@ MAKEFLAGS += --no-print-directory
generate-version-and-build:
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
tag="$$(git describe --dirty 2>/dev/null)" && \
- ver="$$(printf 'package device\nconst WireGuardGoVersion = "%s"\n' "$${tag#v}")" && \
- [ "$$(cat device/version.go 2>/dev/null)" != "$$ver" ] && \
- echo "$$ver" > device/version.go && \
- git update-index --assume-unchanged device/version.go || true
+ ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
+ [ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
+ echo "$$ver" > version.go && \
+ git update-index --assume-unchanged version.go || true
@$(MAKE) wireguard-go
wireguard-go: $(wildcard *.go) $(wildcard */*.go)
@@ -22,7 +22,10 @@ wireguard-go: $(wildcard *.go) $(wildcard */*.go)
install: wireguard-go
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
+test:
+ go test ./...
+
clean:
rm -f wireguard-go
-.PHONY: all clean install generate-version-and-build
+.PHONY: all clean test install generate-version-and-build
diff --git a/README.md b/README.md
index d73bf59..074f7ec 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ To run wireguard-go without forking to the background, pass `-f` or `--foregroun
$ wireguard-go -f wg0
```
-When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/WireGuard/about/src/tools/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
+When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
@@ -26,7 +26,7 @@ To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
### Linux
-This will run on Linux; however **YOU SHOULD NOT RUN THIS ON LINUX**. Instead use the kernel module; see the [installation page](https://www.wireguard.com/install/) for instructions.
+This will run on Linux; however you should instead use the kernel module, which is faster and better integrated into the OS. See the [installation page](https://www.wireguard.com/install/) for instructions.
### macOS
@@ -46,7 +46,7 @@ This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapp
## Building
-This requires an installation of [go](https://golang.org) ≥ 1.12.
+This requires an installation of the latest version of [Go](https://go.dev/).
```
$ git clone https://git.zx2c4.com/wireguard-go
@@ -56,7 +56,7 @@ $ make
## License
- Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
diff --git a/conn/bind_std.go b/conn/bind_std.go
new file mode 100644
index 0000000..46df7fd
--- /dev/null
+++ b/conn/bind_std.go
@@ -0,0 +1,544 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "net/netip"
+ "runtime"
+ "strconv"
+ "sync"
+ "syscall"
+
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+)
+
+var (
+ _ Bind = (*StdNetBind)(nil)
+)
+
+// StdNetBind implements Bind for all platforms. While Windows has its own Bind
+// (see bind_windows.go), it may fall back to StdNetBind.
+// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
+// methods for sending and receiving multiple datagrams per-syscall. See the
+// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
+type StdNetBind struct {
+ mu sync.Mutex // protects all fields except as specified
+ ipv4 *net.UDPConn
+ ipv6 *net.UDPConn
+ ipv4PC *ipv4.PacketConn // will be nil on non-Linux
+ ipv6PC *ipv6.PacketConn // will be nil on non-Linux
+ ipv4TxOffload bool
+ ipv4RxOffload bool
+ ipv6TxOffload bool
+ ipv6RxOffload bool
+
+ // these two fields are not guarded by mu
+ udpAddrPool sync.Pool
+ msgsPool sync.Pool
+
+ blackhole4 bool
+ blackhole6 bool
+}
+
+func NewStdNetBind() Bind {
+ return &StdNetBind{
+ udpAddrPool: sync.Pool{
+ New: func() any {
+ return &net.UDPAddr{
+ IP: make([]byte, 16),
+ }
+ },
+ },
+
+ msgsPool: sync.Pool{
+ New: func() any {
+ // ipv6.Message and ipv4.Message are interchangeable as they are
+ // both aliases for x/net/internal/socket.Message.
+ msgs := make([]ipv6.Message, IdealBatchSize)
+ for i := range msgs {
+ msgs[i].Buffers = make(net.Buffers, 1)
+ msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
+ }
+ return &msgs
+ },
+ },
+ }
+}
+
+type StdNetEndpoint struct {
+ // AddrPort is the endpoint destination.
+ netip.AddrPort
+ // src is the current sticky source address and interface index, if
+ // supported. Typically this is a PKTINFO structure from/for control
+ // messages, see unix.PKTINFO for an example.
+ src []byte
+}
+
+var (
+ _ Bind = (*StdNetBind)(nil)
+ _ Endpoint = &StdNetEndpoint{}
+)
+
+func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
+ e, err := netip.ParseAddrPort(s)
+ if err != nil {
+ return nil, err
+ }
+ return &StdNetEndpoint{
+ AddrPort: e,
+ }, nil
+}
+
+func (e *StdNetEndpoint) ClearSrc() {
+ if e.src != nil {
+ // Truncate src, no need to reallocate.
+ e.src = e.src[:0]
+ }
+}
+
+func (e *StdNetEndpoint) DstIP() netip.Addr {
+ return e.AddrPort.Addr()
+}
+
+// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
+
+func (e *StdNetEndpoint) DstToBytes() []byte {
+ b, _ := e.AddrPort.MarshalBinary()
+ return b
+}
+
+func (e *StdNetEndpoint) DstToString() string {
+ return e.AddrPort.String()
+}
+
+func listenNet(network string, port int) (*net.UDPConn, int, error) {
+ conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // Retrieve port.
+ laddr := conn.LocalAddr()
+ uaddr, err := net.ResolveUDPAddr(
+ laddr.Network(),
+ laddr.String(),
+ )
+ if err != nil {
+ return nil, 0, err
+ }
+ return conn.(*net.UDPConn), uaddr.Port, nil
+}
+
+func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ var err error
+ var tries int
+
+ if s.ipv4 != nil || s.ipv6 != nil {
+ return nil, 0, ErrBindAlreadyOpen
+ }
+
+ // Attempt to open ipv4 and ipv6 listeners on the same port.
+ // If uport is 0, we can retry on failure.
+again:
+ port := int(uport)
+ var v4conn, v6conn *net.UDPConn
+ var v4pc *ipv4.PacketConn
+ var v6pc *ipv6.PacketConn
+
+ v4conn, port, err = listenNet("udp4", port)
+ if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
+ return nil, 0, err
+ }
+
+ // Listen on the same port as we're using for ipv4.
+ v6conn, port, err = listenNet("udp6", port)
+ if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
+ v4conn.Close()
+ tries++
+ goto again
+ }
+ if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
+ v4conn.Close()
+ return nil, 0, err
+ }
+ var fns []ReceiveFunc
+ if v4conn != nil {
+ s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ v4pc = ipv4.NewPacketConn(v4conn)
+ s.ipv4PC = v4pc
+ }
+ fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
+ s.ipv4 = v4conn
+ }
+ if v6conn != nil {
+ s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ v6pc = ipv6.NewPacketConn(v6conn)
+ s.ipv6PC = v6pc
+ }
+ fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
+ s.ipv6 = v6conn
+ }
+ if len(fns) == 0 {
+ return nil, 0, syscall.EAFNOSUPPORT
+ }
+
+ return fns, uint16(port), nil
+}
+
+func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
+ for i := range *msgs {
+ (*msgs)[i].OOB = (*msgs)[i].OOB[:0]
+ (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
+ }
+ s.msgsPool.Put(msgs)
+}
+
+func (s *StdNetBind) getMessages() *[]ipv6.Message {
+ return s.msgsPool.Get().(*[]ipv6.Message)
+}
+
+var (
+ // If compilation fails here these are no longer the same underlying type.
+ _ ipv6.Message = ipv4.Message{}
+)
+
+type batchReader interface {
+ ReadBatch([]ipv6.Message, int) (int, error)
+}
+
+type batchWriter interface {
+ WriteBatch([]ipv6.Message, int) (int, error)
+}
+
+func (s *StdNetBind) receiveIP(
+ br batchReader,
+ conn *net.UDPConn,
+ rxOffload bool,
+ bufs [][]byte,
+ sizes []int,
+ eps []Endpoint,
+) (n int, err error) {
+ msgs := s.getMessages()
+ for i := range bufs {
+ (*msgs)[i].Buffers[0] = bufs[i]
+ (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
+ }
+ defer s.putMessages(msgs)
+ var numMsgs int
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ if rxOffload {
+ readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
+ numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
+ if err != nil {
+ return 0, err
+ }
+ numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
+ if err != nil {
+ return 0, err
+ }
+ } else {
+ numMsgs, err = br.ReadBatch(*msgs, 0)
+ if err != nil {
+ return 0, err
+ }
+ }
+ } else {
+ msg := &(*msgs)[0]
+ msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+ if err != nil {
+ return 0, err
+ }
+ numMsgs = 1
+ }
+ for i := 0; i < numMsgs; i++ {
+ msg := &(*msgs)[i]
+ sizes[i] = msg.N
+ if sizes[i] == 0 {
+ continue
+ }
+ addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+ ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
+ getSrcFromControl(msg.OOB[:msg.NN], ep)
+ eps[i] = ep
+ }
+ return numMsgs, nil
+}
+
+func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
+ return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
+ }
+}
+
+func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
+ return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
+ }
+}
+
+// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
+// rename the IdealBatchSize constant to BatchSize.
+func (s *StdNetBind) BatchSize() int {
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ return IdealBatchSize
+ }
+ return 1
+}
+
+func (s *StdNetBind) Close() error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ var err1, err2 error
+ if s.ipv4 != nil {
+ err1 = s.ipv4.Close()
+ s.ipv4 = nil
+ s.ipv4PC = nil
+ }
+ if s.ipv6 != nil {
+ err2 = s.ipv6.Close()
+ s.ipv6 = nil
+ s.ipv6PC = nil
+ }
+ s.blackhole4 = false
+ s.blackhole6 = false
+ s.ipv4TxOffload = false
+ s.ipv4RxOffload = false
+ s.ipv6TxOffload = false
+ s.ipv6RxOffload = false
+ if err1 != nil {
+ return err1
+ }
+ return err2
+}
+
+type ErrUDPGSODisabled struct {
+ onLaddr string
+ RetryErr error
+}
+
+func (e ErrUDPGSODisabled) Error() string {
+ return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
+}
+
+func (e ErrUDPGSODisabled) Unwrap() error {
+ return e.RetryErr
+}
+
+func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
+ s.mu.Lock()
+ blackhole := s.blackhole4
+ conn := s.ipv4
+ offload := s.ipv4TxOffload
+ br := batchWriter(s.ipv4PC)
+ is6 := false
+ if endpoint.DstIP().Is6() {
+ blackhole = s.blackhole6
+ conn = s.ipv6
+ br = s.ipv6PC
+ is6 = true
+ offload = s.ipv6TxOffload
+ }
+ s.mu.Unlock()
+
+ if blackhole {
+ return nil
+ }
+ if conn == nil {
+ return syscall.EAFNOSUPPORT
+ }
+
+ msgs := s.getMessages()
+ defer s.putMessages(msgs)
+ ua := s.udpAddrPool.Get().(*net.UDPAddr)
+ defer s.udpAddrPool.Put(ua)
+ if is6 {
+ as16 := endpoint.DstIP().As16()
+ copy(ua.IP, as16[:])
+ ua.IP = ua.IP[:16]
+ } else {
+ as4 := endpoint.DstIP().As4()
+ copy(ua.IP, as4[:])
+ ua.IP = ua.IP[:4]
+ }
+ ua.Port = int(endpoint.(*StdNetEndpoint).Port())
+ var (
+ retried bool
+ err error
+ )
+retry:
+ if offload {
+ n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
+ err = s.send(conn, br, (*msgs)[:n])
+ if err != nil && offload && errShouldDisableUDPGSO(err) {
+ offload = false
+ s.mu.Lock()
+ if is6 {
+ s.ipv6TxOffload = false
+ } else {
+ s.ipv4TxOffload = false
+ }
+ s.mu.Unlock()
+ retried = true
+ goto retry
+ }
+ } else {
+ for i := range bufs {
+ (*msgs)[i].Addr = ua
+ (*msgs)[i].Buffers[0] = bufs[i]
+ setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
+ }
+ err = s.send(conn, br, (*msgs)[:len(bufs)])
+ }
+ if retried {
+ return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
+ }
+ return err
+}
+
+func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
+ var (
+ n int
+ err error
+ start int
+ )
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ for {
+ n, err = pc.WriteBatch(msgs[start:], 0)
+ if err != nil || n == len(msgs[start:]) {
+ break
+ }
+ start += n
+ }
+ } else {
+ for _, msg := range msgs {
+ _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
+ if err != nil {
+ break
+ }
+ }
+ }
+ return err
+}
+
+const (
+ // Exceeding these values results in EMSGSIZE. They account for layer3 and
+ // layer4 headers. IPv6 does not need to account for itself as the payload
+ // length field is self excluding.
+ maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
+ maxIPv6PayloadLen = 1<<16 - 1 - 8
+
+ // This is a hard limit imposed by the kernel.
+ udpSegmentMaxDatagrams = 64
+)
+
+type setGSOFunc func(control *[]byte, gsoSize uint16)
+
+func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
+ var (
+ base = -1 // index of msg we are currently coalescing into
+ gsoSize int // segmentation size of msgs[base]
+ dgramCnt int // number of dgrams coalesced into msgs[base]
+ endBatch bool // tracking flag to start a new batch on next iteration of bufs
+ )
+ maxPayloadLen := maxIPv4PayloadLen
+ if ep.DstIP().Is6() {
+ maxPayloadLen = maxIPv6PayloadLen
+ }
+ for i, buf := range bufs {
+ if i > 0 {
+ msgLen := len(buf)
+ baseLenBefore := len(msgs[base].Buffers[0])
+ freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
+ if msgLen+baseLenBefore <= maxPayloadLen &&
+ msgLen <= gsoSize &&
+ msgLen <= freeBaseCap &&
+ dgramCnt < udpSegmentMaxDatagrams &&
+ !endBatch {
+ msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
+ if i == len(bufs)-1 {
+ setGSO(&msgs[base].OOB, uint16(gsoSize))
+ }
+ dgramCnt++
+ if msgLen < gsoSize {
+ // A smaller than gsoSize packet on the tail is legal, but
+ // it must end the batch.
+ endBatch = true
+ }
+ continue
+ }
+ }
+ if dgramCnt > 1 {
+ setGSO(&msgs[base].OOB, uint16(gsoSize))
+ }
+ // Reset prior to incrementing base since we are preparing to start a
+ // new potential batch.
+ endBatch = false
+ base++
+ gsoSize = len(buf)
+ setSrcControl(&msgs[base].OOB, ep)
+ msgs[base].Buffers[0] = buf
+ msgs[base].Addr = addr
+ dgramCnt = 1
+ }
+ return base + 1
+}
+
+type getGSOFunc func(control []byte) (int, error)
+
+func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
+ for i := firstMsgAt; i < len(msgs); i++ {
+ msg := &msgs[i]
+ if msg.N == 0 {
+ return n, err
+ }
+ var (
+ gsoSize int
+ start int
+ end = msg.N
+ numToSplit = 1
+ )
+ gsoSize, err = getGSO(msg.OOB[:msg.NN])
+ if err != nil {
+ return n, err
+ }
+ if gsoSize > 0 {
+ numToSplit = (msg.N + gsoSize - 1) / gsoSize
+ end = gsoSize
+ }
+ for j := 0; j < numToSplit; j++ {
+ if n > i {
+ return n, errors.New("splitting coalesced packet resulted in overflow")
+ }
+ copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
+ msgs[n].N = copied
+ msgs[n].Addr = msg.Addr
+ start = end
+ end += gsoSize
+ if end > msg.N {
+ end = msg.N
+ }
+ n++
+ }
+ if i != n-1 {
+ // It is legal for bytes to move within msg.Buffers[0] as a result
+ // of splitting, so we only zero the source msg len when it is not
+ // the destination of the last split operation above.
+ msg.N = 0
+ }
+ }
+ return n, nil
+}
diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go
new file mode 100644
index 0000000..34a3c9a
--- /dev/null
+++ b/conn/bind_std_test.go
@@ -0,0 +1,250 @@
+package conn
+
+import (
+ "encoding/binary"
+ "net"
+ "testing"
+
+ "golang.org/x/net/ipv6"
+)
+
+func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
+ bind := NewStdNetBind().(*StdNetBind)
+ fns, _, err := bind.Open(0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ bind.Close()
+ bufs := make([][]byte, 1)
+ bufs[0] = make([]byte, 1)
+ sizes := make([]int, 1)
+ eps := make([]Endpoint, 1)
+ for _, fn := range fns {
+ // The ReceiveFuncs must not access conn-related fields on StdNetBind
+ // unguarded. Close() nils the conn-related fields resulting in a panic
+ // if they violate the mutex.
+ fn(bufs, sizes, eps)
+ }
+}
+
+func mockSetGSOSize(control *[]byte, gsoSize uint16) {
+ *control = (*control)[:cap(*control)]
+ binary.LittleEndian.PutUint16(*control, gsoSize)
+}
+
+func Test_coalesceMessages(t *testing.T) {
+ cases := []struct {
+ name string
+ buffs [][]byte
+ wantLens []int
+ wantGSO []int
+ }{
+ {
+ name: "one message no coalesce",
+ buffs: [][]byte{
+ make([]byte, 1, 1),
+ },
+ wantLens: []int{1},
+ wantGSO: []int{0},
+ },
+ {
+ name: "two messages equal len coalesce",
+ buffs: [][]byte{
+ make([]byte, 1, 2),
+ make([]byte, 1, 1),
+ },
+ wantLens: []int{2},
+ wantGSO: []int{1},
+ },
+ {
+ name: "two messages unequal len coalesce",
+ buffs: [][]byte{
+ make([]byte, 2, 3),
+ make([]byte, 1, 1),
+ },
+ wantLens: []int{3},
+ wantGSO: []int{2},
+ },
+ {
+ name: "three messages second unequal len coalesce",
+ buffs: [][]byte{
+ make([]byte, 2, 3),
+ make([]byte, 1, 1),
+ make([]byte, 2, 2),
+ },
+ wantLens: []int{3, 2},
+ wantGSO: []int{2, 0},
+ },
+ {
+ name: "three messages limited cap coalesce",
+ buffs: [][]byte{
+ make([]byte, 2, 4),
+ make([]byte, 2, 2),
+ make([]byte, 2, 2),
+ },
+ wantLens: []int{4, 2},
+ wantGSO: []int{2, 0},
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ addr := &net.UDPAddr{
+ IP: net.ParseIP("127.0.0.1").To4(),
+ Port: 1,
+ }
+ msgs := make([]ipv6.Message, len(tt.buffs))
+ for i := range msgs {
+ msgs[i].Buffers = make([][]byte, 1)
+ msgs[i].OOB = make([]byte, 0, 2)
+ }
+ got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
+ if got != len(tt.wantLens) {
+ t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
+ }
+ for i := 0; i < got; i++ {
+ if msgs[i].Addr != addr {
+ t.Errorf("msgs[%d].Addr != passed addr", i)
+ }
+ gotLen := len(msgs[i].Buffers[0])
+ if gotLen != tt.wantLens[i] {
+ t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
+ }
+ gotGSO, err := mockGetGSOSize(msgs[i].OOB)
+ if err != nil {
+ t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
+ }
+ if gotGSO != tt.wantGSO[i] {
+ t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
+ }
+ }
+ })
+ }
+}
+
+func mockGetGSOSize(control []byte) (int, error) {
+ if len(control) < 2 {
+ return 0, nil
+ }
+ return int(binary.LittleEndian.Uint16(control)), nil
+}
+
+func Test_splitCoalescedMessages(t *testing.T) {
+ newMsg := func(n, gso int) ipv6.Message {
+ msg := ipv6.Message{
+ Buffers: [][]byte{make([]byte, 1<<16-1)},
+ N: n,
+ OOB: make([]byte, 2),
+ }
+ binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
+ if gso > 0 {
+ msg.NN = 2
+ }
+ return msg
+ }
+
+ cases := []struct {
+ name string
+ msgs []ipv6.Message
+ firstMsgAt int
+ wantNumEval int
+ wantMsgLens []int
+ wantErr bool
+ }{
+ {
+ name: "second last split last empty",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(3, 1),
+ newMsg(0, 0),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 3,
+ wantMsgLens: []int{1, 1, 1, 0},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last empty",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(0, 0),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 1,
+ wantMsgLens: []int{1, 0, 0, 0},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last no split",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(1, 0),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 2,
+ wantMsgLens: []int{1, 1, 0, 0},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last split",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(3, 1),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 4,
+ wantMsgLens: []int{1, 1, 1, 1},
+ wantErr: false,
+ },
+ {
+ name: "second last split last split",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(2, 1),
+ newMsg(2, 1),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 4,
+ wantMsgLens: []int{1, 1, 1, 1},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last split overflow",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(4, 1),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 4,
+ wantMsgLens: []int{1, 1, 1, 1},
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
+ if err != nil && !tt.wantErr {
+ t.Fatalf("err: %v", err)
+ }
+ if got != tt.wantNumEval {
+ t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
+ }
+ for i, msg := range tt.msgs {
+ if msg.N != tt.wantMsgLens[i] {
+ t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
+ }
+ }
+ })
+ }
+}
diff --git a/conn/bind_windows.go b/conn/bind_windows.go
new file mode 100644
index 0000000..d5095e0
--- /dev/null
+++ b/conn/bind_windows.go
@@ -0,0 +1,601 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "encoding/binary"
+ "io"
+ "net"
+ "net/netip"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+
+ "golang.zx2c4.com/wireguard/conn/winrio"
+)
+
+const (
+ packetsPerRing = 1024
+ bytesPerPacket = 2048 - 32
+ receiveSpins = 15
+)
+
+type ringPacket struct {
+ addr WinRingEndpoint
+ data [bytesPerPacket]byte
+}
+
+type ringBuffer struct {
+ packets uintptr
+ head, tail uint32
+ id winrio.BufferId
+ iocp windows.Handle
+ isFull bool
+ cq winrio.Cq
+ mu sync.Mutex
+ overlapped windows.Overlapped
+}
+
+func (rb *ringBuffer) Push() *ringPacket {
+ for rb.isFull {
+ panic("ring is full")
+ }
+ ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
+ rb.tail += 1
+ if rb.tail%packetsPerRing == rb.head%packetsPerRing {
+ rb.isFull = true
+ }
+ return ret
+}
+
+func (rb *ringBuffer) Return(count uint32) {
+ if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull {
+ return
+ }
+ rb.head += count
+ rb.isFull = false
+}
+
+type afWinRingBind struct {
+ sock windows.Handle
+ rx, tx ringBuffer
+ rq winrio.Rq
+ mu sync.Mutex
+ blackhole bool
+}
+
+// WinRingBind uses Windows registered I/O for fast ring buffered networking.
+type WinRingBind struct {
+ v4, v6 afWinRingBind
+ mu sync.RWMutex
+ isOpen atomic.Uint32 // 0, 1, or 2
+}
+
+func NewDefaultBind() Bind { return NewWinRingBind() }
+
+func NewWinRingBind() Bind {
+ if !winrio.Initialize() {
+ return NewStdNetBind()
+ }
+ return new(WinRingBind)
+}
+
+type WinRingEndpoint struct {
+ family uint16
+ data [30]byte
+}
+
+var (
+ _ Bind = (*WinRingBind)(nil)
+ _ Endpoint = (*WinRingEndpoint)(nil)
+)
+
+func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
+ host, port, err := net.SplitHostPort(s)
+ if err != nil {
+ return nil, err
+ }
+ host16, err := windows.UTF16PtrFromString(host)
+ if err != nil {
+ return nil, err
+ }
+ port16, err := windows.UTF16PtrFromString(port)
+ if err != nil {
+ return nil, err
+ }
+ hints := windows.AddrinfoW{
+ Flags: windows.AI_NUMERICHOST,
+ Family: windows.AF_UNSPEC,
+ Socktype: windows.SOCK_DGRAM,
+ Protocol: windows.IPPROTO_UDP,
+ }
+ var addrinfo *windows.AddrinfoW
+ err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo)
+ if err != nil {
+ return nil, err
+ }
+ defer windows.FreeAddrInfoW(addrinfo)
+ if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
+ return nil, windows.ERROR_INVALID_ADDRESS
+ }
+ var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
+ copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen))
+ return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
+}
+
+func (*WinRingEndpoint) ClearSrc() {}
+
+func (e *WinRingEndpoint) DstIP() netip.Addr {
+ switch e.family {
+ case windows.AF_INET:
+ return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
+ case windows.AF_INET6:
+ return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
+ }
+ return netip.Addr{}
+}
+
+func (e *WinRingEndpoint) SrcIP() netip.Addr {
+ return netip.Addr{} // not supported
+}
+
+func (e *WinRingEndpoint) DstToBytes() []byte {
+ switch e.family {
+ case windows.AF_INET:
+ b := make([]byte, 0, 6)
+ b = append(b, e.data[2:6]...)
+ b = append(b, e.data[1], e.data[0])
+ return b
+ case windows.AF_INET6:
+ b := make([]byte, 0, 18)
+ b = append(b, e.data[6:22]...)
+ b = append(b, e.data[1], e.data[0])
+ return b
+ }
+ return nil
+}
+
+func (e *WinRingEndpoint) DstToString() string {
+ switch e.family {
+ case windows.AF_INET:
+ return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
+ case windows.AF_INET6:
+ var zone string
+ if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
+ zone = strconv.FormatUint(uint64(scope), 10)
+ }
+ return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
+ }
+ return ""
+}
+
+func (e *WinRingEndpoint) SrcToString() string {
+ return ""
+}
+
+func (ring *ringBuffer) CloseAndZero() {
+ if ring.cq != 0 {
+ winrio.CloseCompletionQueue(ring.cq)
+ ring.cq = 0
+ }
+ if ring.iocp != 0 {
+ windows.CloseHandle(ring.iocp)
+ ring.iocp = 0
+ }
+ if ring.id != 0 {
+ winrio.DeregisterBuffer(ring.id)
+ ring.id = 0
+ }
+ if ring.packets != 0 {
+ windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
+ ring.packets = 0
+ }
+ ring.head = 0
+ ring.tail = 0
+ ring.isFull = false
+}
+
+func (bind *afWinRingBind) CloseAndZero() {
+ bind.rx.CloseAndZero()
+ bind.tx.CloseAndZero()
+ if bind.sock != 0 {
+ windows.CloseHandle(bind.sock)
+ bind.sock = 0
+ }
+ bind.blackhole = false
+}
+
+func (bind *WinRingBind) closeAndZero() {
+ bind.isOpen.Store(0)
+ bind.v4.CloseAndZero()
+ bind.v6.CloseAndZero()
+}
+
+func (ring *ringBuffer) Open() error {
+ var err error
+ packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
+ ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
+ if err != nil {
+ return err
+ }
+ ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
+ if err != nil {
+ return err
+ }
+ ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
+ if err != nil {
+ return err
+ }
+ ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) {
+ var err error
+ bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
+ if err != nil {
+ return nil, err
+ }
+ err = bind.rx.Open()
+ if err != nil {
+ return nil, err
+ }
+ err = bind.tx.Open()
+ if err != nil {
+ return nil, err
+ }
+ bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0)
+ if err != nil {
+ return nil, err
+ }
+ err = windows.Bind(bind.sock, sa)
+ if err != nil {
+ return nil, err
+ }
+ sa, err = windows.Getsockname(bind.sock)
+ if err != nil {
+ return nil, err
+ }
+ return sa, nil
+}
+
+func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+ defer func() {
+ if err != nil {
+ bind.closeAndZero()
+ }
+ }()
+ if bind.isOpen.Load() != 0 {
+ return nil, 0, ErrBindAlreadyOpen
+ }
+ var sa windows.Sockaddr
+ sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
+ if err != nil {
+ return nil, 0, err
+ }
+ sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
+ if err != nil {
+ return nil, 0, err
+ }
+ selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
+ for i := 0; i < packetsPerRing; i++ {
+ err = bind.v4.InsertReceiveRequest()
+ if err != nil {
+ return nil, 0, err
+ }
+ err = bind.v6.InsertReceiveRequest()
+ if err != nil {
+ return nil, 0, err
+ }
+ }
+ bind.isOpen.Store(1)
+ return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
+}
+
+func (bind *WinRingBind) Close() error {
+ bind.mu.RLock()
+ if bind.isOpen.Load() != 1 {
+ bind.mu.RUnlock()
+ return nil
+ }
+ bind.isOpen.Store(2)
+ windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
+ windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
+ windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
+ windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil)
+ bind.mu.RUnlock()
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+ bind.closeAndZero()
+ return nil
+}
+
+// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
+// rename the IdealBatchSize constant to BatchSize.
+func (bind *WinRingBind) BatchSize() int {
+ // TODO: implement batching in and out of the ring
+ return 1
+}
+
+func (bind *WinRingBind) SetMark(mark uint32) error {
+ return nil
+}
+
+func (bind *afWinRingBind) InsertReceiveRequest() error {
+ packet := bind.rx.Push()
+ dataBuffer := &winrio.Buffer{
+ Id: bind.rx.id,
+ Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets),
+ Length: uint32(len(packet.data)),
+ }
+ addressBuffer := &winrio.Buffer{
+ Id: bind.rx.id,
+ Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets),
+ Length: uint32(unsafe.Sizeof(packet.addr)),
+ }
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+ return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
+}
+
+//go:linkname procyield runtime.procyield
+func procyield(cycles uint32)
+
+func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
+ if isOpen.Load() != 1 {
+ return 0, nil, net.ErrClosed
+ }
+ bind.rx.mu.Lock()
+ defer bind.rx.mu.Unlock()
+
+ var err error
+ var count uint32
+ var results [1]winrio.Result
+retry:
+ count = 0
+ for tries := 0; count == 0 && tries < receiveSpins; tries++ {
+ if tries > 0 {
+ if isOpen.Load() != 1 {
+ return 0, nil, net.ErrClosed
+ }
+ procyield(1)
+ }
+ count = winrio.DequeueCompletion(bind.rx.cq, results[:])
+ }
+ if count == 0 {
+ err = winrio.Notify(bind.rx.cq)
+ if err != nil {
+ return 0, nil, err
+ }
+ var bytes uint32
+ var key uintptr
+ var overlapped *windows.Overlapped
+ err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
+ if err != nil {
+ return 0, nil, err
+ }
+ if isOpen.Load() != 1 {
+ return 0, nil, net.ErrClosed
+ }
+ count = winrio.DequeueCompletion(bind.rx.cq, results[:])
+ if count == 0 {
+ return 0, nil, io.ErrNoProgress
+ }
+ }
+ bind.rx.Return(1)
+ err = bind.InsertReceiveRequest()
+ if err != nil {
+ return 0, nil, err
+ }
+ // We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
+ // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
+ // attacker bandwidth, just like the rest of the receive path.
+ if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
+ if isOpen.Load() != 1 {
+ return 0, nil, net.ErrClosed
+ }
+ goto retry
+ }
+ if results[0].Status != 0 {
+ return 0, nil, windows.Errno(results[0].Status)
+ }
+ packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
+ ep := packet.addr
+ n := copy(buf, packet.data[:results[0].BytesTransferred])
+ return n, &ep, nil
+}
+
+func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
+ n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
+ sizes[0] = n
+ eps[0] = ep
+ return 1, err
+}
+
+func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
+ n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
+ sizes[0] = n
+ eps[0] = ep
+ return 1, err
+}
+
+func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
+ if isOpen.Load() != 1 {
+ return net.ErrClosed
+ }
+ if len(buf) > bytesPerPacket {
+ return io.ErrShortBuffer
+ }
+ bind.tx.mu.Lock()
+ defer bind.tx.mu.Unlock()
+ var results [packetsPerRing]winrio.Result
+ count := winrio.DequeueCompletion(bind.tx.cq, results[:])
+ if count == 0 && bind.tx.isFull {
+ err := winrio.Notify(bind.tx.cq)
+ if err != nil {
+ return err
+ }
+ var bytes uint32
+ var key uintptr
+ var overlapped *windows.Overlapped
+ err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
+ if err != nil {
+ return err
+ }
+ if isOpen.Load() != 1 {
+ return net.ErrClosed
+ }
+ count = winrio.DequeueCompletion(bind.tx.cq, results[:])
+ if count == 0 {
+ return io.ErrNoProgress
+ }
+ }
+ if count > 0 {
+ bind.tx.Return(count)
+ }
+ packet := bind.tx.Push()
+ packet.addr = *nend
+ copy(packet.data[:], buf)
+ dataBuffer := &winrio.Buffer{
+ Id: bind.tx.id,
+ Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets),
+ Length: uint32(len(buf)),
+ }
+ addressBuffer := &winrio.Buffer{
+ Id: bind.tx.id,
+ Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets),
+ Length: uint32(unsafe.Sizeof(packet.addr)),
+ }
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+ return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
+}
+
+func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
+ nend, ok := endpoint.(*WinRingEndpoint)
+ if !ok {
+ return ErrWrongEndpointType
+ }
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
+ for _, buf := range bufs {
+ switch nend.family {
+ case windows.AF_INET:
+ if bind.v4.blackhole {
+ continue
+ }
+ if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
+ return err
+ }
+ case windows.AF_INET6:
+ if bind.v6.blackhole {
+ continue
+ }
+ if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ sysconn, err := s.ipv4.SyscallConn()
+ if err != nil {
+ return err
+ }
+ err2 := sysconn.Control(func(fd uintptr) {
+ err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex)
+ })
+ if err2 != nil {
+ return err2
+ }
+ if err != nil {
+ return err
+ }
+ s.blackhole4 = blackhole
+ return nil
+}
+
+func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ sysconn, err := s.ipv6.SyscallConn()
+ if err != nil {
+ return err
+ }
+ err2 := sysconn.Control(func(fd uintptr) {
+ err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex)
+ })
+ if err2 != nil {
+ return err2
+ }
+ if err != nil {
+ return err
+ }
+ s.blackhole6 = blackhole
+ return nil
+}
+
+func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
+ if bind.isOpen.Load() != 1 {
+ return net.ErrClosed
+ }
+ err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
+ if err != nil {
+ return err
+ }
+ bind.v4.blackhole = blackhole
+ return nil
+}
+
+func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
+ if bind.isOpen.Load() != 1 {
+ return net.ErrClosed
+ }
+ err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
+ if err != nil {
+ return err
+ }
+ bind.v6.blackhole = blackhole
+ return nil
+}
+
+func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error {
+ const IP_UNICAST_IF = 31
+ /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
+ var bytes [4]byte
+ binary.BigEndian.PutUint32(bytes[:], interfaceIndex)
+ interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
+ err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex))
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error {
+ const IPV6_UNICAST_IF = 31
+ return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
+}
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
new file mode 100644
index 0000000..74e7add
--- /dev/null
+++ b/conn/bindtest/bindtest.go
@@ -0,0 +1,136 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package bindtest
+
+import (
+ "fmt"
+ "math/rand"
+ "net"
+ "net/netip"
+ "os"
+
+ "golang.zx2c4.com/wireguard/conn"
+)
+
+type ChannelBind struct {
+ rx4, tx4 *chan []byte
+ rx6, tx6 *chan []byte
+ closeSignal chan bool
+ source4, source6 ChannelEndpoint
+ target4, target6 ChannelEndpoint
+}
+
+type ChannelEndpoint uint16
+
+var (
+ _ conn.Bind = (*ChannelBind)(nil)
+ _ conn.Endpoint = (*ChannelEndpoint)(nil)
+)
+
+func NewChannelBinds() [2]conn.Bind {
+ arx4 := make(chan []byte, 8192)
+ brx4 := make(chan []byte, 8192)
+ arx6 := make(chan []byte, 8192)
+ brx6 := make(chan []byte, 8192)
+ var binds [2]ChannelBind
+ binds[0].rx4 = &arx4
+ binds[0].tx4 = &brx4
+ binds[1].rx4 = &brx4
+ binds[1].tx4 = &arx4
+ binds[0].rx6 = &arx6
+ binds[0].tx6 = &brx6
+ binds[1].rx6 = &brx6
+ binds[1].tx6 = &arx6
+ binds[0].target4 = ChannelEndpoint(1)
+ binds[1].target4 = ChannelEndpoint(2)
+ binds[0].target6 = ChannelEndpoint(3)
+ binds[1].target6 = ChannelEndpoint(4)
+ binds[0].source4 = binds[1].target4
+ binds[0].source6 = binds[1].target6
+ binds[1].source4 = binds[0].target4
+ binds[1].source6 = binds[0].target6
+ return [2]conn.Bind{&binds[0], &binds[1]}
+}
+
+func (c ChannelEndpoint) ClearSrc() {}
+
+func (c ChannelEndpoint) SrcToString() string { return "" }
+
+func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) }
+
+func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
+
+func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
+
+func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
+
+func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+ c.closeSignal = make(chan bool)
+ fns = append(fns, c.makeReceiveFunc(*c.rx4))
+ fns = append(fns, c.makeReceiveFunc(*c.rx6))
+ if rand.Uint32()&1 == 0 {
+ return fns, uint16(c.source4), nil
+ } else {
+ return fns, uint16(c.source6), nil
+ }
+}
+
+func (c *ChannelBind) Close() error {
+ if c.closeSignal != nil {
+ select {
+ case <-c.closeSignal:
+ default:
+ close(c.closeSignal)
+ }
+ }
+ return nil
+}
+
+func (c *ChannelBind) BatchSize() int { return 1 }
+
+func (c *ChannelBind) SetMark(mark uint32) error { return nil }
+
+func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
+ return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
+ select {
+ case <-c.closeSignal:
+ return 0, net.ErrClosed
+ case rx := <-ch:
+ copied := copy(bufs[0], rx)
+ sizes[0] = copied
+ eps[0] = c.target6
+ return 1, nil
+ }
+ }
+}
+
+func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
+ for _, b := range bufs {
+ select {
+ case <-c.closeSignal:
+ return net.ErrClosed
+ default:
+ bc := make([]byte, len(b))
+ copy(bc, b)
+ if ep.(ChannelEndpoint) == c.target4 {
+ *c.tx4 <- bc
+ } else if ep.(ChannelEndpoint) == c.target6 {
+ *c.tx6 <- bc
+ } else {
+ return os.ErrInvalid
+ }
+ }
+ }
+ return nil
+}
+
+func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
+ addr, err := netip.ParseAddrPort(s)
+ if err != nil {
+ return nil, err
+ }
+ return ChannelEndpoint(addr.Port()), nil
+}
diff --git a/conn/boundif_android.go b/conn/boundif_android.go
new file mode 100644
index 0000000..dd3ca5b
--- /dev/null
+++ b/conn/boundif_android.go
@@ -0,0 +1,34 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
+ sysconn, err := s.ipv4.SyscallConn()
+ if err != nil {
+ return -1, err
+ }
+ err = sysconn.Control(func(f uintptr) {
+ fd = int(f)
+ })
+ if err != nil {
+ return -1, err
+ }
+ return
+}
+
+func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
+ sysconn, err := s.ipv6.SyscallConn()
+ if err != nil {
+ return -1, err
+ }
+ err = sysconn.Control(func(f uintptr) {
+ fd = int(f)
+ })
+ if err != nil {
+ return -1, err
+ }
+ return
+}
diff --git a/conn/conn.go b/conn/conn.go
new file mode 100644
index 0000000..a1f57d2
--- /dev/null
+++ b/conn/conn.go
@@ -0,0 +1,133 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+// Package conn implements WireGuard's network connections.
+package conn
+
+import (
+ "errors"
+ "fmt"
+ "net/netip"
+ "reflect"
+ "runtime"
+ "strings"
+)
+
+const (
+ IdealBatchSize = 128 // maximum number of packets handled per read and write
+)
+
+// A ReceiveFunc receives at least one packet from the network and writes them
+// into packets. On a successful read it returns the number of elements of
+// sizes, packets, and endpoints that should be evaluated. Some elements of
+// sizes may be zero, and callers should ignore them. Callers must pass a sizes
+// and eps slice with a length greater than or equal to the length of packets.
+// These lengths must not exceed the length of the associated Bind.BatchSize().
+type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
+
+// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
+//
+// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
+// depending on the platform-specific implementation.
+type Bind interface {
+ // Open puts the Bind into a listening state on a given port and reports the actual
+ // port that it bound to. Passing zero results in a random selection.
+ // fns is the set of functions that will be called to receive packets.
+ Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
+
+ // Close closes the Bind listener.
+ // All fns returned by Open must return net.ErrClosed after a call to Close.
+ Close() error
+
+ // SetMark sets the mark for each packet sent through this Bind.
+ // This mark is passed to the kernel as the socket option SO_MARK.
+ SetMark(mark uint32) error
+
+ // Send writes one or more packets in bufs to address ep. The length of
+ // bufs must not exceed BatchSize().
+ Send(bufs [][]byte, ep Endpoint) error
+
+ // ParseEndpoint creates a new endpoint from a string.
+ ParseEndpoint(s string) (Endpoint, error)
+
+ // BatchSize is the number of buffers expected to be passed to
+ // the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
+ BatchSize() int
+}
+
+// BindSocketToInterface is implemented by Bind objects that support being
+// tied to a single network interface. Used by wireguard-windows.
+type BindSocketToInterface interface {
+ BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
+ BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
+}
+
+// PeekLookAtSocketFd is implemented by Bind objects that support having their
+// file descriptor peeked at. Used by wireguard-android.
+type PeekLookAtSocketFd interface {
+ PeekLookAtSocketFd4() (fd int, err error)
+ PeekLookAtSocketFd6() (fd int, err error)
+}
+
+// An Endpoint maintains the source/destination caching for a peer.
+//
+// dst: the remote address of a peer ("endpoint" in uapi terminology)
+// src: the local address from which datagrams originate going to the peer
+type Endpoint interface {
+ ClearSrc() // clears the source address
+ SrcToString() string // returns the local source address (ip:port)
+ DstToString() string // returns the destination address (ip:port)
+ DstToBytes() []byte // used for mac2 cookie calculations
+ DstIP() netip.Addr
+ SrcIP() netip.Addr
+}
+
+var (
+ ErrBindAlreadyOpen = errors.New("bind is already open")
+ ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
+)
+
+func (fn ReceiveFunc) PrettyName() string {
+ name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
+ // 0. cheese/taco.beansIPv6.func12.func21218-fm
+ name = strings.TrimSuffix(name, "-fm")
+ // 1. cheese/taco.beansIPv6.func12.func21218
+ if idx := strings.LastIndexByte(name, '/'); idx != -1 {
+ name = name[idx+1:]
+ // 2. taco.beansIPv6.func12.func21218
+ }
+ for {
+ var idx int
+ for idx = len(name) - 1; idx >= 0; idx-- {
+ if name[idx] < '0' || name[idx] > '9' {
+ break
+ }
+ }
+ if idx == len(name)-1 {
+ break
+ }
+ const dotFunc = ".func"
+ if !strings.HasSuffix(name[:idx+1], dotFunc) {
+ break
+ }
+ name = name[:idx+1-len(dotFunc)]
+ // 3. taco.beansIPv6.func12
+ // 4. taco.beansIPv6
+ }
+ if idx := strings.LastIndexByte(name, '.'); idx != -1 {
+ name = name[idx+1:]
+ // 5. beansIPv6
+ }
+ if name == "" {
+ return fmt.Sprintf("%p", fn)
+ }
+ if strings.HasSuffix(name, "IPv4") {
+ return "v4"
+ }
+ if strings.HasSuffix(name, "IPv6") {
+ return "v6"
+ }
+ return name
+}
diff --git a/conn/conn_test.go b/conn/conn_test.go
new file mode 100644
index 0000000..c6194ee
--- /dev/null
+++ b/conn/conn_test.go
@@ -0,0 +1,24 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "testing"
+)
+
+func TestPrettyName(t *testing.T) {
+ var (
+ recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
+ )
+
+ const want = "TestPrettyName"
+
+ t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
+ if got := recvFunc.PrettyName(); got != want {
+ t.Errorf("PrettyName() = %v, want %v", got, want)
+ }
+ })
+}
diff --git a/conn/controlfns.go b/conn/controlfns.go
new file mode 100644
index 0000000..4f7d90f
--- /dev/null
+++ b/conn/controlfns.go
@@ -0,0 +1,43 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net"
+ "syscall"
+)
+
+// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
+// the max supported by a default configuration of macOS. Some platforms will
+// silently clamp the value to other maximums, such as linux clamping to
+// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
+// around this limitation)
+const socketBufferSize = 7 << 20
+
+// controlFn is the callback function signature from net.ListenConfig.Control.
+// It is used to apply platform specific configuration to the socket prior to
+// bind.
+type controlFn func(network, address string, c syscall.RawConn) error
+
+// controlFns is a list of functions that are called from the listen config
+// that can apply socket options.
+var controlFns = []controlFn{}
+
+// listenConfig returns a net.ListenConfig that applies the controlFns to the
+// socket prior to bind. This is used to apply socket buffer sizing and packet
+// information OOB configuration for sticky sockets.
+func listenConfig() *net.ListenConfig {
+ return &net.ListenConfig{
+ Control: func(network, address string, c syscall.RawConn) error {
+ for _, fn := range controlFns {
+ if err := fn(network, address, c); err != nil {
+ return err
+ }
+ }
+ return nil
+ },
+ }
+}
diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go
new file mode 100644
index 0000000..f6ab1d2
--- /dev/null
+++ b/conn/controlfns_linux.go
@@ -0,0 +1,69 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "fmt"
+ "runtime"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ controlFns = append(controlFns,
+
+ // Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
+ // using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
+ // fail silently - the result of failure is lower performance on very fast
+ // links or high latency links.
+ func(network, address string, c syscall.RawConn) error {
+ return c.Control(func(fd uintptr) {
+ // Set up to *mem_max
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
+ // Set beyond *mem_max if CAP_NET_ADMIN
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
+ })
+ },
+
+ // Enable receiving of the packet information (IP_PKTINFO for IPv4,
+ // IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
+ func(network, address string, c syscall.RawConn) error {
+ var err error
+ switch network {
+ case "udp4":
+ if runtime.GOOS != "android" {
+ c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
+ })
+ }
+ case "udp6":
+ c.Control(func(fd uintptr) {
+ if runtime.GOOS != "android" {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
+ if err != nil {
+ return
+ }
+ }
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
+ })
+ default:
+ err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
+ }
+ return err
+ },
+
+ // Attempt to enable UDP_GRO
+ func(network, address string, c syscall.RawConn) error {
+ c.Control(func(fd uintptr) {
+ _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
+ })
+ return nil
+ },
+ )
+}
diff --git a/conn/controlfns_unix.go b/conn/controlfns_unix.go
new file mode 100644
index 0000000..91692c0
--- /dev/null
+++ b/conn/controlfns_unix.go
@@ -0,0 +1,35 @@
+//go:build !windows && !linux && !wasm
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ controlFns = append(controlFns,
+ func(network, address string, c syscall.RawConn) error {
+ return c.Control(func(fd uintptr) {
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
+ })
+ },
+
+ func(network, address string, c syscall.RawConn) error {
+ var err error
+ if network == "udp6" {
+ c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
+ })
+ }
+ return err
+ },
+ )
+}
diff --git a/conn/controlfns_windows.go b/conn/controlfns_windows.go
new file mode 100644
index 0000000..c3bdf7d
--- /dev/null
+++ b/conn/controlfns_windows.go
@@ -0,0 +1,23 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/windows"
+)
+
+func init() {
+ controlFns = append(controlFns,
+ func(network, address string, c syscall.RawConn) error {
+ return c.Control(func(fd uintptr) {
+ _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize)
+ _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize)
+ })
+ },
+ )
+}
diff --git a/conn/default.go b/conn/default.go
new file mode 100644
index 0000000..b6f761b
--- /dev/null
+++ b/conn/default.go
@@ -0,0 +1,10 @@
+//go:build !windows
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+func NewDefaultBind() Bind { return NewStdNetBind() }
diff --git a/conn/errors_default.go b/conn/errors_default.go
new file mode 100644
index 0000000..f1e5b90
--- /dev/null
+++ b/conn/errors_default.go
@@ -0,0 +1,12 @@
+//go:build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+func errShouldDisableUDPGSO(err error) bool {
+ return false
+}
diff --git a/conn/errors_linux.go b/conn/errors_linux.go
new file mode 100644
index 0000000..8e61000
--- /dev/null
+++ b/conn/errors_linux.go
@@ -0,0 +1,26 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "errors"
+ "os"
+
+ "golang.org/x/sys/unix"
+)
+
+func errShouldDisableUDPGSO(err error) bool {
+ var serr *os.SyscallError
+ if errors.As(err, &serr) {
+ // EIO is returned by udp_send_skb() if the device driver does not have
+ // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
+ // See:
+ // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
+ // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
+ return serr.Err == unix.EIO
+ }
+ return false
+}
diff --git a/conn/features_default.go b/conn/features_default.go
new file mode 100644
index 0000000..d53ff5f
--- /dev/null
+++ b/conn/features_default.go
@@ -0,0 +1,15 @@
+//go:build !linux
+// +build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import "net"
+
+func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
+ return
+}
diff --git a/conn/features_linux.go b/conn/features_linux.go
new file mode 100644
index 0000000..8959d93
--- /dev/null
+++ b/conn/features_linux.go
@@ -0,0 +1,29 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net"
+
+ "golang.org/x/sys/unix"
+)
+
+func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
+ rc, err := conn.SyscallConn()
+ if err != nil {
+ return
+ }
+ err = rc.Control(func(fd uintptr) {
+ _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
+ txOffload = errSyscall == nil
+ opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
+ rxOffload = errSyscall == nil && opt == 1
+ })
+ if err != nil {
+ return false, false
+ }
+ return txOffload, rxOffload
+}
diff --git a/conn/gso_default.go b/conn/gso_default.go
new file mode 100644
index 0000000..57780db
--- /dev/null
+++ b/conn/gso_default.go
@@ -0,0 +1,21 @@
+//go:build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
+func getGSOSize(control []byte) (int, error) {
+ return 0, nil
+}
+
+// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
+func setGSOSize(control *[]byte, gsoSize uint16) {
+}
+
+// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
+// offloading control data.
+const gsoControlSize = 0
diff --git a/conn/gso_linux.go b/conn/gso_linux.go
new file mode 100644
index 0000000..8596b29
--- /dev/null
+++ b/conn/gso_linux.go
@@ -0,0 +1,65 @@
+//go:build linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "fmt"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+const (
+ sizeOfGSOData = 2
+)
+
+// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
+func getGSOSize(control []byte) (int, error) {
+ var (
+ hdr unix.Cmsghdr
+ data []byte
+ rem = control
+ err error
+ )
+
+ for len(rem) > unix.SizeofCmsghdr {
+ hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
+ if err != nil {
+ return 0, fmt.Errorf("error parsing socket control message: %w", err)
+ }
+ if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
+ var gso uint16
+ copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
+ return int(gso), nil
+ }
+ }
+ return 0, nil
+}
+
+// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
+// data in control untouched.
+func setGSOSize(control *[]byte, gsoSize uint16) {
+ existingLen := len(*control)
+ avail := cap(*control) - existingLen
+ space := unix.CmsgSpace(sizeOfGSOData)
+ if avail < space {
+ return
+ }
+ *control = (*control)[:cap(*control)]
+ gsoControl := (*control)[existingLen:]
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
+ hdr.Level = unix.SOL_UDP
+ hdr.Type = unix.UDP_SEGMENT
+ hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
+ copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
+ *control = (*control)[:existingLen+space]
+}
+
+// gsoControlSize returns the recommended buffer size for pooling UDP
+// offloading control data.
+var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
diff --git a/conn/mark_default.go b/conn/mark_default.go
new file mode 100644
index 0000000..3102384
--- /dev/null
+++ b/conn/mark_default.go
@@ -0,0 +1,12 @@
+//go:build !linux && !openbsd && !freebsd
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+func (s *StdNetBind) SetMark(mark uint32) error {
+ return nil
+}
diff --git a/device/mark_unix.go b/conn/mark_unix.go
index 669b328..d9e46ee 100644
--- a/device/mark_unix.go
+++ b/conn/mark_unix.go
@@ -1,11 +1,11 @@
-// +build android openbsd freebsd
+//go:build linux || openbsd || freebsd
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
-package device
+package conn
import (
"runtime"
@@ -26,13 +26,13 @@ func init() {
}
}
-func (bind *nativeBind) SetMark(mark uint32) error {
+func (s *StdNetBind) SetMark(mark uint32) error {
var operr error
if fwmarkIoctl == 0 {
return nil
}
- if bind.ipv4 != nil {
- fd, err := bind.ipv4.SyscallConn()
+ if s.ipv4 != nil {
+ fd, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
@@ -46,8 +46,8 @@ func (bind *nativeBind) SetMark(mark uint32) error {
return err
}
}
- if bind.ipv6 != nil {
- fd, err := bind.ipv6.SyscallConn()
+ if s.ipv6 != nil {
+ fd, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
diff --git a/conn/sticky_default.go b/conn/sticky_default.go
new file mode 100644
index 0000000..0b21386
--- /dev/null
+++ b/conn/sticky_default.go
@@ -0,0 +1,42 @@
+//go:build !linux || android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import "net/netip"
+
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+ return netip.Addr{}
+}
+
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+ return 0
+}
+
+func (e *StdNetEndpoint) SrcToString() string {
+ return ""
+}
+
+// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
+// {get,set}srcControl feature set, but use alternatively named flags and need
+// ports and require testing.
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+}
+
+// setSrcControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+}
+
+// stickyControlSize returns the recommended buffer size for pooling sticky
+// offloading control data.
+const stickyControlSize = 0
+
+const StdNetSupportsStickySockets = false
diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go
new file mode 100644
index 0000000..8e206e9
--- /dev/null
+++ b/conn/sticky_linux.go
@@ -0,0 +1,112 @@
+//go:build linux && !android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net/netip"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+ switch len(e.src) {
+ case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ return netip.AddrFrom4(info.Spec_dst)
+ case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ // TODO: set zone. in order to do so we need to check if the address is
+ // link local, and if it is perform a syscall to turn the ifindex into a
+ // zone string because netip uses string zones.
+ return netip.AddrFrom16(info.Addr)
+ }
+ return netip.Addr{}
+}
+
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+ switch len(e.src) {
+ case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ return info.Ifindex
+ case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ return int32(info.Ifindex)
+ }
+ return 0
+}
+
+func (e *StdNetEndpoint) SrcToString() string {
+ return e.SrcIP().String()
+}
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+ ep.ClearSrc()
+
+ var (
+ hdr unix.Cmsghdr
+ data []byte
+ rem []byte = control
+ err error
+ )
+
+ for len(rem) > unix.SizeofCmsghdr {
+ hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
+ if err != nil {
+ return
+ }
+
+ if hdr.Level == unix.IPPROTO_IP &&
+ hdr.Type == unix.IP_PKTINFO {
+
+ if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
+ ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ }
+ ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
+
+ hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
+ copy(ep.src, hdrBuf)
+ copy(ep.src[unix.CmsgLen(0):], data)
+ return
+ }
+
+ if hdr.Level == unix.IPPROTO_IPV6 &&
+ hdr.Type == unix.IPV6_PKTINFO {
+
+ if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
+ ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
+ }
+
+ ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
+
+ hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
+ copy(ep.src, hdrBuf)
+ copy(ep.src[unix.CmsgLen(0):], data)
+ return
+ }
+ }
+}
+
+// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
+// and source ifindex found in ep. control's len will be set to 0 in the event
+// that ep is a default value.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+ if cap(*control) < len(ep.src) {
+ return
+ }
+ *control = (*control)[:0]
+ *control = append(*control, ep.src...)
+}
+
+// stickyControlSize returns the recommended buffer size for pooling sticky
+// offloading control data.
+var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
+
+const StdNetSupportsStickySockets = true
diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go
new file mode 100644
index 0000000..d2bd584
--- /dev/null
+++ b/conn/sticky_linux_test.go
@@ -0,0 +1,266 @@
+//go:build linux && !android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "context"
+ "net"
+ "net/netip"
+ "runtime"
+ "testing"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
+ var buf []byte
+ if addr.Is4() {
+ buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ hdr := unix.Cmsghdr{
+ Level: unix.IPPROTO_IP,
+ Type: unix.IP_PKTINFO,
+ }
+ hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
+ copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
+
+ info := unix.Inet4Pktinfo{
+ Ifindex: ifidx,
+ Spec_dst: addr.As4(),
+ }
+ copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
+ } else {
+ buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
+ hdr := unix.Cmsghdr{
+ Level: unix.IPPROTO_IPV6,
+ Type: unix.IPV6_PKTINFO,
+ }
+ hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
+ copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
+
+ info := unix.Inet6Pktinfo{
+ Ifindex: uint32(ifidx),
+ Addr: addr.As16(),
+ }
+ copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
+ }
+
+ ep.src = buf
+}
+
+func Test_setSrcControl(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ ep := &StdNetEndpoint{
+ AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
+ }
+ setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
+
+ control := make([]byte, stickyControlSize)
+
+ setSrcControl(&control, ep)
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ if hdr.Level != unix.IPPROTO_IP {
+ t.Errorf("unexpected level: %d", hdr.Level)
+ }
+ if hdr.Type != unix.IP_PKTINFO {
+ t.Errorf("unexpected type: %d", hdr.Type)
+ }
+ if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
+ t.Errorf("unexpected length: %d", hdr.Len)
+ }
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
+ t.Errorf("unexpected address: %v", info.Spec_dst)
+ }
+ if info.Ifindex != 5 {
+ t.Errorf("unexpected ifindex: %d", info.Ifindex)
+ }
+ })
+
+ t.Run("IPv6", func(t *testing.T) {
+ ep := &StdNetEndpoint{
+ AddrPort: netip.MustParseAddrPort("[::1]:1234"),
+ }
+ setSrc(ep, netip.MustParseAddr("::1"), 5)
+
+ control := make([]byte, stickyControlSize)
+
+ setSrcControl(&control, ep)
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ if hdr.Level != unix.IPPROTO_IPV6 {
+ t.Errorf("unexpected level: %d", hdr.Level)
+ }
+ if hdr.Type != unix.IPV6_PKTINFO {
+ t.Errorf("unexpected type: %d", hdr.Type)
+ }
+ if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
+ t.Errorf("unexpected length: %d", hdr.Len)
+ }
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ if info.Addr != ep.SrcIP().As16() {
+ t.Errorf("unexpected address: %v", info.Addr)
+ }
+ if info.Ifindex != 5 {
+ t.Errorf("unexpected ifindex: %d", info.Ifindex)
+ }
+ })
+
+ t.Run("ClearOnNoSrc", func(t *testing.T) {
+ control := make([]byte, stickyControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = 1
+ hdr.Type = 2
+ hdr.Len = 3
+
+ setSrcControl(&control, &StdNetEndpoint{})
+
+ if len(control) != 0 {
+ t.Errorf("unexpected control: %v", control)
+ }
+ })
+}
+
+func Test_getSrcFromControl(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ control := make([]byte, stickyControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IP
+ hdr.Type = unix.IP_PKTINFO
+ hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Spec_dst = [4]byte{127, 0, 0, 1}
+ info.Ifindex = 5
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(control, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+ t.Run("IPv6", func(t *testing.T) {
+ control := make([]byte, stickyControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IPV6
+ hdr.Type = unix.IPV6_PKTINFO
+ hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
+ info.Ifindex = 5
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(control, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("::1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+ t.Run("ClearOnEmpty", func(t *testing.T) {
+ var control []byte
+ ep := &StdNetEndpoint{}
+ setSrc(ep, netip.MustParseAddr("::1"), 5)
+
+ getSrcFromControl(control, ep)
+ if ep.SrcIP().IsValid() {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 0 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+ t.Run("Multiple", func(t *testing.T) {
+ zeroControl := make([]byte, unix.CmsgSpace(0))
+ zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
+ zeroHdr.SetLen(unix.CmsgLen(0))
+
+ control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IP
+ hdr.Type = unix.IP_PKTINFO
+ hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Spec_dst = [4]byte{127, 0, 0, 1}
+ info.Ifindex = 5
+
+ combined := make([]byte, 0)
+ combined = append(combined, zeroControl...)
+ combined = append(combined, control...)
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(combined, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+}
+
+func Test_listenConfig(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ sc, err := conn.(*net.UDPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if runtime.GOOS == "linux" {
+ var i int
+ sc.Control(func(fd uintptr) {
+ i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if i != 1 {
+ t.Error("IP_PKTINFO not set!")
+ }
+ } else {
+ t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
+ }
+ })
+ t.Run("IPv6", func(t *testing.T) {
+ conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ sc, err := conn.(*net.UDPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if runtime.GOOS == "linux" {
+ var i int
+ sc.Control(func(fd uintptr) {
+ i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if i != 1 {
+ t.Error("IPV6_PKTINFO not set!")
+ }
+ } else {
+ t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
+ }
+ })
+}
diff --git a/conn/winrio/rio_windows.go b/conn/winrio/rio_windows.go
new file mode 100644
index 0000000..d1037bb
--- /dev/null
+++ b/conn/winrio/rio_windows.go
@@ -0,0 +1,254 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package winrio
+
+import (
+ "log"
+ "sync"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+const (
+ MsgDontNotify = 1
+ MsgDefer = 2
+ MsgWaitAll = 4
+ MsgCommitOnly = 8
+
+ MaxCqSize = 0x8000000
+
+ invalidBufferId = 0xFFFFFFFF
+ invalidCq = 0
+ invalidRq = 0
+ corruptCq = 0xFFFFFFFF
+)
+
+var extensionFunctionTable struct {
+ cbSize uint32
+ rioReceive uintptr
+ rioReceiveEx uintptr
+ rioSend uintptr
+ rioSendEx uintptr
+ rioCloseCompletionQueue uintptr
+ rioCreateCompletionQueue uintptr
+ rioCreateRequestQueue uintptr
+ rioDequeueCompletion uintptr
+ rioDeregisterBuffer uintptr
+ rioNotify uintptr
+ rioRegisterBuffer uintptr
+ rioResizeCompletionQueue uintptr
+ rioResizeRequestQueue uintptr
+}
+
+type Cq uintptr
+
+type Rq uintptr
+
+type BufferId uintptr
+
+type Buffer struct {
+ Id BufferId
+ Offset uint32
+ Length uint32
+}
+
+type Result struct {
+ Status int32
+ BytesTransferred uint32
+ SocketContext uint64
+ RequestContext uint64
+}
+
+type notificationCompletionType uint32
+
+const (
+ eventCompletion notificationCompletionType = 1
+ iocpCompletion notificationCompletionType = 2
+)
+
+type eventNotificationCompletion struct {
+ completionType notificationCompletionType
+ event windows.Handle
+ notifyReset uint32
+}
+
+type iocpNotificationCompletion struct {
+ completionType notificationCompletionType
+ iocp windows.Handle
+ key uintptr
+ overlapped *windows.Overlapped
+}
+
+var (
+ initialized sync.Once
+ available bool
+)
+
+func Initialize() bool {
+ initialized.Do(func() {
+ var (
+ err error
+ socket windows.Handle
+ cq Cq
+ )
+ defer func() {
+ if err == nil {
+ return
+ }
+ if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 {
+ return
+ }
+ log.Printf("Registered I/O is unavailable: %v", err)
+ }()
+ socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
+ if err != nil {
+ return
+ }
+ defer windows.CloseHandle(socket)
+ WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
+ const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
+ ob := uint32(0)
+ err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
+ (*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)),
+ (*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)),
+ &ob, nil, 0)
+ if err != nil {
+ return
+ }
+
+ // While we should be able to stop here, after getting the function pointers, some anti-virus actually causes
+ // failures in RIOCreateRequestQueue, so keep going to be certain this is supported.
+ var iocp windows.Handle
+ iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
+ if err != nil {
+ return
+ }
+ defer windows.CloseHandle(iocp)
+ var overlapped windows.Overlapped
+ cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped)
+ if err != nil {
+ return
+ }
+ defer CloseCompletionQueue(cq)
+ _, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0)
+ if err != nil {
+ return
+ }
+ available = true
+ })
+ return available
+}
+
+func Socket(af, typ, proto int32) (windows.Handle, error) {
+ return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO)
+}
+
+func CloseCompletionQueue(cq Cq) {
+ _, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0)
+}
+
+func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) {
+ notificationCompletion := &eventNotificationCompletion{
+ completionType: eventCompletion,
+ event: event,
+ }
+ if notifyReset {
+ notificationCompletion.notifyReset = 1
+ }
+ ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
+ if ret == invalidCq {
+ return 0, err
+ }
+ return Cq(ret), nil
+}
+
+func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) {
+ notificationCompletion := &iocpNotificationCompletion{
+ completionType: iocpCompletion,
+ iocp: iocp,
+ key: key,
+ overlapped: overlapped,
+ }
+ ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
+ if ret == invalidCq {
+ return 0, err
+ }
+ return Cq(ret), nil
+}
+
+func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) {
+ ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0)
+ if ret == invalidCq {
+ return 0, err
+ }
+ return Cq(ret), nil
+}
+
+func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) {
+ ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0)
+ if ret == invalidRq {
+ return 0, err
+ }
+ return Rq(ret), nil
+}
+
+func DequeueCompletion(cq Cq, results []Result) uint32 {
+ var array uintptr
+ if len(results) > 0 {
+ array = uintptr(unsafe.Pointer(&results[0]))
+ }
+ ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results)))
+ if ret == corruptCq {
+ panic("cq is corrupt")
+ }
+ return uint32(ret)
+}
+
+func DeregisterBuffer(id BufferId) {
+ _, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0)
+}
+
+func RegisterBuffer(buffer []byte) (BufferId, error) {
+ var buf unsafe.Pointer
+ if len(buffer) > 0 {
+ buf = unsafe.Pointer(&buffer[0])
+ }
+ return RegisterPointer(buf, uint32(len(buffer)))
+}
+
+func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) {
+ ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0)
+ if ret == invalidBufferId {
+ return 0, err
+ }
+ return BufferId(ret), nil
+}
+
+func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
+ ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
+ if ret == 0 {
+ return err
+ }
+ return nil
+}
+
+func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
+ ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
+ if ret == 0 {
+ return err
+ }
+ return nil
+}
+
+func Notify(cq Cq) error {
+ ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0)
+ if ret != 0 {
+ return windows.Errno(ret)
+ }
+ return nil
+}
diff --git a/device/allowedips.go b/device/allowedips.go
index efc27c0..fa46f97 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -1,173 +1,201 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
+ "container/list"
+ "encoding/binary"
"errors"
"math/bits"
"net"
+ "net/netip"
"sync"
"unsafe"
)
-type trieEntry struct {
- cidr uint
- child [2]*trieEntry
- bits net.IP
- peer *Peer
-
- // index of "branching" bit
-
- bit_at_byte uint
- bit_at_shift uint
-}
-
-func isLittleEndian() bool {
- one := uint32(1)
- return *(*byte)(unsafe.Pointer(&one)) != 0
+type parentIndirection struct {
+ parentBit **trieEntry
+ parentBitType uint8
}
-func swapU32(i uint32) uint32 {
- if !isLittleEndian() {
- return i
- }
-
- return bits.ReverseBytes32(i)
-}
-
-func swapU64(i uint64) uint64 {
- if !isLittleEndian() {
- return i
- }
-
- return bits.ReverseBytes64(i)
+type trieEntry struct {
+ peer *Peer
+ child [2]*trieEntry
+ parent parentIndirection
+ cidr uint8
+ bitAtByte uint8
+ bitAtShift uint8
+ bits []byte
+ perPeerElem *list.Element
}
-func commonBits(ip1 net.IP, ip2 net.IP) uint {
+func commonBits(ip1, ip2 []byte) uint8 {
size := len(ip1)
if size == net.IPv4len {
- a := (*uint32)(unsafe.Pointer(&ip1[0]))
- b := (*uint32)(unsafe.Pointer(&ip2[0]))
- x := *a ^ *b
- return uint(bits.LeadingZeros32(swapU32(x)))
+ a := binary.BigEndian.Uint32(ip1)
+ b := binary.BigEndian.Uint32(ip2)
+ x := a ^ b
+ return uint8(bits.LeadingZeros32(x))
} else if size == net.IPv6len {
- a := (*uint64)(unsafe.Pointer(&ip1[0]))
- b := (*uint64)(unsafe.Pointer(&ip2[0]))
- x := *a ^ *b
+ a := binary.BigEndian.Uint64(ip1)
+ b := binary.BigEndian.Uint64(ip2)
+ x := a ^ b
if x != 0 {
- return uint(bits.LeadingZeros64(swapU64(x)))
+ return uint8(bits.LeadingZeros64(x))
}
- a = (*uint64)(unsafe.Pointer(&ip1[8]))
- b = (*uint64)(unsafe.Pointer(&ip2[8]))
- x = *a ^ *b
- return 64 + uint(bits.LeadingZeros64(swapU64(x)))
+ a = binary.BigEndian.Uint64(ip1[8:])
+ b = binary.BigEndian.Uint64(ip2[8:])
+ x = a ^ b
+ return 64 + uint8(bits.LeadingZeros64(x))
} else {
panic("Wrong size bit string")
}
}
-func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
- if node == nil {
- return node
- }
-
- // walk recursively
-
- node.child[0] = node.child[0].removeByPeer(p)
- node.child[1] = node.child[1].removeByPeer(p)
+func (node *trieEntry) addToPeerEntries() {
+ node.perPeerElem = node.peer.trieEntries.PushBack(node)
+}
- if node.peer != p {
- return node
+func (node *trieEntry) removeFromPeerEntries() {
+ if node.perPeerElem != nil {
+ node.peer.trieEntries.Remove(node.perPeerElem)
+ node.perPeerElem = nil
}
+}
- // remove peer & merge
+func (node *trieEntry) choose(ip []byte) byte {
+ return (ip[node.bitAtByte] >> node.bitAtShift) & 1
+}
- node.peer = nil
- if node.child[0] == nil {
- return node.child[1]
+func (node *trieEntry) maskSelf() {
+ mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
+ for i := 0; i < len(mask); i++ {
+ node.bits[i] &= mask[i]
}
- return node.child[0]
}
-func (node *trieEntry) choose(ip net.IP) byte {
- return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
+func (node *trieEntry) zeroizePointers() {
+ // Make the garbage collector's life slightly easier
+ node.peer = nil
+ node.child[0] = nil
+ node.child[1] = nil
+ node.parent.parentBit = nil
}
-func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
-
- // at leaf
-
- if node == nil {
- return &trieEntry{
- bits: ip,
- peer: peer,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
+func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
+ for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
+ parent = node
+ if parent.cidr == cidr {
+ exact = true
+ return
}
+ bit := node.choose(ip)
+ node = node.child[bit]
}
+ return
+}
- // traverse deeper
-
- common := commonBits(node.bits, ip)
- if node.cidr <= cidr && common >= node.cidr {
- if node.cidr == cidr {
- node.peer = peer
- return node
+func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
+ if *trie.parentBit == nil {
+ node := &trieEntry{
+ peer: peer,
+ parent: trie,
+ bits: ip,
+ cidr: cidr,
+ bitAtByte: cidr / 8,
+ bitAtShift: 7 - (cidr % 8),
}
- bit := node.choose(ip)
- node.child[bit] = node.child[bit].insert(ip, cidr, peer)
- return node
+ node.maskSelf()
+ node.addToPeerEntries()
+ *trie.parentBit = node
+ return
+ }
+ node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
+ if exact {
+ node.removeFromPeerEntries()
+ node.peer = peer
+ node.addToPeerEntries()
+ return
}
-
- // split node
newNode := &trieEntry{
- bits: ip,
- peer: peer,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
+ peer: peer,
+ bits: ip,
+ cidr: cidr,
+ bitAtByte: cidr / 8,
+ bitAtShift: 7 - (cidr % 8),
}
+ newNode.maskSelf()
+ newNode.addToPeerEntries()
- cidr = min(cidr, common)
-
- // check for shorter prefix
+ var down *trieEntry
+ if node == nil {
+ down = *trie.parentBit
+ } else {
+ bit := node.choose(ip)
+ down = node.child[bit]
+ if down == nil {
+ newNode.parent = parentIndirection{&node.child[bit], bit}
+ node.child[bit] = newNode
+ return
+ }
+ }
+ common := commonBits(down.bits, ip)
+ if common < cidr {
+ cidr = common
+ }
+ parent := node
if newNode.cidr == cidr {
- bit := newNode.choose(node.bits)
- newNode.child[bit] = node
- return newNode
+ bit := newNode.choose(down.bits)
+ down.parent = parentIndirection{&newNode.child[bit], bit}
+ newNode.child[bit] = down
+ if parent == nil {
+ newNode.parent = trie
+ *trie.parentBit = newNode
+ } else {
+ bit := parent.choose(newNode.bits)
+ newNode.parent = parentIndirection{&parent.child[bit], bit}
+ parent.child[bit] = newNode
+ }
+ return
}
- // create new parent for node & newNode
-
- parent := &trieEntry{
- bits: ip,
- peer: nil,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
+ node = &trieEntry{
+ bits: append([]byte{}, newNode.bits...),
+ cidr: cidr,
+ bitAtByte: cidr / 8,
+ bitAtShift: 7 - (cidr % 8),
+ }
+ node.maskSelf()
+
+ bit := node.choose(down.bits)
+ down.parent = parentIndirection{&node.child[bit], bit}
+ node.child[bit] = down
+ bit = node.choose(newNode.bits)
+ newNode.parent = parentIndirection{&node.child[bit], bit}
+ node.child[bit] = newNode
+ if parent == nil {
+ node.parent = trie
+ *trie.parentBit = node
+ } else {
+ bit := parent.choose(node.bits)
+ node.parent = parentIndirection{&parent.child[bit], bit}
+ parent.child[bit] = node
}
-
- bit := parent.choose(ip)
- parent.child[bit] = newNode
- parent.child[bit^1] = node
-
- return parent
}
-func (node *trieEntry) lookup(ip net.IP) *Peer {
+func (node *trieEntry) lookup(ip []byte) *Peer {
var found *Peer
- size := uint(len(ip))
+ size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
if node.peer != nil {
found = node.peer
}
- if node.bit_at_byte == size {
+ if node.bitAtByte == size {
break
}
bit := node.choose(ip)
@@ -176,76 +204,91 @@ func (node *trieEntry) lookup(ip net.IP) *Peer {
return found
}
-func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
- if node == nil {
- return results
- }
- if node.peer == p {
- mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
- results = append(results, net.IPNet{
- Mask: mask,
- IP: node.bits.Mask(mask),
- })
- }
- results = node.child[0].entriesForPeer(p, results)
- results = node.child[1].entriesForPeer(p, results)
- return results
-}
-
type AllowedIPs struct {
IPv4 *trieEntry
IPv6 *trieEntry
mutex sync.RWMutex
}
-func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
+func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock()
defer table.mutex.RUnlock()
- allowed := make([]net.IPNet, 0, 10)
- allowed = table.IPv4.entriesForPeer(peer, allowed)
- allowed = table.IPv6.entriesForPeer(peer, allowed)
- return allowed
-}
-
-func (table *AllowedIPs) Reset() {
- table.mutex.Lock()
- defer table.mutex.Unlock()
-
- table.IPv4 = nil
- table.IPv6 = nil
+ for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
+ node := elem.Value.(*trieEntry)
+ a, _ := netip.AddrFromSlice(node.bits)
+ if !cb(netip.PrefixFrom(a, int(node.cidr))) {
+ return
+ }
+ }
}
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
- table.IPv4 = table.IPv4.removeByPeer(peer)
- table.IPv6 = table.IPv6.removeByPeer(peer)
+ var next *list.Element
+ for elem := peer.trieEntries.Front(); elem != nil; elem = next {
+ next = elem.Next()
+ node := elem.Value.(*trieEntry)
+
+ node.removeFromPeerEntries()
+ node.peer = nil
+ if node.child[0] != nil && node.child[1] != nil {
+ continue
+ }
+ bit := 0
+ if node.child[0] == nil {
+ bit = 1
+ }
+ child := node.child[bit]
+ if child != nil {
+ child.parent = node.parent
+ }
+ *node.parent.parentBit = child
+ if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
+ node.zeroizePointers()
+ continue
+ }
+ parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
+ if parent.peer != nil {
+ node.zeroizePointers()
+ continue
+ }
+ child = parent.child[node.parent.parentBitType^1]
+ if child != nil {
+ child.parent = parent.parent
+ }
+ *parent.parent.parentBit = child
+ node.zeroizePointers()
+ parent.zeroizePointers()
+ }
}
-func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
+func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
- switch len(ip) {
- case net.IPv6len:
- table.IPv6 = table.IPv6.insert(ip, cidr, peer)
- case net.IPv4len:
- table.IPv4 = table.IPv4.insert(ip, cidr, peer)
- default:
+ if prefix.Addr().Is6() {
+ ip := prefix.Addr().As16()
+ parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else if prefix.Addr().Is4() {
+ ip := prefix.Addr().As4()
+ parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else {
panic(errors.New("inserting unknown address type"))
}
}
-func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
+func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
- return table.IPv4.lookup(address)
-}
-
-func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
- table.mutex.RLock()
- defer table.mutex.RUnlock()
- return table.IPv6.lookup(address)
+ switch len(ip) {
+ case net.IPv6len:
+ return table.IPv6.lookup(ip)
+ case net.IPv4len:
+ return table.IPv4.lookup(ip)
+ default:
+ panic(errors.New("looking up unknown address type"))
+ }
}
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
index 59c10f7..07065c3 100644
--- a/device/allowedips_rand_test.go
+++ b/device/allowedips_rand_test.go
@@ -1,25 +1,28 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"math/rand"
+ "net"
+ "net/netip"
"sort"
"testing"
)
const (
- NumberOfPeers = 100
- NumberOfAddresses = 250
- NumberOfTests = 10000
+ NumberOfPeers = 100
+ NumberOfPeerRemovals = 4
+ NumberOfAddresses = 250
+ NumberOfTests = 10000
)
type SlowNode struct {
peer *Peer
- cidr uint
+ cidr uint8
bits []byte
}
@@ -37,7 +40,7 @@ func (r SlowRouter) Swap(i, j int) {
r[i], r[j] = r[j], r[i]
}
-func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter {
+func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
for _, t := range r {
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
t.peer = peer
@@ -64,68 +67,75 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
return nil
}
-func TestTrieRandomIPv4(t *testing.T) {
- var trie *trieEntry
- var slow SlowRouter
- var peers []*Peer
-
- rand.Seed(1)
-
- const AddressLength = 4
-
- for n := 0; n < NumberOfPeers; n += 1 {
- peers = append(peers, &Peer{})
- }
-
- for n := 0; n < NumberOfAddresses; n += 1 {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- cidr := uint(rand.Uint32() % (AddressLength * 8))
- index := rand.Int() % NumberOfPeers
- trie = trie.insert(addr[:], cidr, peers[index])
- slow = slow.Insert(addr[:], cidr, peers[index])
- }
-
- for n := 0; n < NumberOfTests; n += 1 {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- peer1 := slow.Lookup(addr[:])
- peer2 := trie.lookup(addr[:])
- if peer1 != peer2 {
- t.Error("Trie did not match naive implementation, for:", addr)
+func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
+ n := 0
+ for _, x := range r {
+ if x.peer != peer {
+ r[n] = x
+ n++
}
}
+ return r[:n]
}
-func TestTrieRandomIPv6(t *testing.T) {
- var trie *trieEntry
- var slow SlowRouter
+func TestTrieRandom(t *testing.T) {
+ var slow4, slow6 SlowRouter
var peers []*Peer
+ var allowedIPs AllowedIPs
rand.Seed(1)
- const AddressLength = 16
-
- for n := 0; n < NumberOfPeers; n += 1 {
+ for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{})
}
- for n := 0; n < NumberOfAddresses; n += 1 {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- cidr := uint(rand.Uint32() % (AddressLength * 8))
- index := rand.Int() % NumberOfPeers
- trie = trie.insert(addr[:], cidr, peers[index])
- slow = slow.Insert(addr[:], cidr, peers[index])
+ for n := 0; n < NumberOfAddresses; n++ {
+ var addr4 [4]byte
+ rand.Read(addr4[:])
+ cidr := uint8(rand.Intn(32) + 1)
+ index := rand.Intn(NumberOfPeers)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
+ slow4 = slow4.Insert(addr4[:], cidr, peers[index])
+
+ var addr6 [16]byte
+ rand.Read(addr6[:])
+ cidr = uint8(rand.Intn(128) + 1)
+ index = rand.Intn(NumberOfPeers)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
+ slow6 = slow6.Insert(addr6[:], cidr, peers[index])
}
- for n := 0; n < NumberOfTests; n += 1 {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- peer1 := slow.Lookup(addr[:])
- peer2 := trie.lookup(addr[:])
- if peer1 != peer2 {
- t.Error("Trie did not match naive implementation, for:", addr)
+ var p int
+ for p = 0; ; p++ {
+ for n := 0; n < NumberOfTests; n++ {
+ var addr4 [4]byte
+ rand.Read(addr4[:])
+ peer1 := slow4.Lookup(addr4[:])
+ peer2 := allowedIPs.Lookup(addr4[:])
+ if peer1 != peer2 {
+ t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
+ }
+
+ var addr6 [16]byte
+ rand.Read(addr6[:])
+ peer1 = slow6.Lookup(addr6[:])
+ peer2 = allowedIPs.Lookup(addr6[:])
+ if peer1 != peer2 {
+ t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
+ }
+ }
+ if p >= len(peers) || p >= NumberOfPeerRemovals {
+ break
}
+ allowedIPs.RemoveByPeer(peers[p])
+ slow4 = slow4.RemoveByPeer(peers[p])
+ slow6 = slow6.RemoveByPeer(peers[p])
+ }
+ for ; p < len(peers); p++ {
+ allowedIPs.RemoveByPeer(peers[p])
+ }
+
+ if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
+ t.Error("Failed to remove all nodes from trie by peer")
}
}
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
index 075ff06..cde068e 100644
--- a/device/allowedips_test.go
+++ b/device/allowedips_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,40 +8,17 @@ package device
import (
"math/rand"
"net"
+ "net/netip"
"testing"
)
-/* Todo: More comprehensive
- */
-
type testPairCommonBits struct {
s1 []byte
s2 []byte
- match uint
-}
-
-type testPairTrieInsert struct {
- key []byte
- cidr uint
- peer *Peer
-}
-
-type testPairTrieLookup struct {
- key []byte
- peer *Peer
-}
-
-func printTrie(t *testing.T, p *trieEntry) {
- if p == nil {
- return
- }
- t.Log(p)
- printTrie(t, p.child[0])
- printTrie(t, p.child[1])
+ match uint8
}
func TestCommonBits(t *testing.T) {
-
tests := []testPairCommonBits{
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
@@ -62,27 +39,28 @@ func TestCommonBits(t *testing.T) {
}
}
-func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
+func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
var trie *trieEntry
var peers []*Peer
+ root := parentIndirection{&trie, 2}
rand.Seed(1)
const AddressLength = 4
- for n := 0; n < peerNumber; n += 1 {
+ for n := 0; n < peerNumber; n++ {
peers = append(peers, &Peer{})
}
- for n := 0; n < addressNumber; n += 1 {
+ for n := 0; n < addressNumber; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
- cidr := uint(rand.Uint32() % (AddressLength * 8))
+ cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber
- trie = trie.insert(addr[:], cidr, peers[index])
+ root.insert(addr[:], cidr, peers[index])
}
- for n := 0; n < b.N; n += 1 {
+ for n := 0; n < b.N; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
trie.lookup(addr[:])
@@ -117,21 +95,21 @@ func TestTrieIPv4(t *testing.T) {
g := &Peer{}
h := &Peer{}
- var trie *trieEntry
+ var allowedIPs AllowedIPs
- insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
- trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
+ insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
- p := trie.lookup([]byte{a, b, c, d})
+ p := allowedIPs.Lookup([]byte{a, b, c, d})
if p != peer {
t.Error("Assert EQ failed")
}
}
assertNEQ := func(peer *Peer, a, b, c, d byte) {
- p := trie.lookup([]byte{a, b, c, d})
+ p := allowedIPs.Lookup([]byte{a, b, c, d})
if p == peer {
t.Error("Assert NEQ failed")
}
@@ -173,7 +151,7 @@ func TestTrieIPv4(t *testing.T) {
assertEQ(a, 192, 0, 0, 0)
assertEQ(a, 255, 0, 0, 0)
- trie = trie.removeByPeer(a)
+ allowedIPs.RemoveByPeer(a)
assertNEQ(a, 1, 0, 0, 0)
assertNEQ(a, 64, 0, 0, 0)
@@ -181,12 +159,21 @@ func TestTrieIPv4(t *testing.T) {
assertNEQ(a, 192, 0, 0, 0)
assertNEQ(a, 255, 0, 0, 0)
- trie = nil
+ allowedIPs.RemoveByPeer(a)
+ allowedIPs.RemoveByPeer(b)
+ allowedIPs.RemoveByPeer(c)
+ allowedIPs.RemoveByPeer(d)
+ allowedIPs.RemoveByPeer(e)
+ allowedIPs.RemoveByPeer(g)
+ allowedIPs.RemoveByPeer(h)
+ if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
+ t.Error("Expected removing all the peers to empty trie, but it did not")
+ }
insert(a, 192, 168, 0, 0, 16)
insert(a, 192, 168, 0, 0, 24)
- trie = trie.removeByPeer(a)
+ allowedIPs.RemoveByPeer(a)
assertNEQ(a, 192, 168, 0, 1)
}
@@ -204,7 +191,7 @@ func TestTrieIPv6(t *testing.T) {
g := &Peer{}
h := &Peer{}
- var trie *trieEntry
+ var allowedIPs AllowedIPs
expand := func(a uint32) []byte {
var out [4]byte
@@ -215,13 +202,13 @@ func TestTrieIPv6(t *testing.T) {
return out[:]
}
- insert := func(peer *Peer, a, b, c, d uint32, cidr uint) {
+ insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
var addr []byte
addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- trie = trie.insert(addr, cidr, peer)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d uint32) {
@@ -230,7 +217,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- p := trie.lookup(addr)
+ p := allowedIPs.Lookup(addr)
if p != peer {
t.Error("Assert EQ failed")
}
diff --git a/device/bind_test.go b/device/bind_test.go
index 0c2e2cf..302a521 100644
--- a/device/bind_test.go
+++ b/device/bind_test.go
@@ -1,23 +1,24 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
-import "errors"
+import (
+ "errors"
+
+ "golang.zx2c4.com/wireguard/conn"
+)
type DummyDatagram struct {
msg []byte
- endpoint Endpoint
- world bool // better type
+ endpoint conn.Endpoint
}
type DummyBind struct {
in6 chan DummyDatagram
- ou6 chan DummyDatagram
in4 chan DummyDatagram
- ou4 chan DummyDatagram
closed bool
}
@@ -25,21 +26,21 @@ func (b *DummyBind) SetMark(v uint32) error {
return nil
}
-func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in6
if !ok {
return 0, nil, errors.New("closed")
}
- copy(buff, datagram.msg)
+ copy(buf, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil
}
-func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in4
if !ok {
return 0, nil, errors.New("closed")
}
- copy(buff, datagram.msg)
+ copy(buf, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil
}
@@ -50,6 +51,6 @@ func (b *DummyBind) Close() error {
return nil
}
-func (b *DummyBind) Send(buff []byte, end Endpoint) error {
+func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error {
return nil
}
diff --git a/device/boundif_android.go b/device/boundif_android.go
deleted file mode 100644
index 6d0fecf..0000000
--- a/device/boundif_android.go
+++ /dev/null
@@ -1,44 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import "errors"
-
-func (device *Device) PeekLookAtSocketFd4() (fd int, err error) {
- nb, ok := device.net.bind.(*nativeBind)
- if !ok {
- return 0, errors.New("no socket exists")
- }
- sysconn, err := nb.ipv4.SyscallConn()
- if err != nil {
- return
- }
- err = sysconn.Control(func(f uintptr) {
- fd = int(f)
- })
- if err != nil {
- return
- }
- return
-}
-
-func (device *Device) PeekLookAtSocketFd6() (fd int, err error) {
- nb, ok := device.net.bind.(*nativeBind)
- if !ok {
- return 0, errors.New("no socket exists")
- }
- sysconn, err := nb.ipv6.SyscallConn()
- if err != nil {
- return
- }
- err = sysconn.Control(func(f uintptr) {
- fd = int(f)
- })
- if err != nil {
- return
- }
- return
-}
diff --git a/device/boundif_windows.go b/device/boundif_windows.go
deleted file mode 100644
index 6908415..0000000
--- a/device/boundif_windows.go
+++ /dev/null
@@ -1,64 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
- "encoding/binary"
- "errors"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-const (
- sockoptIP_UNICAST_IF = 31
- sockoptIPV6_UNICAST_IF = 31
-)
-
-func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
- /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
- bytes := make([]byte, 4)
- binary.BigEndian.PutUint32(bytes, interfaceIndex)
- interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
-
- if device.net.bind == nil {
- return errors.New("Bind is not yet initialized")
- }
-
- sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
- if err != nil {
- return err
- }
- err2 := sysconn.Control(func(fd uintptr) {
- err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, sockoptIP_UNICAST_IF, int(interfaceIndex))
- })
- if err2 != nil {
- return err2
- }
- if err != nil {
- return err
- }
- device.net.bind.(*nativeBind).blackhole4 = blackhole
- return nil
-}
-
-func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
- sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
- if err != nil {
- return err
- }
- err2 := sysconn.Control(func(fd uintptr) {
- err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, sockoptIPV6_UNICAST_IF, int(interfaceIndex))
- })
- if err2 != nil {
- return err2
- }
- if err != nil {
- return err
- }
- device.net.bind.(*nativeBind).blackhole6 = blackhole
- return nil
-}
diff --git a/device/channels.go b/device/channels.go
new file mode 100644
index 0000000..e526f6b
--- /dev/null
+++ b/device/channels.go
@@ -0,0 +1,137 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "runtime"
+ "sync"
+)
+
+// An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
+// An outboundQueue is ref-counted using its wg field.
+// An outboundQueue created with newOutboundQueue has one reference.
+// Every additional writer must call wg.Add(1).
+// Every completed writer must call wg.Done().
+// When no further writers will be added,
+// call wg.Done to remove the initial reference.
+// When the refcount hits 0, the queue's channel is closed.
+type outboundQueue struct {
+ c chan *QueueOutboundElementsContainer
+ wg sync.WaitGroup
+}
+
+func newOutboundQueue() *outboundQueue {
+ q := &outboundQueue{
+ c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
+ }
+ q.wg.Add(1)
+ go func() {
+ q.wg.Wait()
+ close(q.c)
+ }()
+ return q
+}
+
+// A inboundQueue is similar to an outboundQueue; see those docs.
+type inboundQueue struct {
+ c chan *QueueInboundElementsContainer
+ wg sync.WaitGroup
+}
+
+func newInboundQueue() *inboundQueue {
+ q := &inboundQueue{
+ c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
+ }
+ q.wg.Add(1)
+ go func() {
+ q.wg.Wait()
+ close(q.c)
+ }()
+ return q
+}
+
+// A handshakeQueue is similar to an outboundQueue; see those docs.
+type handshakeQueue struct {
+ c chan QueueHandshakeElement
+ wg sync.WaitGroup
+}
+
+func newHandshakeQueue() *handshakeQueue {
+ q := &handshakeQueue{
+ c: make(chan QueueHandshakeElement, QueueHandshakeSize),
+ }
+ q.wg.Add(1)
+ go func() {
+ q.wg.Wait()
+ close(q.c)
+ }()
+ return q
+}
+
+type autodrainingInboundQueue struct {
+ c chan *QueueInboundElementsContainer
+}
+
+// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
+// It is useful in cases in which is it hard to manage the lifetime of the channel.
+// The returned channel must not be closed. Senders should signal shutdown using
+// some other means, such as sending a sentinel nil values.
+func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
+ q := &autodrainingInboundQueue{
+ c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
+ }
+ runtime.SetFinalizer(q, device.flushInboundQueue)
+ return q
+}
+
+func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
+ for {
+ select {
+ case elemsContainer := <-q.c:
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ device.PutInboundElementsContainer(elemsContainer)
+ default:
+ return
+ }
+ }
+}
+
+type autodrainingOutboundQueue struct {
+ c chan *QueueOutboundElementsContainer
+}
+
+// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
+// It is useful in cases in which is it hard to manage the lifetime of the channel.
+// The returned channel must not be closed. Senders should signal shutdown using
+// some other means, such as sending a sentinel nil values.
+// All sends to the channel must be best-effort, because there may be no receivers.
+func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
+ q := &autodrainingOutboundQueue{
+ c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
+ }
+ runtime.SetFinalizer(q, device.flushOutboundQueue)
+ return q
+}
+
+func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
+ for {
+ select {
+ case elemsContainer := <-q.c:
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsContainer(elemsContainer)
+ default:
+ return
+ }
+ }
+}
diff --git a/device/conn.go b/device/conn.go
deleted file mode 100644
index 7b341f6..0000000
--- a/device/conn.go
+++ /dev/null
@@ -1,187 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
- "errors"
- "net"
- "strings"
-
- "golang.org/x/net/ipv4"
- "golang.org/x/net/ipv6"
-)
-
-const (
- ConnRoutineNumber = 2
-)
-
-/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
- */
-type Bind interface {
- SetMark(value uint32) error
- ReceiveIPv6(buff []byte) (int, Endpoint, error)
- ReceiveIPv4(buff []byte) (int, Endpoint, error)
- Send(buff []byte, end Endpoint) error
- Close() error
-}
-
-/* An Endpoint maintains the source/destination caching for a peer
- *
- * dst : the remote address of a peer ("endpoint" in uapi terminology)
- * src : the local address from which datagrams originate going to the peer
- */
-type Endpoint interface {
- ClearSrc() // clears the source address
- SrcToString() string // returns the local source address (ip:port)
- DstToString() string // returns the destination address (ip:port)
- DstToBytes() []byte // used for mac2 cookie calculations
- DstIP() net.IP
- SrcIP() net.IP
-}
-
-func parseEndpoint(s string) (*net.UDPAddr, error) {
- // ensure that the host is an IP address
-
- host, _, err := net.SplitHostPort(s)
- if err != nil {
- return nil, err
- }
- if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
- // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
- // trying to make sure with a small sanity test that this is a real IP address and
- // not something that's likely to incur DNS lookups.
- host = host[:i]
- }
- if ip := net.ParseIP(host); ip == nil {
- return nil, errors.New("Failed to parse IP address: " + host)
- }
-
- // parse address and port
-
- addr, err := net.ResolveUDPAddr("udp", s)
- if err != nil {
- return nil, err
- }
- ip4 := addr.IP.To4()
- if ip4 != nil {
- addr.IP = ip4
- }
- return addr, err
-}
-
-func unsafeCloseBind(device *Device) error {
- var err error
- netc := &device.net
- if netc.bind != nil {
- err = netc.bind.Close()
- netc.bind = nil
- }
- netc.stopping.Wait()
- return err
-}
-
-func (device *Device) BindSetMark(mark uint32) error {
-
- device.net.Lock()
- defer device.net.Unlock()
-
- // check if modified
-
- if device.net.fwmark == mark {
- return nil
- }
-
- // update fwmark on existing bind
-
- device.net.fwmark = mark
- if device.isUp.Get() && device.net.bind != nil {
- if err := device.net.bind.SetMark(mark); err != nil {
- return err
- }
- }
-
- // clear cached source addresses
-
- device.peers.RLock()
- for _, peer := range device.peers.keyMap {
- peer.Lock()
- defer peer.Unlock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- }
- device.peers.RUnlock()
-
- return nil
-}
-
-func (device *Device) BindUpdate() error {
-
- device.net.Lock()
- defer device.net.Unlock()
-
- // close existing sockets
-
- if err := unsafeCloseBind(device); err != nil {
- return err
- }
-
- // open new sockets
-
- if device.isUp.Get() {
-
- // bind to new port
-
- var err error
- netc := &device.net
- netc.bind, netc.port, err = CreateBind(netc.port, device)
- if err != nil {
- netc.bind = nil
- netc.port = 0
- return err
- }
-
- // set fwmark
-
- if netc.fwmark != 0 {
- err = netc.bind.SetMark(netc.fwmark)
- if err != nil {
- return err
- }
- }
-
- // clear cached source addresses
-
- device.peers.RLock()
- for _, peer := range device.peers.keyMap {
- peer.Lock()
- defer peer.Unlock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- }
- device.peers.RUnlock()
-
- // start receiving routines
-
- device.net.starting.Add(ConnRoutineNumber)
- device.net.stopping.Add(ConnRoutineNumber)
- go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
- go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
- device.net.starting.Wait()
-
- device.log.Debug.Println("UDP bind has been updated")
- }
-
- return nil
-}
-
-func (device *Device) BindClose() error {
- device.net.Lock()
- err := unsafeCloseBind(device)
- device.net.Unlock()
- return err
-}
diff --git a/device/conn_default.go b/device/conn_default.go
deleted file mode 100644
index 661f57d..0000000
--- a/device/conn_default.go
+++ /dev/null
@@ -1,178 +0,0 @@
-// +build !linux android
-
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
- "net"
- "os"
- "syscall"
-)
-
-/* This code is meant to be a temporary solution
- * on platforms for which the sticky socket / source caching behavior
- * has not yet been implemented.
- *
- * See conn_linux.go for an implementation on the linux platform.
- */
-
-type nativeBind struct {
- ipv4 *net.UDPConn
- ipv6 *net.UDPConn
- blackhole4 bool
- blackhole6 bool
-}
-
-type NativeEndpoint net.UDPAddr
-
-var _ Bind = (*nativeBind)(nil)
-var _ Endpoint = (*NativeEndpoint)(nil)
-
-func CreateEndpoint(s string) (Endpoint, error) {
- addr, err := parseEndpoint(s)
- return (*NativeEndpoint)(addr), err
-}
-
-func (_ *NativeEndpoint) ClearSrc() {}
-
-func (e *NativeEndpoint) DstIP() net.IP {
- return (*net.UDPAddr)(e).IP
-}
-
-func (e *NativeEndpoint) SrcIP() net.IP {
- return nil // not supported
-}
-
-func (e *NativeEndpoint) DstToBytes() []byte {
- addr := (*net.UDPAddr)(e)
- out := addr.IP.To4()
- if out == nil {
- out = addr.IP
- }
- out = append(out, byte(addr.Port&0xff))
- out = append(out, byte((addr.Port>>8)&0xff))
- return out
-}
-
-func (e *NativeEndpoint) DstToString() string {
- return (*net.UDPAddr)(e).String()
-}
-
-func (e *NativeEndpoint) SrcToString() string {
- return ""
-}
-
-func listenNet(network string, port int) (*net.UDPConn, int, error) {
-
- // listen
-
- conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
- if err != nil {
- return nil, 0, err
- }
-
- // retrieve port
-
- laddr := conn.LocalAddr()
- uaddr, err := net.ResolveUDPAddr(
- laddr.Network(),
- laddr.String(),
- )
- if err != nil {
- return nil, 0, err
- }
- return conn, uaddr.Port, nil
-}
-
-func extractErrno(err error) error {
- opErr, ok := err.(*net.OpError)
- if !ok {
- return nil
- }
- syscallErr, ok := opErr.Err.(*os.SyscallError)
- if !ok {
- return nil
- }
- return syscallErr.Err
-}
-
-func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
- var err error
- var bind nativeBind
-
- port := int(uport)
-
- bind.ipv4, port, err = listenNet("udp4", port)
- if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
- return nil, 0, err
- }
-
- bind.ipv6, port, err = listenNet("udp6", port)
- if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
- bind.ipv4.Close()
- bind.ipv4 = nil
- return nil, 0, err
- }
-
- return &bind, uint16(port), nil
-}
-
-func (bind *nativeBind) Close() error {
- var err1, err2 error
- if bind.ipv4 != nil {
- err1 = bind.ipv4.Close()
- }
- if bind.ipv6 != nil {
- err2 = bind.ipv6.Close()
- }
- if err1 != nil {
- return err1
- }
- return err2
-}
-
-func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
- if bind.ipv4 == nil {
- return 0, nil, syscall.EAFNOSUPPORT
- }
- n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
- if endpoint != nil {
- endpoint.IP = endpoint.IP.To4()
- }
- return n, (*NativeEndpoint)(endpoint), err
-}
-
-func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
- if bind.ipv6 == nil {
- return 0, nil, syscall.EAFNOSUPPORT
- }
- n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
- return n, (*NativeEndpoint)(endpoint), err
-}
-
-func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error {
- var err error
- nend := endpoint.(*NativeEndpoint)
- if nend.IP.To4() != nil {
- if bind.ipv4 == nil {
- return syscall.EAFNOSUPPORT
- }
- if bind.blackhole4 {
- return nil
- }
- _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
- } else {
- if bind.ipv6 == nil {
- return syscall.EAFNOSUPPORT
- }
- if bind.blackhole6 {
- return nil
- }
- _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
- }
- return err
-}
diff --git a/device/conn_linux.go b/device/conn_linux.go
deleted file mode 100644
index f74ad51..0000000
--- a/device/conn_linux.go
+++ /dev/null
@@ -1,757 +0,0 @@
-// +build !android
-
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- *
- * This implements userspace semantics of "sticky sockets", modeled after
- * WireGuard's kernelspace implementation. This is more or less a straight port
- * of the sticky-sockets.c example code:
- * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
- *
- * Currently there is no way to achieve this within the net package:
- * See e.g. https://github.com/golang/go/issues/17930
- * So this code is remains platform dependent.
- */
-
-package device
-
-import (
- "errors"
- "net"
- "strconv"
- "sync"
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/unix"
- "golang.zx2c4.com/wireguard/rwcancel"
-)
-
-const (
- FD_ERR = -1
-)
-
-type IPv4Source struct {
- src [4]byte
- ifindex int32
-}
-
-type IPv6Source struct {
- src [16]byte
- //ifindex belongs in dst.ZoneId
-}
-
-type NativeEndpoint struct {
- dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
- src [unsafe.Sizeof(IPv6Source{})]byte
- isV6 bool
-}
-
-func (endpoint *NativeEndpoint) src4() *IPv4Source {
- return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
-}
-
-func (endpoint *NativeEndpoint) src6() *IPv6Source {
- return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0]))
-}
-
-func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 {
- return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
-}
-
-func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
- return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
-}
-
-type nativeBind struct {
- sock4 int
- sock6 int
- netlinkSock int
- netlinkCancel *rwcancel.RWCancel
- lastMark uint32
-}
-
-var _ Endpoint = (*NativeEndpoint)(nil)
-var _ Bind = (*nativeBind)(nil)
-
-func CreateEndpoint(s string) (Endpoint, error) {
- var end NativeEndpoint
- addr, err := parseEndpoint(s)
- if err != nil {
- return nil, err
- }
-
- ipv4 := addr.IP.To4()
- if ipv4 != nil {
- dst := end.dst4()
- end.isV6 = false
- dst.Port = addr.Port
- copy(dst.Addr[:], ipv4)
- end.ClearSrc()
- return &end, nil
- }
-
- ipv6 := addr.IP.To16()
- if ipv6 != nil {
- zone, err := zoneToUint32(addr.Zone)
- if err != nil {
- return nil, err
- }
- dst := end.dst6()
- end.isV6 = true
- dst.Port = addr.Port
- dst.ZoneId = zone
- copy(dst.Addr[:], ipv6[:])
- end.ClearSrc()
- return &end, nil
- }
-
- return nil, errors.New("Invalid IP address")
-}
-
-func createNetlinkRouteSocket() (int, error) {
- sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
- if err != nil {
- return -1, err
- }
- saddr := &unix.SockaddrNetlink{
- Family: unix.AF_NETLINK,
- Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
- }
- err = unix.Bind(sock, saddr)
- if err != nil {
- unix.Close(sock)
- return -1, err
- }
- return sock, nil
-
-}
-
-func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
- var err error
- var bind nativeBind
- var newPort uint16
-
- bind.netlinkSock, err = createNetlinkRouteSocket()
- if err != nil {
- return nil, 0, err
- }
- bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
- if err != nil {
- unix.Close(bind.netlinkSock)
- return nil, 0, err
- }
-
- go bind.routineRouteListener(device)
-
- // attempt ipv6 bind, update port if succesful
-
- bind.sock6, newPort, err = create6(port)
- if err != nil {
- if err != syscall.EAFNOSUPPORT {
- bind.netlinkCancel.Cancel()
- return nil, 0, err
- }
- } else {
- port = newPort
- }
-
- // attempt ipv4 bind, update port if succesful
-
- bind.sock4, newPort, err = create4(port)
- if err != nil {
- if err != syscall.EAFNOSUPPORT {
- bind.netlinkCancel.Cancel()
- unix.Close(bind.sock6)
- return nil, 0, err
- }
- } else {
- port = newPort
- }
-
- if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR {
- return nil, 0, errors.New("ipv4 and ipv6 not supported")
- }
-
- return &bind, port, nil
-}
-
-func (bind *nativeBind) SetMark(value uint32) error {
- if bind.sock6 != -1 {
- err := unix.SetsockoptInt(
- bind.sock6,
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
-
- if err != nil {
- return err
- }
- }
-
- if bind.sock4 != -1 {
- err := unix.SetsockoptInt(
- bind.sock4,
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
-
- if err != nil {
- return err
- }
- }
-
- bind.lastMark = value
- return nil
-}
-
-func closeUnblock(fd int) error {
- // shutdown to unblock readers and writers
- unix.Shutdown(fd, unix.SHUT_RDWR)
- return unix.Close(fd)
-}
-
-func (bind *nativeBind) Close() error {
- var err1, err2, err3 error
- if bind.sock6 != -1 {
- err1 = closeUnblock(bind.sock6)
- }
- if bind.sock4 != -1 {
- err2 = closeUnblock(bind.sock4)
- }
- err3 = bind.netlinkCancel.Cancel()
-
- if err1 != nil {
- return err1
- }
- if err2 != nil {
- return err2
- }
- return err3
-}
-
-func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
- var end NativeEndpoint
- if bind.sock6 == -1 {
- return 0, nil, syscall.EAFNOSUPPORT
- }
- n, err := receive6(
- bind.sock6,
- buff,
- &end,
- )
- return n, &end, err
-}
-
-func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
- var end NativeEndpoint
- if bind.sock4 == -1 {
- return 0, nil, syscall.EAFNOSUPPORT
- }
- n, err := receive4(
- bind.sock4,
- buff,
- &end,
- )
- return n, &end, err
-}
-
-func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
- nend := end.(*NativeEndpoint)
- if !nend.isV6 {
- if bind.sock4 == -1 {
- return syscall.EAFNOSUPPORT
- }
- return send4(bind.sock4, nend, buff)
- } else {
- if bind.sock6 == -1 {
- return syscall.EAFNOSUPPORT
- }
- return send6(bind.sock6, nend, buff)
- }
-}
-
-func (end *NativeEndpoint) SrcIP() net.IP {
- if !end.isV6 {
- return net.IPv4(
- end.src4().src[0],
- end.src4().src[1],
- end.src4().src[2],
- end.src4().src[3],
- )
- } else {
- return end.src6().src[:]
- }
-}
-
-func (end *NativeEndpoint) DstIP() net.IP {
- if !end.isV6 {
- return net.IPv4(
- end.dst4().Addr[0],
- end.dst4().Addr[1],
- end.dst4().Addr[2],
- end.dst4().Addr[3],
- )
- } else {
- return end.dst6().Addr[:]
- }
-}
-
-func (end *NativeEndpoint) DstToBytes() []byte {
- if !end.isV6 {
- return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
- } else {
- return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
- }
-}
-
-func (end *NativeEndpoint) SrcToString() string {
- return end.SrcIP().String()
-}
-
-func (end *NativeEndpoint) DstToString() string {
- var udpAddr net.UDPAddr
- udpAddr.IP = end.DstIP()
- if !end.isV6 {
- udpAddr.Port = end.dst4().Port
- } else {
- udpAddr.Port = end.dst6().Port
- }
- return udpAddr.String()
-}
-
-func (end *NativeEndpoint) ClearDst() {
- for i := range end.dst {
- end.dst[i] = 0
- }
-}
-
-func (end *NativeEndpoint) ClearSrc() {
- for i := range end.src {
- end.src[i] = 0
- }
-}
-
-func zoneToUint32(zone string) (uint32, error) {
- if zone == "" {
- return 0, nil
- }
- if intr, err := net.InterfaceByName(zone); err == nil {
- return uint32(intr.Index), nil
- }
- n, err := strconv.ParseUint(zone, 10, 32)
- return uint32(n), err
-}
-
-func create4(port uint16) (int, uint16, error) {
-
- // create socket
-
- fd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM,
- 0,
- )
-
- if err != nil {
- return FD_ERR, 0, err
- }
-
- addr := unix.SockaddrInet4{
- Port: int(port),
- }
-
- // set sockopts and bind
-
- if err := func() error {
- if err := unix.SetsockoptInt(
- fd,
- unix.SOL_SOCKET,
- unix.SO_REUSEADDR,
- 1,
- ); err != nil {
- return err
- }
-
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IP,
- unix.IP_PKTINFO,
- 1,
- ); err != nil {
- return err
- }
-
- return unix.Bind(fd, &addr)
- }(); err != nil {
- unix.Close(fd)
- return FD_ERR, 0, err
- }
-
- sa, err := unix.Getsockname(fd)
- if err == nil {
- addr.Port = sa.(*unix.SockaddrInet4).Port
- }
-
- return fd, uint16(addr.Port), err
-}
-
-func create6(port uint16) (int, uint16, error) {
-
- // create socket
-
- fd, err := unix.Socket(
- unix.AF_INET6,
- unix.SOCK_DGRAM,
- 0,
- )
-
- if err != nil {
- return FD_ERR, 0, err
- }
-
- // set sockopts and bind
-
- addr := unix.SockaddrInet6{
- Port: int(port),
- }
-
- if err := func() error {
-
- if err := unix.SetsockoptInt(
- fd,
- unix.SOL_SOCKET,
- unix.SO_REUSEADDR,
- 1,
- ); err != nil {
- return err
- }
-
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IPV6,
- unix.IPV6_RECVPKTINFO,
- 1,
- ); err != nil {
- return err
- }
-
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IPV6,
- unix.IPV6_V6ONLY,
- 1,
- ); err != nil {
- return err
- }
-
- return unix.Bind(fd, &addr)
-
- }(); err != nil {
- unix.Close(fd)
- return FD_ERR, 0, err
- }
-
- sa, err := unix.Getsockname(fd)
- if err == nil {
- addr.Port = sa.(*unix.SockaddrInet6).Port
- }
-
- return fd, uint16(addr.Port), err
-}
-
-func send4(sock int, end *NativeEndpoint, buff []byte) error {
-
- // construct message header
-
- cmsg := struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
- }{
- unix.Cmsghdr{
- Level: unix.IPPROTO_IP,
- Type: unix.IP_PKTINFO,
- Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
- },
- unix.Inet4Pktinfo{
- Spec_dst: end.src4().src,
- Ifindex: end.src4().ifindex,
- },
- }
-
- _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
-
- if err == nil {
- return nil
- }
-
- // clear src and retry
-
- if err == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet4Pktinfo{}
- _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
- }
-
- return err
-}
-
-func send6(sock int, end *NativeEndpoint, buff []byte) error {
-
- // construct message header
-
- cmsg := struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
- }{
- unix.Cmsghdr{
- Level: unix.IPPROTO_IPV6,
- Type: unix.IPV6_PKTINFO,
- Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
- },
- unix.Inet6Pktinfo{
- Addr: end.src6().src,
- Ifindex: end.dst6().ZoneId,
- },
- }
-
- if cmsg.pktinfo.Addr == [16]byte{} {
- cmsg.pktinfo.Ifindex = 0
- }
-
- _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
-
- if err == nil {
- return nil
- }
-
- // clear src and retry
-
- if err == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet6Pktinfo{}
- _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
- }
-
- return err
-}
-
-func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
-
- // contruct message header
-
- var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
- }
-
- size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
-
- if err != nil {
- return 0, err
- }
- end.isV6 = false
-
- if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
- *end.dst4() = *newDst4
- }
-
- // update source cache
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
- cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
- end.src4().src = cmsg.pktinfo.Spec_dst
- end.src4().ifindex = cmsg.pktinfo.Ifindex
- }
-
- return size, nil
-}
-
-func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
-
- // contruct message header
-
- var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
- }
-
- size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
-
- if err != nil {
- return 0, err
- }
- end.isV6 = true
-
- if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
- *end.dst6() = *newDst6
- }
-
- // update source cache
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
- cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
- end.src6().src = cmsg.pktinfo.Addr
- end.dst6().ZoneId = cmsg.pktinfo.Ifindex
- }
-
- return size, nil
-}
-
-func (bind *nativeBind) routineRouteListener(device *Device) {
- type peerEndpointPtr struct {
- peer *Peer
- endpoint *Endpoint
- }
- var reqPeer map[uint32]peerEndpointPtr
- var reqPeerLock sync.Mutex
-
- defer unix.Close(bind.netlinkSock)
-
- for msg := make([]byte, 1<<16); ; {
- var err error
- var msgn int
- for {
- msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
- if err == nil || !rwcancel.RetryAfterError(err) {
- break
- }
- if !bind.netlinkCancel.ReadyRead() {
- return
- }
- }
- if err != nil {
- return
- }
-
- for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
-
- hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
-
- if uint(hdr.Len) > uint(len(remain)) {
- break
- }
-
- switch hdr.Type {
- case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
- if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
- if uint(len(remain)) < uint(hdr.Len) {
- break
- }
- if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
- attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
- for {
- if uint(len(attr)) < uint(unix.SizeofRtAttr) {
- break
- }
- attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
- if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
- break
- }
- if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
- ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
- reqPeerLock.Lock()
- if reqPeer == nil {
- reqPeerLock.Unlock()
- break
- }
- pePtr, ok := reqPeer[hdr.Seq]
- reqPeerLock.Unlock()
- if !ok {
- break
- }
- pePtr.peer.Lock()
- if &pePtr.peer.endpoint != pePtr.endpoint {
- pePtr.peer.Unlock()
- break
- }
- if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
- pePtr.peer.Unlock()
- break
- }
- pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc()
- pePtr.peer.Unlock()
- }
- attr = attr[attrhdr.Len:]
- }
- }
- break
- }
- reqPeerLock.Lock()
- reqPeer = make(map[uint32]peerEndpointPtr)
- reqPeerLock.Unlock()
- go func() {
- device.peers.RLock()
- i := uint32(1)
- for _, peer := range device.peers.keyMap {
- peer.RLock()
- if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
- peer.RUnlock()
- continue
- }
- if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
- peer.RUnlock()
- break
- }
- nlmsg := struct {
- hdr unix.NlMsghdr
- msg unix.RtMsg
- dsthdr unix.RtAttr
- dst [4]byte
- srchdr unix.RtAttr
- src [4]byte
- markhdr unix.RtAttr
- mark uint32
- }{
- unix.NlMsghdr{
- Type: uint16(unix.RTM_GETROUTE),
- Flags: unix.NLM_F_REQUEST,
- Seq: i,
- },
- unix.RtMsg{
- Family: unix.AF_INET,
- Dst_len: 32,
- Src_len: 32,
- },
- unix.RtAttr{
- Len: 8,
- Type: unix.RTA_DST,
- },
- peer.endpoint.(*NativeEndpoint).dst4().Addr,
- unix.RtAttr{
- Len: 8,
- Type: unix.RTA_SRC,
- },
- peer.endpoint.(*NativeEndpoint).src4().src,
- unix.RtAttr{
- Len: 8,
- Type: unix.RTA_MARK,
- },
- uint32(bind.lastMark),
- }
- nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
- reqPeerLock.Lock()
- reqPeer[i] = peerEndpointPtr{
- peer: peer,
- endpoint: &peer.endpoint,
- }
- reqPeerLock.Unlock()
- peer.RUnlock()
- i++
- _, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
- if err != nil {
- break
- }
- }
- device.peers.RUnlock()
- }()
- }
- remain = remain[hdr.Len:]
- }
- }
-}
diff --git a/device/constants.go b/device/constants.go
index e316f32..59854a1 100644
--- a/device/constants.go
+++ b/device/constants.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -12,8 +12,8 @@ import (
/* Specification constants */
const (
- RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
- RejectAfterMessages = (1 << 64) - (1 << 4) - 1
+ RekeyAfterMessages = (1 << 60)
+ RejectAfterMessages = (1 << 64) - (1 << 13) - 1
RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90
RekeyTimeout = time.Second * 5
@@ -35,7 +35,6 @@ const (
/* Implementation constants */
const (
- UnderLoadQueueSize = QueueHandshakeSize / 8
UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
MaxPeers = 1 << 16 // maximum number of configured peers
)
diff --git a/device/cookie.go b/device/cookie.go
index f134128..876f05d 100644
--- a/device/cookie.go
+++ b/device/cookie.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -83,7 +83,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
return hmac.Equal(mac1[:], msg[smac1:smac2])
}
-func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
+func (st *CookieChecker) CheckMAC2(msg, src []byte) bool {
st.RLock()
defer st.RUnlock()
@@ -119,7 +119,6 @@ func (st *CookieChecker) CreateReply(
recv uint32,
src []byte,
) (*MessageCookieReply, error) {
-
st.RLock()
// refresh cookie secret
@@ -204,7 +203,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
_, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
-
if err != nil {
return false
}
@@ -215,7 +213,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
}
func (st *CookieGenerator) AddMacs(msg []byte) {
-
size := len(msg)
smac2 := size - blake2s.Size128
diff --git a/device/cookie_test.go b/device/cookie_test.go
index 79a6a86..4f1e50a 100644
--- a/device/cookie_test.go
+++ b/device/cookie_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -10,7 +10,6 @@ import (
)
func TestCookieMAC1(t *testing.T) {
-
// setup generator / checker
var (
@@ -132,12 +131,12 @@ func TestCookieMAC1(t *testing.T) {
msg[5] ^= 0x20
- srcBad1 := []byte{192, 168, 13, 37, 40, 01}
+ srcBad1 := []byte{192, 168, 13, 37, 40, 1}
if checker.CheckMAC2(msg, srcBad1) {
t.Fatal("MAC2 generation/verification failed")
}
- srcBad2 := []byte{192, 168, 13, 38, 40, 01}
+ srcBad2 := []byte{192, 168, 13, 38, 40, 1}
if checker.CheckMAC2(msg, srcBad2) {
t.Fatal("MAC2 generation/verification failed")
}
diff --git a/device/device.go b/device/device.go
index 569c5a8..83c33ee 100644
--- a/device/device.go
+++ b/device/device.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -11,37 +11,40 @@ import (
"sync/atomic"
"time"
+ "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ratelimiter"
+ "golang.zx2c4.com/wireguard/rwcancel"
"golang.zx2c4.com/wireguard/tun"
)
-const (
- DeviceRoutineNumberPerCPU = 3
- DeviceRoutineNumberAdditional = 2
-)
-
type Device struct {
- isUp AtomicBool // device is (going) up
- isClosed AtomicBool // device is closed? (acting as guard)
- log *Logger
-
- // synchronized resources (locks acquired in order)
-
state struct {
- starting sync.WaitGroup
+ // state holds the device's state. It is accessed atomically.
+ // Use the device.deviceState method to read it.
+ // device.deviceState does not acquire the mutex, so it captures only a snapshot.
+ // During state transitions, the state variable is updated before the device itself.
+ // The state is thus either the current state of the device or
+ // the intended future state of the device.
+ // For example, while executing a call to Up, state will be deviceStateUp.
+ // There is no guarantee that that intended future state of the device
+ // will become the actual state; Up can fail.
+ // The device can also change state multiple times between time of check and time of use.
+ // Unsynchronized uses of state must therefore be advisory/best-effort only.
+ state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
+ // stopping blocks until all inputs to Device have been closed.
stopping sync.WaitGroup
+ // mu protects state changes.
sync.Mutex
- changing AtomicBool
- current bool
}
net struct {
- starting sync.WaitGroup
stopping sync.WaitGroup
sync.RWMutex
- bind Bind // bind interface
- port uint16 // listening port
- fwmark uint32 // mark value (0 = disabled)
+ bind conn.Bind // bind interface
+ netlinkCancel *rwcancel.RWCancel
+ port uint16 // listening port
+ fwmark uint32 // mark value (0 = disabled)
+ brokenRoaming bool
}
staticIdentity struct {
@@ -51,153 +54,176 @@ type Device struct {
}
peers struct {
- sync.RWMutex
- keyMap map[NoisePublicKey]*Peer
+ sync.RWMutex // protects keyMap
+ keyMap map[NoisePublicKey]*Peer
}
- // unprotected / "self-synchronising resources"
+ rate struct {
+ underLoadUntil atomic.Int64
+ limiter ratelimiter.Ratelimiter
+ }
allowedips AllowedIPs
indexTable IndexTable
cookieChecker CookieChecker
- rate struct {
- underLoadUntil atomic.Value
- limiter ratelimiter.Ratelimiter
- }
-
pool struct {
- messageBufferPool *sync.Pool
- messageBufferReuseChan chan *[MaxMessageSize]byte
- inboundElementPool *sync.Pool
- inboundElementReuseChan chan *QueueInboundElement
- outboundElementPool *sync.Pool
- outboundElementReuseChan chan *QueueOutboundElement
+ inboundElementsContainer *WaitPool
+ outboundElementsContainer *WaitPool
+ messageBuffers *WaitPool
+ inboundElements *WaitPool
+ outboundElements *WaitPool
}
queue struct {
- encryption chan *QueueOutboundElement
- decryption chan *QueueInboundElement
- handshake chan QueueHandshakeElement
- }
-
- signals struct {
- stop chan struct{}
+ encryption *outboundQueue
+ decryption *inboundQueue
+ handshake *handshakeQueue
}
tun struct {
device tun.Device
- mtu int32
+ mtu atomic.Int32
}
+
+ ipcMutex sync.RWMutex
+ closed chan struct{}
+ log *Logger
}
-/* Converts the peer into a "zombie", which remains in the peer map,
- * but processes no packets and does not exists in the routing table.
- *
- * Must hold device.peers.Mutex
- */
-func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
+// deviceState represents the state of a Device.
+// There are three states: down, up, closed.
+// Transitions:
+//
+// down -----+
+// ↑↓ ↓
+// up -> closed
+type deviceState uint32
- // stop routing and processing of packets
+//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
+const (
+ deviceStateDown deviceState = iota
+ deviceStateUp
+ deviceStateClosed
+)
+
+// deviceState returns device.state.state as a deviceState
+// See those docs for how to interpret this value.
+func (device *Device) deviceState() deviceState {
+ return deviceState(device.state.state.Load())
+}
+
+// isClosed reports whether the device is closed (or is closing).
+// See device.state.state comments for how to interpret this value.
+func (device *Device) isClosed() bool {
+ return device.deviceState() == deviceStateClosed
+}
+
+// isUp reports whether the device is up (or is attempting to come up).
+// See device.state.state comments for how to interpret this value.
+func (device *Device) isUp() bool {
+ return device.deviceState() == deviceStateUp
+}
+// Must hold device.peers.Lock()
+func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
+ // stop routing and processing of packets
device.allowedips.RemoveByPeer(peer)
peer.Stop()
// remove from peer map
-
delete(device.peers.keyMap, key)
}
-func deviceUpdateState(device *Device) {
-
- // check if state already being updated (guard)
-
- if device.state.changing.Swap(true) {
- return
- }
-
- // compare to current state of device
-
+// changeState attempts to change the device state to match want.
+func (device *Device) changeState(want deviceState) (err error) {
device.state.Lock()
-
- newIsUp := device.isUp.Get()
-
- if newIsUp == device.state.current {
- device.state.changing.Set(false)
- device.state.Unlock()
- return
+ defer device.state.Unlock()
+ old := device.deviceState()
+ if old == deviceStateClosed {
+ // once closed, always closed
+ device.log.Verbosef("Interface closed, ignored requested state %s", want)
+ return nil
}
-
- // change state of device
-
- switch newIsUp {
- case true:
- if err := device.BindUpdate(); err != nil {
- device.log.Error.Printf("Unable to update bind: %v\n", err)
- device.isUp.Set(false)
+ switch want {
+ case old:
+ return nil
+ case deviceStateUp:
+ device.state.state.Store(uint32(deviceStateUp))
+ err = device.upLocked()
+ if err == nil {
break
}
- device.peers.RLock()
- for _, peer := range device.peers.keyMap {
- peer.Start()
- if peer.persistentKeepaliveInterval > 0 {
- peer.SendKeepalive()
- }
+ fallthrough // up failed; bring the device all the way back down
+ case deviceStateDown:
+ device.state.state.Store(uint32(deviceStateDown))
+ errDown := device.downLocked()
+ if err == nil {
+ err = errDown
}
- device.peers.RUnlock()
-
- case false:
- device.BindClose()
- device.peers.RLock()
- for _, peer := range device.peers.keyMap {
- peer.Stop()
- }
- device.peers.RUnlock()
}
+ device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
+ return
+}
- // update state variables
-
- device.state.current = newIsUp
- device.state.changing.Set(false)
- device.state.Unlock()
+// upLocked attempts to bring the device up and reports whether it succeeded.
+// The caller must hold device.state.mu and is responsible for updating device.state.state.
+func (device *Device) upLocked() error {
+ if err := device.BindUpdate(); err != nil {
+ device.log.Errorf("Unable to update bind: %v", err)
+ return err
+ }
- // check for state change in the mean time
+ // The IPC set operation waits for peers to be created before calling Start() on them,
+ // so if there's a concurrent IPC set request happening, we should wait for it to complete.
+ device.ipcMutex.Lock()
+ defer device.ipcMutex.Unlock()
- deviceUpdateState(device)
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.Start()
+ if peer.persistentKeepaliveInterval.Load() > 0 {
+ peer.SendKeepalive()
+ }
+ }
+ device.peers.RUnlock()
+ return nil
}
-func (device *Device) Up() {
-
- // closed device cannot be brought up
+// downLocked attempts to bring the device down.
+// The caller must hold device.state.mu and is responsible for updating device.state.state.
+func (device *Device) downLocked() error {
+ err := device.BindClose()
+ if err != nil {
+ device.log.Errorf("Bind close failed: %v", err)
+ }
- if device.isClosed.Get() {
- return
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.Stop()
}
+ device.peers.RUnlock()
+ return err
+}
- device.isUp.Set(true)
- deviceUpdateState(device)
+func (device *Device) Up() error {
+ return device.changeState(deviceStateUp)
}
-func (device *Device) Down() {
- device.isUp.Set(false)
- deviceUpdateState(device)
+func (device *Device) Down() error {
+ return device.changeState(deviceStateDown)
}
func (device *Device) IsUnderLoad() bool {
-
// check if currently under load
-
now := time.Now()
- underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
+ underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
if underLoad {
- device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime))
+ device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
return true
}
-
// check if recently under load
-
- until := device.rate.underLoadUntil.Load().(time.Time)
- return until.After(now)
+ return device.rate.underLoadUntil.Load() > now.UnixNano()
}
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
@@ -224,7 +250,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
publicKey := sk.publicKey()
for key, peer := range device.peers.keyMap {
if peer.handshake.remoteStatic.Equals(publicKey) {
- unsafeRemovePeer(device, peer, key)
+ peer.handshake.mutex.RUnlock()
+ removePeerLocked(device, peer, key)
+ peer.handshake.mutex.RLock()
}
}
@@ -236,23 +264,11 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// do static-static DH pre-computations
- rmKey := device.staticIdentity.privateKey.IsZero()
-
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
- for key, peer := range device.peers.keyMap {
+ for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
-
- if rmKey {
- handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
- } else {
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
- }
-
- if isZero(handshake.precomputedStaticStatic[:]) {
- unsafeRemovePeer(device, peer, key)
- } else {
- expiredPeers = append(expiredPeers, peer)
- }
+ handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
+ expiredPeers = append(expiredPeers, peer)
}
for _, peer := range lockedPeers {
@@ -265,68 +281,63 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
return nil
}
-func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
+func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device := new(Device)
-
- device.isUp.Set(false)
- device.isClosed.Set(false)
-
+ device.state.state.Store(uint32(deviceStateDown))
+ device.closed = make(chan struct{})
device.log = logger
-
+ device.net.bind = bind
device.tun.device = tunDevice
mtu, err := device.tun.device.MTU()
if err != nil {
- logger.Error.Println("Trouble determining MTU, assuming default:", err)
+ device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
mtu = DefaultMTU
}
- device.tun.mtu = int32(mtu)
-
+ device.tun.mtu.Store(int32(mtu))
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
-
device.rate.limiter.Init()
- device.rate.underLoadUntil.Store(time.Time{})
-
device.indexTable.Init()
- device.allowedips.Reset()
device.PopulatePools()
// create queues
- device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
- device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
- device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
-
- // prepare signals
-
- device.signals.stop = make(chan struct{})
-
- // prepare net
-
- device.net.port = 0
- device.net.bind = nil
+ device.queue.handshake = newHandshakeQueue()
+ device.queue.encryption = newOutboundQueue()
+ device.queue.decryption = newInboundQueue()
// start workers
cpus := runtime.NumCPU()
- device.state.starting.Wait()
device.state.stopping.Wait()
- device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
- device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
- for i := 0; i < cpus; i += 1 {
- go device.RoutineEncryption()
- go device.RoutineDecryption()
- go device.RoutineHandshake()
+ device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
+ for i := 0; i < cpus; i++ {
+ go device.RoutineEncryption(i + 1)
+ go device.RoutineDecryption(i + 1)
+ go device.RoutineHandshake(i + 1)
}
+ device.state.stopping.Add(1) // RoutineReadFromTUN
+ device.queue.encryption.wg.Add(1) // RoutineReadFromTUN
go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader()
- device.state.starting.Wait()
-
return device
}
+// BatchSize returns the BatchSize for the device as a whole which is the max of
+// the bind batch size and the tun batch size. The batch size reported by device
+// is the size used to construct memory pools, and is the allowed batch size for
+// the lifetime of the device.
+func (device *Device) BatchSize() int {
+ size := device.net.bind.BatchSize()
+ dSize := device.tun.device.BatchSize()
+ if size < dSize {
+ size = dSize
+ }
+ return size
+}
+
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.RLock()
defer device.peers.RUnlock()
@@ -341,7 +352,7 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
peer, ok := device.peers.keyMap[key]
if ok {
- unsafeRemovePeer(device, peer, key)
+ removePeerLocked(device, peer, key)
}
}
@@ -350,67 +361,50 @@ func (device *Device) RemoveAllPeers() {
defer device.peers.Unlock()
for key, peer := range device.peers.keyMap {
- unsafeRemovePeer(device, peer, key)
+ removePeerLocked(device, peer, key)
}
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
}
-func (device *Device) FlushPacketQueues() {
- for {
- select {
- case elem, ok := <-device.queue.decryption:
- if ok {
- elem.Drop()
- }
- case elem, ok := <-device.queue.encryption:
- if ok {
- elem.Drop()
- }
- case <-device.queue.handshake:
- default:
- return
- }
- }
-
-}
-
func (device *Device) Close() {
- if device.isClosed.Swap(true) {
- return
- }
-
- device.state.starting.Wait()
-
- device.log.Info.Println("Device closing")
- device.state.changing.Set(true)
device.state.Lock()
defer device.state.Unlock()
+ device.ipcMutex.Lock()
+ defer device.ipcMutex.Unlock()
+ if device.isClosed() {
+ return
+ }
+ device.state.state.Store(uint32(deviceStateClosed))
+ device.log.Verbosef("Device closing")
device.tun.device.Close()
- device.BindClose()
-
- device.isUp.Set(false)
-
- close(device.signals.stop)
+ device.downLocked()
+ // Remove peers before closing queues,
+ // because peers assume that queues are active.
device.RemoveAllPeers()
+ // We kept a reference to the encryption and decryption queues,
+ // in case we started any new peers that might write to them.
+ // No new peers are coming; we are done with these queues.
+ device.queue.encryption.wg.Done()
+ device.queue.decryption.wg.Done()
+ device.queue.handshake.wg.Done()
device.state.stopping.Wait()
- device.FlushPacketQueues()
device.rate.limiter.Close()
- device.state.changing.Set(false)
- device.log.Info.Println("Interface closed")
+ device.log.Verbosef("Device closed")
+ close(device.closed)
}
func (device *Device) Wait() chan struct{} {
- return device.signals.stop
+ return device.closed
}
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
- if device.isClosed.Get() {
+ if !device.isUp() {
return
}
@@ -425,3 +419,118 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
}
device.peers.RUnlock()
}
+
+// closeBindLocked closes the device's net.bind.
+// The caller must hold the net mutex.
+func closeBindLocked(device *Device) error {
+ var err error
+ netc := &device.net
+ if netc.netlinkCancel != nil {
+ netc.netlinkCancel.Cancel()
+ }
+ if netc.bind != nil {
+ err = netc.bind.Close()
+ }
+ netc.stopping.Wait()
+ return err
+}
+
+func (device *Device) Bind() conn.Bind {
+ device.net.Lock()
+ defer device.net.Unlock()
+ return device.net.bind
+}
+
+func (device *Device) BindSetMark(mark uint32) error {
+ device.net.Lock()
+ defer device.net.Unlock()
+
+ // check if modified
+ if device.net.fwmark == mark {
+ return nil
+ }
+
+ // update fwmark on existing bind
+ device.net.fwmark = mark
+ if device.isUp() && device.net.bind != nil {
+ if err := device.net.bind.SetMark(mark); err != nil {
+ return err
+ }
+ }
+
+ // clear cached source addresses
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.markEndpointSrcForClearing()
+ }
+ device.peers.RUnlock()
+
+ return nil
+}
+
+func (device *Device) BindUpdate() error {
+ device.net.Lock()
+ defer device.net.Unlock()
+
+ // close existing sockets
+ if err := closeBindLocked(device); err != nil {
+ return err
+ }
+
+ // open new sockets
+ if !device.isUp() {
+ return nil
+ }
+
+ // bind to new port
+ var err error
+ var recvFns []conn.ReceiveFunc
+ netc := &device.net
+
+ recvFns, netc.port, err = netc.bind.Open(netc.port)
+ if err != nil {
+ netc.port = 0
+ return err
+ }
+
+ netc.netlinkCancel, err = device.startRouteListener(netc.bind)
+ if err != nil {
+ netc.bind.Close()
+ netc.port = 0
+ return err
+ }
+
+ // set fwmark
+ if netc.fwmark != 0 {
+ err = netc.bind.SetMark(netc.fwmark)
+ if err != nil {
+ return err
+ }
+ }
+
+ // clear cached source addresses
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.markEndpointSrcForClearing()
+ }
+ device.peers.RUnlock()
+
+ // start receiving routines
+ device.net.stopping.Add(len(recvFns))
+ device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
+ device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
+ batchSize := netc.bind.BatchSize()
+ for _, fn := range recvFns {
+ go device.RoutineReceiveIncoming(batchSize, fn)
+ }
+
+ device.log.Verbosef("UDP bind has been updated")
+ return nil
+}
+
+func (device *Device) BindClose() error {
+ device.net.Lock()
+ err := closeBindLocked(device)
+ device.net.Unlock()
+ return err
+}
diff --git a/device/device_test.go b/device/device_test.go
index 14cc605..fff172b 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -1,238 +1,476 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
- "bufio"
"bytes"
- "encoding/binary"
+ "encoding/hex"
+ "fmt"
"io"
- "net"
+ "math/rand"
+ "net/netip"
"os"
- "strings"
+ "runtime"
+ "runtime/pprof"
+ "sync"
+ "sync/atomic"
"testing"
"time"
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/conn/bindtest"
"golang.zx2c4.com/wireguard/tun"
+ "golang.zx2c4.com/wireguard/tun/tuntest"
)
-func TestTwoDevicePing(t *testing.T) {
- // TODO(crawshaw): pick unused ports on localhost
- cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58
-listen_port=53511
-replace_peers=true
-public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725
-protocol_version=1
-replace_allowed_ips=true
-allowed_ip=1.0.0.2/32
-endpoint=127.0.0.1:53512`
- tun1 := NewChannelTUN()
- dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
- dev1.Up()
- defer dev1.Close()
- if err := dev1.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg1))); err != nil {
- t.Fatal(err)
- }
-
- cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768
-listen_port=53512
-replace_peers=true
-public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427
-protocol_version=1
-replace_allowed_ips=true
-allowed_ip=1.0.0.1/32
-endpoint=127.0.0.1:53511`
- tun2 := NewChannelTUN()
- dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
- dev2.Up()
- defer dev2.Close()
- if err := dev2.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg2))); err != nil {
- t.Fatal(err)
+// uapiCfg returns a string that contains cfg formatted use with IpcSet.
+// cfg is a series of alternating key/value strings.
+// uapiCfg exists because editors and humans like to insert
+// whitespace into configs, which can cause failures, some of which are silent.
+// For example, a leading blank newline causes the remainder
+// of the config to be silently ignored.
+func uapiCfg(cfg ...string) string {
+ if len(cfg)%2 != 0 {
+ panic("odd number of args to uapiReader")
}
-
- t.Run("ping 1.0.0.1", func(t *testing.T) {
- msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
- tun2.Outbound <- msg2to1
- select {
- case msgRecv := <-tun1.Inbound:
- if !bytes.Equal(msg2to1, msgRecv) {
- t.Error("ping did not transit correctly")
- }
- case <-time.After(300 * time.Millisecond):
- t.Error("ping did not transit")
+ buf := new(bytes.Buffer)
+ for i, s := range cfg {
+ buf.WriteString(s)
+ sep := byte('\n')
+ if i%2 == 0 {
+ sep = '='
}
- })
+ buf.WriteByte(sep)
+ }
+ return buf.String()
+}
- t.Run("ping 1.0.0.2", func(t *testing.T) {
- msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
- tun1.Outbound <- msg1to2
- select {
- case msgRecv := <-tun2.Inbound:
- if !bytes.Equal(msg1to2, msgRecv) {
- t.Error("return ping did not transit correctly")
- }
- case <-time.After(300 * time.Millisecond):
- t.Error("return ping did not transit")
- }
- })
+// genConfigs generates a pair of configs that connect to each other.
+// The configs use distinct, probably-usable ports.
+func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
+ var key1, key2 NoisePrivateKey
+ _, err := rand.Read(key1[:])
+ if err != nil {
+ tb.Errorf("unable to generate private key random bytes: %v", err)
+ }
+ _, err = rand.Read(key2[:])
+ if err != nil {
+ tb.Errorf("unable to generate private key random bytes: %v", err)
+ }
+ pub1, pub2 := key1.publicKey(), key2.publicKey()
+
+ cfgs[0] = uapiCfg(
+ "private_key", hex.EncodeToString(key1[:]),
+ "listen_port", "0",
+ "replace_peers", "true",
+ "public_key", hex.EncodeToString(pub2[:]),
+ "protocol_version", "1",
+ "replace_allowed_ips", "true",
+ "allowed_ip", "1.0.0.2/32",
+ )
+ endpointCfgs[0] = uapiCfg(
+ "public_key", hex.EncodeToString(pub2[:]),
+ "endpoint", "127.0.0.1:%d",
+ )
+ cfgs[1] = uapiCfg(
+ "private_key", hex.EncodeToString(key2[:]),
+ "listen_port", "0",
+ "replace_peers", "true",
+ "public_key", hex.EncodeToString(pub1[:]),
+ "protocol_version", "1",
+ "replace_allowed_ips", "true",
+ "allowed_ip", "1.0.0.1/32",
+ )
+ endpointCfgs[1] = uapiCfg(
+ "public_key", hex.EncodeToString(pub1[:]),
+ "endpoint", "127.0.0.1:%d",
+ )
+ return
+}
+
+// A testPair is a pair of testPeers.
+type testPair [2]testPeer
+
+// A testPeer is a peer used for testing.
+type testPeer struct {
+ tun *tuntest.ChannelTUN
+ dev *Device
+ ip netip.Addr
}
-func ping(dst, src net.IP) []byte {
- localPort := uint16(1337)
- seq := uint16(0)
+type SendDirection bool
- payload := make([]byte, 4)
- binary.BigEndian.PutUint16(payload[0:], localPort)
- binary.BigEndian.PutUint16(payload[2:], seq)
+const (
+ Ping SendDirection = true
+ Pong SendDirection = false
+)
- return genICMPv4(payload, dst, src)
+func (d SendDirection) String() string {
+ if d == Ping {
+ return "ping"
+ }
+ return "pong"
}
-// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
-func checksum(buf []byte, initial uint16) uint16 {
- v := uint32(initial)
- for i := 0; i < len(buf)-1; i += 2 {
- v += uint32(binary.BigEndian.Uint16(buf[i:]))
+func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) {
+ tb.Helper()
+ p0, p1 := pair[0], pair[1]
+ if !ping {
+ // pong is the new ping
+ p0, p1 = p1, p0
}
- if len(buf)%2 == 1 {
- v += uint32(buf[len(buf)-1]) << 8
+ msg := tuntest.Ping(p0.ip, p1.ip)
+ p1.tun.Outbound <- msg
+ timer := time.NewTimer(5 * time.Second)
+ defer timer.Stop()
+ var err error
+ select {
+ case msgRecv := <-p0.tun.Inbound:
+ if !bytes.Equal(msg, msgRecv) {
+ err = fmt.Errorf("%s did not transit correctly", ping)
+ }
+ case <-timer.C:
+ err = fmt.Errorf("%s did not transit", ping)
+ case <-done:
}
- for v > 0xffff {
- v = (v >> 16) + (v & 0xffff)
+ if err != nil {
+ // The error may have occurred because the test is done.
+ select {
+ case <-done:
+ return
+ default:
+ }
+ // Real error.
+ tb.Error(err)
}
- return ^uint16(v)
}
-func genICMPv4(payload []byte, dst, src net.IP) []byte {
- const (
- icmpv4ProtocolNumber = 1
- icmpv4Echo = 8
- icmpv4ChecksumOffset = 2
- icmpv4Size = 8
- ipv4Size = 20
- ipv4TotalLenOffset = 2
- ipv4ChecksumOffset = 10
- ttl = 65
- )
+// genTestPair creates a testPair.
+func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
+ cfg, endpointCfg := genConfigs(tb)
+ var binds [2]conn.Bind
+ if realSocket {
+ binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
+ } else {
+ binds = bindtest.NewChannelBinds()
+ }
+ // Bring up a ChannelTun for each config.
+ for i := range pair {
+ p := &pair[i]
+ p.tun = tuntest.NewChannelTUN()
+ p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
+ level := LogLevelVerbose
+ if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
+ level = LogLevelError
+ }
+ p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
+ if err := p.dev.IpcSet(cfg[i]); err != nil {
+ tb.Errorf("failed to configure device %d: %v", i, err)
+ p.dev.Close()
+ continue
+ }
+ if err := p.dev.Up(); err != nil {
+ tb.Errorf("failed to bring up device %d: %v", i, err)
+ p.dev.Close()
+ continue
+ }
+ endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port)
+ }
+ for i := range pair {
+ p := &pair[i]
+ if err := p.dev.IpcSet(endpointCfg[i]); err != nil {
+ tb.Errorf("failed to configure device endpoint %d: %v", i, err)
+ p.dev.Close()
+ continue
+ }
+ // The device is ready. Close it when the test completes.
+ tb.Cleanup(p.dev.Close)
+ }
+ return
+}
+
+func TestTwoDevicePing(t *testing.T) {
+ goroutineLeakCheck(t)
+ pair := genTestPair(t, true)
+ t.Run("ping 1.0.0.1", func(t *testing.T) {
+ pair.Send(t, Ping, nil)
+ })
+ t.Run("ping 1.0.0.2", func(t *testing.T) {
+ pair.Send(t, Pong, nil)
+ })
+}
+
+func TestUpDown(t *testing.T) {
+ goroutineLeakCheck(t)
+ const itrials = 50
+ const otrials = 10
+
+ for n := 0; n < otrials; n++ {
+ pair := genTestPair(t, false)
+ for i := range pair {
+ for k := range pair[i].dev.peers.keyMap {
+ pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
+ }
+ }
+ var wg sync.WaitGroup
+ wg.Add(len(pair))
+ for i := range pair {
+ go func(d *Device) {
+ defer wg.Done()
+ for i := 0; i < itrials; i++ {
+ if err := d.Up(); err != nil {
+ t.Errorf("failed up bring up device: %v", err)
+ }
+ time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
+ if err := d.Down(); err != nil {
+ t.Errorf("failed to bring down device: %v", err)
+ }
+ time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
+ }
+ }(pair[i].dev)
+ }
+ wg.Wait()
+ for i := range pair {
+ pair[i].dev.Up()
+ pair[i].dev.Close()
+ }
+ }
+}
- hdr := make([]byte, ipv4Size+icmpv4Size)
+// TestConcurrencySafety does other things concurrently with tunnel use.
+// It is intended to be used with the race detector to catch data races.
+func TestConcurrencySafety(t *testing.T) {
+ pair := genTestPair(t, true)
+ done := make(chan struct{})
- ip := hdr[0:ipv4Size]
- icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
+ const warmupIters = 10
+ var warmup sync.WaitGroup
+ warmup.Add(warmupIters)
+ go func() {
+ // Send data continuously back and forth until we're done.
+ // Note that we may continue to attempt to send data
+ // even after done is closed.
+ i := warmupIters
+ for ping := Ping; ; ping = !ping {
+ pair.Send(t, ping, done)
+ select {
+ case <-done:
+ return
+ default:
+ }
+ if i > 0 {
+ warmup.Done()
+ i--
+ }
+ }
+ }()
+ warmup.Wait()
- // https://tools.ietf.org/html/rfc792
- icmpv4[0] = icmpv4Echo // type
- icmpv4[1] = 0 // code
- chksum := ^checksum(icmpv4, checksum(payload, 0))
- binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
+ applyCfg := func(cfg string) {
+ err := pair[0].dev.IpcSet(cfg)
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
- // https://tools.ietf.org/html/rfc760 section 3.1
- length := uint16(len(hdr) + len(payload))
- ip[0] = (4 << 4) | (ipv4Size / 4)
- binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
- ip[8] = ttl
- ip[9] = icmpv4ProtocolNumber
- copy(ip[12:], src.To4())
- copy(ip[16:], dst.To4())
- chksum = ^checksum(ip[:], 0)
- binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
+ // Change persistent_keepalive_interval concurrently with tunnel use.
+ t.Run("persistentKeepaliveInterval", func(t *testing.T) {
+ var pub NoisePublicKey
+ for key := range pair[0].dev.peers.keyMap {
+ pub = key
+ break
+ }
+ cfg := uapiCfg(
+ "public_key", hex.EncodeToString(pub[:]),
+ "persistent_keepalive_interval", "1",
+ )
+ for i := 0; i < 1000; i++ {
+ applyCfg(cfg)
+ }
+ })
- var v []byte
- v = append(v, hdr...)
- v = append(v, payload...)
- return []byte(v)
-}
+ // Change private keys concurrently with tunnel use.
+ t.Run("privateKey", func(t *testing.T) {
+ bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777")
+ good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:]))
+ // Set iters to a large number like 1000 to flush out data races quickly.
+ // Don't leave it large. That can cause logical races
+ // in which the handshake is interleaved with key changes
+ // such that the private key appears to be unchanging but
+ // other state gets reset, which can cause handshake failures like
+ // "Received packet with invalid mac1".
+ const iters = 1
+ for i := 0; i < iters; i++ {
+ applyCfg(bad)
+ applyCfg(good)
+ }
+ })
-// TODO(crawshaw): find a reusable home for this. package devicetest?
-type ChannelTUN struct {
- Inbound chan []byte // incoming packets, closed on TUN close
- Outbound chan []byte // outbound packets, blocks forever on TUN close
+ // Perform bind updates and keepalive sends concurrently with tunnel use.
+ t.Run("bindUpdate and keepalive", func(t *testing.T) {
+ const iters = 10
+ for i := 0; i < iters; i++ {
+ for _, peer := range pair {
+ peer.dev.BindUpdate()
+ peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
+ }
+ }
+ })
- closed chan struct{}
- events chan tun.Event
- tun chTun
+ close(done)
}
-func NewChannelTUN() *ChannelTUN {
- c := &ChannelTUN{
- Inbound: make(chan []byte),
- Outbound: make(chan []byte),
- closed: make(chan struct{}),
- events: make(chan tun.Event, 1),
+func BenchmarkLatency(b *testing.B) {
+ pair := genTestPair(b, true)
+
+ // Establish a connection.
+ pair.Send(b, Ping, nil)
+ pair.Send(b, Pong, nil)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ pair.Send(b, Ping, nil)
+ pair.Send(b, Pong, nil)
}
- c.tun.c = c
- c.events <- tun.EventUp
- return c
}
-func (c *ChannelTUN) TUN() tun.Device {
- return &c.tun
-}
+func BenchmarkThroughput(b *testing.B) {
+ pair := genTestPair(b, true)
-type chTun struct {
- c *ChannelTUN
-}
+ // Establish a connection.
+ pair.Send(b, Ping, nil)
+ pair.Send(b, Pong, nil)
-func (t *chTun) File() *os.File { return nil }
+ // Measure how long it takes to receive b.N packets,
+ // starting when we receive the first packet.
+ var recv atomic.Uint64
+ var elapsed time.Duration
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ var start time.Time
+ for {
+ <-pair[0].tun.Inbound
+ new := recv.Add(1)
+ if new == 1 {
+ start = time.Now()
+ }
+ // Careful! Don't change this to else if; b.N can be equal to 1.
+ if new == uint64(b.N) {
+ elapsed = time.Since(start)
+ return
+ }
+ }
+ }()
-func (t *chTun) Read(data []byte, offset int) (int, error) {
- select {
- case <-t.c.closed:
- return 0, io.EOF // TODO(crawshaw): what is the correct error value?
- case msg := <-t.c.Outbound:
- return copy(data[offset:], msg), nil
+ // Send packets as fast as we can until we've received enough.
+ ping := tuntest.Ping(pair[0].ip, pair[1].ip)
+ pingc := pair[1].tun.Outbound
+ var sent uint64
+ for recv.Load() != uint64(b.N) {
+ sent++
+ pingc <- ping
}
+ wg.Wait()
+
+ b.ReportMetric(float64(elapsed)/float64(b.N), "ns/op")
+ b.ReportMetric(1-float64(b.N)/float64(sent), "packet-loss")
}
-// Write is called by the wireguard device to deliver a packet for routing.
-func (t *chTun) Write(data []byte, offset int) (int, error) {
- if offset == -1 {
- close(t.c.closed)
- close(t.c.events)
- return 0, io.EOF
+func BenchmarkUAPIGet(b *testing.B) {
+ pair := genTestPair(b, true)
+ pair.Send(b, Ping, nil)
+ pair.Send(b, Pong, nil)
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ pair[0].dev.IpcGetOperation(io.Discard)
}
- msg := make([]byte, len(data)-offset)
- copy(msg, data[offset:])
- select {
- case <-t.c.closed:
- return 0, io.EOF // TODO(crawshaw): what is the correct error value?
- case t.c.Inbound <- msg:
- return len(data) - offset, nil
+}
+
+func goroutineLeakCheck(t *testing.T) {
+ goroutines := func() (int, []byte) {
+ p := pprof.Lookup("goroutine")
+ b := new(bytes.Buffer)
+ p.WriteTo(b, 1)
+ return p.Count(), b.Bytes()
}
+
+ startGoroutines, startStacks := goroutines()
+ t.Cleanup(func() {
+ if t.Failed() {
+ return
+ }
+ // Give goroutines time to exit, if they need it.
+ for i := 0; i < 10000; i++ {
+ if runtime.NumGoroutine() <= startGoroutines {
+ return
+ }
+ time.Sleep(1 * time.Millisecond)
+ }
+ endGoroutines, endStacks := goroutines()
+ t.Logf("starting stacks:\n%s\n", startStacks)
+ t.Logf("ending stacks:\n%s\n", endStacks)
+ t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
+ })
}
-func (t *chTun) Flush() error { return nil }
-func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
-func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
-func (t *chTun) Events() chan tun.Event { return t.c.events }
-func (t *chTun) Close() error {
- t.Write(nil, -1)
- return nil
+type fakeBindSized struct {
+ size int
}
-func assertNil(t *testing.T, err error) {
- if err != nil {
- t.Fatal(err)
- }
+func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+ return nil, 0, nil
}
+func (b *fakeBindSized) Close() error { return nil }
+func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
+func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
+func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
+func (b *fakeBindSized) BatchSize() int { return b.size }
-func assertEqual(t *testing.T, a, b []byte) {
- if !bytes.Equal(a, b) {
- t.Fatal(a, "!=", b)
- }
+type fakeTUNDeviceSized struct {
+ size int
}
-func randDevice(t *testing.T) *Device {
- sk, err := newPrivateKey()
- if err != nil {
- t.Fatal(err)
+func (t *fakeTUNDeviceSized) File() *os.File { return nil }
+func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
+ return 0, nil
+}
+func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
+func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
+func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
+func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
+func (t *fakeTUNDeviceSized) Close() error { return nil }
+func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
+
+func TestBatchSize(t *testing.T) {
+ d := Device{}
+
+ d.net.bind = &fakeBindSized{1}
+ d.tun.device = &fakeTUNDeviceSized{1}
+ if want, got := 1, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{1}
+ d.tun.device = &fakeTUNDeviceSized{128}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{128}
+ d.tun.device = &fakeTUNDeviceSized{1}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{128}
+ d.tun.device = &fakeTUNDeviceSized{128}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
}
- tun := newDummyTUN("dummy")
- logger := NewLogger(LogLevelError, "")
- device := NewDevice(tun, logger)
- device.SetPrivateKey(sk)
- return device
}
diff --git a/device/devicestate_string.go b/device/devicestate_string.go
new file mode 100644
index 0000000..6577dd4
--- /dev/null
+++ b/device/devicestate_string.go
@@ -0,0 +1,16 @@
+// Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT.
+
+package device
+
+import "strconv"
+
+const _deviceState_name = "DownUpClosed"
+
+var _deviceState_index = [...]uint8{0, 4, 6, 12}
+
+func (i deviceState) String() string {
+ if i >= deviceState(len(_deviceState_index)-1) {
+ return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]]
+}
diff --git a/device/endpoint_test.go b/device/endpoint_test.go
index 1896790..93a4998 100644
--- a/device/endpoint_test.go
+++ b/device/endpoint_test.go
@@ -1,53 +1,49 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"math/rand"
- "net"
+ "net/netip"
)
type DummyEndpoint struct {
- src [16]byte
- dst [16]byte
+ src, dst netip.Addr
}
func CreateDummyEndpoint() (*DummyEndpoint, error) {
- var end DummyEndpoint
- if _, err := rand.Read(end.src[:]); err != nil {
+ var src, dst [16]byte
+ if _, err := rand.Read(src[:]); err != nil {
return nil, err
}
- _, err := rand.Read(end.dst[:])
- return &end, err
+ _, err := rand.Read(dst[:])
+ return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
}
func (e *DummyEndpoint) ClearSrc() {}
func (e *DummyEndpoint) SrcToString() string {
- var addr net.UDPAddr
- addr.IP = e.SrcIP()
- addr.Port = 1000
- return addr.String()
+ return netip.AddrPortFrom(e.SrcIP(), 1000).String()
}
func (e *DummyEndpoint) DstToString() string {
- var addr net.UDPAddr
- addr.IP = e.DstIP()
- addr.Port = 1000
- return addr.String()
+ return netip.AddrPortFrom(e.DstIP(), 1000).String()
}
-func (e *DummyEndpoint) SrcToBytes() []byte {
- return e.src[:]
+func (e *DummyEndpoint) DstToBytes() []byte {
+ out := e.DstIP().AsSlice()
+ out = append(out, byte(1000&0xff))
+ out = append(out, byte((1000>>8)&0xff))
+ return out
}
-func (e *DummyEndpoint) DstIP() net.IP {
- return e.dst[:]
+func (e *DummyEndpoint) DstIP() netip.Addr {
+ return e.dst
}
-func (e *DummyEndpoint) SrcIP() net.IP {
- return e.src[:]
+func (e *DummyEndpoint) SrcIP() netip.Addr {
+ return e.src
}
diff --git a/device/indextable.go b/device/indextable.go
index 4cba970..00ade7d 100644
--- a/device/indextable.go
+++ b/device/indextable.go
@@ -1,14 +1,14 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"crypto/rand"
+ "encoding/binary"
"sync"
- "unsafe"
)
type IndexTableEntry struct {
@@ -25,7 +25,8 @@ type IndexTable struct {
func randUint32() (uint32, error) {
var integer [4]byte
_, err := rand.Read(integer[:])
- return *(*uint32)(unsafe.Pointer(&integer[0])), err
+ // Arbitrary endianness; both are intrinsified by the Go compiler.
+ return binary.LittleEndian.Uint32(integer[:]), err
}
func (table *IndexTable) Init() {
diff --git a/device/ip.go b/device/ip.go
index 9d4fb74..eaf2363 100644
--- a/device/ip.go
+++ b/device/ip.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/kdf_test.go b/device/kdf_test.go
index cb8dbab..f9c76d6 100644
--- a/device/kdf_test.go
+++ b/device/kdf_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -20,7 +20,7 @@ type KDFTest struct {
t2 string
}
-func assertEquals(t *testing.T, a string, b string) {
+func assertEquals(t *testing.T, a, b string) {
if a != b {
t.Fatal("expected", a, "=", b)
}
diff --git a/device/keypair.go b/device/keypair.go
index 9c78fa9..e3540d7 100644
--- a/device/keypair.go
+++ b/device/keypair.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,6 +8,7 @@ package device
import (
"crypto/cipher"
"sync"
+ "sync/atomic"
"time"
"golang.zx2c4.com/wireguard/replay"
@@ -21,10 +22,10 @@ import (
*/
type Keypair struct {
- sendNonce uint64
+ sendNonce atomic.Uint64
send cipher.AEAD
receive cipher.AEAD
- replayFilter replay.ReplayFilter
+ replayFilter replay.Filter
isInitiator bool
created time.Time
localIndex uint32
@@ -35,7 +36,7 @@ type Keypairs struct {
sync.RWMutex
current *Keypair
previous *Keypair
- next *Keypair
+ next atomic.Pointer[Keypair]
}
func (kp *Keypairs) Current() *Keypair {
diff --git a/device/logger.go b/device/logger.go
index 7c8b704..22b0df0 100644
--- a/device/logger.go
+++ b/device/logger.go
@@ -1,59 +1,48 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
- "io"
- "io/ioutil"
"log"
"os"
)
+// A Logger provides logging for a Device.
+// The functions are Printf-style functions.
+// They must be safe for concurrent use.
+// They do not require a trailing newline in the format.
+// If nil, that level of logging will be silent.
+type Logger struct {
+ Verbosef func(format string, args ...any)
+ Errorf func(format string, args ...any)
+}
+
+// Log levels for use with NewLogger.
const (
LogLevelSilent = iota
LogLevelError
- LogLevelInfo
- LogLevelDebug
+ LogLevelVerbose
)
-type Logger struct {
- Debug *log.Logger
- Info *log.Logger
- Error *log.Logger
-}
+// Function for use in Logger for discarding logged lines.
+func DiscardLogf(format string, args ...any) {}
+// NewLogger constructs a Logger that writes to stdout.
+// It logs at the specified log level and above.
+// It decorates log lines with the log level, date, time, and prepend.
func NewLogger(level int, prepend string) *Logger {
- output := os.Stdout
- logger := new(Logger)
-
- logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) {
- if level >= LogLevelDebug {
- return output, output, output
- }
- if level >= LogLevelInfo {
- return output, output, ioutil.Discard
- }
- if level >= LogLevelError {
- return output, ioutil.Discard, ioutil.Discard
- }
- return ioutil.Discard, ioutil.Discard, ioutil.Discard
- }()
-
- logger.Debug = log.New(logDebug,
- "DEBUG: "+prepend,
- log.Ldate|log.Ltime,
- )
-
- logger.Info = log.New(logInfo,
- "INFO: "+prepend,
- log.Ldate|log.Ltime,
- )
- logger.Error = log.New(logErr,
- "ERROR: "+prepend,
- log.Ldate|log.Ltime,
- )
+ logger := &Logger{DiscardLogf, DiscardLogf}
+ logf := func(prefix string) func(string, ...any) {
+ return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
+ }
+ if level >= LogLevelVerbose {
+ logger.Verbosef = logf("DEBUG")
+ }
+ if level >= LogLevelError {
+ logger.Errorf = logf("ERROR")
+ }
return logger
}
diff --git a/device/mark_default.go b/device/mark_default.go
deleted file mode 100644
index 7de2524..0000000
--- a/device/mark_default.go
+++ /dev/null
@@ -1,12 +0,0 @@
-// +build !linux,!openbsd,!freebsd
-
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-func (bind *nativeBind) SetMark(mark uint32) error {
- return nil
-}
diff --git a/device/misc.go b/device/misc.go
deleted file mode 100644
index a38d1c1..0000000
--- a/device/misc.go
+++ /dev/null
@@ -1,48 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
- "sync/atomic"
-)
-
-/* Atomic Boolean */
-
-const (
- AtomicFalse = int32(iota)
- AtomicTrue
-)
-
-type AtomicBool struct {
- int32
-}
-
-func (a *AtomicBool) Get() bool {
- return atomic.LoadInt32(&a.int32) == AtomicTrue
-}
-
-func (a *AtomicBool) Swap(val bool) bool {
- flag := AtomicFalse
- if val {
- flag = AtomicTrue
- }
- return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
-}
-
-func (a *AtomicBool) Set(val bool) {
- flag := AtomicFalse
- if val {
- flag = AtomicTrue
- }
- atomic.StoreInt32(&a.int32, flag)
-}
-
-func min(a, b uint) uint {
- if a > b {
- return b
- }
- return a
-}
diff --git a/device/mobilequirks.go b/device/mobilequirks.go
new file mode 100644
index 0000000..0a0080e
--- /dev/null
+++ b/device/mobilequirks.go
@@ -0,0 +1,19 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+// DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created,
+// though it will try to deal with it, and race maybe, if called after.
+func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
+ device.net.brokenRoaming = true
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.endpoint.Lock()
+ peer.endpoint.disableRoaming = peer.endpoint.val != nil
+ peer.endpoint.Unlock()
+ }
+ device.peers.RUnlock()
+}
diff --git a/device/noise-helpers.go b/device/noise-helpers.go
index f5e4b4b..c2f356b 100644
--- a/device/noise-helpers.go
+++ b/device/noise-helpers.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -9,6 +9,7 @@ import (
"crypto/hmac"
"crypto/rand"
"crypto/subtle"
+ "errors"
"hash"
"golang.org/x/crypto/blake2s"
@@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
return
}
-func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
+var errInvalidPublicKey = errors.New("invalid public key")
+
+func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarMult(&ss, ask, apk)
- return ss
+ if isZero(ss[:]) {
+ return ss, errInvalidPublicKey
+ }
+ return ss, nil
}
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index 88c6aae..e8f6145 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -1,29 +1,50 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"errors"
+ "fmt"
"sync"
"time"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
+
"golang.zx2c4.com/wireguard/tai64n"
)
+type handshakeState int
+
const (
- HandshakeZeroed = iota
- HandshakeInitiationCreated
- HandshakeInitiationConsumed
- HandshakeResponseCreated
- HandshakeResponseConsumed
+ handshakeZeroed = handshakeState(iota)
+ handshakeInitiationCreated
+ handshakeInitiationConsumed
+ handshakeResponseCreated
+ handshakeResponseConsumed
)
+func (hs handshakeState) String() string {
+ switch hs {
+ case handshakeZeroed:
+ return "handshakeZeroed"
+ case handshakeInitiationCreated:
+ return "handshakeInitiationCreated"
+ case handshakeInitiationConsumed:
+ return "handshakeInitiationConsumed"
+ case handshakeResponseCreated:
+ return "handshakeResponseCreated"
+ case handshakeResponseConsumed:
+ return "handshakeResponseConsumed"
+ default:
+ return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
+ }
+}
+
const (
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
@@ -39,13 +60,13 @@ const (
)
const (
- MessageInitiationSize = 148 // size of handshake initation message
+ MessageInitiationSize = 148 // size of handshake initiation message
MessageResponseSize = 92 // size of response message
MessageCookieReplySize = 64 // size of cookie reply message
- MessageTransportHeaderSize = 16 // size of data preceeding content in transport message
+ MessageTransportHeaderSize = 16 // size of data preceding content in transport message
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
MessageKeepaliveSize = MessageTransportSize // size of keepalive
- MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message
+ MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
)
const (
@@ -95,11 +116,11 @@ type MessageCookieReply struct {
}
type Handshake struct {
- state int
+ state handshakeState
mutex sync.RWMutex
hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key
- presharedKey NoiseSymmetricKey // psk
+ presharedKey NoisePresharedKey // psk
localEphemeral NoisePrivateKey // ephemeral secret key
localIndex uint32 // used to clear hash-table
remoteIndex uint32 // index for sending
@@ -117,11 +138,11 @@ var (
ZeroNonce [chacha20poly1305.NonceSize]byte
)
-func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
+func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
KDF1(dst, c[:], data)
}
-func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
+func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
hash, _ := blake2s.New256(nil)
hash.Write(h[:])
hash.Write(data)
@@ -135,7 +156,7 @@ func (h *Handshake) Clear() {
setZero(h.chainKey[:])
setZero(h.hash[:])
h.localIndex = 0
- h.state = HandshakeZeroed
+ h.state = handshakeZeroed
}
func (h *Handshake) mixHash(data []byte) {
@@ -154,7 +175,6 @@ func init() {
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
-
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -162,12 +182,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
- if isZero(handshake.precomputedStaticStatic[:]) {
- return nil, errors.New("static shared secret is zero")
- }
-
// create ephemeral key
-
var err error
handshake.hash = InitialHash
handshake.chainKey = InitialChainKey
@@ -176,59 +191,56 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
return nil, err
}
- // assign index
-
- device.indexTable.Delete(handshake.localIndex)
- handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
-
- if err != nil {
- return nil, err
- }
-
handshake.mixHash(handshake.remoteStatic[:])
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
- Sender: handshake.localIndex,
}
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
// encrypt static key
-
- func() {
- var key [chacha20poly1305.KeySize]byte
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- KDF2(
- &handshake.chainKey,
- &key,
- handshake.chainKey[:],
- ss[:],
- )
- aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if err != nil {
+ return nil, err
+ }
+ var key [chacha20poly1305.KeySize]byte
+ KDF2(
+ &handshake.chainKey,
+ &key,
+ handshake.chainKey[:],
+ ss[:],
+ )
+ aead, _ := chacha20poly1305.New(key[:])
+ aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
handshake.mixHash(msg.Static[:])
// encrypt timestamp
-
+ if isZero(handshake.precomputedStaticStatic[:]) {
+ return nil, errInvalidPublicKey
+ }
+ KDF2(
+ &handshake.chainKey,
+ &key,
+ handshake.chainKey[:],
+ handshake.precomputedStaticStatic[:],
+ )
timestamp := tai64n.Now()
- func() {
- var key [chacha20poly1305.KeySize]byte
- KDF2(
- &handshake.chainKey,
- &key,
- handshake.chainKey[:],
- handshake.precomputedStaticStatic[:],
- )
- aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
- }()
+ aead, _ = chacha20poly1305.New(key[:])
+ aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
+
+ // assign index
+ device.indexTable.Delete(handshake.localIndex)
+ msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
+ if err != nil {
+ return nil, err
+ }
+ handshake.localIndex = msg.Sender
handshake.mixHash(msg.Timestamp[:])
- handshake.state = HandshakeInitiationCreated
+ handshake.state = handshakeInitiationCreated
return &msg, nil
}
@@ -250,16 +262,15 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key
-
- var err error
var peerPK NoisePublicKey
- func() {
- var key [chacha20poly1305.KeySize]byte
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- KDF2(&chainKey, &key, chainKey[:], ss[:])
- aead, _ := chacha20poly1305.New(key[:])
- _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
- }()
+ var key [chacha20poly1305.KeySize]byte
+ ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return nil
+ }
+ KDF2(&chainKey, &key, chainKey[:], ss[:])
+ aead, _ := chacha20poly1305.New(key[:])
+ _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
if err != nil {
return nil
}
@@ -268,28 +279,29 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// lookup peer
peer := device.LookupPeer(peerPK)
- if peer == nil {
+ if peer == nil || !peer.isRunning.Load() {
return nil
}
handshake := &peer.handshake
- if isZero(handshake.precomputedStaticStatic[:]) {
- return nil
- }
// verify identity
var timestamp tai64n.Timestamp
- var key [chacha20poly1305.KeySize]byte
handshake.mutex.RLock()
+
+ if isZero(handshake.precomputedStaticStatic[:]) {
+ handshake.mutex.RUnlock()
+ return nil
+ }
KDF2(
&chainKey,
&key,
chainKey[:],
handshake.precomputedStaticStatic[:],
)
- aead, _ := chacha20poly1305.New(key[:])
+ aead, _ = chacha20poly1305.New(key[:])
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
if err != nil {
handshake.mutex.RUnlock()
@@ -299,11 +311,15 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// protect against replay & flood
- var ok bool
- ok = timestamp.After(handshake.lastTimestamp)
- ok = ok && time.Since(handshake.lastInitiationConsumption) > HandshakeInitationRate
+ replay := !timestamp.After(handshake.lastTimestamp)
+ flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
handshake.mutex.RUnlock()
- if !ok {
+ if replay {
+ device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp)
+ return nil
+ }
+ if flood {
+ device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer)
return nil
}
@@ -322,7 +338,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
if now.After(handshake.lastInitiationConsumption) {
handshake.lastInitiationConsumption = now
}
- handshake.state = HandshakeInitiationConsumed
+ handshake.state = handshakeInitiationConsumed
handshake.mutex.Unlock()
@@ -337,7 +353,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
- if handshake.state != HandshakeInitiationConsumed {
+ if handshake.state != handshakeInitiationConsumed {
return nil, errors.New("handshake initiation must be consumed first")
}
@@ -365,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:])
- func() {
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
- handshake.mixKey(ss[:])
- ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- handshake.mixKey(ss[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
+ if err != nil {
+ return nil, err
+ }
+ handshake.mixKey(ss[:])
+ ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if err != nil {
+ return nil, err
+ }
+ handshake.mixKey(ss[:])
// add preshared key
@@ -387,13 +407,11 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(tau[:])
- func() {
- aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
- handshake.mixHash(msg.Empty[:])
- }()
+ aead, _ := chacha20poly1305.New(key[:])
+ aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
+ handshake.mixHash(msg.Empty[:])
- handshake.state = HandshakeResponseCreated
+ handshake.state = handshakeResponseCreated
return &msg, nil
}
@@ -417,13 +435,12 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
)
ok := func() bool {
-
// lock handshake state
handshake.mutex.RLock()
defer handshake.mutex.RUnlock()
- if handshake.state != HandshakeInitiationCreated {
+ if handshake.state != handshakeInitiationCreated {
return false
}
@@ -437,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
- func() {
- ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
- mixKey(&chainKey, &chainKey, ss[:])
- setZero(ss[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return false
+ }
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
- func() {
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- mixKey(&chainKey, &chainKey, ss[:])
- setZero(ss[:])
- }()
+ ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return false
+ }
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
// add preshared key (psk)
@@ -465,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// authenticate transcript
aead, _ := chacha20poly1305.New(key[:])
- _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
+ _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil {
return false
}
@@ -484,7 +503,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake.hash = hash
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
- handshake.state = HandshakeResponseConsumed
+ handshake.state = handshakeResponseConsumed
handshake.mutex.Unlock()
@@ -509,7 +528,7 @@ func (peer *Peer) BeginSymmetricSession() error {
var sendKey [chacha20poly1305.KeySize]byte
var recvKey [chacha20poly1305.KeySize]byte
- if handshake.state == HandshakeResponseConsumed {
+ if handshake.state == handshakeResponseConsumed {
KDF2(
&sendKey,
&recvKey,
@@ -517,7 +536,7 @@ func (peer *Peer) BeginSymmetricSession() error {
nil,
)
isInitiator = true
- } else if handshake.state == HandshakeResponseCreated {
+ } else if handshake.state == handshakeResponseCreated {
KDF2(
&recvKey,
&sendKey,
@@ -526,7 +545,7 @@ func (peer *Peer) BeginSymmetricSession() error {
)
isInitiator = false
} else {
- return errors.New("invalid state for keypair derivation")
+ return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
}
// zero handshake
@@ -534,7 +553,7 @@ func (peer *Peer) BeginSymmetricSession() error {
setZero(handshake.chainKey[:])
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
setZero(handshake.localEphemeral[:])
- peer.handshake.state = HandshakeZeroed
+ peer.handshake.state = handshakeZeroed
// create AEAD instances
@@ -546,8 +565,7 @@ func (peer *Peer) BeginSymmetricSession() error {
setZero(recvKey[:])
keypair.created = time.Now()
- keypair.sendNonce = 0
- keypair.replayFilter.Init()
+ keypair.replayFilter.Reset()
keypair.isInitiator = isInitiator
keypair.localIndex = peer.handshake.localIndex
keypair.remoteIndex = peer.handshake.remoteIndex
@@ -564,12 +582,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock()
previous := keypairs.previous
- next := keypairs.next
+ next := keypairs.next.Load()
current := keypairs.current
if isInitiator {
if next != nil {
- keypairs.next = nil
+ keypairs.next.Store(nil)
keypairs.previous = next
device.DeleteKeypair(current)
} else {
@@ -578,7 +596,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous)
keypairs.current = keypair
} else {
- keypairs.next = keypair
+ keypairs.next.Store(keypair)
device.DeleteKeypair(next)
keypairs.previous = nil
device.DeleteKeypair(previous)
@@ -589,18 +607,19 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs
- if keypairs.next != receivedKeypair {
+
+ if keypairs.next.Load() != receivedKeypair {
return false
}
keypairs.Lock()
defer keypairs.Unlock()
- if keypairs.next != receivedKeypair {
+ if keypairs.next.Load() != receivedKeypair {
return false
}
old := keypairs.previous
keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old)
- keypairs.current = keypairs.next
- keypairs.next = nil
+ keypairs.current = keypairs.next.Load()
+ keypairs.next.Store(nil)
return true
}
diff --git a/device/noise-types.go b/device/noise-types.go
index 6b1f16f..e850359 100644
--- a/device/noise-types.go
+++ b/device/noise-types.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -9,19 +9,18 @@ import (
"crypto/subtle"
"encoding/hex"
"errors"
-
- "golang.org/x/crypto/chacha20poly1305"
)
const (
- NoisePublicKeySize = 32
- NoisePrivateKeySize = 32
+ NoisePublicKeySize = 32
+ NoisePrivateKeySize = 32
+ NoisePresharedKeySize = 32
)
type (
NoisePublicKey [NoisePublicKeySize]byte
NoisePrivateKey [NoisePrivateKeySize]byte
- NoiseSymmetricKey [chacha20poly1305.KeySize]byte
+ NoisePresharedKey [NoisePresharedKeySize]byte
NoiseNonce uint64 // padded to 12-bytes
)
@@ -52,18 +51,19 @@ func (key *NoisePrivateKey) FromHex(src string) (err error) {
return
}
-func (key NoisePrivateKey) ToHex() string {
- return hex.EncodeToString(key[:])
+func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
+ err = loadExactHex(key[:], src)
+ if key.IsZero() {
+ return
+ }
+ key.clamp()
+ return
}
func (key *NoisePublicKey) FromHex(src string) error {
return loadExactHex(key[:], src)
}
-func (key NoisePublicKey) ToHex() string {
- return hex.EncodeToString(key[:])
-}
-
func (key NoisePublicKey) IsZero() bool {
var zero NoisePublicKey
return key.Equals(zero)
@@ -73,10 +73,6 @@ func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
}
-func (key *NoiseSymmetricKey) FromHex(src string) error {
+func (key *NoisePresharedKey) FromHex(src string) error {
return loadExactHex(key[:], src)
}
-
-func (key NoiseSymmetricKey) ToHex() string {
- return hex.EncodeToString(key[:])
-}
diff --git a/device/noise_test.go b/device/noise_test.go
index 6ba3f2e..2dd5324 100644
--- a/device/noise_test.go
+++ b/device/noise_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -9,6 +9,9 @@ import (
"bytes"
"encoding/binary"
"testing"
+
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/tun/tuntest"
)
func TestCurveWrappers(t *testing.T) {
@@ -21,14 +24,38 @@ func TestCurveWrappers(t *testing.T) {
pk1 := sk1.publicKey()
pk2 := sk2.publicKey()
- ss1 := sk1.sharedSecret(pk2)
- ss2 := sk2.sharedSecret(pk1)
+ ss1, err1 := sk1.sharedSecret(pk2)
+ ss2, err2 := sk2.sharedSecret(pk1)
- if ss1 != ss2 {
+ if ss1 != ss2 || err1 != nil || err2 != nil {
t.Fatal("Failed to compute shared secet")
}
}
+func randDevice(t *testing.T) *Device {
+ sk, err := newPrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+ tun := tuntest.NewChannelTUN()
+ logger := NewLogger(LogLevelError, "")
+ device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger)
+ device.SetPrivateKey(sk)
+ return device
+}
+
+func assertNil(t *testing.T, err error) {
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func assertEqual(t *testing.T, a, b []byte) {
+ if !bytes.Equal(a, b) {
+ t.Fatal(a, "!=", b)
+ }
+}
+
func TestNoiseHandshake(t *testing.T) {
dev1 := randDevice(t)
dev2 := randDevice(t)
@@ -36,8 +63,16 @@ func TestNoiseHandshake(t *testing.T) {
defer dev1.Close()
defer dev2.Close()
- peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
- peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
+ peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
+ if err != nil {
+ t.Fatal(err)
+ }
+ peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
+ if err != nil {
+ t.Fatal(err)
+ }
+ peer1.Start()
+ peer2.Start()
assertEqual(
t,
@@ -113,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) {
t.Fatal("failed to derive keypair for peer 2", err)
}
- key1 := peer1.keypairs.next
+ key1 := peer1.keypairs.next.Load()
key2 := peer2.keypairs.current
// encrypting / decryption test
diff --git a/device/peer.go b/device/peer.go
index 91d975a..47a2f14 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -1,37 +1,35 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
- "encoding/base64"
+ "container/list"
"errors"
- "fmt"
"sync"
"sync/atomic"
"time"
-)
-const (
- PeerRoutineNumber = 3
+ "golang.zx2c4.com/wireguard/conn"
)
type Peer struct {
- isRunning AtomicBool
- sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
- keypairs Keypairs
- handshake Handshake
- device *Device
- endpoint Endpoint
- persistentKeepaliveInterval uint16
-
- // This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly
- stats struct {
- txBytes uint64 // bytes send to peer (endpoint)
- rxBytes uint64 // bytes received from peer
- lastHandshakeNano int64 // nano seconds since epoch
+ isRunning atomic.Bool
+ keypairs Keypairs
+ handshake Handshake
+ device *Device
+ stopping sync.WaitGroup // routines pending stop
+ txBytes atomic.Uint64 // bytes send to peer (endpoint)
+ rxBytes atomic.Uint64 // bytes received from peer
+ lastHandshakeNano atomic.Int64 // nano seconds since epoch
+
+ endpoint struct {
+ sync.Mutex
+ val conn.Endpoint
+ clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
+ disableRoaming bool
}
timers struct {
@@ -40,40 +38,32 @@ type Peer struct {
newHandshake *Timer
zeroKeyMaterial *Timer
persistentKeepalive *Timer
- handshakeAttempts uint32
- needAnotherKeepalive AtomicBool
- sentLastMinuteHandshake AtomicBool
+ handshakeAttempts atomic.Uint32
+ needAnotherKeepalive atomic.Bool
+ sentLastMinuteHandshake atomic.Bool
}
- signals struct {
- newKeypairArrived chan struct{}
- flushNonceQueue chan struct{}
+ state struct {
+ sync.Mutex // protects against concurrent Start/Stop
}
queue struct {
- nonce chan *QueueOutboundElement // nonce / pre-handshake queue
- outbound chan *QueueOutboundElement // sequential ordering of work
- inbound chan *QueueInboundElement // sequential ordering of work
- packetInNonceQueueIsAwaitingKey AtomicBool
- }
-
- routines struct {
- sync.Mutex // held when stopping / starting routines
- starting sync.WaitGroup // routines pending start
- stopping sync.WaitGroup // routines pending stop
- stop chan struct{} // size 0, stop all go routines in peer
+ staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
+ outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
+ inbound *autodrainingInboundQueue // sequential ordering of tun writing
}
- cookieGenerator CookieGenerator
+ cookieGenerator CookieGenerator
+ trieEntries list.List
+ persistentKeepaliveInterval atomic.Uint32
}
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
- if device.isClosed.Get() {
+ if device.isClosed() {
return nil, errors.New("device closed")
}
// lock resources
-
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -81,136 +71,144 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
defer device.peers.Unlock()
// check if over limit
-
if len(device.peers.keyMap) >= MaxPeers {
return nil, errors.New("too many peers")
}
// create peer
-
peer := new(Peer)
- peer.Lock()
- defer peer.Unlock()
peer.cookieGenerator.Init(pk)
peer.device = device
- peer.isRunning.Set(false)
+ peer.queue.outbound = newAutodrainingOutboundQueue(device)
+ peer.queue.inbound = newAutodrainingInboundQueue(device)
+ peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
// map public key
-
_, ok := device.peers.keyMap[pk]
if ok {
return nil, errors.New("adding existing peer")
}
// pre-compute DH
-
handshake := &peer.handshake
handshake.mutex.Lock()
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
- ssIsZero := isZero(handshake.precomputedStaticStatic[:])
+ handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.remoteStatic = pk
handshake.mutex.Unlock()
// reset endpoint
+ peer.endpoint.Lock()
+ peer.endpoint.val = nil
+ peer.endpoint.disableRoaming = false
+ peer.endpoint.clearSrcOnTx = false
+ peer.endpoint.Unlock()
- peer.endpoint = nil
-
- // conditionally add
-
- if !ssIsZero {
- device.peers.keyMap[pk] = peer
- } else {
- return nil, nil
- }
-
- // start peer
+ // init timers
+ peer.timersInit()
- if peer.device.isUp.Get() {
- peer.Start()
- }
+ // add
+ device.peers.keyMap[pk] = peer
return peer, nil
}
-func (peer *Peer) SendBuffer(buffer []byte) error {
+func (peer *Peer) SendBuffers(buffers [][]byte) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()
- if peer.device.net.bind == nil {
- return errors.New("no bind")
+ if peer.device.isClosed() {
+ return nil
}
- peer.RLock()
- defer peer.RUnlock()
-
- if peer.endpoint == nil {
+ peer.endpoint.Lock()
+ endpoint := peer.endpoint.val
+ if endpoint == nil {
+ peer.endpoint.Unlock()
return errors.New("no known endpoint for peer")
}
+ if peer.endpoint.clearSrcOnTx {
+ endpoint.ClearSrc()
+ peer.endpoint.clearSrcOnTx = false
+ }
+ peer.endpoint.Unlock()
- err := peer.device.net.bind.Send(buffer, peer.endpoint)
+ err := peer.device.net.bind.Send(buffers, endpoint)
if err == nil {
- atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer)))
+ var totalLen uint64
+ for _, b := range buffers {
+ totalLen += uint64(len(b))
+ }
+ peer.txBytes.Add(totalLen)
}
return err
}
func (peer *Peer) String() string {
- base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
- abbreviatedKey := "invalid"
- if len(base64Key) == 44 {
- abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43]
+ // The awful goo that follows is identical to:
+ //
+ // base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
+ // abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
+ // return fmt.Sprintf("peer(%s)", abbreviatedKey)
+ //
+ // except that it is considerably more efficient.
+ src := peer.handshake.remoteStatic
+ b64 := func(input byte) byte {
+ return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
}
- return fmt.Sprintf("peer(%s)", abbreviatedKey)
+ b := []byte("peer(____…____)")
+ const first = len("peer(")
+ const second = len("peer(____…")
+ b[first+0] = b64((src[0] >> 2) & 63)
+ b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
+ b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
+ b[first+3] = b64(src[2] & 63)
+ b[second+0] = b64(src[29] & 63)
+ b[second+1] = b64((src[30] >> 2) & 63)
+ b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
+ b[second+3] = b64((src[31] << 2) & 63)
+ return string(b)
}
func (peer *Peer) Start() {
-
// should never start a peer on a closed device
-
- if peer.device.isClosed.Get() {
+ if peer.device.isClosed() {
return
}
// prevent simultaneous start/stop operations
+ peer.state.Lock()
+ defer peer.state.Unlock()
- peer.routines.Lock()
- defer peer.routines.Unlock()
-
- if peer.isRunning.Get() {
+ if peer.isRunning.Load() {
return
}
device := peer.device
- device.log.Debug.Println(peer, "- Starting...")
+ device.log.Verbosef("%v - Starting", peer)
// reset routine state
+ peer.stopping.Wait()
+ peer.stopping.Add(2)
- peer.routines.starting.Wait()
- peer.routines.stopping.Wait()
- peer.routines.stop = make(chan struct{})
- peer.routines.starting.Add(PeerRoutineNumber)
- peer.routines.stopping.Add(PeerRoutineNumber)
-
- // prepare queues
+ peer.handshake.mutex.Lock()
+ peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
+ peer.handshake.mutex.Unlock()
- peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
- peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
- peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
+ peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes
- peer.timersInit()
- peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
- peer.signals.newKeypairArrived = make(chan struct{}, 1)
- peer.signals.flushNonceQueue = make(chan struct{}, 1)
+ peer.timersStart()
- // wait for routines to start
+ device.flushInboundQueue(peer.queue.inbound)
+ device.flushOutboundQueue(peer.queue.outbound)
- go peer.RoutineNonce()
- go peer.RoutineSequentialSender()
- go peer.RoutineSequentialReceiver()
+ // Use the device batch size, not the bind batch size, as the device size is
+ // the size of the batch pools.
+ batchSize := peer.device.BatchSize()
+ go peer.RoutineSequentialSender(batchSize)
+ go peer.RoutineSequentialReceiver(batchSize)
- peer.routines.starting.Wait()
- peer.isRunning.Set(true)
+ peer.isRunning.Store(true)
}
func (peer *Peer) ZeroAndFlushAll() {
@@ -222,10 +220,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
- device.DeleteKeypair(keypairs.next)
+ device.DeleteKeypair(keypairs.next.Load())
keypairs.previous = nil
keypairs.current = nil
- keypairs.next = nil
+ keypairs.next.Store(nil)
keypairs.Unlock()
// clear handshake state
@@ -236,7 +234,7 @@ func (peer *Peer) ZeroAndFlushAll() {
handshake.Clear()
handshake.mutex.Unlock()
- peer.FlushNonceQueue()
+ peer.FlushStagedPackets()
}
func (peer *Peer) ExpireCurrentKeypairs() {
@@ -244,58 +242,55 @@ func (peer *Peer) ExpireCurrentKeypairs() {
handshake.mutex.Lock()
peer.device.indexTable.Delete(handshake.localIndex)
handshake.Clear()
- handshake.mutex.Unlock()
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
+ handshake.mutex.Unlock()
keypairs := &peer.keypairs
keypairs.Lock()
if keypairs.current != nil {
- keypairs.current.sendNonce = RejectAfterMessages
+ keypairs.current.sendNonce.Store(RejectAfterMessages)
}
- if keypairs.next != nil {
- keypairs.next.sendNonce = RejectAfterMessages
+ if next := keypairs.next.Load(); next != nil {
+ next.sendNonce.Store(RejectAfterMessages)
}
keypairs.Unlock()
}
func (peer *Peer) Stop() {
-
- // prevent simultaneous start/stop operations
+ peer.state.Lock()
+ defer peer.state.Unlock()
if !peer.isRunning.Swap(false) {
return
}
- peer.routines.starting.Wait()
-
- peer.routines.Lock()
- defer peer.routines.Unlock()
-
- peer.device.log.Debug.Println(peer, "- Stopping...")
+ peer.device.log.Verbosef("%v - Stopping", peer)
peer.timersStop()
-
- // stop & wait for ongoing peer routines
-
- close(peer.routines.stop)
- peer.routines.stopping.Wait()
-
- // close queues
-
- close(peer.queue.nonce)
- close(peer.queue.outbound)
- close(peer.queue.inbound)
+ // Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
+ peer.queue.inbound.c <- nil
+ peer.queue.outbound.c <- nil
+ peer.stopping.Wait()
+ peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us
peer.ZeroAndFlushAll()
}
-var RoamingDisabled bool
+func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ if peer.endpoint.disableRoaming {
+ return
+ }
+ peer.endpoint.clearSrcOnTx = false
+ peer.endpoint.val = endpoint
+}
-func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
- if RoamingDisabled {
+func (peer *Peer) markEndpointSrcForClearing() {
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ if peer.endpoint.val == nil {
return
}
- peer.Lock()
- peer.endpoint = endpoint
- peer.Unlock()
+ peer.endpoint.clearSrcOnTx = true
}
diff --git a/device/pools.go b/device/pools.go
index 98f4ef1..94f3dc7 100644
--- a/device/pools.go
+++ b/device/pools.go
@@ -1,89 +1,120 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
-import "sync"
+import (
+ "sync"
+ "sync/atomic"
+)
-func (device *Device) PopulatePools() {
- if PreallocatedBuffersPerPool == 0 {
- device.pool.messageBufferPool = &sync.Pool{
- New: func() interface{} {
- return new([MaxMessageSize]byte)
- },
- }
- device.pool.inboundElementPool = &sync.Pool{
- New: func() interface{} {
- return new(QueueInboundElement)
- },
- }
- device.pool.outboundElementPool = &sync.Pool{
- New: func() interface{} {
- return new(QueueOutboundElement)
- },
- }
- } else {
- device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool)
- for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
- device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte)
- }
- device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool)
- for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
- device.pool.inboundElementReuseChan <- new(QueueInboundElement)
- }
- device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool)
- for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
- device.pool.outboundElementReuseChan <- new(QueueOutboundElement)
+type WaitPool struct {
+ pool sync.Pool
+ cond sync.Cond
+ lock sync.Mutex
+ count atomic.Uint32
+ max uint32
+}
+
+func NewWaitPool(max uint32, new func() any) *WaitPool {
+ p := &WaitPool{pool: sync.Pool{New: new}, max: max}
+ p.cond = sync.Cond{L: &p.lock}
+ return p
+}
+
+func (p *WaitPool) Get() any {
+ if p.max != 0 {
+ p.lock.Lock()
+ for p.count.Load() >= p.max {
+ p.cond.Wait()
}
+ p.count.Add(1)
+ p.lock.Unlock()
}
+ return p.pool.Get()
}
-func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
- if PreallocatedBuffersPerPool == 0 {
- return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte)
- } else {
- return <-device.pool.messageBufferReuseChan
+func (p *WaitPool) Put(x any) {
+ p.pool.Put(x)
+ if p.max == 0 {
+ return
}
+ p.count.Add(^uint32(0))
+ p.cond.Signal()
}
-func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
- if PreallocatedBuffersPerPool == 0 {
- device.pool.messageBufferPool.Put(msg)
- } else {
- device.pool.messageBufferReuseChan <- msg
- }
+func (device *Device) PopulatePools() {
+ device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ s := make([]*QueueInboundElement, 0, device.BatchSize())
+ return &QueueInboundElementsContainer{elems: s}
+ })
+ device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ s := make([]*QueueOutboundElement, 0, device.BatchSize())
+ return &QueueOutboundElementsContainer{elems: s}
+ })
+ device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ return new([MaxMessageSize]byte)
+ })
+ device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ return new(QueueInboundElement)
+ })
+ device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ return new(QueueOutboundElement)
+ })
}
-func (device *Device) GetInboundElement() *QueueInboundElement {
- if PreallocatedBuffersPerPool == 0 {
- return device.pool.inboundElementPool.Get().(*QueueInboundElement)
- } else {
- return <-device.pool.inboundElementReuseChan
+func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
+ c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
+ c.Mutex = sync.Mutex{}
+ return c
+}
+
+func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
+ for i := range c.elems {
+ c.elems[i] = nil
}
+ c.elems = c.elems[:0]
+ device.pool.inboundElementsContainer.Put(c)
+}
+
+func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
+ c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
+ c.Mutex = sync.Mutex{}
+ return c
}
-func (device *Device) PutInboundElement(msg *QueueInboundElement) {
- if PreallocatedBuffersPerPool == 0 {
- device.pool.inboundElementPool.Put(msg)
- } else {
- device.pool.inboundElementReuseChan <- msg
+func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
+ for i := range c.elems {
+ c.elems[i] = nil
}
+ c.elems = c.elems[:0]
+ device.pool.outboundElementsContainer.Put(c)
+}
+
+func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
+ return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
+}
+
+func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
+ device.pool.messageBuffers.Put(msg)
+}
+
+func (device *Device) GetInboundElement() *QueueInboundElement {
+ return device.pool.inboundElements.Get().(*QueueInboundElement)
+}
+
+func (device *Device) PutInboundElement(elem *QueueInboundElement) {
+ elem.clearPointers()
+ device.pool.inboundElements.Put(elem)
}
func (device *Device) GetOutboundElement() *QueueOutboundElement {
- if PreallocatedBuffersPerPool == 0 {
- return device.pool.outboundElementPool.Get().(*QueueOutboundElement)
- } else {
- return <-device.pool.outboundElementReuseChan
- }
+ return device.pool.outboundElements.Get().(*QueueOutboundElement)
}
-func (device *Device) PutOutboundElement(msg *QueueOutboundElement) {
- if PreallocatedBuffersPerPool == 0 {
- device.pool.outboundElementPool.Put(msg)
- } else {
- device.pool.outboundElementReuseChan <- msg
- }
+func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
+ elem.clearPointers()
+ device.pool.outboundElements.Put(elem)
}
diff --git a/device/pools_test.go b/device/pools_test.go
new file mode 100644
index 0000000..82d7493
--- /dev/null
+++ b/device/pools_test.go
@@ -0,0 +1,139 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "math/rand"
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func TestWaitPool(t *testing.T) {
+ t.Skip("Currently disabled")
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ startTrials := int32(100000)
+ if raceEnabled {
+ // This test can be very slow with -race.
+ startTrials /= 10
+ }
+ trials.Store(startTrials)
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ t.Skip("Not enough cores")
+ }
+ p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
+ wg.Add(workers)
+ var max atomic.Uint32
+ updateMax := func() {
+ count := p.count.Load()
+ if count > p.max {
+ t.Errorf("count (%d) > max (%d)", count, p.max)
+ }
+ for {
+ old := max.Load()
+ if count <= old {
+ break
+ }
+ if max.CompareAndSwap(old, count) {
+ break
+ }
+ }
+ }
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
+ updateMax()
+ x := p.Get()
+ updateMax()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ updateMax()
+ p.Put(x)
+ updateMax()
+ }
+ }()
+ }
+ wg.Wait()
+ if max.Load() != p.max {
+ t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
+ }
+}
+
+func BenchmarkWaitPool(b *testing.B) {
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ b.Skip("Not enough cores")
+ }
+ p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
+ wg.Add(workers)
+ b.ResetTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
+ x := p.Get()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ p.Put(x)
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func BenchmarkWaitPoolEmpty(b *testing.B) {
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ b.Skip("Not enough cores")
+ }
+ p := NewWaitPool(0, func() any { return make([]byte, 16) })
+ wg.Add(workers)
+ b.ResetTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
+ x := p.Get()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ p.Put(x)
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func BenchmarkSyncPool(b *testing.B) {
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ b.Skip("Not enough cores")
+ }
+ p := sync.Pool{New: func() any { return make([]byte, 16) }}
+ wg.Add(workers)
+ b.ResetTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
+ x := p.Get()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ p.Put(x)
+ }
+ }()
+ }
+ wg.Wait()
+}
diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go
index f5c042d..25f700a 100644
--- a/device/queueconstants_android.go
+++ b/device/queueconstants_android.go
@@ -1,16 +1,19 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
+import "golang.zx2c4.com/wireguard/conn"
+
/* Reduce memory consumption for Android */
const (
+ QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
- MaxSegmentSize = 2200
+ MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
PreallocatedBuffersPerPool = 4096
)
diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go
index cf86ba1..ea763d0 100644
--- a/device/queueconstants_default.go
+++ b/device/queueconstants_default.go
@@ -1,13 +1,16 @@
-// +build !android,!ios
+//go:build !android && !ios && !windows
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
+import "golang.zx2c4.com/wireguard/conn"
+
const (
+ QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
diff --git a/device/queueconstants_ios.go b/device/queueconstants_ios.go
index 589b0aa..acd3cec 100644
--- a/device/queueconstants_ios.go
+++ b/device/queueconstants_ios.go
@@ -1,18 +1,21 @@
-// +build ios
+//go:build ios
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
-/* Fit within memory limits for iOS's Network Extension API, which has stricter requirements */
-
-const (
- QueueOutboundSize = 1024
- QueueInboundSize = 1024
- QueueHandshakeSize = 1024
- MaxSegmentSize = 1700
- PreallocatedBuffersPerPool = 1024
+// Fit within memory limits for iOS's Network Extension API, which has stricter requirements.
+// These are vars instead of consts, because heavier network extensions might want to reduce
+// them further.
+var (
+ QueueStagedSize = 128
+ QueueOutboundSize = 1024
+ QueueInboundSize = 1024
+ QueueHandshakeSize = 1024
+ PreallocatedBuffersPerPool uint32 = 1024
)
+
+const MaxSegmentSize = 1700
diff --git a/device/queueconstants_windows.go b/device/queueconstants_windows.go
new file mode 100644
index 0000000..1eee32b
--- /dev/null
+++ b/device/queueconstants_windows.go
@@ -0,0 +1,15 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+const (
+ QueueStagedSize = 128
+ QueueOutboundSize = 1024
+ QueueInboundSize = 1024
+ QueueHandshakeSize = 1024
+ MaxSegmentSize = 2048 - 32 // largest possible UDP datagram
+ PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth
+)
diff --git a/device/race_disabled_test.go b/device/race_disabled_test.go
new file mode 100644
index 0000000..bb5c450
--- /dev/null
+++ b/device/race_disabled_test.go
@@ -0,0 +1,10 @@
+//go:build !race
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+const raceEnabled = false
diff --git a/device/race_enabled_test.go b/device/race_enabled_test.go
new file mode 100644
index 0000000..4e9daea
--- /dev/null
+++ b/device/race_enabled_test.go
@@ -0,0 +1,10 @@
+//go:build race
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+const raceEnabled = true
diff --git a/device/receive.go b/device/receive.go
index 7d0693e..1ab3e29 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,66 +8,46 @@ package device
import (
"bytes"
"encoding/binary"
+ "errors"
"net"
- "strconv"
"sync"
- "sync/atomic"
"time"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
+ "golang.zx2c4.com/wireguard/conn"
)
type QueueHandshakeElement struct {
msgType uint32
packet []byte
- endpoint Endpoint
+ endpoint conn.Endpoint
buffer *[MaxMessageSize]byte
}
type QueueInboundElement struct {
- dropped int32
- sync.Mutex
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
keypair *Keypair
- endpoint Endpoint
-}
-
-func (elem *QueueInboundElement) Drop() {
- atomic.StoreInt32(&elem.dropped, AtomicTrue)
+ endpoint conn.Endpoint
}
-func (elem *QueueInboundElement) IsDropped() bool {
- return atomic.LoadInt32(&elem.dropped) == AtomicTrue
-}
-
-func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool {
- select {
- case inboundQueue <- element:
- select {
- case decryptionQueue <- element:
- return true
- default:
- element.Drop()
- element.Unlock()
- return false
- }
- default:
- device.PutInboundElement(element)
- return false
- }
+type QueueInboundElementsContainer struct {
+ sync.Mutex
+ elems []*QueueInboundElement
}
-func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool {
- select {
- case queue <- element:
- return true
- default:
- return false
- }
+// clearPointers clears elem fields that contain pointers.
+// This makes the garbage collector's life easier and
+// avoids accidentally keeping other objects around unnecessarily.
+// It also reduces the possible collateral damage from use-after-free bugs.
+func (elem *QueueInboundElement) clearPointers() {
+ elem.buffer = nil
+ elem.packet = nil
+ elem.keypair = nil
+ elem.endpoint = nil
}
/* Called when a new authenticated message has been received
@@ -75,12 +55,12 @@ func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, elem
* NOTE: Not thread safe, but called by sequential receiver!
*/
func (peer *Peer) keepKeyFreshReceiving() {
- if peer.timers.sentLastMinuteHandshake.Get() {
+ if peer.timers.sentLastMinuteHandshake.Load() {
return
}
keypair := peer.keypairs.Current()
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
- peer.timers.sentLastMinuteHandshake.Set(true)
+ peer.timers.sentLastMinuteHandshake.Store(true)
peer.SendHandshakeInitiation(false)
}
}
@@ -90,188 +70,189 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
-func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
-
- logDebug := device.log.Debug
+func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
+ recvName := recv.PrettyName()
defer func() {
- logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped")
+ device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
+ device.queue.decryption.wg.Done()
+ device.queue.handshake.wg.Done()
device.net.stopping.Done()
}()
- logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - started")
- device.net.starting.Done()
+ device.log.Verbosef("Routine: receive incoming %s - started", recvName)
// receive datagrams until conn is closed
- buffer := device.GetMessageBuffer()
-
var (
- err error
- size int
- endpoint Endpoint
+ bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
+ bufs = make([][]byte, maxBatchSize)
+ err error
+ sizes = make([]int, maxBatchSize)
+ count int
+ endpoints = make([]conn.Endpoint, maxBatchSize)
+ deathSpiral int
+ elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
)
- for {
-
- // read next datagram
+ for i := range bufsArrs {
+ bufsArrs[i] = device.GetMessageBuffer()
+ bufs[i] = bufsArrs[i][:]
+ }
- switch IP {
- case ipv4.Version:
- size, endpoint, err = bind.ReceiveIPv4(buffer[:])
- case ipv6.Version:
- size, endpoint, err = bind.ReceiveIPv6(buffer[:])
- default:
- panic("invalid IP version")
+ defer func() {
+ for i := 0; i < maxBatchSize; i++ {
+ if bufsArrs[i] != nil {
+ device.PutMessageBuffer(bufsArrs[i])
+ }
}
+ }()
+ for {
+ count, err = recv(bufs, sizes, endpoints)
if err != nil {
- device.PutMessageBuffer(buffer)
+ if errors.Is(err, net.ErrClosed) {
+ return
+ }
+ device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
+ if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
+ return
+ }
+ if deathSpiral < 10 {
+ deathSpiral++
+ time.Sleep(time.Second / 3)
+ continue
+ }
return
}
+ deathSpiral = 0
- if size < MinMessageSize {
- continue
- }
-
- // check size of packet
+ // handle each packet in the batch
+ for i, size := range sizes[:count] {
+ if size < MinMessageSize {
+ continue
+ }
- packet := buffer[:size]
- msgType := binary.LittleEndian.Uint32(packet[:4])
+ // check size of packet
- var okay bool
+ packet := bufsArrs[i][:size]
+ msgType := binary.LittleEndian.Uint32(packet[:4])
- switch msgType {
+ switch msgType {
- // check if transport
+ // check if transport
- case MessageTransportType:
+ case MessageTransportType:
- // check size
+ // check size
- if len(packet) < MessageTransportSize {
- continue
- }
+ if len(packet) < MessageTransportSize {
+ continue
+ }
- // lookup key pair
+ // lookup key pair
- receiver := binary.LittleEndian.Uint32(
- packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
- )
- value := device.indexTable.Lookup(receiver)
- keypair := value.keypair
- if keypair == nil {
- continue
- }
+ receiver := binary.LittleEndian.Uint32(
+ packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
+ )
+ value := device.indexTable.Lookup(receiver)
+ keypair := value.keypair
+ if keypair == nil {
+ continue
+ }
- // check keypair expiry
+ // check keypair expiry
- if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
- continue
- }
+ if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
+ continue
+ }
- // create work element
- peer := value.peer
- elem := device.GetInboundElement()
- elem.packet = packet
- elem.buffer = buffer
- elem.keypair = keypair
- elem.dropped = AtomicFalse
- elem.endpoint = endpoint
- elem.counter = 0
- elem.Mutex = sync.Mutex{}
- elem.Lock()
-
- // add to decryption queues
-
- if peer.isRunning.Get() {
- if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) {
- buffer = device.GetMessageBuffer()
+ // create work element
+ peer := value.peer
+ elem := device.GetInboundElement()
+ elem.packet = packet
+ elem.buffer = bufsArrs[i]
+ elem.keypair = keypair
+ elem.endpoint = endpoints[i]
+ elem.counter = 0
+
+ elemsForPeer, ok := elemsByPeer[peer]
+ if !ok {
+ elemsForPeer = device.GetInboundElementsContainer()
+ elemsForPeer.Lock()
+ elemsByPeer[peer] = elemsForPeer
}
- }
+ elemsForPeer.elems = append(elemsForPeer.elems, elem)
+ bufsArrs[i] = device.GetMessageBuffer()
+ bufs[i] = bufsArrs[i][:]
+ continue
- continue
+ // otherwise it is a fixed size & handshake related packet
- // otherwise it is a fixed size & handshake related packet
+ case MessageInitiationType:
+ if len(packet) != MessageInitiationSize {
+ continue
+ }
- case MessageInitiationType:
- okay = len(packet) == MessageInitiationSize
+ case MessageResponseType:
+ if len(packet) != MessageResponseSize {
+ continue
+ }
- case MessageResponseType:
- okay = len(packet) == MessageResponseSize
+ case MessageCookieReplyType:
+ if len(packet) != MessageCookieReplySize {
+ continue
+ }
- case MessageCookieReplyType:
- okay = len(packet) == MessageCookieReplySize
+ default:
+ device.log.Verbosef("Received message with unknown type")
+ continue
+ }
- default:
- logDebug.Println("Received message with unknown type")
+ select {
+ case device.queue.handshake.c <- QueueHandshakeElement{
+ msgType: msgType,
+ buffer: bufsArrs[i],
+ packet: packet,
+ endpoint: endpoints[i],
+ }:
+ bufsArrs[i] = device.GetMessageBuffer()
+ bufs[i] = bufsArrs[i][:]
+ default:
+ }
}
-
- if okay {
- if (device.addToHandshakeQueue(
- device.queue.handshake,
- QueueHandshakeElement{
- msgType: msgType,
- buffer: buffer,
- packet: packet,
- endpoint: endpoint,
- },
- )) {
- buffer = device.GetMessageBuffer()
+ for peer, elemsContainer := range elemsByPeer {
+ if peer.isRunning.Load() {
+ peer.queue.inbound.c <- elemsContainer
+ device.queue.decryption.c <- elemsContainer
+ } else {
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ device.PutInboundElementsContainer(elemsContainer)
}
+ delete(elemsByPeer, peer)
}
}
}
-func (device *Device) RoutineDecryption() {
-
+func (device *Device) RoutineDecryption(id int) {
var nonce [chacha20poly1305.NonceSize]byte
- logDebug := device.log.Debug
- defer func() {
- logDebug.Println("Routine: decryption worker - stopped")
- device.state.stopping.Done()
- }()
- logDebug.Println("Routine: decryption worker - started")
- device.state.starting.Done()
-
- for {
- select {
- case <-device.signals.stop:
- return
-
- case elem, ok := <-device.queue.decryption:
-
- if !ok {
- return
- }
-
- // check if dropped
-
- if elem.IsDropped() {
- continue
- }
+ defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
+ device.log.Verbosef("Routine: decryption worker %d - started", id)
+ for elemsContainer := range device.queue.decryption.c {
+ for _, elem := range elemsContainer.elems {
// split message into fields
-
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
- // expand nonce
-
- nonce[0x4] = counter[0x0]
- nonce[0x5] = counter[0x1]
- nonce[0x6] = counter[0x2]
- nonce[0x7] = counter[0x3]
-
- nonce[0x8] = counter[0x4]
- nonce[0x9] = counter[0x5]
- nonce[0xa] = counter[0x6]
- nonce[0xb] = counter[0x7]
-
// decrypt and release to consumer
-
var err error
elem.counter = binary.LittleEndian.Uint64(counter)
+ // copy counter to nonce
+ binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
elem.packet, err = elem.keypair.receive.Open(
content[:0],
nonce[:],
@@ -279,51 +260,23 @@ func (device *Device) RoutineDecryption() {
nil,
)
if err != nil {
- elem.Drop()
- device.PutMessageBuffer(elem.buffer)
+ elem.packet = nil
}
- elem.Unlock()
}
+ elemsContainer.Unlock()
}
}
/* Handles incoming packets related to handshake
*/
-func (device *Device) RoutineHandshake() {
-
- logInfo := device.log.Info
- logError := device.log.Error
- logDebug := device.log.Debug
-
- var elem QueueHandshakeElement
- var ok bool
-
+func (device *Device) RoutineHandshake(id int) {
defer func() {
- logDebug.Println("Routine: handshake worker - stopped")
- device.state.stopping.Done()
- if elem.buffer != nil {
- device.PutMessageBuffer(elem.buffer)
- }
+ device.log.Verbosef("Routine: handshake worker %d - stopped", id)
+ device.queue.encryption.wg.Done()
}()
+ device.log.Verbosef("Routine: handshake worker %d - started", id)
- logDebug.Println("Routine: handshake worker - started")
- device.state.starting.Done()
-
- for {
- if elem.buffer != nil {
- device.PutMessageBuffer(elem.buffer)
- elem.buffer = nil
- }
-
- select {
- case elem, ok = <-device.queue.handshake:
- case <-device.signals.stop:
- return
- }
-
- if !ok {
- return
- }
+ for elem := range device.queue.handshake.c {
// handle cookie fields and ratelimiting
@@ -337,8 +290,8 @@ func (device *Device) RoutineHandshake() {
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &reply)
if err != nil {
- logDebug.Println("Failed to decode cookie reply")
- return
+ device.log.Verbosef("Failed to decode cookie reply")
+ goto skip
}
// lookup peer from index
@@ -346,27 +299,27 @@ func (device *Device) RoutineHandshake() {
entry := device.indexTable.Lookup(reply.Receiver)
if entry.peer == nil {
- continue
+ goto skip
}
// consume reply
- if peer := entry.peer; peer.isRunning.Get() {
- logDebug.Println("Receiving cookie response from ", elem.endpoint.DstToString())
+ if peer := entry.peer; peer.isRunning.Load() {
+ device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
if !peer.cookieGenerator.ConsumeReply(&reply) {
- logDebug.Println("Could not decrypt invalid cookie response")
+ device.log.Verbosef("Could not decrypt invalid cookie response")
}
}
- continue
+ goto skip
case MessageInitiationType, MessageResponseType:
// check mac fields and maybe ratelimit
if !device.cookieChecker.CheckMAC1(elem.packet) {
- logDebug.Println("Received packet with invalid mac1")
- continue
+ device.log.Verbosef("Received packet with invalid mac1")
+ goto skip
}
// endpoints destination address is the source of the datagram
@@ -377,19 +330,19 @@ func (device *Device) RoutineHandshake() {
if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
device.SendHandshakeCookie(&elem)
- continue
+ goto skip
}
// check ratelimiter
if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
- continue
+ goto skip
}
}
default:
- logError.Println("Invalid packet ended up in the handshake queue")
- continue
+ device.log.Errorf("Invalid packet ended up in the handshake queue")
+ goto skip
}
// handle handshake initiation/response content
@@ -403,19 +356,16 @@ func (device *Device) RoutineHandshake() {
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil {
- logError.Println("Failed to decode initiation message")
- continue
+ device.log.Errorf("Failed to decode initiation message")
+ goto skip
}
// consume initiation
peer := device.ConsumeMessageInitiation(&msg)
if peer == nil {
- logInfo.Println(
- "Received invalid initiation message from",
- elem.endpoint.DstToString(),
- )
- continue
+ device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
+ goto skip
}
// update timers
@@ -426,8 +376,8 @@ func (device *Device) RoutineHandshake() {
// update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
- logDebug.Println(peer, "- Received handshake initiation")
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+ device.log.Verbosef("%v - Received handshake initiation", peer)
+ peer.rxBytes.Add(uint64(len(elem.packet)))
peer.SendHandshakeResponse()
@@ -439,26 +389,23 @@ func (device *Device) RoutineHandshake() {
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil {
- logError.Println("Failed to decode response message")
- continue
+ device.log.Errorf("Failed to decode response message")
+ goto skip
}
// consume response
peer := device.ConsumeMessageResponse(&msg)
if peer == nil {
- logInfo.Println(
- "Received invalid response message from",
- elem.endpoint.DstToString(),
- )
- continue
+ device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
+ goto skip
}
// update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
- logDebug.Println(peer, "- Received handshake response")
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+ device.log.Verbosef("%v - Received handshake response", peer)
+ peer.rxBytes.Add(uint64(len(elem.packet)))
// update timers
@@ -470,178 +417,124 @@ func (device *Device) RoutineHandshake() {
err = peer.BeginSymmetricSession()
if err != nil {
- logError.Println(peer, "- Failed to derive keypair:", err)
- continue
+ device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
+ goto skip
}
peer.timersSessionDerived()
peer.timersHandshakeComplete()
peer.SendKeepalive()
- select {
- case peer.signals.newKeypairArrived <- struct{}{}:
- default:
- }
}
+ skip:
+ device.PutMessageBuffer(elem.buffer)
}
}
-func (peer *Peer) RoutineSequentialReceiver() {
-
+func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
device := peer.device
- logInfo := device.log.Info
- logError := device.log.Error
- logDebug := device.log.Debug
-
- var elem *QueueInboundElement
-
defer func() {
- logDebug.Println(peer, "- Routine: sequential receiver - stopped")
- peer.routines.stopping.Done()
- if elem != nil {
- if !elem.IsDropped() {
- device.PutMessageBuffer(elem.buffer)
- }
- device.PutInboundElement(elem)
- }
+ device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
+ peer.stopping.Done()
}()
+ device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
- logDebug.Println(peer, "- Routine: sequential receiver - started")
-
- peer.routines.starting.Done()
-
- for {
- if elem != nil {
- if !elem.IsDropped() {
- device.PutMessageBuffer(elem.buffer)
- }
- device.PutInboundElement(elem)
- elem = nil
- }
+ bufs := make([][]byte, 0, maxBatchSize)
- var elemOk bool
- select {
- case <-peer.routines.stop:
+ for elemsContainer := range peer.queue.inbound.c {
+ if elemsContainer == nil {
return
- case elem, elemOk = <-peer.queue.inbound:
- if !elemOk {
- return
- }
- }
-
- // wait for decryption
-
- elem.Lock()
-
- if elem.IsDropped() {
- continue
- }
-
- // check for replay
-
- if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
- continue
- }
-
- // update endpoint
- peer.SetEndpointFromPacket(elem.endpoint)
-
- // check if using new keypair
- if peer.ReceivedWithKeypair(elem.keypair) {
- peer.timersHandshakeComplete()
- select {
- case peer.signals.newKeypairArrived <- struct{}{}:
- default:
- }
- }
-
- peer.keepKeyFreshReceiving()
- peer.timersAnyAuthenticatedPacketTraversal()
- peer.timersAnyAuthenticatedPacketReceived()
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
-
- // check for keepalive
-
- if len(elem.packet) == 0 {
- logDebug.Println(peer, "- Receiving keepalive packet")
- continue
}
- peer.timersDataReceived()
-
- // verify source and strip padding
-
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
-
- // strip padding
-
- if len(elem.packet) < ipv4.HeaderLen {
+ elemsContainer.Lock()
+ validTailPacket := -1
+ dataPacketReceived := false
+ rxBytesLen := uint64(0)
+ for i, elem := range elemsContainer.elems {
+ if elem.packet == nil {
+ // decryption failed
continue
}
- field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
- length := binary.BigEndian.Uint16(field)
- if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+ if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
continue
}
- elem.packet = elem.packet[:length]
-
- // verify IPv4 source
-
- src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
- if device.allowedips.LookupIPv4(src) != peer {
- logInfo.Println(
- "IPv4 packet with disallowed source address from",
- peer,
- )
- continue
+ validTailPacket = i
+ if peer.ReceivedWithKeypair(elem.keypair) {
+ peer.SetEndpointFromPacket(elem.endpoint)
+ peer.timersHandshakeComplete()
+ peer.SendStagedPackets()
}
+ rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
- case ipv6.Version:
-
- // strip padding
-
- if len(elem.packet) < ipv6.HeaderLen {
+ if len(elem.packet) == 0 {
+ device.log.Verbosef("%v - Receiving keepalive packet", peer)
continue
}
+ dataPacketReceived = true
- field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
- length := binary.BigEndian.Uint16(field)
- length += ipv6.HeaderLen
- if int(length) > len(elem.packet) {
- continue
- }
-
- elem.packet = elem.packet[:length]
+ switch elem.packet[0] >> 4 {
+ case 4:
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
+ field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
+ length := binary.BigEndian.Uint16(field)
+ if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+ continue
+ }
+ elem.packet = elem.packet[:length]
+ src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
+ if device.allowedips.Lookup(src) != peer {
+ device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
+ continue
+ }
- // verify IPv6 source
+ case 6:
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
+ field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
+ length := binary.BigEndian.Uint16(field)
+ length += ipv6.HeaderLen
+ if int(length) > len(elem.packet) {
+ continue
+ }
+ elem.packet = elem.packet[:length]
+ src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
+ if device.allowedips.Lookup(src) != peer {
+ device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
+ continue
+ }
- src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
- if device.allowedips.LookupIPv6(src) != peer {
- logInfo.Println(
- "IPv6 packet with disallowed source address from",
- peer,
- )
+ default:
+ device.log.Verbosef("Packet with invalid IP version from %v", peer)
continue
}
- default:
- logInfo.Println("Packet with invalid IP version from", peer)
- continue
+ bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
}
- // write to tun device
-
- offset := MessageTransportOffsetContent
- _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
- if len(peer.queue.inbound) == 0 {
- err = device.tun.device.Flush()
- if err != nil {
- peer.device.log.Error.Printf("Unable to flush packets: %v", err)
+ peer.rxBytes.Add(rxBytesLen)
+ if validTailPacket >= 0 {
+ peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
+ peer.keepKeyFreshReceiving()
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketReceived()
+ }
+ if dataPacketReceived {
+ peer.timersDataReceived()
+ }
+ if len(bufs) > 0 {
+ _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
+ if err != nil && !device.isClosed() {
+ device.log.Errorf("Failed to write packets to TUN device: %v", err)
}
}
- if err != nil && !device.isClosed.Get() {
- logError.Println("Failed to write packet to TUN device:", err)
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
}
+ bufs = bufs[:0]
+ device.PutInboundElementsContainer(elemsContainer)
}
}
diff --git a/device/send.go b/device/send.go
index 72633be..769720a 100644
--- a/device/send.go
+++ b/device/send.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,14 +8,17 @@ package device
import (
"bytes"
"encoding/binary"
+ "errors"
"net"
+ "os"
"sync"
- "sync/atomic"
"time"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/tun"
)
/* Outbound flow
@@ -43,8 +46,6 @@ import (
*/
type QueueOutboundElement struct {
- dropped int32
- sync.Mutex
buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption
@@ -52,80 +53,52 @@ type QueueOutboundElement struct {
peer *Peer // related peer
}
+type QueueOutboundElementsContainer struct {
+ sync.Mutex
+ elems []*QueueOutboundElement
+}
+
func (device *Device) NewOutboundElement() *QueueOutboundElement {
elem := device.GetOutboundElement()
- elem.dropped = AtomicFalse
elem.buffer = device.GetMessageBuffer()
- elem.Mutex = sync.Mutex{}
elem.nonce = 0
- elem.keypair = nil
- elem.peer = nil
+ // keypair and peer were cleared (if necessary) by clearPointers.
return elem
}
-func (elem *QueueOutboundElement) Drop() {
- atomic.StoreInt32(&elem.dropped, AtomicTrue)
-}
-
-func (elem *QueueOutboundElement) IsDropped() bool {
- return atomic.LoadInt32(&elem.dropped) == AtomicTrue
-}
-
-func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundElement, device *Device) {
- for {
- select {
- case queue <- element:
- return
- default:
- select {
- case old := <-queue:
- device.PutMessageBuffer(old.buffer)
- device.PutOutboundElement(old)
- default:
- }
- }
- }
+// clearPointers clears elem fields that contain pointers.
+// This makes the garbage collector's life easier and
+// avoids accidentally keeping other objects around unnecessarily.
+// It also reduces the possible collateral damage from use-after-free bugs.
+func (elem *QueueOutboundElement) clearPointers() {
+ elem.buffer = nil
+ elem.packet = nil
+ elem.keypair = nil
+ elem.peer = nil
}
-func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement, encryptionQueue chan *QueueOutboundElement, element *QueueOutboundElement) {
- select {
- case outboundQueue <- element:
+/* Queues a keepalive if no packets are queued for peer
+ */
+func (peer *Peer) SendKeepalive() {
+ if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
+ elem := peer.device.NewOutboundElement()
+ elemsContainer := peer.device.GetOutboundElementsContainer()
+ elemsContainer.elems = append(elemsContainer.elems, elem)
select {
- case encryptionQueue <- element:
- return
+ case peer.queue.staged <- elemsContainer:
+ peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
default:
- element.Drop()
- element.peer.device.PutMessageBuffer(element.buffer)
- element.Unlock()
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ peer.device.PutOutboundElementsContainer(elemsContainer)
}
- default:
- element.peer.device.PutMessageBuffer(element.buffer)
- element.peer.device.PutOutboundElement(element)
- }
-}
-
-/* Queues a keepalive if no packets are queued for peer
- */
-func (peer *Peer) SendKeepalive() bool {
- if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey.Get() || !peer.isRunning.Get() {
- return false
- }
- elem := peer.device.NewOutboundElement()
- elem.packet = nil
- select {
- case peer.queue.nonce <- elem:
- peer.device.log.Debug.Println(peer, "- Sending keepalive packet")
- return true
- default:
- peer.device.PutMessageBuffer(elem.buffer)
- peer.device.PutOutboundElement(elem)
- return false
}
+ peer.SendStagedPackets()
}
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if !isRetry {
- atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
+ peer.timers.handshakeAttempts.Store(0)
}
peer.handshake.mutex.RLock()
@@ -143,16 +116,16 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.handshake.lastSentHandshake = time.Now()
peer.handshake.mutex.Unlock()
- peer.device.log.Debug.Println(peer, "- Sending handshake initiation")
+ peer.device.log.Verbosef("%v - Sending handshake initiation", peer)
msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil {
- peer.device.log.Error.Println(peer, "- Failed to create initiation message:", err)
+ peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
return err
}
- var buff [MessageInitiationSize]byte
- writer := bytes.NewBuffer(buff[:0])
+ var buf [MessageInitiationSize]byte
+ writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
@@ -160,9 +133,9 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- err = peer.SendBuffer(packet)
+ err = peer.SendBuffers([][]byte{packet})
if err != nil {
- peer.device.log.Error.Println(peer, "- Failed to send handshake initiation", err)
+ peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
peer.timersHandshakeInitiated()
@@ -174,23 +147,23 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.handshake.lastSentHandshake = time.Now()
peer.handshake.mutex.Unlock()
- peer.device.log.Debug.Println(peer, "- Sending handshake response")
+ peer.device.log.Verbosef("%v - Sending handshake response", peer)
response, err := peer.device.CreateMessageResponse(peer)
if err != nil {
- peer.device.log.Error.Println(peer, "- Failed to create response message:", err)
+ peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
return err
}
- var buff [MessageResponseSize]byte
- writer := bytes.NewBuffer(buff[:0])
+ var buf [MessageResponseSize]byte
+ writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, response)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
err = peer.BeginSymmetricSession()
if err != nil {
- peer.device.log.Error.Println(peer, "- Failed to derive keypair:", err)
+ peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
return err
}
@@ -198,28 +171,29 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- err = peer.SendBuffer(packet)
+ // TODO: allocation could be avoided
+ err = peer.SendBuffers([][]byte{packet})
if err != nil {
- peer.device.log.Error.Println(peer, "- Failed to send handshake response", err)
+ peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
}
return err
}
func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
-
- device.log.Debug.Println("Sending cookie response for denied handshake message for", initiatingElem.endpoint.DstToString())
+ device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
if err != nil {
- device.log.Error.Println("Failed to create cookie reply:", err)
+ device.log.Errorf("Failed to create cookie reply: %v", err)
return err
}
- var buff [MessageCookieReplySize]byte
- writer := bytes.NewBuffer(buff[:0])
+ var buf [MessageCookieReplySize]byte
+ writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply)
- device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
+ // TODO: allocation could be avoided
+ device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
return nil
}
@@ -228,280 +202,255 @@ func (peer *Peer) keepKeyFreshSending() {
if keypair == nil {
return
}
- nonce := atomic.LoadUint64(&keypair.sendNonce)
+ nonce := keypair.sendNonce.Load()
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
peer.SendHandshakeInitiation(false)
}
}
-/* Reads packets from the TUN and inserts
- * into nonce queue for peer
- *
- * Obs. Single instance per TUN device
- */
func (device *Device) RoutineReadFromTUN() {
-
- logDebug := device.log.Debug
- logError := device.log.Error
-
defer func() {
- logDebug.Println("Routine: TUN reader - stopped")
+ device.log.Verbosef("Routine: TUN reader - stopped")
device.state.stopping.Done()
+ device.queue.encryption.wg.Done()
}()
- logDebug.Println("Routine: TUN reader - started")
- device.state.starting.Done()
-
- var elem *QueueOutboundElement
+ device.log.Verbosef("Routine: TUN reader - started")
+
+ var (
+ batchSize = device.BatchSize()
+ readErr error
+ elems = make([]*QueueOutboundElement, batchSize)
+ bufs = make([][]byte, batchSize)
+ elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
+ count = 0
+ sizes = make([]int, batchSize)
+ offset = MessageTransportHeaderSize
+ )
+
+ for i := range elems {
+ elems[i] = device.NewOutboundElement()
+ bufs[i] = elems[i].buffer[:]
+ }
- for {
- if elem != nil {
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ defer func() {
+ for _, elem := range elems {
+ if elem != nil {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
}
- elem = device.NewOutboundElement()
-
- // read packet
-
- offset := MessageTransportHeaderSize
- size, err := device.tun.device.Read(elem.buffer[:], offset)
+ }()
- if err != nil {
- if !device.isClosed.Get() {
- logError.Println("Failed to read packet from TUN device:", err)
- device.Close()
+ for {
+ // read packets
+ count, readErr = device.tun.device.Read(bufs, sizes, offset)
+ for i := 0; i < count; i++ {
+ if sizes[i] < 1 {
+ continue
}
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
- return
- }
- if size == 0 || size > MaxContentSize {
- continue
- }
+ elem := elems[i]
+ elem.packet = bufs[i][offset : offset+sizes[i]]
- elem.packet = elem.buffer[offset : offset+size]
+ // lookup peer
+ var peer *Peer
+ switch elem.packet[0] >> 4 {
+ case 4:
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
+ dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+ peer = device.allowedips.Lookup(dst)
- // lookup peer
+ case 6:
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
+ dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+ peer = device.allowedips.Lookup(dst)
- var peer *Peer
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
- if len(elem.packet) < ipv4.HeaderLen {
- continue
+ default:
+ device.log.Verbosef("Received packet with unknown IP version")
}
- dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
- peer = device.allowedips.LookupIPv4(dst)
- case ipv6.Version:
- if len(elem.packet) < ipv6.HeaderLen {
+ if peer == nil {
continue
}
- dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
- peer = device.allowedips.LookupIPv6(dst)
-
- default:
- logDebug.Println("Received packet with unknown IP version")
+ elemsForPeer, ok := elemsByPeer[peer]
+ if !ok {
+ elemsForPeer = device.GetOutboundElementsContainer()
+ elemsByPeer[peer] = elemsForPeer
+ }
+ elemsForPeer.elems = append(elemsForPeer.elems, elem)
+ elems[i] = device.NewOutboundElement()
+ bufs[i] = elems[i].buffer[:]
}
- if peer == nil {
- continue
+ for peer, elemsForPeer := range elemsByPeer {
+ if peer.isRunning.Load() {
+ peer.StagePackets(elemsForPeer)
+ peer.SendStagedPackets()
+ } else {
+ for _, elem := range elemsForPeer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsContainer(elemsForPeer)
+ }
+ delete(elemsByPeer, peer)
}
- // insert into nonce/pre-handshake queue
-
- if peer.isRunning.Get() {
- if peer.queue.packetInNonceQueueIsAwaitingKey.Get() {
- peer.SendHandshakeInitiation(false)
+ if readErr != nil {
+ if errors.Is(readErr, tun.ErrTooManySegments) {
+ // TODO: record stat for this
+ // This will happen if MSS is surprisingly small (< 576)
+ // coincident with reasonably high throughput.
+ device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
+ continue
}
- addToNonceQueue(peer.queue.nonce, elem, device)
- elem = nil
+ if !device.isClosed() {
+ if !errors.Is(readErr, os.ErrClosed) {
+ device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
+ }
+ go device.Close()
+ }
+ return
}
}
}
-func (peer *Peer) FlushNonceQueue() {
- select {
- case peer.signals.flushNonceQueue <- struct{}{}:
- default:
- }
-}
-
-/* Queues packets when there is no handshake.
- * Then assigns nonces to packets sequentially
- * and creates "work" structs for workers
- *
- * Obs. A single instance per peer
- */
-func (peer *Peer) RoutineNonce() {
- var keypair *Keypair
-
- device := peer.device
- logDebug := device.log.Debug
-
- flush := func() {
- for {
- select {
- case elem := <-peer.queue.nonce:
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
- default:
- return
+func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
+ for {
+ select {
+ case peer.queue.staged <- elems:
+ return
+ default:
+ }
+ select {
+ case tooOld := <-peer.queue.staged:
+ for _, elem := range tooOld.elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
}
+ peer.device.PutOutboundElementsContainer(tooOld)
+ default:
}
}
+}
- defer func() {
- flush()
- logDebug.Println(peer, "- Routine: nonce worker - stopped")
- peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
- peer.routines.stopping.Done()
- }()
+func (peer *Peer) SendStagedPackets() {
+top:
+ if len(peer.queue.staged) == 0 || !peer.device.isUp() {
+ return
+ }
- peer.routines.starting.Done()
- logDebug.Println(peer, "- Routine: nonce worker - started")
+ keypair := peer.keypairs.Current()
+ if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
+ peer.SendHandshakeInitiation(false)
+ return
+ }
for {
- NextPacket:
- peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
-
+ var elemsContainerOOO *QueueOutboundElementsContainer
select {
- case <-peer.routines.stop:
- return
-
- case <-peer.signals.flushNonceQueue:
- flush()
- goto NextPacket
-
- case elem, ok := <-peer.queue.nonce:
-
- if !ok {
- return
- }
-
- // make sure to always pick the newest key
-
- for {
-
- // check validity of newest key pair
-
- keypair = peer.keypairs.Current()
- if keypair != nil && keypair.sendNonce < RejectAfterMessages {
- if time.Since(keypair.created) < RejectAfterTime {
- break
+ case elemsContainer := <-peer.queue.staged:
+ i := 0
+ for _, elem := range elemsContainer.elems {
+ elem.peer = peer
+ elem.nonce = keypair.sendNonce.Add(1) - 1
+ if elem.nonce >= RejectAfterMessages {
+ keypair.sendNonce.Store(RejectAfterMessages)
+ if elemsContainerOOO == nil {
+ elemsContainerOOO = peer.device.GetOutboundElementsContainer()
}
- }
- peer.queue.packetInNonceQueueIsAwaitingKey.Set(true)
-
- // no suitable key pair, request for new handshake
-
- select {
- case <-peer.signals.newKeypairArrived:
- default:
+ elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
+ continue
+ } else {
+ elemsContainer.elems[i] = elem
+ i++
}
- peer.SendHandshakeInitiation(false)
-
- // wait for key to be established
-
- logDebug.Println(peer, "- Awaiting keypair")
+ elem.keypair = keypair
+ }
+ elemsContainer.Lock()
+ elemsContainer.elems = elemsContainer.elems[:i]
- select {
- case <-peer.signals.newKeypairArrived:
- logDebug.Println(peer, "- Obtained awaited keypair")
+ if elemsContainerOOO != nil {
+ peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
+ }
- case <-peer.signals.flushNonceQueue:
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
- flush()
- goto NextPacket
+ if len(elemsContainer.elems) == 0 {
+ peer.device.PutOutboundElementsContainer(elemsContainer)
+ goto top
+ }
- case <-peer.routines.stop:
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
- return
+ // add to parallel and sequential queue
+ if peer.isRunning.Load() {
+ peer.queue.outbound.c <- elemsContainer
+ peer.device.queue.encryption.c <- elemsContainer
+ } else {
+ for _, elem := range elemsContainer.elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
}
+ peer.device.PutOutboundElementsContainer(elemsContainer)
}
- peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
-
- // populate work element
- elem.peer = peer
- elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
-
- // double check in case of race condition added by future code
-
- if elem.nonce >= RejectAfterMessages {
- atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
- goto NextPacket
+ if elemsContainerOOO != nil {
+ goto top
}
+ default:
+ return
+ }
+ }
+}
- elem.keypair = keypair
- elem.dropped = AtomicFalse
- elem.Lock()
-
- // add to parallel and sequential queue
- addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem)
+func (peer *Peer) FlushStagedPackets() {
+ for {
+ select {
+ case elemsContainer := <-peer.queue.staged:
+ for _, elem := range elemsContainer.elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ }
+ peer.device.PutOutboundElementsContainer(elemsContainer)
+ default:
+ return
}
}
}
+func calculatePaddingSize(packetSize, mtu int) int {
+ lastUnit := packetSize
+ if mtu == 0 {
+ return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
+ }
+ if lastUnit > mtu {
+ lastUnit %= mtu
+ }
+ paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
+ if paddedSize > mtu {
+ paddedSize = mtu
+ }
+ return paddedSize - lastUnit
+}
+
/* Encrypts the elements in the queue
* and marks them for sequential consumption (by releasing the mutex)
*
* Obs. One instance per core
*/
-func (device *Device) RoutineEncryption() {
-
+func (device *Device) RoutineEncryption(id int) {
+ var paddingZeros [PaddingMultiple]byte
var nonce [chacha20poly1305.NonceSize]byte
- logDebug := device.log.Debug
-
- defer func() {
- for {
- select {
- case elem, ok := <-device.queue.encryption:
- if ok && !elem.IsDropped() {
- elem.Drop()
- device.PutMessageBuffer(elem.buffer)
- elem.Unlock()
- }
- default:
- goto out
- }
- }
- out:
- logDebug.Println("Routine: encryption worker - stopped")
- device.state.stopping.Done()
- }()
-
- logDebug.Println("Routine: encryption worker - started")
- device.state.starting.Done()
-
- for {
-
- // fetch next element
-
- select {
- case <-device.signals.stop:
- return
-
- case elem, ok := <-device.queue.encryption:
-
- if !ok {
- return
- }
-
- // check if dropped
-
- if elem.IsDropped() {
- continue
- }
+ defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
+ device.log.Verbosef("Routine: encryption worker %d - started", id)
+ for elemsContainer := range device.queue.encryption.c {
+ for _, elem := range elemsContainer.elems {
// populate header fields
-
header := elem.buffer[:MessageTransportHeaderSize]
fieldType := header[0:4]
@@ -513,16 +462,8 @@ func (device *Device) RoutineEncryption() {
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16
-
- mtu := int(atomic.LoadInt32(&device.tun.mtu))
- lastUnit := len(elem.packet) % mtu
- paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
- if paddedSize > mtu {
- paddedSize = mtu
- }
- for i := len(elem.packet); i < paddedSize; i++ {
- elem.packet = append(elem.packet, 0)
- }
+ paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
+ elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
// encrypt content and release to consumer
@@ -533,82 +474,73 @@ func (device *Device) RoutineEncryption() {
elem.packet,
nil,
)
- elem.Unlock()
}
+ elemsContainer.Unlock()
}
}
-/* Sequentially reads packets from queue and sends to endpoint
- *
- * Obs. Single instance per peer.
- * The routine terminates then the outbound queue is closed.
- */
-func (peer *Peer) RoutineSequentialSender() {
-
+func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
device := peer.device
-
- logDebug := device.log.Debug
- logError := device.log.Error
-
defer func() {
- for {
- select {
- case elem, ok := <-peer.queue.outbound:
- if ok {
- if !elem.IsDropped() {
- device.PutMessageBuffer(elem.buffer)
- elem.Drop()
- }
- device.PutOutboundElement(elem)
- }
- default:
- goto out
- }
- }
- out:
- logDebug.Println(peer, "- Routine: sequential sender - stopped")
- peer.routines.stopping.Done()
+ defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
+ peer.stopping.Done()
}()
+ device.log.Verbosef("%v - Routine: sequential sender - started", peer)
- logDebug.Println(peer, "- Routine: sequential sender - started")
-
- peer.routines.starting.Done()
-
- for {
- select {
+ bufs := make([][]byte, 0, maxBatchSize)
- case <-peer.routines.stop:
+ for elemsContainer := range peer.queue.outbound.c {
+ bufs = bufs[:0]
+ if elemsContainer == nil {
return
-
- case elem, ok := <-peer.queue.outbound:
-
- if !ok {
- return
- }
-
- elem.Lock()
- if elem.IsDropped() {
+ }
+ if !peer.isRunning.Load() {
+ // peer has been stopped; return re-usable elems to the shared pool.
+ // This is an optimization only. It is possible for the peer to be stopped
+ // immediately after this check, in which case, elem will get processed.
+ // The timers and SendBuffers code are resilient to a few stragglers.
+ // TODO: rework peer shutdown order to ensure
+ // that we never accidentally keep timers alive longer than necessary.
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
- continue
}
-
- peer.timersAnyAuthenticatedPacketTraversal()
- peer.timersAnyAuthenticatedPacketSent()
-
- // send message and return buffer to pool
-
- err := peer.SendBuffer(elem.packet)
+ continue
+ }
+ dataSent := false
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
if len(elem.packet) != MessageKeepaliveSize {
- peer.timersDataSent()
+ dataSent = true
}
+ bufs = append(bufs, elem.packet)
+ }
+
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketSent()
+
+ err := peer.SendBuffers(bufs)
+ if dataSent {
+ peer.timersDataSent()
+ }
+ for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
- if err != nil {
- logError.Println(peer, "- Failed to send data packet", err)
- continue
+ }
+ device.PutOutboundElementsContainer(elemsContainer)
+ if err != nil {
+ var errGSO conn.ErrUDPGSODisabled
+ if errors.As(err, &errGSO) {
+ device.log.Verbosef(err.Error())
+ err = errGSO.RetryErr
}
-
- peer.keepKeyFreshSending()
}
+ if err != nil {
+ device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
+ continue
+ }
+
+ peer.keepKeyFreshSending()
}
}
diff --git a/device/sticky_default.go b/device/sticky_default.go
new file mode 100644
index 0000000..1038256
--- /dev/null
+++ b/device/sticky_default.go
@@ -0,0 +1,12 @@
+//go:build !linux
+
+package device
+
+import (
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/rwcancel"
+)
+
+func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
+ return nil, nil
+}
diff --git a/device/sticky_linux.go b/device/sticky_linux.go
new file mode 100644
index 0000000..6057ff1
--- /dev/null
+++ b/device/sticky_linux.go
@@ -0,0 +1,224 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ *
+ * This implements userspace semantics of "sticky sockets", modeled after
+ * WireGuard's kernelspace implementation. This is more or less a straight port
+ * of the sticky-sockets.c example code:
+ * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
+ *
+ * Currently there is no way to achieve this within the net package:
+ * See e.g. https://github.com/golang/go/issues/17930
+ * So this code is remains platform dependent.
+ */
+
+package device
+
+import (
+ "sync"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/rwcancel"
+)
+
+func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
+ if !conn.StdNetSupportsStickySockets {
+ return nil, nil
+ }
+ if _, ok := bind.(*conn.StdNetBind); !ok {
+ return nil, nil
+ }
+
+ netlinkSock, err := createNetlinkRouteSocket()
+ if err != nil {
+ return nil, err
+ }
+ netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
+ if err != nil {
+ unix.Close(netlinkSock)
+ return nil, err
+ }
+
+ go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
+
+ return netlinkCancel, nil
+}
+
+func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
+ type peerEndpointPtr struct {
+ peer *Peer
+ endpoint *conn.Endpoint
+ }
+ var reqPeer map[uint32]peerEndpointPtr
+ var reqPeerLock sync.Mutex
+
+ defer netlinkCancel.Close()
+ defer unix.Close(netlinkSock)
+
+ for msg := make([]byte, 1<<16); ; {
+ var err error
+ var msgn int
+ for {
+ msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
+ if err == nil || !rwcancel.RetryAfterError(err) {
+ break
+ }
+ if !netlinkCancel.ReadyRead() {
+ return
+ }
+ }
+ if err != nil {
+ return
+ }
+
+ for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
+
+ hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
+
+ if uint(hdr.Len) > uint(len(remain)) {
+ break
+ }
+
+ switch hdr.Type {
+ case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
+ if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
+ if uint(len(remain)) < uint(hdr.Len) {
+ break
+ }
+ if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
+ attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
+ for {
+ if uint(len(attr)) < uint(unix.SizeofRtAttr) {
+ break
+ }
+ attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
+ if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
+ break
+ }
+ if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
+ ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
+ reqPeerLock.Lock()
+ if reqPeer == nil {
+ reqPeerLock.Unlock()
+ break
+ }
+ pePtr, ok := reqPeer[hdr.Seq]
+ reqPeerLock.Unlock()
+ if !ok {
+ break
+ }
+ pePtr.peer.endpoint.Lock()
+ if &pePtr.peer.endpoint.val != pePtr.endpoint {
+ pePtr.peer.endpoint.Unlock()
+ break
+ }
+ if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
+ pePtr.peer.endpoint.Unlock()
+ break
+ }
+ pePtr.peer.endpoint.clearSrcOnTx = true
+ pePtr.peer.endpoint.Unlock()
+ }
+ attr = attr[attrhdr.Len:]
+ }
+ }
+ break
+ }
+ reqPeerLock.Lock()
+ reqPeer = make(map[uint32]peerEndpointPtr)
+ reqPeerLock.Unlock()
+ go func() {
+ device.peers.RLock()
+ i := uint32(1)
+ for _, peer := range device.peers.keyMap {
+ peer.endpoint.Lock()
+ if peer.endpoint.val == nil {
+ peer.endpoint.Unlock()
+ continue
+ }
+ nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
+ if nativeEP == nil {
+ peer.endpoint.Unlock()
+ continue
+ }
+ if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
+ peer.endpoint.Unlock()
+ break
+ }
+ nlmsg := struct {
+ hdr unix.NlMsghdr
+ msg unix.RtMsg
+ dsthdr unix.RtAttr
+ dst [4]byte
+ srchdr unix.RtAttr
+ src [4]byte
+ markhdr unix.RtAttr
+ mark uint32
+ }{
+ unix.NlMsghdr{
+ Type: uint16(unix.RTM_GETROUTE),
+ Flags: unix.NLM_F_REQUEST,
+ Seq: i,
+ },
+ unix.RtMsg{
+ Family: unix.AF_INET,
+ Dst_len: 32,
+ Src_len: 32,
+ },
+ unix.RtAttr{
+ Len: 8,
+ Type: unix.RTA_DST,
+ },
+ nativeEP.DstIP().As4(),
+ unix.RtAttr{
+ Len: 8,
+ Type: unix.RTA_SRC,
+ },
+ nativeEP.SrcIP().As4(),
+ unix.RtAttr{
+ Len: 8,
+ Type: unix.RTA_MARK,
+ },
+ device.net.fwmark,
+ }
+ nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
+ reqPeerLock.Lock()
+ reqPeer[i] = peerEndpointPtr{
+ peer: peer,
+ endpoint: &peer.endpoint.val,
+ }
+ reqPeerLock.Unlock()
+ peer.endpoint.Unlock()
+ i++
+ _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
+ if err != nil {
+ break
+ }
+ }
+ device.peers.RUnlock()
+ }()
+ }
+ remain = remain[hdr.Len:]
+ }
+ }
+}
+
+func createNetlinkRouteSocket() (int, error) {
+ sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
+ if err != nil {
+ return -1, err
+ }
+ saddr := &unix.SockaddrNetlink{
+ Family: unix.AF_NETLINK,
+ Groups: unix.RTMGRP_IPV4_ROUTE,
+ }
+ err = unix.Bind(sock, saddr)
+ if err != nil {
+ unix.Close(sock)
+ return -1, err
+ }
+ return sock, nil
+}
diff --git a/device/timers.go b/device/timers.go
index 18ee736..d4a4ed4 100644
--- a/device/timers.go
+++ b/device/timers.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*
* This is based heavily on timers.c from the kernel implementation.
*/
@@ -8,16 +8,16 @@
package device
import (
- "math/rand"
"sync"
- "sync/atomic"
"time"
+ _ "unsafe"
)
-/* This Timer structure and related functions should roughly copy the interface of
- * the Linux kernel's struct timer_list.
- */
+//go:linkname fastrandn runtime.fastrandn
+func fastrandn(n uint32) uint32
+// A Timer manages time-based aspects of the WireGuard protocol.
+// Timer roughly copies the interface of the Linux kernel's struct timer_list.
type Timer struct {
*time.Timer
modifyingLock sync.RWMutex
@@ -29,18 +29,17 @@ func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer {
timer := &Timer{}
timer.Timer = time.AfterFunc(time.Hour, func() {
timer.runningLock.Lock()
+ defer timer.runningLock.Unlock()
timer.modifyingLock.Lock()
if !timer.isPending {
timer.modifyingLock.Unlock()
- timer.runningLock.Unlock()
return
}
timer.isPending = false
timer.modifyingLock.Unlock()
expirationFunction(peer)
- timer.runningLock.Unlock()
})
timer.Stop()
return timer
@@ -74,12 +73,12 @@ func (timer *Timer) IsPending() bool {
}
func (peer *Peer) timersActive() bool {
- return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0
+ return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
}
func expiredRetransmitHandshake(peer *Peer) {
- if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes {
- peer.device.log.Debug.Printf("%s - Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2)
+ if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
+ peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
if peer.timersActive() {
peer.timers.sendKeepalive.Del()
@@ -88,7 +87,7 @@ func expiredRetransmitHandshake(peer *Peer) {
/* We drop all packets without a keypair and don't try again,
* if we try unsuccessfully for too long to make a handshake.
*/
- peer.FlushNonceQueue()
+ peer.FlushStagedPackets()
/* We set a timer for destroying any residue that might be left
* of a partial exchange.
@@ -97,15 +96,11 @@ func expiredRetransmitHandshake(peer *Peer) {
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
}
} else {
- atomic.AddUint32(&peer.timers.handshakeAttempts, 1)
- peer.device.log.Debug.Printf("%s - Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1)
+ peer.timers.handshakeAttempts.Add(1)
+ peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */
- peer.Lock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- peer.Unlock()
+ peer.markEndpointSrcForClearing()
peer.SendHandshakeInitiation(true)
}
@@ -113,8 +108,8 @@ func expiredRetransmitHandshake(peer *Peer) {
func expiredSendKeepalive(peer *Peer) {
peer.SendKeepalive()
- if peer.timers.needAnotherKeepalive.Get() {
- peer.timers.needAnotherKeepalive.Set(false)
+ if peer.timers.needAnotherKeepalive.Load() {
+ peer.timers.needAnotherKeepalive.Store(false)
if peer.timersActive() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
}
@@ -122,24 +117,19 @@ func expiredSendKeepalive(peer *Peer) {
}
func expiredNewHandshake(peer *Peer) {
- peer.device.log.Debug.Printf("%s - Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
+ peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
/* We clear the endpoint address src address, in case this is the cause of trouble. */
- peer.Lock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- peer.Unlock()
+ peer.markEndpointSrcForClearing()
peer.SendHandshakeInitiation(false)
-
}
func expiredZeroKeyMaterial(peer *Peer) {
- peer.device.log.Debug.Printf("%s - Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds()))
+ peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds()))
peer.ZeroAndFlushAll()
}
func expiredPersistentKeepalive(peer *Peer) {
- if peer.persistentKeepaliveInterval > 0 {
+ if peer.persistentKeepaliveInterval.Load() > 0 {
peer.SendKeepalive()
}
}
@@ -147,7 +137,7 @@ func expiredPersistentKeepalive(peer *Peer) {
/* Should be called after an authenticated data packet is sent. */
func (peer *Peer) timersDataSent() {
if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
- peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
+ peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
}
}
@@ -157,7 +147,7 @@ func (peer *Peer) timersDataReceived() {
if !peer.timers.sendKeepalive.IsPending() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} else {
- peer.timers.needAnotherKeepalive.Set(true)
+ peer.timers.needAnotherKeepalive.Store(true)
}
}
}
@@ -179,7 +169,7 @@ func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
/* Should be called after a handshake initiation message is sent. */
func (peer *Peer) timersHandshakeInitiated() {
if peer.timersActive() {
- peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
+ peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
}
}
@@ -188,9 +178,9 @@ func (peer *Peer) timersHandshakeComplete() {
if peer.timersActive() {
peer.timers.retransmitHandshake.Del()
}
- atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
- peer.timers.sentLastMinuteHandshake.Set(false)
- atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
+ peer.timers.handshakeAttempts.Store(0)
+ peer.timers.sentLastMinuteHandshake.Store(false)
+ peer.lastHandshakeNano.Store(time.Now().UnixNano())
}
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
@@ -202,8 +192,9 @@ func (peer *Peer) timersSessionDerived() {
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
- if peer.persistentKeepaliveInterval > 0 && peer.timersActive() {
- peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second)
+ keepalive := peer.persistentKeepaliveInterval.Load()
+ if keepalive > 0 && peer.timersActive() {
+ peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
}
}
@@ -213,9 +204,12 @@ func (peer *Peer) timersInit() {
peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake)
peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial)
peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive)
- atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
- peer.timers.sentLastMinuteHandshake.Set(false)
- peer.timers.needAnotherKeepalive.Set(false)
+}
+
+func (peer *Peer) timersStart() {
+ peer.timers.handshakeAttempts.Store(0)
+ peer.timers.sentLastMinuteHandshake.Store(false)
+ peer.timers.needAnotherKeepalive.Store(false)
}
func (peer *Peer) timersStop() {
diff --git a/device/tun.go b/device/tun.go
index 0a3fc79..2a2ace9 100644
--- a/device/tun.go
+++ b/device/tun.go
@@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
- "sync/atomic"
+ "fmt"
"golang.zx2c4.com/wireguard/tun"
)
@@ -14,43 +14,40 @@ import (
const DefaultMTU = 1420
func (device *Device) RoutineTUNEventReader() {
- setUp := false
- logDebug := device.log.Debug
- logInfo := device.log.Info
- logError := device.log.Error
-
- logDebug.Println("Routine: event worker - started")
- device.state.starting.Done()
+ device.log.Verbosef("Routine: event worker - started")
for event := range device.tun.device.Events() {
if event&tun.EventMTUUpdate != 0 {
mtu, err := device.tun.device.MTU()
- old := atomic.LoadInt32(&device.tun.mtu)
if err != nil {
- logError.Println("Failed to load updated MTU of device:", err)
- } else if int(old) != mtu {
- if mtu+MessageTransportSize > MaxMessageSize {
- logInfo.Println("MTU updated:", mtu, "(too large)")
- } else {
- logInfo.Println("MTU updated:", mtu)
- }
- atomic.StoreInt32(&device.tun.mtu, int32(mtu))
+ device.log.Errorf("Failed to load updated MTU of device: %v", err)
+ continue
+ }
+ if mtu < 0 {
+ device.log.Errorf("MTU not updated to negative value: %v", mtu)
+ continue
+ }
+ var tooLarge string
+ if mtu > MaxContentSize {
+ tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
+ mtu = MaxContentSize
+ }
+ old := device.tun.mtu.Swap(int32(mtu))
+ if int(old) != mtu {
+ device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
}
}
- if event&tun.EventUp != 0 && !setUp {
- logInfo.Println("Interface set up")
- setUp = true
+ if event&tun.EventUp != 0 {
+ device.log.Verbosef("Interface up requested")
device.Up()
}
- if event&tun.EventDown != 0 && setUp {
- logInfo.Println("Interface set down")
- setUp = false
+ if event&tun.EventDown != 0 {
+ device.log.Verbosef("Interface down requested")
device.Down()
}
}
- logDebug.Println("Routine: event worker - stopped")
- device.state.stopping.Done()
+ device.log.Verbosef("Routine: event worker - stopped")
}
diff --git a/device/tun_test.go b/device/tun_test.go
deleted file mode 100644
index 5614771..0000000
--- a/device/tun_test.go
+++ /dev/null
@@ -1,56 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
- "errors"
- "os"
-
- "golang.zx2c4.com/wireguard/tun"
-)
-
-// newDummyTUN creates a dummy TUN device with the specified name.
-func newDummyTUN(name string) tun.Device {
- return &dummyTUN{
- name: name,
- packets: make(chan []byte, 100),
- events: make(chan tun.Event, 10),
- }
-}
-
-// A dummyTUN is a tun.Device which is used in unit tests.
-type dummyTUN struct {
- name string
- mtu int
- packets chan []byte
- events chan tun.Event
-}
-
-func (d *dummyTUN) Events() chan tun.Event { return d.events }
-func (*dummyTUN) File() *os.File { return nil }
-func (*dummyTUN) Flush() error { return nil }
-func (d *dummyTUN) MTU() (int, error) { return d.mtu, nil }
-func (d *dummyTUN) Name() (string, error) { return d.name, nil }
-
-func (d *dummyTUN) Close() error {
- close(d.events)
- close(d.packets)
- return nil
-}
-
-func (d *dummyTUN) Read(b []byte, offset int) (int, error) {
- buf, ok := <-d.packets
- if !ok {
- return 0, errors.New("device closed")
- }
- copy(b[offset:], buf)
- return len(buf), nil
-}
-
-func (d *dummyTUN) Write(b []byte, offset int) (int, error) {
- d.packets <- b[offset:]
- return len(b), nil
-}
diff --git a/device/uapi.go b/device/uapi.go
index 999eeb5..d81dae3 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -1,43 +1,77 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"bufio"
+ "bytes"
+ "errors"
"fmt"
"io"
"net"
+ "net/netip"
"strconv"
"strings"
- "sync/atomic"
+ "sync"
"time"
"golang.zx2c4.com/wireguard/ipc"
)
type IPCError struct {
- int64
+ code int64 // error code
+ err error // underlying/wrapped error
}
func (s IPCError) Error() string {
- return fmt.Sprintf("IPC error: %d", s.int64)
+ return fmt.Sprintf("IPC error %d: %v", s.code, s.err)
+}
+
+func (s IPCError) Unwrap() error {
+ return s.err
}
func (s IPCError) ErrorCode() int64 {
- return s.int64
+ return s.code
+}
+
+func ipcErrorf(code int64, msg string, args ...any) *IPCError {
+ return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
}
-func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
- lines := make([]string, 0, 100)
- send := func(line string) {
- lines = append(lines, line)
+var byteBufferPool = &sync.Pool{
+ New: func() any { return new(bytes.Buffer) },
+}
+
+// IpcGetOperation implements the WireGuard configuration protocol "get" operation.
+// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
+func (device *Device) IpcGetOperation(w io.Writer) error {
+ device.ipcMutex.RLock()
+ defer device.ipcMutex.RUnlock()
+
+ buf := byteBufferPool.Get().(*bytes.Buffer)
+ buf.Reset()
+ defer byteBufferPool.Put(buf)
+ sendf := func(format string, args ...any) {
+ fmt.Fprintf(buf, format, args...)
+ buf.WriteByte('\n')
+ }
+ keyf := func(prefix string, key *[32]byte) {
+ buf.Grow(len(key)*2 + 2 + len(prefix))
+ buf.WriteString(prefix)
+ buf.WriteByte('=')
+ const hex = "0123456789abcdef"
+ for i := 0; i < len(key); i++ {
+ buf.WriteByte(hex[key[i]>>4])
+ buf.WriteByte(hex[key[i]&0xf])
+ }
+ buf.WriteByte('\n')
}
func() {
-
// lock required resources
device.net.RLock()
@@ -52,353 +86,326 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
// serialize device related values
if !device.staticIdentity.privateKey.IsZero() {
- send("private_key=" + device.staticIdentity.privateKey.ToHex())
+ keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey))
}
if device.net.port != 0 {
- send(fmt.Sprintf("listen_port=%d", device.net.port))
+ sendf("listen_port=%d", device.net.port)
}
if device.net.fwmark != 0 {
- send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
+ sendf("fwmark=%d", device.net.fwmark)
}
- // serialize each peer state
-
for _, peer := range device.peers.keyMap {
- peer.RLock()
- defer peer.RUnlock()
-
- send("public_key=" + peer.handshake.remoteStatic.ToHex())
- send("preshared_key=" + peer.handshake.presharedKey.ToHex())
- send("protocol_version=1")
- if peer.endpoint != nil {
- send("endpoint=" + peer.endpoint.DstToString())
+ // Serialize peer state.
+ peer.handshake.mutex.RLock()
+ keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
+ keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
+ peer.handshake.mutex.RUnlock()
+ sendf("protocol_version=1")
+ peer.endpoint.Lock()
+ if peer.endpoint.val != nil {
+ sendf("endpoint=%s", peer.endpoint.val.DstToString())
}
+ peer.endpoint.Unlock()
- nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
+ nano := peer.lastHandshakeNano.Load()
secs := nano / time.Second.Nanoseconds()
nano %= time.Second.Nanoseconds()
- send(fmt.Sprintf("last_handshake_time_sec=%d", secs))
- send(fmt.Sprintf("last_handshake_time_nsec=%d", nano))
- send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes)))
- send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)))
- send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
-
- for _, ip := range device.allowedips.EntriesForPeer(peer) {
- send("allowed_ip=" + ip.String())
- }
+ sendf("last_handshake_time_sec=%d", secs)
+ sendf("last_handshake_time_nsec=%d", nano)
+ sendf("tx_bytes=%d", peer.txBytes.Load())
+ sendf("rx_bytes=%d", peer.rxBytes.Load())
+ sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
+ device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
+ sendf("allowed_ip=%s", prefix.String())
+ return true
+ })
}
}()
// send lines (does not require resource locks)
-
- for _, line := range lines {
- _, err := socket.WriteString(line + "\n")
- if err != nil {
- return &IPCError{ipc.IpcErrorIO}
- }
+ if _, err := w.Write(buf.Bytes()); err != nil {
+ return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err)
}
return nil
}
-func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
- scanner := bufio.NewScanner(socket)
- logError := device.log.Error
- logDebug := device.log.Debug
+// IpcSetOperation implements the WireGuard configuration protocol "set" operation.
+// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
+func (device *Device) IpcSetOperation(r io.Reader) (err error) {
+ device.ipcMutex.Lock()
+ defer device.ipcMutex.Unlock()
- var peer *Peer
+ defer func() {
+ if err != nil {
+ device.log.Errorf("%v", err)
+ }
+ }()
- dummy := false
- createdNewPeer := false
+ peer := new(ipcSetPeer)
deviceConfig := true
+ scanner := bufio.NewScanner(r)
for scanner.Scan() {
-
- // parse line
-
line := scanner.Text()
if line == "" {
+ // Blank line means terminate operation.
+ peer.handlePostConfig()
return nil
}
- parts := strings.Split(line, "=")
- if len(parts) != 2 {
- return &IPCError{ipc.IpcErrorProtocol}
+ key, value, ok := strings.Cut(line, "=")
+ if !ok {
+ return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
}
- key := parts[0]
- value := parts[1]
-
- /* device configuration */
-
- if deviceConfig {
-
- switch key {
- case "private_key":
- var sk NoisePrivateKey
- err := sk.FromHex(value)
- if err != nil {
- logError.Println("Failed to set private_key:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
- logDebug.Println("UAPI: Updating private key")
- device.SetPrivateKey(sk)
-
- case "listen_port":
-
- // parse port number
-
- port, err := strconv.ParseUint(value, 10, 16)
- if err != nil {
- logError.Println("Failed to parse listen_port:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
-
- // update port and rebind
-
- logDebug.Println("UAPI: Updating listen port")
-
- device.net.Lock()
- device.net.port = uint16(port)
- device.net.Unlock()
-
- if err := device.BindUpdate(); err != nil {
- logError.Println("Failed to set listen_port:", err)
- return &IPCError{ipc.IpcErrorPortInUse}
- }
- case "fwmark":
-
- // parse fwmark field
-
- fwmark, err := func() (uint32, error) {
- if value == "" {
- return 0, nil
- }
- mark, err := strconv.ParseUint(value, 10, 32)
- return uint32(mark), err
- }()
-
- if err != nil {
- logError.Println("Invalid fwmark", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
-
- logDebug.Println("UAPI: Updating fwmark")
-
- if err := device.BindSetMark(uint32(fwmark)); err != nil {
- logError.Println("Failed to update fwmark:", err)
- return &IPCError{ipc.IpcErrorPortInUse}
- }
-
- case "public_key":
- // switch to peer configuration
- logDebug.Println("UAPI: Transition to peer configuration")
+ if key == "public_key" {
+ if deviceConfig {
deviceConfig = false
-
- case "replace_peers":
- if value != "true" {
- logError.Println("Failed to set replace_peers, invalid value:", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
- logDebug.Println("UAPI: Removing all peers")
- device.RemoveAllPeers()
-
- default:
- logError.Println("Invalid UAPI device key:", key)
- return &IPCError{ipc.IpcErrorInvalid}
}
+ peer.handlePostConfig()
+ // Load/create the peer we are now configuring.
+ err := device.handlePublicKeyLine(peer, value)
+ if err != nil {
+ return err
+ }
+ continue
}
- /* peer configuration */
-
- if !deviceConfig {
-
- switch key {
-
- case "public_key":
- var publicKey NoisePublicKey
- err := publicKey.FromHex(value)
- if err != nil {
- logError.Println("Failed to get peer by public key:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
-
- // ignore peer with public key of device
-
- device.staticIdentity.RLock()
- dummy = device.staticIdentity.publicKey.Equals(publicKey)
- device.staticIdentity.RUnlock()
-
- if dummy {
- peer = &Peer{}
- } else {
- peer = device.LookupPeer(publicKey)
- }
+ var err error
+ if deviceConfig {
+ err = device.handleDeviceLine(key, value)
+ } else {
+ err = device.handlePeerLine(peer, key, value)
+ }
+ if err != nil {
+ return err
+ }
+ }
+ peer.handlePostConfig()
- createdNewPeer = peer == nil
- if createdNewPeer {
- peer, err = device.NewPeer(publicKey)
- if err != nil {
- logError.Println("Failed to create new peer:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
- if peer == nil {
- dummy = true
- peer = &Peer{}
- } else {
- logDebug.Println(peer, "- UAPI: Created")
- }
- }
+ if err := scanner.Err(); err != nil {
+ return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err)
+ }
+ return nil
+}
- case "update_only":
+func (device *Device) handleDeviceLine(key, value string) error {
+ switch key {
+ case "private_key":
+ var sk NoisePrivateKey
+ err := sk.FromMaybeZeroHex(value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
+ }
+ device.log.Verbosef("UAPI: Updating private key")
+ device.SetPrivateKey(sk)
- // allow disabling of creation
+ case "listen_port":
+ port, err := strconv.ParseUint(value, 10, 16)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
+ }
- if value != "true" {
- logError.Println("Failed to set update only, invalid value:", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
- if createdNewPeer && !dummy {
- device.RemovePeer(peer.handshake.remoteStatic)
- peer = &Peer{}
- dummy = true
- }
+ // update port and rebind
+ device.log.Verbosef("UAPI: Updating listen port")
- case "remove":
+ device.net.Lock()
+ device.net.port = uint16(port)
+ device.net.Unlock()
- // remove currently selected peer from device
+ if err := device.BindUpdate(); err != nil {
+ return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
+ }
- if value != "true" {
- logError.Println("Failed to set remove, invalid value:", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
- if !dummy {
- logDebug.Println(peer, "- UAPI: Removing")
- device.RemovePeer(peer.handshake.remoteStatic)
- }
- peer = &Peer{}
- dummy = true
+ case "fwmark":
+ mark, err := strconv.ParseUint(value, 10, 32)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
+ }
- case "preshared_key":
+ device.log.Verbosef("UAPI: Updating fwmark")
+ if err := device.BindSetMark(uint32(mark)); err != nil {
+ return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
+ }
- // update PSK
+ case "replace_peers":
+ if value != "true" {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
+ }
+ device.log.Verbosef("UAPI: Removing all peers")
+ device.RemoveAllPeers()
- logDebug.Println(peer, "- UAPI: Updating preshared key")
+ default:
+ return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
+ }
- peer.handshake.mutex.Lock()
- err := peer.handshake.presharedKey.FromHex(value)
- peer.handshake.mutex.Unlock()
+ return nil
+}
- if err != nil {
- logError.Println("Failed to set preshared key:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
+// An ipcSetPeer is the current state of an IPC set operation on a peer.
+type ipcSetPeer struct {
+ *Peer // Peer is the current peer being operated on
+ dummy bool // dummy reports whether this peer is a temporary, placeholder peer
+ created bool // new reports whether this is a newly created peer
+ pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on
+}
- case "endpoint":
+func (peer *ipcSetPeer) handlePostConfig() {
+ if peer.Peer == nil || peer.dummy {
+ return
+ }
+ if peer.created {
+ peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
+ }
+ if peer.device.isUp() {
+ peer.Start()
+ if peer.pkaOn {
+ peer.SendKeepalive()
+ }
+ peer.SendStagedPackets()
+ }
+}
- // set endpoint destination
+func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
+ // Load/create the peer we are configuring.
+ var publicKey NoisePublicKey
+ err := publicKey.FromHex(value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
+ }
- logDebug.Println(peer, "- UAPI: Updating endpoint")
+ // Ignore peer with the same public key as this device.
+ device.staticIdentity.RLock()
+ peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
+ device.staticIdentity.RUnlock()
- err := func() error {
- peer.Lock()
- defer peer.Unlock()
- endpoint, err := CreateEndpoint(value)
- if err != nil {
- return err
- }
- peer.endpoint = endpoint
- return nil
- }()
+ if peer.dummy {
+ peer.Peer = &Peer{}
+ } else {
+ peer.Peer = device.LookupPeer(publicKey)
+ }
- if err != nil {
- logError.Println("Failed to set endpoint:", err, ":", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
-
- case "persistent_keepalive_interval":
-
- // update persistent keepalive interval
-
- logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval")
-
- secs, err := strconv.ParseUint(value, 10, 16)
- if err != nil {
- logError.Println("Failed to set persistent keepalive interval:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
-
- old := peer.persistentKeepaliveInterval
- peer.persistentKeepaliveInterval = uint16(secs)
-
- // send immediate keepalive if we're turning it on and before it wasn't on
-
- if old == 0 && secs != 0 {
- if err != nil {
- logError.Println("Failed to get tun device status:", err)
- return &IPCError{ipc.IpcErrorIO}
- }
- if device.isUp.Get() && !dummy {
- peer.SendKeepalive()
- }
- }
+ peer.created = peer.Peer == nil
+ if peer.created {
+ peer.Peer, err = device.NewPeer(publicKey)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
+ }
+ device.log.Verbosef("%v - UAPI: Created", peer.Peer)
+ }
+ return nil
+}
- case "replace_allowed_ips":
+func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
+ switch key {
+ case "update_only":
+ // allow disabling of creation
+ if value != "true" {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
+ }
+ if peer.created && !peer.dummy {
+ device.RemovePeer(peer.handshake.remoteStatic)
+ peer.Peer = &Peer{}
+ peer.dummy = true
+ }
- logDebug.Println(peer, "- UAPI: Removing all allowedips")
+ case "remove":
+ // remove currently selected peer from device
+ if value != "true" {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
+ }
+ if !peer.dummy {
+ device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
+ device.RemovePeer(peer.handshake.remoteStatic)
+ }
+ peer.Peer = &Peer{}
+ peer.dummy = true
- if value != "true" {
- logError.Println("Failed to replace allowedips, invalid value:", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
+ case "preshared_key":
+ device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer)
- if dummy {
- continue
- }
+ peer.handshake.mutex.Lock()
+ err := peer.handshake.presharedKey.FromHex(value)
+ peer.handshake.mutex.Unlock()
- device.allowedips.RemoveByPeer(peer)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
+ }
- case "allowed_ip":
+ case "endpoint":
+ device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
+ endpoint, err := device.net.bind.ParseEndpoint(value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
+ }
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ peer.endpoint.val = endpoint
- logDebug.Println(peer, "- UAPI: Adding allowedip")
+ case "persistent_keepalive_interval":
+ device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
- _, network, err := net.ParseCIDR(value)
- if err != nil {
- logError.Println("Failed to set allowed ip:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
+ secs, err := strconv.ParseUint(value, 10, 16)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
+ }
- if dummy {
- continue
- }
+ old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
- ones, _ := network.Mask.Size()
- device.allowedips.Insert(network.IP, uint(ones), peer)
+ // Send immediate keepalive if we're turning it on and before it wasn't on.
+ peer.pkaOn = old == 0 && secs != 0
- case "protocol_version":
+ case "replace_allowed_ips":
+ device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
+ if value != "true" {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
+ }
+ if peer.dummy {
+ return nil
+ }
+ device.allowedips.RemoveByPeer(peer.Peer)
- if value != "1" {
- logError.Println("Invalid protocol version:", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
+ case "allowed_ip":
+ device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
+ prefix, err := netip.ParsePrefix(value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
+ }
+ if peer.dummy {
+ return nil
+ }
+ device.allowedips.Insert(prefix, peer.Peer)
- default:
- logError.Println("Invalid UAPI peer key:", key)
- return &IPCError{ipc.IpcErrorInvalid}
- }
+ case "protocol_version":
+ if value != "1" {
+ return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
}
+
+ default:
+ return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
}
return nil
}
-func (device *Device) IpcHandle(socket net.Conn) {
+func (device *Device) IpcGet() (string, error) {
+ buf := new(strings.Builder)
+ if err := device.IpcGetOperation(buf); err != nil {
+ return "", err
+ }
+ return buf.String(), nil
+}
- // create buffered read/writer
+func (device *Device) IpcSet(uapiConf string) error {
+ return device.IpcSetOperation(strings.NewReader(uapiConf))
+}
+func (device *Device) IpcHandle(socket net.Conn) {
defer socket.Close()
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
@@ -407,35 +414,44 @@ func (device *Device) IpcHandle(socket net.Conn) {
return bufio.NewReadWriter(reader, writer)
}(socket)
- defer buffered.Flush()
-
- op, err := buffered.ReadString('\n')
- if err != nil {
- return
- }
-
- // handle operation
-
- var status *IPCError
-
- switch op {
- case "set=1\n":
- status = device.IpcSetOperation(buffered.Reader)
-
- case "get=1\n":
- status = device.IpcGetOperation(buffered.Writer)
-
- default:
- device.log.Error.Println("Invalid UAPI operation:", op)
- return
- }
+ for {
+ op, err := buffered.ReadString('\n')
+ if err != nil {
+ return
+ }
- // write status
+ // handle operation
+ switch op {
+ case "set=1\n":
+ err = device.IpcSetOperation(buffered.Reader)
+ case "get=1\n":
+ var nextByte byte
+ nextByte, err = buffered.ReadByte()
+ if err != nil {
+ return
+ }
+ if nextByte != '\n' {
+ err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
+ break
+ }
+ err = device.IpcGetOperation(buffered.Writer)
+ default:
+ device.log.Errorf("invalid UAPI operation: %v", op)
+ return
+ }
- if status != nil {
- device.log.Error.Println(status)
- fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
- } else {
- fmt.Fprintf(buffered, "errno=0\n\n")
+ // write status
+ var status *IPCError
+ if err != nil && !errors.As(err, &status) {
+ // shouldn't happen
+ status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
+ }
+ if status != nil {
+ device.log.Errorf("%v", status)
+ fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
+ } else {
+ fmt.Fprintf(buffered, "errno=0\n\n")
+ }
+ buffered.Flush()
}
}
diff --git a/device/version.go b/device/version.go
deleted file mode 100644
index 326b9a9..0000000
--- a/device/version.go
+++ /dev/null
@@ -1,3 +0,0 @@
-package device
-
-const WireGuardGoVersion = "0.0.20191012"
diff --git a/format_test.go b/format_test.go
new file mode 100644
index 0000000..6f6cab7
--- /dev/null
+++ b/format_test.go
@@ -0,0 +1,51 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+package main
+
+import (
+ "bytes"
+ "go/format"
+ "io/fs"
+ "os"
+ "path/filepath"
+ "runtime"
+ "sync"
+ "testing"
+)
+
+func TestFormatting(t *testing.T) {
+ var wg sync.WaitGroup
+ filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error {
+ if err != nil {
+ t.Errorf("unable to walk %s: %v", path, err)
+ return nil
+ }
+ if d.IsDir() || filepath.Ext(path) != ".go" {
+ return nil
+ }
+ wg.Add(1)
+ go func(path string) {
+ defer wg.Done()
+ src, err := os.ReadFile(path)
+ if err != nil {
+ t.Errorf("unable to read %s: %v", path, err)
+ return
+ }
+ if runtime.GOOS == "windows" {
+ src = bytes.ReplaceAll(src, []byte{'\r', '\n'}, []byte{'\n'})
+ }
+ formatted, err := format.Source(src)
+ if err != nil {
+ t.Errorf("unable to format %s: %v", path, err)
+ return
+ }
+ if !bytes.Equal(src, formatted) {
+ t.Errorf("unformatted code: %s", path)
+ }
+ }(path)
+ return nil
+ })
+ wg.Wait()
+}
diff --git a/go.mod b/go.mod
index 34b1e72..919dc49 100644
--- a/go.mod
+++ b/go.mod
@@ -1,10 +1,16 @@
module golang.zx2c4.com/wireguard
-go 1.12
+go 1.20
require (
- golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc
- golang.org/x/net v0.0.0-20191003171128-d98b1b443823
- golang.org/x/sys v0.0.0-20191003212358-c178f38b412c
- golang.org/x/text v0.3.2
+ golang.org/x/crypto v0.13.0
+ golang.org/x/net v0.15.0
+ golang.org/x/sys v0.12.0
+ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
+ gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
+)
+
+require (
+ github.com/google/btree v1.0.1 // indirect
+ golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
)
diff --git a/go.sum b/go.sum
index 970f4cb..6bcecea 100644
--- a/go.sum
+++ b/go.sum
@@ -1,14 +1,14 @@
-golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc h1:c0o/qxkaO2LF5t6fQrT4b5hzyggAkLLlCUjqfRxd8Q4=
-golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
-golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20191003171128-d98b1b443823 h1:Ypyv6BNJh07T1pUSrehkLemqPKXhus2MkfktJ91kRh4=
-golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191003212358-c178f38b412c h1:6Zx7DRlKXf79yfxuQ/7GqV3w2y7aDsk6bGg0MzF5RVU=
-golang.org/x/sys v0.0.0-20191003212358-c178f38b412c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
-golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
-golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
+github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
+golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
+golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
+golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8=
+golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
+golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
+golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
+golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
+golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
+gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
+gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
diff --git a/ipc/winpipe/file.go b/ipc/namedpipe/file.go
index 29d02a7..ab1e13d 100644
--- a/ipc/winpipe/file.go
+++ b/ipc/namedpipe/file.go
@@ -1,63 +1,31 @@
-// +build windows
+// Copyright 2021 The Go Authors. All rights reserved.
+// Copyright 2015 Microsoft
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2005 Microsoft
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-package winpipe
+//go:build windows
+
+package namedpipe
import (
- "errors"
"io"
+ "os"
"runtime"
"sync"
"sync/atomic"
"time"
+ "unsafe"
"golang.org/x/sys/windows"
)
-//sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx
-//sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort
-//sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
-//sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
-//sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
-
-type atomicBool int32
-
-func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
-func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
-func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
-func (b *atomicBool) swap(new bool) bool {
- var newInt int32
- if new {
- newInt = 1
- }
- return atomic.SwapInt32((*int32)(b), newInt) == 1
-}
-
-const (
- cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 1
- cFILE_SKIP_SET_EVENT_ON_HANDLE = 2
-)
+type timeoutChan chan struct{}
var (
- ErrFileClosed = errors.New("file has already been closed")
- ErrTimeout = &timeoutError{}
+ ioInitOnce sync.Once
+ ioCompletionPort windows.Handle
)
-type timeoutError struct{}
-
-func (e *timeoutError) Error() string { return "i/o timeout" }
-func (e *timeoutError) Timeout() bool { return true }
-func (e *timeoutError) Temporary() bool { return true }
-
-type timeoutChan chan struct{}
-
-var ioInitOnce sync.Once
-var ioCompletionPort windows.Handle
-
// ioResult contains the result of an asynchronous IO operation
type ioResult struct {
bytes uint32
@@ -71,7 +39,7 @@ type ioOperation struct {
}
func initIo() {
- h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff)
+ h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
if err != nil {
panic(err)
}
@@ -79,13 +47,13 @@ func initIo() {
go ioCompletionProcessor(h)
}
-// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
+// file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
// It takes ownership of this handle and will close it if it is garbage collected.
-type win32File struct {
+type file struct {
handle windows.Handle
wg sync.WaitGroup
wgLock sync.RWMutex
- closing atomicBool
+ closing atomic.Bool
socket bool
readDeadline deadlineHandler
writeDeadline deadlineHandler
@@ -96,18 +64,18 @@ type deadlineHandler struct {
channel timeoutChan
channelLock sync.RWMutex
timer *time.Timer
- timedout atomicBool
+ timedout atomic.Bool
}
-// makeWin32File makes a new win32File from an existing file handle
-func makeWin32File(h windows.Handle) (*win32File, error) {
- f := &win32File{handle: h}
+// makeFile makes a new file from an existing file handle
+func makeFile(h windows.Handle) (*file, error) {
+ f := &file{handle: h}
ioInitOnce.Do(initIo)
- _, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
+ _, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0)
if err != nil {
return nil, err
}
- err = setFileCompletionNotificationModes(h, cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS|cFILE_SKIP_SET_EVENT_ON_HANDLE)
+ err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE)
if err != nil {
return nil, err
}
@@ -116,18 +84,14 @@ func makeWin32File(h windows.Handle) (*win32File, error) {
return f, nil
}
-func MakeOpenFile(h windows.Handle) (io.ReadWriteCloser, error) {
- return makeWin32File(h)
-}
-
// closeHandle closes the resources associated with a Win32 handle
-func (f *win32File) closeHandle() {
+func (f *file) closeHandle() {
f.wgLock.Lock()
// Atomically set that we are closing, releasing the resources only once.
- if !f.closing.swap(true) {
+ if f.closing.Swap(true) == false {
f.wgLock.Unlock()
// cancel all IO and wait for it to complete
- cancelIoEx(f.handle, nil)
+ windows.CancelIoEx(f.handle, nil)
f.wg.Wait()
// at this point, no new IO can start
windows.Close(f.handle)
@@ -137,19 +101,19 @@ func (f *win32File) closeHandle() {
}
}
-// Close closes a win32File.
-func (f *win32File) Close() error {
+// Close closes a file.
+func (f *file) Close() error {
f.closeHandle()
return nil
}
// prepareIo prepares for a new IO operation.
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
-func (f *win32File) prepareIo() (*ioOperation, error) {
+func (f *file) prepareIo() (*ioOperation, error) {
f.wgLock.RLock()
- if f.closing.isSet() {
+ if f.closing.Load() {
f.wgLock.RUnlock()
- return nil, ErrFileClosed
+ return nil, os.ErrClosed
}
f.wg.Add(1)
f.wgLock.RUnlock()
@@ -164,7 +128,7 @@ func ioCompletionProcessor(h windows.Handle) {
var bytes uint32
var key uintptr
var op *ioOperation
- err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE)
+ err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE)
if op == nil {
panic(err)
}
@@ -174,13 +138,13 @@ func ioCompletionProcessor(h windows.Handle) {
// asyncIo processes the return value from ReadFile or WriteFile, blocking until
// the operation has actually completed.
-func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
+func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
if err != windows.ERROR_IO_PENDING {
return int(bytes), err
}
- if f.closing.isSet() {
- cancelIoEx(f.handle, &c.o)
+ if f.closing.Load() {
+ windows.CancelIoEx(f.handle, &c.o)
}
var timeout timeoutChan
@@ -195,20 +159,20 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
case r = <-c.ch:
err = r.err
if err == windows.ERROR_OPERATION_ABORTED {
- if f.closing.isSet() {
- err = ErrFileClosed
+ if f.closing.Load() {
+ err = os.ErrClosed
}
} else if err != nil && f.socket {
// err is from Win32. Query the overlapped structure to get the winsock error.
var bytes, flags uint32
- err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
+ err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
}
case <-timeout:
- cancelIoEx(f.handle, &c.o)
+ windows.CancelIoEx(f.handle, &c.o)
r = <-c.ch
err = r.err
if err == windows.ERROR_OPERATION_ABORTED {
- err = ErrTimeout
+ err = os.ErrDeadlineExceeded
}
}
@@ -220,15 +184,15 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
}
// Read reads from a file handle.
-func (f *win32File) Read(b []byte) (int, error) {
+func (f *file) Read(b []byte) (int, error) {
c, err := f.prepareIo()
if err != nil {
return 0, err
}
defer f.wg.Done()
- if f.readDeadline.timedout.isSet() {
- return 0, ErrTimeout
+ if f.readDeadline.timedout.Load() {
+ return 0, os.ErrDeadlineExceeded
}
var bytes uint32
@@ -247,15 +211,15 @@ func (f *win32File) Read(b []byte) (int, error) {
}
// Write writes to a file handle.
-func (f *win32File) Write(b []byte) (int, error) {
+func (f *file) Write(b []byte) (int, error) {
c, err := f.prepareIo()
if err != nil {
return 0, err
}
defer f.wg.Done()
- if f.writeDeadline.timedout.isSet() {
- return 0, ErrTimeout
+ if f.writeDeadline.timedout.Load() {
+ return 0, os.ErrDeadlineExceeded
}
var bytes uint32
@@ -265,19 +229,19 @@ func (f *win32File) Write(b []byte) (int, error) {
return n, err
}
-func (f *win32File) SetReadDeadline(deadline time.Time) error {
+func (f *file) SetReadDeadline(deadline time.Time) error {
return f.readDeadline.set(deadline)
}
-func (f *win32File) SetWriteDeadline(deadline time.Time) error {
+func (f *file) SetWriteDeadline(deadline time.Time) error {
return f.writeDeadline.set(deadline)
}
-func (f *win32File) Flush() error {
+func (f *file) Flush() error {
return windows.FlushFileBuffers(f.handle)
}
-func (f *win32File) Fd() uintptr {
+func (f *file) Fd() uintptr {
return uintptr(f.handle)
}
@@ -291,7 +255,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
}
d.timer = nil
}
- d.timedout.setFalse()
+ d.timedout.Store(false)
select {
case <-d.channel:
@@ -306,7 +270,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
}
timeoutIO := func() {
- d.timedout.setTrue()
+ d.timedout.Store(true)
close(d.channel)
}
diff --git a/ipc/namedpipe/namedpipe.go b/ipc/namedpipe/namedpipe.go
new file mode 100644
index 0000000..ef3dea1
--- /dev/null
+++ b/ipc/namedpipe/namedpipe.go
@@ -0,0 +1,485 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Copyright 2015 Microsoft
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build windows
+
+// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
+package namedpipe
+
+import (
+ "context"
+ "io"
+ "net"
+ "os"
+ "runtime"
+ "sync/atomic"
+ "time"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+type pipe struct {
+ *file
+ path string
+}
+
+type messageBytePipe struct {
+ pipe
+ writeClosed atomic.Bool
+ readEOF bool
+}
+
+type pipeAddress string
+
+func (f *pipe) LocalAddr() net.Addr {
+ return pipeAddress(f.path)
+}
+
+func (f *pipe) RemoteAddr() net.Addr {
+ return pipeAddress(f.path)
+}
+
+func (f *pipe) SetDeadline(t time.Time) error {
+ f.SetReadDeadline(t)
+ f.SetWriteDeadline(t)
+ return nil
+}
+
+// CloseWrite closes the write side of a message pipe in byte mode.
+func (f *messageBytePipe) CloseWrite() error {
+ if !f.writeClosed.CompareAndSwap(false, true) {
+ return io.ErrClosedPipe
+ }
+ err := f.file.Flush()
+ if err != nil {
+ f.writeClosed.Store(false)
+ return err
+ }
+ _, err = f.file.Write(nil)
+ if err != nil {
+ f.writeClosed.Store(false)
+ return err
+ }
+ return nil
+}
+
+// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
+// they are used to implement CloseWrite.
+func (f *messageBytePipe) Write(b []byte) (int, error) {
+ if f.writeClosed.Load() {
+ return 0, io.ErrClosedPipe
+ }
+ if len(b) == 0 {
+ return 0, nil
+ }
+ return f.file.Write(b)
+}
+
+// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
+// mode pipe will return io.EOF, as will all subsequent reads.
+func (f *messageBytePipe) Read(b []byte) (int, error) {
+ if f.readEOF {
+ return 0, io.EOF
+ }
+ n, err := f.file.Read(b)
+ if err == io.EOF {
+ // If this was the result of a zero-byte read, then
+ // it is possible that the read was due to a zero-size
+ // message. Since we are simulating CloseWrite with a
+ // zero-byte message, ensure that all future Read calls
+ // also return EOF.
+ f.readEOF = true
+ } else if err == windows.ERROR_MORE_DATA {
+ // ERROR_MORE_DATA indicates that the pipe's read mode is message mode
+ // and the message still has more bytes. Treat this as a success, since
+ // this package presents all named pipes as byte streams.
+ err = nil
+ }
+ return n, err
+}
+
+func (f *pipe) Handle() windows.Handle {
+ return f.handle
+}
+
+func (s pipeAddress) Network() string {
+ return "pipe"
+}
+
+func (s pipeAddress) String() string {
+ return string(s)
+}
+
+// tryDialPipe attempts to dial the specified pipe until cancellation or timeout.
+func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) {
+ for {
+ select {
+ case <-ctx.Done():
+ return 0, ctx.Err()
+ default:
+ path16, err := windows.UTF16PtrFromString(*path)
+ if err != nil {
+ return 0, err
+ }
+ h, err := windows.CreateFile(path16, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
+ if err == nil {
+ return h, nil
+ }
+ if err != windows.ERROR_PIPE_BUSY {
+ return h, &os.PathError{Err: err, Op: "open", Path: *path}
+ }
+ // Wait 10 msec and try again. This is a rather simplistic
+ // view, as we always try each 10 milliseconds.
+ time.Sleep(10 * time.Millisecond)
+ }
+ }
+}
+
+// DialConfig exposes various options for use in Dial and DialContext.
+type DialConfig struct {
+ ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID.
+}
+
+// DialTimeout connects to the specified named pipe by path, timing out if the
+// connection takes longer than the specified duration. If timeout is zero, then
+// we use a default timeout of 2 seconds.
+func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
+ if timeout == 0 {
+ timeout = time.Second * 2
+ }
+ absTimeout := time.Now().Add(timeout)
+ ctx, _ := context.WithDeadline(context.Background(), absTimeout)
+ conn, err := config.DialContext(ctx, path)
+ if err == context.DeadlineExceeded {
+ return nil, os.ErrDeadlineExceeded
+ }
+ return conn, err
+}
+
+// DialContext attempts to connect to the specified named pipe by path.
+func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) {
+ var err error
+ var h windows.Handle
+ h, err = tryDialPipe(ctx, &path)
+ if err != nil {
+ return nil, err
+ }
+
+ if config.ExpectedOwner != nil {
+ sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
+ if err != nil {
+ windows.Close(h)
+ return nil, err
+ }
+ realOwner, _, err := sd.Owner()
+ if err != nil {
+ windows.Close(h)
+ return nil, err
+ }
+ if !realOwner.Equals(config.ExpectedOwner) {
+ windows.Close(h)
+ return nil, windows.ERROR_ACCESS_DENIED
+ }
+ }
+
+ var flags uint32
+ err = windows.GetNamedPipeInfo(h, &flags, nil, nil, nil)
+ if err != nil {
+ windows.Close(h)
+ return nil, err
+ }
+
+ f, err := makeFile(h)
+ if err != nil {
+ windows.Close(h)
+ return nil, err
+ }
+
+ // If the pipe is in message mode, return a message byte pipe, which
+ // supports CloseWrite.
+ if flags&windows.PIPE_TYPE_MESSAGE != 0 {
+ return &messageBytePipe{
+ pipe: pipe{file: f, path: path},
+ }, nil
+ }
+ return &pipe{file: f, path: path}, nil
+}
+
+var defaultDialer DialConfig
+
+// DialTimeout calls DialConfig.DialTimeout using an empty configuration.
+func DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
+ return defaultDialer.DialTimeout(path, timeout)
+}
+
+// DialContext calls DialConfig.DialContext using an empty configuration.
+func DialContext(ctx context.Context, path string) (net.Conn, error) {
+ return defaultDialer.DialContext(ctx, path)
+}
+
+type acceptResponse struct {
+ f *file
+ err error
+}
+
+type pipeListener struct {
+ firstHandle windows.Handle
+ path string
+ config ListenConfig
+ acceptCh chan chan acceptResponse
+ closeCh chan int
+ doneCh chan int
+}
+
+func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) {
+ path16, err := windows.UTF16PtrFromString(path)
+ if err != nil {
+ return 0, &os.PathError{Op: "open", Path: path, Err: err}
+ }
+
+ var oa windows.OBJECT_ATTRIBUTES
+ oa.Length = uint32(unsafe.Sizeof(oa))
+
+ var ntPath windows.NTUnicodeString
+ if err := windows.RtlDosPathNameToNtPathName(path16, &ntPath, nil, nil); err != nil {
+ if ntstatus, ok := err.(windows.NTStatus); ok {
+ err = ntstatus.Errno()
+ }
+ return 0, &os.PathError{Op: "open", Path: path, Err: err}
+ }
+ defer windows.LocalFree(windows.Handle(unsafe.Pointer(ntPath.Buffer)))
+ oa.ObjectName = &ntPath
+
+ // The security descriptor is only needed for the first pipe.
+ if isFirstPipe {
+ if sd != nil {
+ oa.SecurityDescriptor = sd
+ } else {
+ // Construct the default named pipe security descriptor.
+ var acl *windows.ACL
+ if err := windows.RtlDefaultNpAcl(&acl); err != nil {
+ return 0, err
+ }
+ defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl)))
+ sd, err = windows.NewSecurityDescriptor()
+ if err != nil {
+ return 0, err
+ }
+ if err = sd.SetDACL(acl, true, false); err != nil {
+ return 0, err
+ }
+ oa.SecurityDescriptor = sd
+ }
+ }
+
+ typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS)
+ if c.MessageMode {
+ typ |= windows.FILE_PIPE_MESSAGE_TYPE
+ }
+
+ disposition := uint32(windows.FILE_OPEN)
+ access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
+ if isFirstPipe {
+ disposition = windows.FILE_CREATE
+ // By not asking for read or write access, the named pipe file system
+ // will put this pipe into an initially disconnected state, blocking
+ // client connections until the next call with isFirstPipe == false.
+ access = windows.SYNCHRONIZE
+ }
+
+ timeout := int64(-50 * 10000) // 50ms
+
+ var (
+ h windows.Handle
+ iosb windows.IO_STATUS_BLOCK
+ )
+ err = windows.NtCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout)
+ if err != nil {
+ if ntstatus, ok := err.(windows.NTStatus); ok {
+ err = ntstatus.Errno()
+ }
+ return 0, &os.PathError{Op: "open", Path: path, Err: err}
+ }
+
+ runtime.KeepAlive(ntPath)
+ return h, nil
+}
+
+func (l *pipeListener) makeServerPipe() (*file, error) {
+ h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
+ if err != nil {
+ return nil, err
+ }
+ f, err := makeFile(h)
+ if err != nil {
+ windows.Close(h)
+ return nil, err
+ }
+ return f, nil
+}
+
+func (l *pipeListener) makeConnectedServerPipe() (*file, error) {
+ p, err := l.makeServerPipe()
+ if err != nil {
+ return nil, err
+ }
+
+ // Wait for the client to connect.
+ ch := make(chan error)
+ go func(p *file) {
+ ch <- connectPipe(p)
+ }(p)
+
+ select {
+ case err = <-ch:
+ if err != nil {
+ p.Close()
+ p = nil
+ }
+ case <-l.closeCh:
+ // Abort the connect request by closing the handle.
+ p.Close()
+ p = nil
+ err = <-ch
+ if err == nil || err == os.ErrClosed {
+ err = net.ErrClosed
+ }
+ }
+ return p, err
+}
+
+func (l *pipeListener) listenerRoutine() {
+ closed := false
+ for !closed {
+ select {
+ case <-l.closeCh:
+ closed = true
+ case responseCh := <-l.acceptCh:
+ var (
+ p *file
+ err error
+ )
+ for {
+ p, err = l.makeConnectedServerPipe()
+ // If the connection was immediately closed by the client, try
+ // again.
+ if err != windows.ERROR_NO_DATA {
+ break
+ }
+ }
+ responseCh <- acceptResponse{p, err}
+ closed = err == net.ErrClosed
+ }
+ }
+ windows.Close(l.firstHandle)
+ l.firstHandle = 0
+ // Notify Close and Accept callers that the handle has been closed.
+ close(l.doneCh)
+}
+
+// ListenConfig contains configuration for the pipe listener.
+type ListenConfig struct {
+ // SecurityDescriptor contains a Windows security descriptor. If nil, the default from RtlDefaultNpAcl is used.
+ SecurityDescriptor *windows.SECURITY_DESCRIPTOR
+
+ // MessageMode determines whether the pipe is in byte or message mode. In either
+ // case the pipe is read in byte mode by default. The only practical difference in
+ // this implementation is that CloseWrite is only supported for message mode pipes;
+ // CloseWrite is implemented as a zero-byte write, but zero-byte writes are only
+ // transferred to the reader (and returned as io.EOF in this implementation)
+ // when the pipe is in message mode.
+ MessageMode bool
+
+ // InputBufferSize specifies the initial size of the input buffer, in bytes, which the OS will grow as needed.
+ InputBufferSize int32
+
+ // OutputBufferSize specifies the initial size of the output buffer, in bytes, which the OS will grow as needed.
+ OutputBufferSize int32
+}
+
+// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe.
+// The pipe must not already exist.
+func (c *ListenConfig) Listen(path string) (net.Listener, error) {
+ h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
+ if err != nil {
+ return nil, err
+ }
+ l := &pipeListener{
+ firstHandle: h,
+ path: path,
+ config: *c,
+ acceptCh: make(chan chan acceptResponse),
+ closeCh: make(chan int),
+ doneCh: make(chan int),
+ }
+ // The first connection is swallowed on Windows 7 & 8, so synthesize it.
+ if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) {
+ path16, err := windows.UTF16PtrFromString(path)
+ if err == nil {
+ h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
+ if err == nil {
+ windows.CloseHandle(h)
+ }
+ }
+ }
+ go l.listenerRoutine()
+ return l, nil
+}
+
+var defaultListener ListenConfig
+
+// Listen calls ListenConfig.Listen using an empty configuration.
+func Listen(path string) (net.Listener, error) {
+ return defaultListener.Listen(path)
+}
+
+func connectPipe(p *file) error {
+ c, err := p.prepareIo()
+ if err != nil {
+ return err
+ }
+ defer p.wg.Done()
+
+ err = windows.ConnectNamedPipe(p.handle, &c.o)
+ _, err = p.asyncIo(c, nil, 0, err)
+ if err != nil && err != windows.ERROR_PIPE_CONNECTED {
+ return err
+ }
+ return nil
+}
+
+func (l *pipeListener) Accept() (net.Conn, error) {
+ ch := make(chan acceptResponse)
+ select {
+ case l.acceptCh <- ch:
+ response := <-ch
+ err := response.err
+ if err != nil {
+ return nil, err
+ }
+ if l.config.MessageMode {
+ return &messageBytePipe{
+ pipe: pipe{file: response.f, path: l.path},
+ }, nil
+ }
+ return &pipe{file: response.f, path: l.path}, nil
+ case <-l.doneCh:
+ return nil, net.ErrClosed
+ }
+}
+
+func (l *pipeListener) Close() error {
+ select {
+ case l.closeCh <- 1:
+ <-l.doneCh
+ case <-l.doneCh:
+ }
+ return nil
+}
+
+func (l *pipeListener) Addr() net.Addr {
+ return pipeAddress(l.path)
+}
diff --git a/ipc/namedpipe/namedpipe_test.go b/ipc/namedpipe/namedpipe_test.go
new file mode 100644
index 0000000..998453b
--- /dev/null
+++ b/ipc/namedpipe/namedpipe_test.go
@@ -0,0 +1,674 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Copyright 2015 Microsoft
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build windows
+
+package namedpipe_test
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "errors"
+ "io"
+ "net"
+ "os"
+ "sync"
+ "syscall"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/windows"
+ "golang.zx2c4.com/wireguard/ipc/namedpipe"
+)
+
+func randomPipePath() string {
+ guid, err := windows.GenerateGUID()
+ if err != nil {
+ panic(err)
+ }
+ return `\\.\PIPE\go-namedpipe-test-` + guid.String()
+}
+
+func TestPingPong(t *testing.T) {
+ const (
+ ping = 42
+ pong = 24
+ )
+ pipePath := randomPipePath()
+ listener, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatalf("unable to listen on pipe: %v", err)
+ }
+ defer listener.Close()
+ go func() {
+ incoming, err := listener.Accept()
+ if err != nil {
+ t.Fatalf("unable to accept pipe connection: %v", err)
+ }
+ defer incoming.Close()
+ var data [1]byte
+ _, err = incoming.Read(data[:])
+ if err != nil {
+ t.Fatalf("unable to read ping from pipe: %v", err)
+ }
+ if data[0] != ping {
+ t.Fatalf("expected ping, got %d", data[0])
+ }
+ data[0] = pong
+ _, err = incoming.Write(data[:])
+ if err != nil {
+ t.Fatalf("unable to write pong to pipe: %v", err)
+ }
+ }()
+ client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err != nil {
+ t.Fatalf("unable to dial pipe: %v", err)
+ }
+ defer client.Close()
+ client.SetDeadline(time.Now().Add(time.Second * 5))
+ var data [1]byte
+ data[0] = ping
+ _, err = client.Write(data[:])
+ if err != nil {
+ t.Fatalf("unable to write ping to pipe: %v", err)
+ }
+ _, err = client.Read(data[:])
+ if err != nil {
+ t.Fatalf("unable to read pong from pipe: %v", err)
+ }
+ if data[0] != pong {
+ t.Fatalf("expected pong, got %d", data[0])
+ }
+}
+
+func TestDialUnknownFailsImmediately(t *testing.T) {
+ _, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0))
+ if !errors.Is(err, syscall.ENOENT) {
+ t.Fatalf("expected ENOENT got %v", err)
+ }
+}
+
+func TestDialListenerTimesOut(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+ pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond)
+ if err == nil {
+ pipe.Close()
+ }
+ if err != os.ErrDeadlineExceeded {
+ t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
+ }
+}
+
+func TestDialContextListenerTimesOut(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+ d := 10 * time.Millisecond
+ ctx, _ := context.WithTimeout(context.Background(), d)
+ pipe, err := namedpipe.DialContext(ctx, pipePath)
+ if err == nil {
+ pipe.Close()
+ }
+ if err != context.DeadlineExceeded {
+ t.Fatalf("expected context.DeadlineExceeded, got %v", err)
+ }
+}
+
+func TestDialListenerGetsCancelled(t *testing.T) {
+ pipePath := randomPipePath()
+ ctx, cancel := context.WithCancel(context.Background())
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+ ch := make(chan error)
+ go func(ctx context.Context, ch chan error) {
+ _, err := namedpipe.DialContext(ctx, pipePath)
+ ch <- err
+ }(ctx, ch)
+ time.Sleep(time.Millisecond * 30)
+ cancel()
+ err = <-ch
+ if err != context.Canceled {
+ t.Fatalf("expected context.Canceled, got %v", err)
+ }
+}
+
+func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
+ if windows.NewLazySystemDLL("ntdll.dll").NewProc("wine_get_version").Find() == nil {
+ t.Skip("dacls on named pipes are broken on wine")
+ }
+ pipePath := randomPipePath()
+ sd, _ := windows.SecurityDescriptorFromString("D:")
+ l, err := (&namedpipe.ListenConfig{
+ SecurityDescriptor: sd,
+ }).Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+ pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err == nil {
+ pipe.Close()
+ }
+ if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
+ t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
+ }
+}
+
+func getConnection(cfg *namedpipe.ListenConfig) (client, server net.Conn, err error) {
+ pipePath := randomPipePath()
+ if cfg == nil {
+ cfg = &namedpipe.ListenConfig{}
+ }
+ l, err := cfg.Listen(pipePath)
+ if err != nil {
+ return
+ }
+ defer l.Close()
+
+ type response struct {
+ c net.Conn
+ err error
+ }
+ ch := make(chan response)
+ go func() {
+ c, err := l.Accept()
+ ch <- response{c, err}
+ }()
+
+ c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err != nil {
+ return
+ }
+
+ r := <-ch
+ if err = r.err; err != nil {
+ c.Close()
+ return
+ }
+
+ client = c
+ server = r.c
+ return
+}
+
+func TestReadTimeout(t *testing.T) {
+ c, s, err := getConnection(nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ defer s.Close()
+
+ c.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
+
+ buf := make([]byte, 10)
+ _, err = c.Read(buf)
+ if err != os.ErrDeadlineExceeded {
+ t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
+ }
+}
+
+func server(l net.Listener, ch chan int) {
+ c, err := l.Accept()
+ if err != nil {
+ panic(err)
+ }
+ rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
+ s, err := rw.ReadString('\n')
+ if err != nil {
+ panic(err)
+ }
+ _, err = rw.WriteString("got " + s)
+ if err != nil {
+ panic(err)
+ }
+ err = rw.Flush()
+ if err != nil {
+ panic(err)
+ }
+ c.Close()
+ ch <- 1
+}
+
+func TestFullListenDialReadWrite(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+
+ ch := make(chan int)
+ go server(l, ch)
+
+ c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
+ _, err = rw.WriteString("hello world\n")
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = rw.Flush()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ s, err := rw.ReadString('\n')
+ if err != nil {
+ t.Fatal(err)
+ }
+ ms := "got hello world\n"
+ if s != ms {
+ t.Errorf("expected '%s', got '%s'", ms, s)
+ }
+
+ <-ch
+}
+
+func TestCloseAbortsListen(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ch := make(chan error)
+ go func() {
+ _, err := l.Accept()
+ ch <- err
+ }()
+
+ time.Sleep(30 * time.Millisecond)
+ l.Close()
+
+ err = <-ch
+ if err != net.ErrClosed {
+ t.Fatalf("expected net.ErrClosed, got %v", err)
+ }
+}
+
+func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) {
+ b := make([]byte, 10)
+ w.Close()
+ n, err := r.Read(b)
+ if n > 0 {
+ t.Errorf("unexpected byte count %d", n)
+ }
+ if err != io.EOF {
+ t.Errorf("expected EOF: %v", err)
+ }
+}
+
+func TestCloseClientEOFServer(t *testing.T) {
+ c, s, err := getConnection(nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ defer s.Close()
+ ensureEOFOnClose(t, c, s)
+}
+
+func TestCloseServerEOFClient(t *testing.T) {
+ c, s, err := getConnection(nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ defer s.Close()
+ ensureEOFOnClose(t, s, c)
+}
+
+func TestCloseWriteEOF(t *testing.T) {
+ cfg := &namedpipe.ListenConfig{
+ MessageMode: true,
+ }
+ c, s, err := getConnection(cfg)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ defer s.Close()
+
+ type closeWriter interface {
+ CloseWrite() error
+ }
+
+ err = c.(closeWriter).CloseWrite()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ b := make([]byte, 10)
+ _, err = s.Read(b)
+ if err != io.EOF {
+ t.Fatal(err)
+ }
+}
+
+func TestAcceptAfterCloseFails(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ l.Close()
+ _, err = l.Accept()
+ if err != net.ErrClosed {
+ t.Fatalf("expected net.ErrClosed, got %v", err)
+ }
+}
+
+func TestDialTimesOutByDefault(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+ pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds.
+ if err == nil {
+ pipe.Close()
+ }
+ if err != os.ErrDeadlineExceeded {
+ t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
+ }
+}
+
+func TestTimeoutPendingRead(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+
+ serverDone := make(chan struct{})
+
+ go func() {
+ s, err := l.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ time.Sleep(1 * time.Second)
+ s.Close()
+ close(serverDone)
+ }()
+
+ client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer client.Close()
+
+ clientErr := make(chan error)
+ go func() {
+ buf := make([]byte, 10)
+ _, err = client.Read(buf)
+ clientErr <- err
+ }()
+
+ time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline
+ client.SetReadDeadline(time.Unix(1, 0))
+
+ select {
+ case err = <-clientErr:
+ if err != os.ErrDeadlineExceeded {
+ t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
+ }
+ case <-time.After(100 * time.Millisecond):
+ t.Fatalf("timed out while waiting for read to cancel")
+ <-clientErr
+ }
+ <-serverDone
+}
+
+func TestTimeoutPendingWrite(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+
+ serverDone := make(chan struct{})
+
+ go func() {
+ s, err := l.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ time.Sleep(1 * time.Second)
+ s.Close()
+ close(serverDone)
+ }()
+
+ client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer client.Close()
+
+ clientErr := make(chan error)
+ go func() {
+ _, err = client.Write([]byte("this should timeout"))
+ clientErr <- err
+ }()
+
+ time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline
+ client.SetWriteDeadline(time.Unix(1, 0))
+
+ select {
+ case err = <-clientErr:
+ if err != os.ErrDeadlineExceeded {
+ t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
+ }
+ case <-time.After(100 * time.Millisecond):
+ t.Fatalf("timed out while waiting for write to cancel")
+ <-clientErr
+ }
+ <-serverDone
+}
+
+type CloseWriter interface {
+ CloseWrite() error
+}
+
+func TestEchoWithMessaging(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := (&namedpipe.ListenConfig{
+ MessageMode: true, // Use message mode so that CloseWrite() is supported
+ InputBufferSize: 65536, // Use 64KB buffers to improve performance
+ OutputBufferSize: 65536,
+ }).Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+
+ listenerDone := make(chan bool)
+ clientDone := make(chan bool)
+ go func() {
+ // server echo
+ conn, err := l.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
+ _, err = io.Copy(conn, conn)
+ if err != nil {
+ t.Fatal(err)
+ }
+ conn.(CloseWriter).CloseWrite()
+ close(listenerDone)
+ }()
+ client, err := namedpipe.DialTimeout(pipePath, time.Second)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer client.Close()
+
+ go func() {
+ // client read back
+ bytes := make([]byte, 2)
+ n, e := client.Read(bytes)
+ if e != nil {
+ t.Fatal(e)
+ }
+ if n != 2 || bytes[0] != 0 || bytes[1] != 1 {
+ t.Fatalf("expected 2 bytes, got %v", n)
+ }
+ close(clientDone)
+ }()
+
+ payload := make([]byte, 2)
+ payload[0] = 0
+ payload[1] = 1
+
+ n, err := client.Write(payload)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 2 {
+ t.Fatalf("expected 2 bytes, got %v", n)
+ }
+ client.(CloseWriter).CloseWrite()
+ <-listenerDone
+ <-clientDone
+}
+
+func TestConnectRace(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+ go func() {
+ for {
+ s, err := l.Accept()
+ if err == net.ErrClosed {
+ return
+ }
+
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.Close()
+ }
+ }()
+
+ for i := 0; i < 1000; i++ {
+ c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+ }
+}
+
+func TestMessageReadMode(t *testing.T) {
+ if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 {
+ t.Skipf("Skipping on Windows %d", maj)
+ }
+ var wg sync.WaitGroup
+ defer wg.Wait()
+ pipePath := randomPipePath()
+ l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+
+ msg := ([]byte)("hello world")
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ s, err := l.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = s.Write(msg)
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.Close()
+ }()
+
+ c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ mode := uint32(windows.PIPE_READMODE_MESSAGE)
+ err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ch := make([]byte, 1)
+ var vmsg []byte
+ for {
+ n, err := c.Read(ch)
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 1 {
+ t.Fatalf("expected 1, got %d", n)
+ }
+ vmsg = append(vmsg, ch[0])
+ }
+ if !bytes.Equal(msg, vmsg) {
+ t.Fatalf("expected %s, got %s", msg, vmsg)
+ }
+}
+
+func TestListenConnectRace(t *testing.T) {
+ if testing.Short() {
+ t.Skip("Skipping long race test")
+ }
+ pipePath := randomPipePath()
+ for i := 0; i < 50 && !t.Failed(); i++ {
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err == nil {
+ c.Close()
+ }
+ wg.Done()
+ }()
+ s, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Error(i, err)
+ } else {
+ s.Close()
+ }
+ wg.Wait()
+ }
+}
diff --git a/ipc/uapi_bsd.go b/ipc/uapi_bsd.go
index 75cc0e3..ddcaf27 100644
--- a/ipc/uapi_bsd.go
+++ b/ipc/uapi_bsd.go
@@ -1,33 +1,21 @@
-// +build darwin freebsd openbsd
+//go:build darwin || freebsd || openbsd
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
import (
"errors"
- "fmt"
"net"
"os"
- "path"
"unsafe"
"golang.org/x/sys/unix"
)
-var socketDirectory = "/var/run/wireguard"
-
-const (
- IpcErrorIO = -int64(unix.EIO)
- IpcErrorProtocol = -int64(unix.EPROTO)
- IpcErrorInvalid = -int64(unix.EINVAL)
- IpcErrorPortInUse = -int64(unix.EADDRINUSE)
- socketName = "%s.sock"
-)
-
type UAPIListener struct {
listener net.Listener // unix socket listener
connNew chan net.Conn
@@ -66,7 +54,6 @@ func (l *UAPIListener) Addr() net.Addr {
}
func UAPIListen(name string, file *os.File) (net.Listener, error) {
-
// wrap file in listener
listener, err := net.FileListener(file)
@@ -84,10 +71,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
unixListener.SetUnlinkOnClose(true)
}
- socketPath := path.Join(
- socketDirectory,
- fmt.Sprintf(socketName, name),
- )
+ socketPath := sockPath(name)
// watch for deletion of socket
@@ -119,7 +103,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
l.connErr <- err
return
}
- if kerr != nil || n != 1 {
+ if (kerr != nil || n != 1) && kerr != unix.EINTR {
if kerr != nil {
l.connErr <- kerr
} else {
@@ -146,58 +130,3 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
return uapi, nil
}
-
-func UAPIOpen(name string) (*os.File, error) {
-
- // check if path exist
-
- err := os.MkdirAll(socketDirectory, 0755)
- if err != nil && !os.IsExist(err) {
- return nil, err
- }
-
- // open UNIX socket
-
- socketPath := path.Join(
- socketDirectory,
- fmt.Sprintf(socketName, name),
- )
-
- addr, err := net.ResolveUnixAddr("unix", socketPath)
- if err != nil {
- return nil, err
- }
-
- oldUmask := unix.Umask(0077)
- listener, err := func() (*net.UnixListener, error) {
-
- // initial connection attempt
-
- listener, err := net.ListenUnix("unix", addr)
- if err == nil {
- return listener, nil
- }
-
- // check if socket already active
-
- _, err = net.Dial("unix", socketPath)
- if err == nil {
- return nil, errors.New("unix socket in use")
- }
-
- // cleanup & attempt again
-
- err = os.Remove(socketPath)
- if err != nil {
- return nil, err
- }
- return net.ListenUnix("unix", addr)
- }()
- unix.Umask(oldUmask)
-
- if err != nil {
- return nil, err
- }
-
- return listener.File()
-}
diff --git a/ipc/uapi_linux.go b/ipc/uapi_linux.go
index a3c95ca..1562a18 100644
--- a/ipc/uapi_linux.go
+++ b/ipc/uapi_linux.go
@@ -1,31 +1,18 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
import (
- "errors"
- "fmt"
"net"
"os"
- "path"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
)
-var socketDirectory = "/var/run/wireguard"
-
-const (
- IpcErrorIO = -int64(unix.EIO)
- IpcErrorProtocol = -int64(unix.EPROTO)
- IpcErrorInvalid = -int64(unix.EINVAL)
- IpcErrorPortInUse = -int64(unix.EADDRINUSE)
- socketName = "%s.sock"
-)
-
type UAPIListener struct {
listener net.Listener // unix socket listener
connNew chan net.Conn
@@ -64,7 +51,6 @@ func (l *UAPIListener) Addr() net.Addr {
}
func UAPIListen(name string, file *os.File) (net.Listener, error) {
-
// wrap file in listener
listener, err := net.FileListener(file)
@@ -84,10 +70,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
// watch for deletion of socket
- socketPath := path.Join(
- socketDirectory,
- fmt.Sprintf(socketName, name),
- )
+ socketPath := sockPath(name)
uapi.inotifyFd, err = unix.InotifyInit()
if err != nil {
@@ -113,14 +96,15 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
}
go func(l *UAPIListener) {
- var buff [0]byte
+ var buf [0]byte
for {
+ defer uapi.inotifyRWCancel.Close()
// start with lstat to avoid race condition
if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
l.connErr <- err
return
}
- _, err := uapi.inotifyRWCancel.Read(buff[:])
+ _, err := uapi.inotifyRWCancel.Read(buf[:])
if err != nil {
l.connErr <- err
return
@@ -143,58 +127,3 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
return uapi, nil
}
-
-func UAPIOpen(name string) (*os.File, error) {
-
- // check if path exist
-
- err := os.MkdirAll(socketDirectory, 0755)
- if err != nil && !os.IsExist(err) {
- return nil, err
- }
-
- // open UNIX socket
-
- socketPath := path.Join(
- socketDirectory,
- fmt.Sprintf(socketName, name),
- )
-
- addr, err := net.ResolveUnixAddr("unix", socketPath)
- if err != nil {
- return nil, err
- }
-
- oldUmask := unix.Umask(0077)
- listener, err := func() (*net.UnixListener, error) {
-
- // initial connection attempt
-
- listener, err := net.ListenUnix("unix", addr)
- if err == nil {
- return listener, nil
- }
-
- // check if socket already active
-
- _, err = net.Dial("unix", socketPath)
- if err == nil {
- return nil, errors.New("unix socket in use")
- }
-
- // cleanup & attempt again
-
- err = os.Remove(socketPath)
- if err != nil {
- return nil, err
- }
- return net.ListenUnix("unix", addr)
- }()
- unix.Umask(oldUmask)
-
- if err != nil {
- return nil, err
- }
-
- return listener.File()
-}
diff --git a/ipc/uapi_unix.go b/ipc/uapi_unix.go
new file mode 100644
index 0000000..e67be26
--- /dev/null
+++ b/ipc/uapi_unix.go
@@ -0,0 +1,66 @@
+//go:build linux || darwin || freebsd || openbsd
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package ipc
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "os"
+
+ "golang.org/x/sys/unix"
+)
+
+const (
+ IpcErrorIO = -int64(unix.EIO)
+ IpcErrorProtocol = -int64(unix.EPROTO)
+ IpcErrorInvalid = -int64(unix.EINVAL)
+ IpcErrorPortInUse = -int64(unix.EADDRINUSE)
+ IpcErrorUnknown = -55 // ENOANO
+)
+
+// socketDirectory is variable because it is modified by a linker
+// flag in wireguard-android.
+var socketDirectory = "/var/run/wireguard"
+
+func sockPath(iface string) string {
+ return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
+}
+
+func UAPIOpen(name string) (*os.File, error) {
+ if err := os.MkdirAll(socketDirectory, 0o755); err != nil {
+ return nil, err
+ }
+
+ socketPath := sockPath(name)
+ addr, err := net.ResolveUnixAddr("unix", socketPath)
+ if err != nil {
+ return nil, err
+ }
+
+ oldUmask := unix.Umask(0o077)
+ defer unix.Umask(oldUmask)
+
+ listener, err := net.ListenUnix("unix", addr)
+ if err == nil {
+ return listener.File()
+ }
+
+ // Test socket, if not in use cleanup and try again.
+ if _, err := net.Dial("unix", socketPath); err == nil {
+ return nil, errors.New("unix socket in use")
+ }
+ if err := os.Remove(socketPath); err != nil {
+ return nil, err
+ }
+ listener, err = net.ListenUnix("unix", addr)
+ if err != nil {
+ return nil, err
+ }
+ return listener.File()
+}
diff --git a/ipc/uapi_wasm.go b/ipc/uapi_wasm.go
new file mode 100644
index 0000000..fa84684
--- /dev/null
+++ b/ipc/uapi_wasm.go
@@ -0,0 +1,15 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package ipc
+
+// Made up sentinel error codes for {js,wasip1}/wasm.
+const (
+ IpcErrorIO = 1
+ IpcErrorInvalid = 2
+ IpcErrorPortInUse = 3
+ IpcErrorUnknown = 4
+ IpcErrorProtocol = 5
+)
diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go
index ead0dc5..aa023c9 100644
--- a/ipc/uapi_windows.go
+++ b/ipc/uapi_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
@@ -9,8 +9,7 @@ import (
"net"
"golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/ipc/winpipe"
+ "golang.zx2c4.com/wireguard/ipc/namedpipe"
)
// TODO: replace these with actual standard windows error numbers from the win package
@@ -19,6 +18,7 @@ const (
IpcErrorProtocol = -int64(71)
IpcErrorInvalid = -int64(22)
IpcErrorPortInUse = -int64(98)
+ IpcErrorUnknown = -int64(55)
)
type UAPIListener struct {
@@ -53,18 +53,16 @@ var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR
func init() {
var err error
- /* SDDL_DEVOBJ_SYS_ALL from the WDK */
- UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
+ UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)(A;;GA;;;BA)S:(ML;;NWNRNX;;;HI)")
if err != nil {
panic(err)
}
}
func UAPIListen(name string) (net.Listener, error) {
- config := winpipe.PipeConfig{
+ listener, err := (&namedpipe.ListenConfig{
SecurityDescriptor: UAPISecurityDescriptor,
- }
- listener, err := winpipe.ListenPipe(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config)
+ }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name)
if err != nil {
return nil, err
}
diff --git a/ipc/winpipe/mksyscall.go b/ipc/winpipe/mksyscall.go
deleted file mode 100644
index 19ac03a..0000000
--- a/ipc/winpipe/mksyscall.go
+++ /dev/null
@@ -1,9 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2005 Microsoft
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-
-package winpipe
-
-//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go pipe.go file.go
diff --git a/ipc/winpipe/pipe.go b/ipc/winpipe/pipe.go
deleted file mode 100644
index 06b3037..0000000
--- a/ipc/winpipe/pipe.go
+++ /dev/null
@@ -1,509 +0,0 @@
-// +build windows
-
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2005 Microsoft
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-
-package winpipe
-
-import (
- "context"
- "errors"
- "fmt"
- "io"
- "net"
- "os"
- "runtime"
- "time"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-//sys connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) = ConnectNamedPipe
-//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateNamedPipeW
-//sys createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW
-//sys getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
-//sys getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
-//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc
-//sys ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile
-//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
-//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U
-//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl
-
-type ioStatusBlock struct {
- Status, Information uintptr
-}
-
-type objectAttributes struct {
- Length uintptr
- RootDirectory uintptr
- ObjectName *unicodeString
- Attributes uintptr
- SecurityDescriptor *windows.SECURITY_DESCRIPTOR
- SecurityQoS uintptr
-}
-
-type unicodeString struct {
- Length uint16
- MaximumLength uint16
- Buffer uintptr
-}
-
-type ntstatus int32
-
-func (status ntstatus) Err() error {
- if status >= 0 {
- return nil
- }
- return rtlNtStatusToDosError(status)
-}
-
-const (
- cSECURITY_SQOS_PRESENT = 0x100000
- cSECURITY_ANONYMOUS = 0
-
- cPIPE_TYPE_MESSAGE = 4
-
- cPIPE_READMODE_MESSAGE = 2
-
- cFILE_OPEN = 1
- cFILE_CREATE = 2
-
- cFILE_PIPE_MESSAGE_TYPE = 1
- cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2
-)
-
-var (
- // ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed.
- // This error should match net.errClosing since docker takes a dependency on its text.
- ErrPipeListenerClosed = errors.New("use of closed network connection")
-
- errPipeWriteClosed = errors.New("pipe has been closed for write")
-)
-
-type win32Pipe struct {
- *win32File
- path string
-}
-
-type win32MessageBytePipe struct {
- win32Pipe
- writeClosed bool
- readEOF bool
-}
-
-type pipeAddress string
-
-func (f *win32Pipe) LocalAddr() net.Addr {
- return pipeAddress(f.path)
-}
-
-func (f *win32Pipe) RemoteAddr() net.Addr {
- return pipeAddress(f.path)
-}
-
-func (f *win32Pipe) SetDeadline(t time.Time) error {
- f.SetReadDeadline(t)
- f.SetWriteDeadline(t)
- return nil
-}
-
-// CloseWrite closes the write side of a message pipe in byte mode.
-func (f *win32MessageBytePipe) CloseWrite() error {
- if f.writeClosed {
- return errPipeWriteClosed
- }
- err := f.win32File.Flush()
- if err != nil {
- return err
- }
- _, err = f.win32File.Write(nil)
- if err != nil {
- return err
- }
- f.writeClosed = true
- return nil
-}
-
-// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
-// they are used to implement CloseWrite().
-func (f *win32MessageBytePipe) Write(b []byte) (int, error) {
- if f.writeClosed {
- return 0, errPipeWriteClosed
- }
- if len(b) == 0 {
- return 0, nil
- }
- return f.win32File.Write(b)
-}
-
-// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
-// mode pipe will return io.EOF, as will all subsequent reads.
-func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
- if f.readEOF {
- return 0, io.EOF
- }
- n, err := f.win32File.Read(b)
- if err == io.EOF {
- // If this was the result of a zero-byte read, then
- // it is possible that the read was due to a zero-size
- // message. Since we are simulating CloseWrite with a
- // zero-byte message, ensure that all future Read() calls
- // also return EOF.
- f.readEOF = true
- } else if err == windows.ERROR_MORE_DATA {
- // ERROR_MORE_DATA indicates that the pipe's read mode is message mode
- // and the message still has more bytes. Treat this as a success, since
- // this package presents all named pipes as byte streams.
- err = nil
- }
- return n, err
-}
-
-func (s pipeAddress) Network() string {
- return "pipe"
-}
-
-func (s pipeAddress) String() string {
- return string(s)
-}
-
-// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
-func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) {
- for {
- select {
- case <-ctx.Done():
- return windows.Handle(0), ctx.Err()
- default:
- h, err := createFile(*path, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
- if err == nil {
- return h, nil
- }
- if err != windows.ERROR_PIPE_BUSY {
- return h, &os.PathError{Err: err, Op: "open", Path: *path}
- }
- // Wait 10 msec and try again. This is a rather simplistic
- // view, as we always try each 10 milliseconds.
- time.Sleep(time.Millisecond * 10)
- }
- }
-}
-
-// DialPipe connects to a named pipe by path, timing out if the connection
-// takes longer than the specified duration. If timeout is nil, then we use
-// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
-func DialPipe(path string, timeout *time.Duration, expectedOwner *windows.SID) (net.Conn, error) {
- var absTimeout time.Time
- if timeout != nil {
- absTimeout = time.Now().Add(*timeout)
- } else {
- absTimeout = time.Now().Add(time.Second * 2)
- }
- ctx, _ := context.WithDeadline(context.Background(), absTimeout)
- conn, err := DialPipeContext(ctx, path, expectedOwner)
- if err == context.DeadlineExceeded {
- return nil, ErrTimeout
- }
- return conn, err
-}
-
-// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
-// cancellation or timeout.
-func DialPipeContext(ctx context.Context, path string, expectedOwner *windows.SID) (net.Conn, error) {
- var err error
- var h windows.Handle
- h, err = tryDialPipe(ctx, &path)
- if err != nil {
- return nil, err
- }
-
- if expectedOwner != nil {
- sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
- if err != nil {
- windows.Close(h)
- return nil, err
- }
- realOwner, _, err := sd.Owner()
- if err != nil {
- windows.Close(h)
- return nil, err
- }
- if !realOwner.Equals(expectedOwner) {
- windows.Close(h)
- return nil, windows.ERROR_ACCESS_DENIED
- }
- }
-
- var flags uint32
- err = getNamedPipeInfo(h, &flags, nil, nil, nil)
- if err != nil {
- windows.Close(h)
- return nil, err
- }
-
- f, err := makeWin32File(h)
- if err != nil {
- windows.Close(h)
- return nil, err
- }
-
- // If the pipe is in message mode, return a message byte pipe, which
- // supports CloseWrite().
- if flags&cPIPE_TYPE_MESSAGE != 0 {
- return &win32MessageBytePipe{
- win32Pipe: win32Pipe{win32File: f, path: path},
- }, nil
- }
- return &win32Pipe{win32File: f, path: path}, nil
-}
-
-type acceptResponse struct {
- f *win32File
- err error
-}
-
-type win32PipeListener struct {
- firstHandle windows.Handle
- path string
- config PipeConfig
- acceptCh chan (chan acceptResponse)
- closeCh chan int
- doneCh chan int
-}
-
-func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *PipeConfig, first bool) (windows.Handle, error) {
- path16, err := windows.UTF16FromString(path)
- if err != nil {
- return 0, &os.PathError{Op: "open", Path: path, Err: err}
- }
-
- var oa objectAttributes
- oa.Length = unsafe.Sizeof(oa)
-
- var ntPath unicodeString
- if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil {
- return 0, &os.PathError{Op: "open", Path: path, Err: err}
- }
- defer windows.LocalFree(windows.Handle(ntPath.Buffer))
- oa.ObjectName = &ntPath
-
- // The security descriptor is only needed for the first pipe.
- if first {
- if sd != nil {
- oa.SecurityDescriptor = sd
- } else {
- // Construct the default named pipe security descriptor.
- var dacl uintptr
- if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
- return 0, fmt.Errorf("getting default named pipe ACL: %s", err)
- }
- defer windows.LocalFree(windows.Handle(dacl))
- sd, err := windows.NewSecurityDescriptor()
- if err != nil {
- return 0, fmt.Errorf("creating new security descriptor: %s", err)
- }
- if err = sd.SetDACL((*windows.ACL)(unsafe.Pointer(dacl)), true, false); err != nil {
- return 0, fmt.Errorf("assigning dacl: %s", err)
- }
- sd, err = sd.ToSelfRelative()
- if err != nil {
- return 0, fmt.Errorf("converting to self-relative: %s", err)
- }
- oa.SecurityDescriptor = sd
- }
- }
-
- typ := uint32(cFILE_PIPE_REJECT_REMOTE_CLIENTS)
- if c.MessageMode {
- typ |= cFILE_PIPE_MESSAGE_TYPE
- }
-
- disposition := uint32(cFILE_OPEN)
- access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
- if first {
- disposition = cFILE_CREATE
- // By not asking for read or write access, the named pipe file system
- // will put this pipe into an initially disconnected state, blocking
- // client connections until the next call with first == false.
- access = windows.SYNCHRONIZE
- }
-
- timeout := int64(-50 * 10000) // 50ms
-
- var (
- h windows.Handle
- iosb ioStatusBlock
- )
- err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err()
- if err != nil {
- return 0, &os.PathError{Op: "open", Path: path, Err: err}
- }
-
- runtime.KeepAlive(ntPath)
- return h, nil
-}
-
-func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
- h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
- if err != nil {
- return nil, err
- }
- f, err := makeWin32File(h)
- if err != nil {
- windows.Close(h)
- return nil, err
- }
- return f, nil
-}
-
-func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) {
- p, err := l.makeServerPipe()
- if err != nil {
- return nil, err
- }
-
- // Wait for the client to connect.
- ch := make(chan error)
- go func(p *win32File) {
- ch <- connectPipe(p)
- }(p)
-
- select {
- case err = <-ch:
- if err != nil {
- p.Close()
- p = nil
- }
- case <-l.closeCh:
- // Abort the connect request by closing the handle.
- p.Close()
- p = nil
- err = <-ch
- if err == nil || err == ErrFileClosed {
- err = ErrPipeListenerClosed
- }
- }
- return p, err
-}
-
-func (l *win32PipeListener) listenerRoutine() {
- closed := false
- for !closed {
- select {
- case <-l.closeCh:
- closed = true
- case responseCh := <-l.acceptCh:
- var (
- p *win32File
- err error
- )
- for {
- p, err = l.makeConnectedServerPipe()
- // If the connection was immediately closed by the client, try
- // again.
- if err != windows.ERROR_NO_DATA {
- break
- }
- }
- responseCh <- acceptResponse{p, err}
- closed = err == ErrPipeListenerClosed
- }
- }
- windows.Close(l.firstHandle)
- l.firstHandle = 0
- // Notify Close() and Accept() callers that the handle has been closed.
- close(l.doneCh)
-}
-
-// PipeConfig contain configuration for the pipe listener.
-type PipeConfig struct {
- // SecurityDescriptor contains a Windows security descriptor.
- SecurityDescriptor *windows.SECURITY_DESCRIPTOR
-
- // MessageMode determines whether the pipe is in byte or message mode. In either
- // case the pipe is read in byte mode by default. The only practical difference in
- // this implementation is that CloseWrite() is only supported for message mode pipes;
- // CloseWrite() is implemented as a zero-byte write, but zero-byte writes are only
- // transferred to the reader (and returned as io.EOF in this implementation)
- // when the pipe is in message mode.
- MessageMode bool
-
- // InputBufferSize specifies the size the input buffer, in bytes.
- InputBufferSize int32
-
- // OutputBufferSize specifies the size the input buffer, in bytes.
- OutputBufferSize int32
-}
-
-// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe.
-// The pipe must not already exist.
-func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
- if c == nil {
- c = &PipeConfig{}
- }
- h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
- if err != nil {
- return nil, err
- }
- l := &win32PipeListener{
- firstHandle: h,
- path: path,
- config: *c,
- acceptCh: make(chan (chan acceptResponse)),
- closeCh: make(chan int),
- doneCh: make(chan int),
- }
- go l.listenerRoutine()
- return l, nil
-}
-
-func connectPipe(p *win32File) error {
- c, err := p.prepareIo()
- if err != nil {
- return err
- }
- defer p.wg.Done()
-
- err = connectNamedPipe(p.handle, &c.o)
- _, err = p.asyncIo(c, nil, 0, err)
- if err != nil && err != windows.ERROR_PIPE_CONNECTED {
- return err
- }
- return nil
-}
-
-func (l *win32PipeListener) Accept() (net.Conn, error) {
- ch := make(chan acceptResponse)
- select {
- case l.acceptCh <- ch:
- response := <-ch
- err := response.err
- if err != nil {
- return nil, err
- }
- if l.config.MessageMode {
- return &win32MessageBytePipe{
- win32Pipe: win32Pipe{win32File: response.f, path: l.path},
- }, nil
- }
- return &win32Pipe{win32File: response.f, path: l.path}, nil
- case <-l.doneCh:
- return nil, ErrPipeListenerClosed
- }
-}
-
-func (l *win32PipeListener) Close() error {
- select {
- case l.closeCh <- 1:
- <-l.doneCh
- case <-l.doneCh:
- }
- return nil
-}
-
-func (l *win32PipeListener) Addr() net.Addr {
- return pipeAddress(l.path)
-}
diff --git a/ipc/winpipe/zsyscall_windows.go b/ipc/winpipe/zsyscall_windows.go
deleted file mode 100644
index 9954329..0000000
--- a/ipc/winpipe/zsyscall_windows.go
+++ /dev/null
@@ -1,238 +0,0 @@
-// Code generated by 'go generate'; DO NOT EDIT.
-
-package winpipe
-
-import (
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-var _ unsafe.Pointer
-
-// Do the interface allocations only once for common
-// Errno values.
-const (
- errnoERROR_IO_PENDING = 997
-)
-
-var (
- errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
-)
-
-// errnoErr returns common boxed Errno values, to prevent
-// allocations at runtime.
-func errnoErr(e syscall.Errno) error {
- switch e {
- case 0:
- return nil
- case errnoERROR_IO_PENDING:
- return errERROR_IO_PENDING
- }
- // TODO: add more here, after collecting data on the common
- // error values see on Windows. (perhaps when running
- // all.bat?)
- return e
-}
-
-var (
- modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
- modntdll = windows.NewLazySystemDLL("ntdll.dll")
- modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
-
- procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
- procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW")
- procCreateFileW = modkernel32.NewProc("CreateFileW")
- procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo")
- procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW")
- procLocalAlloc = modkernel32.NewProc("LocalAlloc")
- procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile")
- procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
- procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
- procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
- procCancelIoEx = modkernel32.NewProc("CancelIoEx")
- procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
- procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
- procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
- procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
-)
-
-func connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) {
- r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
- var _p0 *uint16
- _p0, err = syscall.UTF16PtrFromString(name)
- if err != nil {
- return
- }
- return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
-}
-
-func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
- r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0)
- handle = windows.Handle(r0)
- if handle == windows.InvalidHandle {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) {
- var _p0 *uint16
- _p0, err = syscall.UTF16PtrFromString(name)
- if err != nil {
- return
- }
- return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile)
-}
-
-func _createFile(name *uint16, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) {
- r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0)
- handle = windows.Handle(r0)
- if handle == windows.InvalidHandle {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
- r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
- r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func localAlloc(uFlags uint32, length uint32) (ptr uintptr) {
- r0, _, _ := syscall.Syscall(procLocalAlloc.Addr(), 2, uintptr(uFlags), uintptr(length), 0)
- ptr = uintptr(r0)
- return
-}
-
-func ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) {
- r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
- status = ntstatus(r0)
- return
-}
-
-func rtlNtStatusToDosError(status ntstatus) (winerr error) {
- r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0)
- if r0 != 0 {
- winerr = syscall.Errno(r0)
- }
- return
-}
-
-func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) {
- r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0)
- status = ntstatus(r0)
- return
-}
-
-func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) {
- r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0)
- status = ntstatus(r0)
- return
-}
-
-func cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) {
- r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) {
- r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0)
- newport = windows.Handle(r0)
- if newport == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
- r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) {
- r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
- var _p0 uint32
- if wait {
- _p0 = 1
- } else {
- _p0 = 0
- }
- r1, _, e1 := syscall.Syscall6(procWSAGetOverlappedResult.Addr(), 5, uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
diff --git a/main.go b/main.go
index 053f488..e016116 100644
--- a/main.go
+++ b/main.go
@@ -1,8 +1,8 @@
-// +build !windows
+//go:build !windows
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
@@ -13,8 +13,9 @@ import (
"os/signal"
"runtime"
"strconv"
- "syscall"
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
@@ -32,32 +33,33 @@ const (
)
func printUsage() {
- fmt.Printf("usage:\n")
- fmt.Printf("%s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
+ fmt.Printf("Usage: %s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
}
func warning() {
- if runtime.GOOS != "linux" || os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
+ switch runtime.GOOS {
+ case "linux", "freebsd", "openbsd":
+ if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
+ return
+ }
+ default:
return
}
- fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
- fmt.Fprintln(os.Stderr, "W G")
- fmt.Fprintln(os.Stderr, "W You are running this software on a Linux kernel, G")
- fmt.Fprintln(os.Stderr, "W which is probably unnecessary and misguided. This G")
- fmt.Fprintln(os.Stderr, "W is because the Linux kernel has built-in first G")
- fmt.Fprintln(os.Stderr, "W class support for WireGuard, and this support is G")
- fmt.Fprintln(os.Stderr, "W much more refined than this slower userspace G")
- fmt.Fprintln(os.Stderr, "W implementation. For more information on G")
- fmt.Fprintln(os.Stderr, "W installing the kernel module, please visit: G")
- fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
- fmt.Fprintln(os.Stderr, "W G")
- fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
+ fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐")
+ fmt.Fprintln(os.Stderr, "│ │")
+ fmt.Fprintln(os.Stderr, "│ Running wireguard-go is not required because this │")
+ fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. For │")
+ fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
+ fmt.Fprintln(os.Stderr, "│ please visit: │")
+ fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │")
+ fmt.Fprintln(os.Stderr, "│ │")
+ fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘")
}
func main() {
if len(os.Args) == 2 && os.Args[1] == "--version" {
- fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", device.WireGuardGoVersion, runtime.GOOS, runtime.GOARCH)
+ fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", Version, runtime.GOOS, runtime.GOARCH)
return
}
@@ -97,21 +99,19 @@ func main() {
logLevel := func() int {
switch os.Getenv("LOG_LEVEL") {
- case "debug":
- return device.LogLevelDebug
- case "info":
- return device.LogLevelInfo
+ case "verbose", "debug":
+ return device.LogLevelVerbose
case "error":
return device.LogLevelError
case "silent":
return device.LogLevelSilent
}
- return device.LogLevelInfo
+ return device.LogLevelError
}()
// open TUN device (or use supplied fd)
- tun, err := func() (tun.Device, error) {
+ tdev, err := func() (tun.Device, error) {
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
if tunFdStr == "" {
return tun.CreateTUN(interfaceName, device.DefaultMTU)
@@ -124,7 +124,7 @@ func main() {
return nil, err
}
- err = syscall.SetNonblock(int(fd), true)
+ err = unix.SetNonblock(int(fd), true)
if err != nil {
return nil, err
}
@@ -134,7 +134,7 @@ func main() {
}()
if err == nil {
- realInterfaceName, err2 := tun.Name()
+ realInterfaceName, err2 := tdev.Name()
if err2 == nil {
interfaceName = realInterfaceName
}
@@ -145,12 +145,10 @@ func main() {
fmt.Sprintf("(%s) ", interfaceName),
)
- logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
-
- logger.Debug.Println("Debug log enabled")
+ logger.Verbosef("Starting wireguard-go version %s", Version)
if err != nil {
- logger.Error.Println("Failed to create TUN device:", err)
+ logger.Errorf("Failed to create TUN device: %v", err)
os.Exit(ExitSetupFailed)
}
@@ -171,9 +169,8 @@ func main() {
return os.NewFile(uintptr(fd), ""), nil
}()
-
if err != nil {
- logger.Error.Println("UAPI listen error:", err)
+ logger.Errorf("UAPI listen error: %v", err)
os.Exit(ExitSetupFailed)
return
}
@@ -199,7 +196,7 @@ func main() {
files[0], // stdin
files[1], // stdout
files[2], // stderr
- tun.File(),
+ tdev.File(),
fileUAPI,
},
Dir: ".",
@@ -208,7 +205,7 @@ func main() {
path, err := os.Executable()
if err != nil {
- logger.Error.Println("Failed to determine executable:", err)
+ logger.Errorf("Failed to determine executable: %v", err)
os.Exit(ExitSetupFailed)
}
@@ -218,23 +215,23 @@ func main() {
attr,
)
if err != nil {
- logger.Error.Println("Failed to daemonize:", err)
+ logger.Errorf("Failed to daemonize: %v", err)
os.Exit(ExitSetupFailed)
}
process.Release()
return
}
- device := device.NewDevice(tun, logger)
+ device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
- logger.Info.Println("Device started")
+ logger.Verbosef("Device started")
errs := make(chan error)
term := make(chan os.Signal, 1)
uapi, err := ipc.UAPIListen(interfaceName, fileUAPI)
if err != nil {
- logger.Error.Println("Failed to listen on uapi socket:", err)
+ logger.Errorf("Failed to listen on uapi socket: %v", err)
os.Exit(ExitSetupFailed)
}
@@ -249,11 +246,11 @@ func main() {
}
}()
- logger.Info.Println("UAPI listener started")
+ logger.Verbosef("UAPI listener started")
// wait for program to terminate
- signal.Notify(term, syscall.SIGTERM)
+ signal.Notify(term, unix.SIGTERM)
signal.Notify(term, os.Interrupt)
select {
@@ -267,5 +264,5 @@ func main() {
uapi.Close()
device.Close()
- logger.Info.Println("Shutting down")
+ logger.Verbosef("Shutting down")
}
diff --git a/main_windows.go b/main_windows.go
index f57bc8d..a4dc46f 100644
--- a/main_windows.go
+++ b/main_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
@@ -9,8 +9,10 @@ import (
"fmt"
"os"
"os/signal"
- "syscall"
+ "golang.org/x/sys/windows"
+
+ "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
@@ -31,11 +33,10 @@ func main() {
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is <https://git.zx2c4.com/wireguard-windows/>, which includes this code as a module.")
logger := device.NewLogger(
- device.LogLevelDebug,
+ device.LogLevelVerbose,
fmt.Sprintf("(%s) ", interfaceName),
)
- logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
- logger.Debug.Println("Debug log enabled")
+ logger.Verbosef("Starting wireguard-go version %s", Version)
tun, err := tun.CreateTUN(interfaceName, 0)
if err == nil {
@@ -44,17 +45,21 @@ func main() {
interfaceName = realInterfaceName
}
} else {
- logger.Error.Println("Failed to create TUN device:", err)
+ logger.Errorf("Failed to create TUN device: %v", err)
os.Exit(ExitSetupFailed)
}
- device := device.NewDevice(tun, logger)
- device.Up()
- logger.Info.Println("Device started")
+ device := device.NewDevice(tun, conn.NewDefaultBind(), logger)
+ err = device.Up()
+ if err != nil {
+ logger.Errorf("Failed to bring up device: %v", err)
+ os.Exit(ExitSetupFailed)
+ }
+ logger.Verbosef("Device started")
uapi, err := ipc.UAPIListen(interfaceName)
if err != nil {
- logger.Error.Println("Failed to listen on uapi socket:", err)
+ logger.Errorf("Failed to listen on uapi socket: %v", err)
os.Exit(ExitSetupFailed)
}
@@ -71,13 +76,13 @@ func main() {
go device.IpcHandle(conn)
}
}()
- logger.Info.Println("UAPI listener started")
+ logger.Verbosef("UAPI listener started")
// wait for program to terminate
signal.Notify(term, os.Interrupt)
signal.Notify(term, os.Kill)
- signal.Notify(term, syscall.SIGTERM)
+ signal.Notify(term, windows.SIGTERM)
select {
case <-term:
@@ -90,5 +95,5 @@ func main() {
uapi.Close()
device.Close()
- logger.Info.Println("Shutting down")
+ logger.Verbosef("Shutting down")
}
diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go
index 772c45a..f7d05ef 100644
--- a/ratelimiter/ratelimiter.go
+++ b/ratelimiter/ratelimiter.go
@@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ratelimiter
import (
- "net"
+ "net/netip"
"sync"
"time"
)
@@ -20,21 +20,22 @@ const (
)
type RatelimiterEntry struct {
- sync.Mutex
+ mu sync.Mutex
lastTime time.Time
tokens int64
}
type Ratelimiter struct {
- sync.RWMutex
- stopReset chan struct{}
- tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
- tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
+ mu sync.RWMutex
+ timeNow func() time.Time
+
+ stopReset chan struct{} // send to reset, close to stop
+ table map[netip.Addr]*RatelimiterEntry
}
func (rate *Ratelimiter) Close() {
- rate.Lock()
- defer rate.Unlock()
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
if rate.stopReset != nil {
close(rate.stopReset)
@@ -42,111 +43,83 @@ func (rate *Ratelimiter) Close() {
}
func (rate *Ratelimiter) Init() {
- rate.Lock()
- defer rate.Unlock()
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
- // stop any ongoing garbage collection routine
+ if rate.timeNow == nil {
+ rate.timeNow = time.Now
+ }
+ // stop any ongoing garbage collection routine
if rate.stopReset != nil {
close(rate.stopReset)
}
rate.stopReset = make(chan struct{})
- rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
- rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
+ rate.table = make(map[netip.Addr]*RatelimiterEntry)
- // start garbage collection routine
+ stopReset := rate.stopReset // store in case Init is called again.
+ // Start garbage collection routine.
go func() {
ticker := time.NewTicker(time.Second)
ticker.Stop()
for {
select {
- case _, ok := <-rate.stopReset:
+ case _, ok := <-stopReset:
ticker.Stop()
- if ok {
- ticker = time.NewTicker(time.Second)
- } else {
+ if !ok {
return
}
+ ticker = time.NewTicker(time.Second)
case <-ticker.C:
- func() {
- rate.Lock()
- defer rate.Unlock()
-
- for key, entry := range rate.tableIPv4 {
- entry.Lock()
- if time.Since(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv4, key)
- }
- entry.Unlock()
- }
-
- for key, entry := range rate.tableIPv6 {
- entry.Lock()
- if time.Since(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv6, key)
- }
- entry.Unlock()
- }
-
- if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 {
- ticker.Stop()
- }
- }()
+ if rate.cleanup() {
+ ticker.Stop()
+ }
}
}
}()
}
-func (rate *Ratelimiter) Allow(ip net.IP) bool {
- var entry *RatelimiterEntry
- var keyIPv4 [net.IPv4len]byte
- var keyIPv6 [net.IPv6len]byte
-
- // lookup entry
-
- IPv4 := ip.To4()
- IPv6 := ip.To16()
-
- rate.RLock()
+func (rate *Ratelimiter) cleanup() (empty bool) {
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
- if IPv4 != nil {
- copy(keyIPv4[:], IPv4)
- entry = rate.tableIPv4[keyIPv4]
- } else {
- copy(keyIPv6[:], IPv6)
- entry = rate.tableIPv6[keyIPv6]
+ for key, entry := range rate.table {
+ entry.mu.Lock()
+ if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
+ delete(rate.table, key)
+ }
+ entry.mu.Unlock()
}
- rate.RUnlock()
+ return len(rate.table) == 0
+}
- // make new entry if not found
+func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
+ var entry *RatelimiterEntry
+ // lookup entry
+ rate.mu.RLock()
+ entry = rate.table[ip]
+ rate.mu.RUnlock()
+ // make new entry if not found
if entry == nil {
entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost
- entry.lastTime = time.Now()
- rate.Lock()
- if IPv4 != nil {
- rate.tableIPv4[keyIPv4] = entry
- if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
- rate.stopReset <- struct{}{}
- }
- } else {
- rate.tableIPv6[keyIPv6] = entry
- if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
- rate.stopReset <- struct{}{}
- }
+ entry.lastTime = rate.timeNow()
+ rate.mu.Lock()
+ rate.table[ip] = entry
+ if len(rate.table) == 1 {
+ rate.stopReset <- struct{}{}
}
- rate.Unlock()
+ rate.mu.Unlock()
return true
}
// add tokens to entry
-
- entry.Lock()
- now := time.Now()
+ entry.mu.Lock()
+ now := rate.timeNow()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
entry.lastTime = now
if entry.tokens > maxTokens {
@@ -154,12 +127,11 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
}
// subtract cost of packet
-
if entry.tokens > packetCost {
entry.tokens -= packetCost
- entry.Unlock()
+ entry.mu.Unlock()
return true
}
- entry.Unlock()
+ entry.mu.Unlock()
return false
}
diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go
index a18a097..0bfa3af 100644
--- a/ratelimiter/ratelimiter_test.go
+++ b/ratelimiter/ratelimiter_test.go
@@ -1,32 +1,31 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ratelimiter
import (
- "net"
+ "net/netip"
"testing"
"time"
)
-type RatelimiterResult struct {
+type result struct {
allowed bool
text string
wait time.Duration
}
func TestRatelimiter(t *testing.T) {
+ var rate Ratelimiter
+ var expectedResults []result
- var ratelimiter Ratelimiter
- var expectedResults []RatelimiterResult
-
- Nano := func(nano int64) time.Duration {
+ nano := func(nano int64) time.Duration {
return time.Nanosecond * time.Duration(nano)
}
- Add := func(res RatelimiterResult) {
+ add := func(res result) {
expectedResults = append(
expectedResults,
res,
@@ -34,69 +33,86 @@ func TestRatelimiter(t *testing.T) {
}
for i := 0; i < packetsBurstable; i++ {
- Add(RatelimiterResult{
+ add(result{
allowed: true,
- text: "inital burst",
+ text: "initial burst",
})
}
- Add(RatelimiterResult{
+ add(result{
allowed: false,
text: "after burst",
})
- Add(RatelimiterResult{
+ add(result{
allowed: true,
- wait: Nano(time.Second.Nanoseconds() / packetsPerSecond),
+ wait: nano(time.Second.Nanoseconds() / packetsPerSecond),
text: "filling tokens for single packet",
})
- Add(RatelimiterResult{
+ add(result{
allowed: false,
text: "not having refilled enough",
})
- Add(RatelimiterResult{
+ add(result{
allowed: true,
- wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)),
+ wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)),
text: "filling tokens for two packet burst",
})
- Add(RatelimiterResult{
+ add(result{
allowed: true,
text: "second packet in 2 packet burst",
})
- Add(RatelimiterResult{
+ add(result{
allowed: false,
text: "packet following 2 packet burst",
})
- ips := []net.IP{
- net.ParseIP("127.0.0.1"),
- net.ParseIP("192.168.1.1"),
- net.ParseIP("172.167.2.3"),
- net.ParseIP("97.231.252.215"),
- net.ParseIP("248.97.91.167"),
- net.ParseIP("188.208.233.47"),
- net.ParseIP("104.2.183.179"),
- net.ParseIP("72.129.46.120"),
- net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
- net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
- net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
- net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
- net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
- net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
+ ips := []netip.Addr{
+ netip.MustParseAddr("127.0.0.1"),
+ netip.MustParseAddr("192.168.1.1"),
+ netip.MustParseAddr("172.167.2.3"),
+ netip.MustParseAddr("97.231.252.215"),
+ netip.MustParseAddr("248.97.91.167"),
+ netip.MustParseAddr("188.208.233.47"),
+ netip.MustParseAddr("104.2.183.179"),
+ netip.MustParseAddr("72.129.46.120"),
+ netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
+ netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
+ netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
+ netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
+ netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
+ netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
+ }
+
+ now := time.Now()
+ rate.timeNow = func() time.Time {
+ return now
+ }
+ defer func() {
+ // Lock to avoid data race with cleanup goroutine from Init.
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
+
+ rate.timeNow = time.Now
+ }()
+ timeSleep := func(d time.Duration) {
+ now = now.Add(d + 1)
+ rate.cleanup()
}
- ratelimiter.Init()
+ rate.Init()
+ defer rate.Close()
for i, res := range expectedResults {
- time.Sleep(res.wait)
+ timeSleep(res.wait)
for _, ip := range ips {
- allowed := ratelimiter.Allow(ip)
+ allowed := rate.Allow(ip)
if allowed != res.allowed {
- t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed)
+ t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed)
}
}
}
diff --git a/replay/replay.go b/replay/replay.go
index 0f6b6c9..8b99e23 100644
--- a/replay/replay.go
+++ b/replay/replay.go
@@ -1,83 +1,62 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
+// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
package replay
-/* Implementation of RFC6479
- * https://tools.ietf.org/html/rfc6479
- *
- * The implementation is not safe for concurrent use!
- */
-
-const (
- // See: https://golang.org/src/math/big/arith.go
- _Wordm = ^uintptr(0)
- _WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1
- _WordSize = 1 << _WordLogSize
-)
+type block uint64
const (
- CounterRedundantBitsLog = _WordLogSize + 3
- CounterRedundantBits = _WordSize * 8
- CounterBitsTotal = 2048
- CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits)
+ blockBitLog = 6 // 1<<6 == 64 bits
+ blockBits = 1 << blockBitLog // must be power of 2
+ ringBlocks = 1 << 7 // must be power of 2
+ windowSize = (ringBlocks - 1) * blockBits
+ blockMask = ringBlocks - 1
+ bitMask = blockBits - 1
)
-const (
- BacktrackWords = CounterBitsTotal / _WordSize
-)
-
-func minUint64(a uint64, b uint64) uint64 {
- if a > b {
- return b
- }
- return a
-}
-
-type ReplayFilter struct {
- counter uint64
- backtrack [BacktrackWords]uintptr
+// A Filter rejects replayed messages by checking if message counter value is
+// within a sliding window of previously received messages.
+// The zero value for Filter is an empty filter ready to use.
+// Filters are unsafe for concurrent use.
+type Filter struct {
+ last uint64
+ ring [ringBlocks]block
}
-func (filter *ReplayFilter) Init() {
- filter.counter = 0
- filter.backtrack[0] = 0
+// Reset resets the filter to empty state.
+func (f *Filter) Reset() {
+ f.last = 0
+ f.ring[0] = 0
}
-func (filter *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool {
+// ValidateCounter checks if the counter should be accepted.
+// Overlimit counters (>= limit) are always rejected.
+func (f *Filter) ValidateCounter(counter, limit uint64) bool {
if counter >= limit {
return false
}
-
- indexWord := counter >> CounterRedundantBitsLog
-
- if counter > filter.counter {
-
- // move window forward
-
- current := filter.counter >> CounterRedundantBitsLog
- diff := minUint64(indexWord-current, BacktrackWords)
- for i := uint64(1); i <= diff; i++ {
- filter.backtrack[(current+i)%BacktrackWords] = 0
+ indexBlock := counter >> blockBitLog
+ if counter > f.last { // move window forward
+ current := f.last >> blockBitLog
+ diff := indexBlock - current
+ if diff > ringBlocks {
+ diff = ringBlocks // cap diff to clear the whole ring
}
- filter.counter = counter
-
- } else if filter.counter-counter > CounterWindowSize {
-
- // behind current window
-
+ for i := current + 1; i <= current+diff; i++ {
+ f.ring[i&blockMask] = 0
+ }
+ f.last = counter
+ } else if f.last-counter > windowSize { // behind current window
return false
}
-
- indexWord %= BacktrackWords
- indexBit := counter & uint64(CounterRedundantBits-1)
-
// check and set bit
-
- oldValue := filter.backtrack[indexWord]
- newValue := oldValue | (1 << indexBit)
- filter.backtrack[indexWord] = newValue
- return oldValue != newValue
+ indexBlock &= blockMask
+ indexBit := counter & bitMask
+ old := f.ring[indexBlock]
+ new := old | 1<<indexBit
+ f.ring[indexBlock] = new
+ return old != new
}
diff --git a/replay/replay_test.go b/replay/replay_test.go
index 5365f10..9a9e4a8 100644
--- a/replay/replay_test.go
+++ b/replay/replay_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package replay
@@ -14,22 +14,22 @@ import (
*
*/
-const RejectAfterMessages = (1 << 64) - (1 << 4) - 1
+const RejectAfterMessages = 1<<64 - 1<<13 - 1
func TestReplay(t *testing.T) {
- var filter ReplayFilter
+ var filter Filter
- T_LIM := CounterWindowSize + 1
+ const T_LIM = windowSize + 1
testNumber := 0
- T := func(n uint64, v bool) {
+ T := func(n uint64, expected bool) {
testNumber++
- if filter.ValidateCounter(n, RejectAfterMessages) != v {
- t.Fatal("Test", testNumber, "failed", n, v)
+ if filter.ValidateCounter(n, RejectAfterMessages) != expected {
+ t.Fatal("Test", testNumber, "failed", n, expected)
}
}
- filter.Init()
+ filter.Reset()
T(0, true) /* 1 */
T(1, true) /* 2 */
@@ -67,53 +67,53 @@ func TestReplay(t *testing.T) {
T(0, false) /* 34 */
t.Log("Bulk test 1")
- filter.Init()
+ filter.Reset()
testNumber = 0
- for i := uint64(1); i <= CounterWindowSize; i++ {
+ for i := uint64(1); i <= windowSize; i++ {
T(i, true)
}
T(0, true)
T(0, false)
t.Log("Bulk test 2")
- filter.Init()
+ filter.Reset()
testNumber = 0
- for i := uint64(2); i <= CounterWindowSize+1; i++ {
+ for i := uint64(2); i <= windowSize+1; i++ {
T(i, true)
}
T(1, true)
T(0, false)
t.Log("Bulk test 3")
- filter.Init()
+ filter.Reset()
testNumber = 0
- for i := CounterWindowSize + 1; i > 0; i-- {
+ for i := uint64(windowSize + 1); i > 0; i-- {
T(i, true)
}
t.Log("Bulk test 4")
- filter.Init()
+ filter.Reset()
testNumber = 0
- for i := CounterWindowSize + 2; i > 1; i-- {
+ for i := uint64(windowSize + 2); i > 1; i-- {
T(i, true)
}
T(0, false)
t.Log("Bulk test 5")
- filter.Init()
+ filter.Reset()
testNumber = 0
- for i := CounterWindowSize; i > 0; i-- {
+ for i := uint64(windowSize); i > 0; i-- {
T(i, true)
}
- T(CounterWindowSize+1, true)
+ T(windowSize+1, true)
T(0, false)
t.Log("Bulk test 6")
- filter.Init()
+ filter.Reset()
testNumber = 0
- for i := CounterWindowSize; i > 0; i-- {
+ for i := uint64(windowSize); i > 0; i-- {
T(i, true)
}
T(0, true)
- T(CounterWindowSize+1, true)
+ T(windowSize+1, true)
}
diff --git a/rwcancel/fdset.go b/rwcancel/fdset.go
deleted file mode 100644
index 28746e6..0000000
--- a/rwcancel/fdset.go
+++ /dev/null
@@ -1,22 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-
-package rwcancel
-
-import "golang.org/x/sys/unix"
-
-type fdSet struct {
- unix.FdSet
-}
-
-func (fdset *fdSet) set(i int) {
- bits := 32 << (^uint(0) >> 63)
- fdset.Bits[i/bits] |= 1 << uint(i%bits)
-}
-
-func (fdset *fdSet) check(i int) bool {
- bits := 32 << (^uint(0) >> 63)
- return (fdset.Bits[i/bits] & (1 << uint(i%bits))) != 0
-}
diff --git a/rwcancel/rwcancel.go b/rwcancel/rwcancel.go
index 808e691..e397c0e 100644
--- a/rwcancel/rwcancel.go
+++ b/rwcancel/rwcancel.go
@@ -1,8 +1,12 @@
+//go:build !windows && !wasm
+
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
+// Package rwcancel implements cancelable read/write operations on
+// a file descriptor.
package rwcancel
import (
@@ -13,13 +17,6 @@ import (
"golang.org/x/sys/unix"
)
-func max(a, b int) int {
- if a > b {
- return a
- }
- return b
-}
-
type RWCancel struct {
fd int
closingReader *os.File
@@ -42,27 +39,16 @@ func NewRWCancel(fd int) (*RWCancel, error) {
}
func RetryAfterError(err error) bool {
- if pe, ok := err.(*os.PathError); ok {
- err = pe.Err
- }
- if errno, ok := err.(syscall.Errno); ok {
- switch errno {
- case syscall.EAGAIN, syscall.EINTR:
- return true
- }
-
- }
- return false
+ return errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EINTR)
}
func (rw *RWCancel) ReadyRead() bool {
- closeFd := int(rw.closingReader.Fd())
- fdset := fdSet{}
- fdset.set(rw.fd)
- fdset.set(closeFd)
+ closeFd := int32(rw.closingReader.Fd())
+
+ pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLIN}, {Fd: closeFd, Events: unix.POLLIN}}
var err error
for {
- err = unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil)
+ _, err = unix.Poll(pollFds, -1)
if err == nil || !RetryAfterError(err) {
break
}
@@ -70,20 +56,18 @@ func (rw *RWCancel) ReadyRead() bool {
if err != nil {
return false
}
- if fdset.check(closeFd) {
+ if pollFds[1].Revents != 0 {
return false
}
- return fdset.check(rw.fd)
+ return pollFds[0].Revents != 0
}
func (rw *RWCancel) ReadyWrite() bool {
- closeFd := int(rw.closingReader.Fd())
- fdset := fdSet{}
- fdset.set(rw.fd)
- fdset.set(closeFd)
+ closeFd := int32(rw.closingReader.Fd())
+ pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}}
var err error
for {
- err = unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil)
+ _, err = unix.Poll(pollFds, -1)
if err == nil || !RetryAfterError(err) {
break
}
@@ -91,10 +75,11 @@ func (rw *RWCancel) ReadyWrite() bool {
if err != nil {
return false
}
- if fdset.check(closeFd) {
+
+ if pollFds[1].Revents != 0 {
return false
}
- return fdset.check(rw.fd)
+ return pollFds[0].Revents != 0
}
func (rw *RWCancel) Read(p []byte) (n int, err error) {
@@ -104,7 +89,7 @@ func (rw *RWCancel) Read(p []byte) (n int, err error) {
return n, err
}
if !rw.ReadyRead() {
- return 0, errors.New("fd closed")
+ return 0, os.ErrClosed
}
}
}
@@ -116,7 +101,7 @@ func (rw *RWCancel) Write(p []byte) (n int, err error) {
return n, err
}
if !rw.ReadyWrite() {
- return 0, errors.New("fd closed")
+ return 0, os.ErrClosed
}
}
}
@@ -125,3 +110,8 @@ func (rw *RWCancel) Cancel() (err error) {
_, err = rw.closingWriter.Write([]byte{0})
return
}
+
+func (rw *RWCancel) Close() {
+ rw.closingReader.Close()
+ rw.closingWriter.Close()
+}
diff --git a/rwcancel/rwcancel_stub.go b/rwcancel/rwcancel_stub.go
new file mode 100644
index 0000000..2a98b2b
--- /dev/null
+++ b/rwcancel/rwcancel_stub.go
@@ -0,0 +1,9 @@
+//go:build windows || wasm
+
+// SPDX-License-Identifier: MIT
+
+package rwcancel
+
+type RWCancel struct{}
+
+func (*RWCancel) Cancel() {}
diff --git a/rwcancel/select_default.go b/rwcancel/select_default.go
deleted file mode 100644
index dd23cda..0000000
--- a/rwcancel/select_default.go
+++ /dev/null
@@ -1,14 +0,0 @@
-// +build !linux
-
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-
-package rwcancel
-
-import "golang.org/x/sys/unix"
-
-func unixSelect(nfd int, r *unix.FdSet, w *unix.FdSet, e *unix.FdSet, timeout *unix.Timeval) error {
- return unix.Select(nfd, r, w, e, timeout)
-}
diff --git a/rwcancel/select_linux.go b/rwcancel/select_linux.go
deleted file mode 100644
index 1a72e0a..0000000
--- a/rwcancel/select_linux.go
+++ /dev/null
@@ -1,13 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-
-package rwcancel
-
-import "golang.org/x/sys/unix"
-
-func unixSelect(nfd int, r *unix.FdSet, w *unix.FdSet, e *unix.FdSet, timeout *unix.Timeval) (err error) {
- _, err = unix.Select(nfd, r, w, e, timeout)
- return
-}
diff --git a/tai64n/tai64n.go b/tai64n/tai64n.go
index 565aaa4..8f10b39 100644
--- a/tai64n/tai64n.go
+++ b/tai64n/tai64n.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tai64n
@@ -11,22 +11,31 @@ import (
"time"
)
-const TimestampSize = 12
-const base = uint64(0x400000000000000a)
-const whitenerMask = uint32(0x1000000 - 1)
+const (
+ TimestampSize = 12
+ base = uint64(0x400000000000000a)
+ whitenerMask = uint32(0x1000000 - 1)
+)
type Timestamp [TimestampSize]byte
-func Now() Timestamp {
+func stamp(t time.Time) Timestamp {
var tai64n Timestamp
- now := time.Now()
- secs := base + uint64(now.Unix())
- nano := uint32(now.Nanosecond()) &^ whitenerMask
+ secs := base + uint64(t.Unix())
+ nano := uint32(t.Nanosecond()) &^ whitenerMask
binary.BigEndian.PutUint64(tai64n[:], secs)
binary.BigEndian.PutUint32(tai64n[8:], nano)
return tai64n
}
+func Now() Timestamp {
+ return stamp(time.Now())
+}
+
func (t1 Timestamp) After(t2 Timestamp) bool {
return bytes.Compare(t1[:], t2[:]) > 0
}
+
+func (t Timestamp) String() string {
+ return time.Unix(int64(binary.BigEndian.Uint64(t[:8])-base), int64(binary.BigEndian.Uint32(t[8:12]))).String()
+}
diff --git a/tai64n/tai64n_test.go b/tai64n/tai64n_test.go
index 859660f..c70fc1a 100644
--- a/tai64n/tai64n_test.go
+++ b/tai64n/tai64n_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tai64n
@@ -10,21 +10,31 @@ import (
"time"
)
-/* Testing the essential property of the timestamp
- * as used by WireGuard.
- */
+// Test that timestamps are monotonic as required by Wireguard and that
+// nanosecond-level information is whitened to prevent side channel attacks.
func TestMonotonic(t *testing.T) {
- old := Now()
- for i := 0; i < 50; i++ {
- next := Now()
- if next.After(old) {
- t.Error("Whitening insufficient")
- }
- time.Sleep(time.Duration(whitenerMask)/time.Nanosecond + 1)
- next = Now()
- if !next.After(old) {
- t.Error("Not monotonically increasing on whitened nano-second scale")
- }
- old = next
+ startTime := time.Unix(0, 123456789) // a nontrivial bit pattern
+ // Whitening should reduce timestamp granularity
+ // to more than 10 but fewer than 20 milliseconds.
+ tests := []struct {
+ name string
+ t1, t2 time.Time
+ wantAfter bool
+ }{
+ {"after_10_ns", startTime, startTime.Add(10 * time.Nanosecond), false},
+ {"after_10_us", startTime, startTime.Add(10 * time.Microsecond), false},
+ {"after_1_ms", startTime, startTime.Add(time.Millisecond), false},
+ {"after_10_ms", startTime, startTime.Add(10 * time.Millisecond), false},
+ {"after_20_ms", startTime, startTime.Add(20 * time.Millisecond), true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ts1, ts2 := stamp(tt.t1), stamp(tt.t2)
+ got := ts2.After(ts1)
+ if got != tt.wantAfter {
+ t.Errorf("after = %v; want %v", got, tt.wantAfter)
+ }
+ })
}
}
diff --git a/tests/netns.sh b/tests/netns.sh
index 02d428b..2f2a2cd 100755
--- a/tests/netns.sh
+++ b/tests/netns.sh
@@ -36,7 +36,7 @@ netns0="wg-test-$$-0"
netns1="wg-test-$$-1"
netns2="wg-test-$$-2"
program=$1
-export LOG_LEVEL="info"
+export LOG_LEVEL="verbose"
pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; }
pp() { pretty "" "$*"; "$@"; }
diff --git a/tun/alignment_windows_test.go b/tun/alignment_windows_test.go
new file mode 100644
index 0000000..67a785e
--- /dev/null
+++ b/tun/alignment_windows_test.go
@@ -0,0 +1,67 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "reflect"
+ "testing"
+ "unsafe"
+)
+
+func checkAlignment(t *testing.T, name string, offset uintptr) {
+ t.Helper()
+ if offset%8 != 0 {
+ t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8))
+ }
+}
+
+// TestRateJugglerAlignment checks that atomically-accessed fields are
+// aligned to 64-bit boundaries, as required by the atomic package.
+//
+// Unfortunately, violating this rule on 32-bit platforms results in a
+// hard segfault at runtime.
+func TestRateJugglerAlignment(t *testing.T) {
+ var r rateJuggler
+
+ typ := reflect.TypeOf(&r).Elem()
+ t.Logf("Peer type size: %d, with fields:", typ.Size())
+ for i := 0; i < typ.NumField(); i++ {
+ field := typ.Field(i)
+ t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
+ field.Name,
+ field.Offset,
+ field.Type.Size(),
+ field.Type.Align(),
+ )
+ }
+
+ checkAlignment(t, "rateJuggler.current", unsafe.Offsetof(r.current))
+ checkAlignment(t, "rateJuggler.nextByteCount", unsafe.Offsetof(r.nextByteCount))
+ checkAlignment(t, "rateJuggler.nextStartTime", unsafe.Offsetof(r.nextStartTime))
+}
+
+// TestNativeTunAlignment checks that atomically-accessed fields are
+// aligned to 64-bit boundaries, as required by the atomic package.
+//
+// Unfortunately, violating this rule on 32-bit platforms results in a
+// hard segfault at runtime.
+func TestNativeTunAlignment(t *testing.T) {
+ var tun NativeTun
+
+ typ := reflect.TypeOf(&tun).Elem()
+ t.Logf("Peer type size: %d, with fields:", typ.Size())
+ for i := 0; i < typ.NumField(); i++ {
+ field := typ.Field(i)
+ t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
+ field.Name,
+ field.Offset,
+ field.Type.Size(),
+ field.Type.Align(),
+ )
+ }
+
+ checkAlignment(t, "NativeTun.rate", unsafe.Offsetof(tun.rate))
+}
diff --git a/tun/checksum.go b/tun/checksum.go
new file mode 100644
index 0000000..29a8fc8
--- /dev/null
+++ b/tun/checksum.go
@@ -0,0 +1,118 @@
+package tun
+
+import "encoding/binary"
+
+// TODO: Explore SIMD and/or other assembly optimizations.
+// TODO: Test native endian loads. See RFC 1071 section 2 part B.
+func checksumNoFold(b []byte, initial uint64) uint64 {
+ ac := initial
+
+ for len(b) >= 128 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+ ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+ ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+ ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+ ac += uint64(binary.BigEndian.Uint32(b[32:36]))
+ ac += uint64(binary.BigEndian.Uint32(b[36:40]))
+ ac += uint64(binary.BigEndian.Uint32(b[40:44]))
+ ac += uint64(binary.BigEndian.Uint32(b[44:48]))
+ ac += uint64(binary.BigEndian.Uint32(b[48:52]))
+ ac += uint64(binary.BigEndian.Uint32(b[52:56]))
+ ac += uint64(binary.BigEndian.Uint32(b[56:60]))
+ ac += uint64(binary.BigEndian.Uint32(b[60:64]))
+ ac += uint64(binary.BigEndian.Uint32(b[64:68]))
+ ac += uint64(binary.BigEndian.Uint32(b[68:72]))
+ ac += uint64(binary.BigEndian.Uint32(b[72:76]))
+ ac += uint64(binary.BigEndian.Uint32(b[76:80]))
+ ac += uint64(binary.BigEndian.Uint32(b[80:84]))
+ ac += uint64(binary.BigEndian.Uint32(b[84:88]))
+ ac += uint64(binary.BigEndian.Uint32(b[88:92]))
+ ac += uint64(binary.BigEndian.Uint32(b[92:96]))
+ ac += uint64(binary.BigEndian.Uint32(b[96:100]))
+ ac += uint64(binary.BigEndian.Uint32(b[100:104]))
+ ac += uint64(binary.BigEndian.Uint32(b[104:108]))
+ ac += uint64(binary.BigEndian.Uint32(b[108:112]))
+ ac += uint64(binary.BigEndian.Uint32(b[112:116]))
+ ac += uint64(binary.BigEndian.Uint32(b[116:120]))
+ ac += uint64(binary.BigEndian.Uint32(b[120:124]))
+ ac += uint64(binary.BigEndian.Uint32(b[124:128]))
+ b = b[128:]
+ }
+ if len(b) >= 64 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+ ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+ ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+ ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+ ac += uint64(binary.BigEndian.Uint32(b[32:36]))
+ ac += uint64(binary.BigEndian.Uint32(b[36:40]))
+ ac += uint64(binary.BigEndian.Uint32(b[40:44]))
+ ac += uint64(binary.BigEndian.Uint32(b[44:48]))
+ ac += uint64(binary.BigEndian.Uint32(b[48:52]))
+ ac += uint64(binary.BigEndian.Uint32(b[52:56]))
+ ac += uint64(binary.BigEndian.Uint32(b[56:60]))
+ ac += uint64(binary.BigEndian.Uint32(b[60:64]))
+ b = b[64:]
+ }
+ if len(b) >= 32 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+ ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+ ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+ ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+ b = b[32:]
+ }
+ if len(b) >= 16 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ b = b[16:]
+ }
+ if len(b) >= 8 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ b = b[8:]
+ }
+ if len(b) >= 4 {
+ ac += uint64(binary.BigEndian.Uint32(b))
+ b = b[4:]
+ }
+ if len(b) >= 2 {
+ ac += uint64(binary.BigEndian.Uint16(b))
+ b = b[2:]
+ }
+ if len(b) == 1 {
+ ac += uint64(b[0]) << 8
+ }
+
+ return ac
+}
+
+func checksum(b []byte, initial uint64) uint16 {
+ ac := checksumNoFold(b, initial)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ return uint16(ac)
+}
+
+func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
+ sum := checksumNoFold(srcAddr, 0)
+ sum = checksumNoFold(dstAddr, sum)
+ sum = checksumNoFold([]byte{0, protocol}, sum)
+ tmp := make([]byte, 2)
+ binary.BigEndian.PutUint16(tmp, totalLen)
+ return checksumNoFold(tmp, sum)
+}
diff --git a/tun/checksum_test.go b/tun/checksum_test.go
new file mode 100644
index 0000000..c1ccff5
--- /dev/null
+++ b/tun/checksum_test.go
@@ -0,0 +1,35 @@
+package tun
+
+import (
+ "fmt"
+ "math/rand"
+ "testing"
+)
+
+func BenchmarkChecksum(b *testing.B) {
+ lengths := []int{
+ 64,
+ 128,
+ 256,
+ 512,
+ 1024,
+ 1500,
+ 2048,
+ 4096,
+ 8192,
+ 9000,
+ 9001,
+ }
+
+ for _, length := range lengths {
+ b.Run(fmt.Sprintf("%d", length), func(b *testing.B) {
+ buf := make([]byte, length)
+ rng := rand.New(rand.NewSource(1))
+ rng.Read(buf)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ checksum(buf, 0)
+ }
+ })
+ }
+}
diff --git a/tun/errors.go b/tun/errors.go
new file mode 100644
index 0000000..75ae3a4
--- /dev/null
+++ b/tun/errors.go
@@ -0,0 +1,12 @@
+package tun
+
+import (
+ "errors"
+)
+
+var (
+ // ErrTooManySegments is returned by Device.Read() when segmentation
+ // overflows the length of supplied buffers. This error should not cause
+ // reads to cease.
+ ErrTooManySegments = errors.New("too many segments")
+)
diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go
new file mode 100644
index 0000000..ccd32ed
--- /dev/null
+++ b/tun/netstack/examples/http_client.go
@@ -0,0 +1,54 @@
+//go:build ignore
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package main
+
+import (
+ "io"
+ "log"
+ "net/http"
+ "net/netip"
+
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/device"
+ "golang.zx2c4.com/wireguard/tun/netstack"
+)
+
+func main() {
+ tun, tnet, err := netstack.CreateNetTUN(
+ []netip.Addr{netip.MustParseAddr("192.168.4.28")},
+ []netip.Addr{netip.MustParseAddr("8.8.8.8")},
+ 1420)
+ if err != nil {
+ log.Panic(err)
+ }
+ dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
+ err = dev.IpcSet(`private_key=087ec6e14bbed210e7215cdc73468dfa23f080a1bfb8665b2fd809bd99d28379
+public_key=c4c8e984c5322c8184c72265b92b250fdb63688705f504ba003c88f03393cf28
+allowed_ip=0.0.0.0/0
+endpoint=127.0.0.1:58120
+`)
+ err = dev.Up()
+ if err != nil {
+ log.Panic(err)
+ }
+
+ client := http.Client{
+ Transport: &http.Transport{
+ DialContext: tnet.DialContext,
+ },
+ }
+ resp, err := client.Get("http://192.168.4.29/")
+ if err != nil {
+ log.Panic(err)
+ }
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ log.Panic(err)
+ }
+ log.Println(string(body))
+}
diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go
new file mode 100644
index 0000000..f5b7a8f
--- /dev/null
+++ b/tun/netstack/examples/http_server.go
@@ -0,0 +1,51 @@
+//go:build ignore
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package main
+
+import (
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "net/netip"
+
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/device"
+ "golang.zx2c4.com/wireguard/tun/netstack"
+)
+
+func main() {
+ tun, tnet, err := netstack.CreateNetTUN(
+ []netip.Addr{netip.MustParseAddr("192.168.4.29")},
+ []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
+ 1420,
+ )
+ if err != nil {
+ log.Panic(err)
+ }
+ dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
+ dev.IpcSet(`private_key=003ed5d73b55806c30de3f8a7bdab38af13539220533055e635690b8b87ad641
+listen_port=58120
+public_key=f928d4f6c1b86c12f2562c10b07c555c5c57fd00f59e90c8d8d88767271cbf7c
+allowed_ip=192.168.4.28/32
+persistent_keepalive_interval=25
+`)
+ dev.Up()
+ listener, err := tnet.ListenTCP(&net.TCPAddr{Port: 80})
+ if err != nil {
+ log.Panicln(err)
+ }
+ http.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {
+ log.Printf("> %s - %s - %s", request.RemoteAddr, request.URL.String(), request.UserAgent())
+ io.WriteString(writer, "Hello from userspace TCP!")
+ })
+ err = http.Serve(listener, nil)
+ if err != nil {
+ log.Panicln(err)
+ }
+}
diff --git a/tun/netstack/examples/ping_client.go b/tun/netstack/examples/ping_client.go
new file mode 100644
index 0000000..2eef0fb
--- /dev/null
+++ b/tun/netstack/examples/ping_client.go
@@ -0,0 +1,75 @@
+//go:build ignore
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package main
+
+import (
+ "bytes"
+ "log"
+ "math/rand"
+ "net/netip"
+ "time"
+
+ "golang.org/x/net/icmp"
+ "golang.org/x/net/ipv4"
+
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/device"
+ "golang.zx2c4.com/wireguard/tun/netstack"
+)
+
+func main() {
+ tun, tnet, err := netstack.CreateNetTUN(
+ []netip.Addr{netip.MustParseAddr("192.168.4.29")},
+ []netip.Addr{netip.MustParseAddr("8.8.8.8")},
+ 1420)
+ if err != nil {
+ log.Panic(err)
+ }
+ dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
+ dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
+public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
+endpoint=163.172.161.0:12912
+allowed_ip=0.0.0.0/0
+`)
+ err = dev.Up()
+ if err != nil {
+ log.Panic(err)
+ }
+
+ socket, err := tnet.Dial("ping4", "zx2c4.com")
+ if err != nil {
+ log.Panic(err)
+ }
+ requestPing := icmp.Echo{
+ Seq: rand.Intn(1 << 16),
+ Data: []byte("gopher burrow"),
+ }
+ icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
+ socket.SetReadDeadline(time.Now().Add(time.Second * 10))
+ start := time.Now()
+ _, err = socket.Write(icmpBytes)
+ if err != nil {
+ log.Panic(err)
+ }
+ n, err := socket.Read(icmpBytes[:])
+ if err != nil {
+ log.Panic(err)
+ }
+ replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
+ if err != nil {
+ log.Panic(err)
+ }
+ replyPing, ok := replyPacket.Body.(*icmp.Echo)
+ if !ok {
+ log.Panicf("invalid reply type: %v", replyPacket)
+ }
+ if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
+ log.Panicf("invalid ping reply: %v", replyPing)
+ }
+ log.Printf("Ping latency: %v", time.Since(start))
+}
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
new file mode 100644
index 0000000..2b73054
--- /dev/null
+++ b/tun/netstack/tun.go
@@ -0,0 +1,1055 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package netstack
+
+import (
+ "bytes"
+ "context"
+ "crypto/rand"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/netip"
+ "os"
+ "regexp"
+ "strconv"
+ "strings"
+ "syscall"
+ "time"
+
+ "golang.zx2c4.com/wireguard/tun"
+
+ "golang.org/x/net/dns/dnsmessage"
+ "gvisor.dev/gvisor/pkg/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+type netTun struct {
+ ep *channel.Endpoint
+ stack *stack.Stack
+ events chan tun.Event
+ incomingPacket chan *buffer.View
+ mtu int
+ dnsServers []netip.Addr
+ hasV4, hasV6 bool
+}
+
+type Net netTun
+
+func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
+ HandleLocal: true,
+ }
+ dev := &netTun{
+ ep: channel.New(1024, uint32(mtu), ""),
+ stack: stack.New(opts),
+ events: make(chan tun.Event, 10),
+ incomingPacket: make(chan *buffer.View),
+ dnsServers: dnsServers,
+ mtu: mtu,
+ }
+ sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
+ tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
+ }
+ dev.ep.AddNotify(dev)
+ tcpipErr = dev.stack.CreateNIC(1, dev.ep)
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
+ }
+ for _, ip := range localAddresses {
+ var protoNumber tcpip.NetworkProtocolNumber
+ if ip.Is4() {
+ protoNumber = ipv4.ProtocolNumber
+ } else if ip.Is6() {
+ protoNumber = ipv6.ProtocolNumber
+ }
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: protoNumber,
+ AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
+ }
+ tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
+ }
+ if ip.Is4() {
+ dev.hasV4 = true
+ } else if ip.Is6() {
+ dev.hasV6 = true
+ }
+ }
+ if dev.hasV4 {
+ dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
+ }
+ if dev.hasV6 {
+ dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
+ }
+
+ dev.events <- tun.EventUp
+ return dev, (*Net)(dev), nil
+}
+
+func (tun *netTun) Name() (string, error) {
+ return "go", nil
+}
+
+func (tun *netTun) File() *os.File {
+ return nil
+}
+
+func (tun *netTun) Events() <-chan tun.Event {
+ return tun.events
+}
+
+func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
+ view, ok := <-tun.incomingPacket
+ if !ok {
+ return 0, os.ErrClosed
+ }
+
+ n, err := view.Read(buf[0][offset:])
+ if err != nil {
+ return 0, err
+ }
+ sizes[0] = n
+ return 1, nil
+}
+
+func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
+ for _, buf := range buf {
+ packet := buf[offset:]
+ if len(packet) == 0 {
+ continue
+ }
+
+ pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
+ switch packet[0] >> 4 {
+ case 4:
+ tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
+ case 6:
+ tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
+ default:
+ return 0, syscall.EAFNOSUPPORT
+ }
+ }
+ return len(buf), nil
+}
+
+func (tun *netTun) WriteNotify() {
+ pkt := tun.ep.Read()
+ if pkt.IsNil() {
+ return
+ }
+
+ view := pkt.ToView()
+ pkt.DecRef()
+
+ tun.incomingPacket <- view
+}
+
+func (tun *netTun) Close() error {
+ tun.stack.RemoveNIC(1)
+
+ if tun.events != nil {
+ close(tun.events)
+ }
+
+ tun.ep.Close()
+
+ if tun.incomingPacket != nil {
+ close(tun.incomingPacket)
+ }
+
+ return nil
+}
+
+func (tun *netTun) MTU() (int, error) {
+ return tun.mtu, nil
+}
+
+func (tun *netTun) BatchSize() int {
+ return 1
+}
+
+func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
+ var protoNumber tcpip.NetworkProtocolNumber
+ if endpoint.Addr().Is4() {
+ protoNumber = ipv4.ProtocolNumber
+ } else {
+ protoNumber = ipv6.ProtocolNumber
+ }
+ return tcpip.FullAddress{
+ NIC: 1,
+ Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
+ Port: endpoint.Port(),
+ }, protoNumber
+}
+
+func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.DialContextTCP(ctx, net.stack, fa, pn)
+}
+
+func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
+ if addr == nil {
+ return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
+ }
+ ip, _ := netip.AddrFromSlice(addr.IP)
+ return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
+}
+
+func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.DialTCP(net.stack, fa, pn)
+}
+
+func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
+ if addr == nil {
+ return net.DialTCPAddrPort(netip.AddrPort{})
+ }
+ ip, _ := netip.AddrFromSlice(addr.IP)
+ return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
+}
+
+func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.ListenTCP(net.stack, fa, pn)
+}
+
+func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
+ if addr == nil {
+ return net.ListenTCPAddrPort(netip.AddrPort{})
+ }
+ ip, _ := netip.AddrFromSlice(addr.IP)
+ return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
+}
+
+func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
+ var lfa, rfa *tcpip.FullAddress
+ var pn tcpip.NetworkProtocolNumber
+ if laddr.IsValid() || laddr.Port() > 0 {
+ var addr tcpip.FullAddress
+ addr, pn = convertToFullAddr(laddr)
+ lfa = &addr
+ }
+ if raddr.IsValid() || raddr.Port() > 0 {
+ var addr tcpip.FullAddress
+ addr, pn = convertToFullAddr(raddr)
+ rfa = &addr
+ }
+ return gonet.DialUDP(net.stack, lfa, rfa, pn)
+}
+
+func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
+ return net.DialUDPAddrPort(laddr, netip.AddrPort{})
+}
+
+func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
+ var la, ra netip.AddrPort
+ if laddr != nil {
+ ip, _ := netip.AddrFromSlice(laddr.IP)
+ la = netip.AddrPortFrom(ip, uint16(laddr.Port))
+ }
+ if raddr != nil {
+ ip, _ := netip.AddrFromSlice(raddr.IP)
+ ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
+ }
+ return net.DialUDPAddrPort(la, ra)
+}
+
+func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
+ return net.DialUDP(laddr, nil)
+}
+
+type PingConn struct {
+ laddr PingAddr
+ raddr PingAddr
+ wq waiter.Queue
+ ep tcpip.Endpoint
+ deadline *time.Timer
+}
+
+type PingAddr struct{ addr netip.Addr }
+
+func (ia PingAddr) String() string {
+ return ia.addr.String()
+}
+
+func (ia PingAddr) Network() string {
+ if ia.addr.Is4() {
+ return "ping4"
+ } else if ia.addr.Is6() {
+ return "ping6"
+ }
+ return "ping"
+}
+
+func (ia PingAddr) Addr() netip.Addr {
+ return ia.addr
+}
+
+func PingAddrFromAddr(addr netip.Addr) *PingAddr {
+ return &PingAddr{addr}
+}
+
+func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
+ if !laddr.IsValid() && !raddr.IsValid() {
+ return nil, errors.New("ping dial: invalid address")
+ }
+ v6 := laddr.Is6() || raddr.Is6()
+ bind := laddr.IsValid()
+ if !bind {
+ if v6 {
+ laddr = netip.IPv6Unspecified()
+ } else {
+ laddr = netip.IPv4Unspecified()
+ }
+ }
+
+ tn := icmp.ProtocolNumber4
+ pn := ipv4.ProtocolNumber
+ if v6 {
+ tn = icmp.ProtocolNumber6
+ pn = ipv6.ProtocolNumber
+ }
+
+ pc := &PingConn{
+ laddr: PingAddr{laddr},
+ deadline: time.NewTimer(time.Hour << 10),
+ }
+ pc.deadline.Stop()
+
+ ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
+ if tcpipErr != nil {
+ return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
+ }
+ pc.ep = ep
+
+ if bind {
+ fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
+ if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
+ return nil, fmt.Errorf("ping bind: %s", tcpipErr)
+ }
+ }
+
+ if raddr.IsValid() {
+ pc.raddr = PingAddr{raddr}
+ fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
+ if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
+ return nil, fmt.Errorf("ping connect: %s", tcpipErr)
+ }
+ }
+
+ return pc, nil
+}
+
+func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) {
+ return net.DialPingAddr(laddr, netip.Addr{})
+}
+
+func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) {
+ var la, ra netip.Addr
+ if laddr != nil {
+ la = laddr.addr
+ }
+ if raddr != nil {
+ ra = raddr.addr
+ }
+ return net.DialPingAddr(la, ra)
+}
+
+func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) {
+ var la netip.Addr
+ if laddr != nil {
+ la = laddr.addr
+ }
+ return net.ListenPingAddr(la)
+}
+
+func (pc *PingConn) LocalAddr() net.Addr {
+ return pc.laddr
+}
+
+func (pc *PingConn) RemoteAddr() net.Addr {
+ return pc.raddr
+}
+
+func (pc *PingConn) Close() error {
+ pc.deadline.Reset(0)
+ pc.ep.Close()
+ return nil
+}
+
+func (pc *PingConn) SetWriteDeadline(t time.Time) error {
+ return errors.New("not implemented")
+}
+
+func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
+ var na netip.Addr
+ switch v := addr.(type) {
+ case *PingAddr:
+ na = v.addr
+ case *net.IPAddr:
+ na, _ = netip.AddrFromSlice(v.IP)
+ default:
+ return 0, fmt.Errorf("ping write: wrong net.Addr type")
+ }
+ if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) {
+ return 0, fmt.Errorf("ping write: mismatched protocols")
+ }
+
+ buf := bytes.NewReader(p)
+ rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0))
+ // won't block, no deadlines
+ n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{
+ To: &rfa,
+ })
+ if tcpipErr != nil {
+ return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
+ }
+
+ return int(n64), nil
+}
+
+func (pc *PingConn) Write(p []byte) (n int, err error) {
+ return pc.WriteTo(p, &pc.raddr)
+}
+
+func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
+ e, notifyCh := waiter.NewChannelEntry(waiter.EventIn)
+ pc.wq.EventRegister(&e)
+ defer pc.wq.EventUnregister(&e)
+
+ select {
+ case <-pc.deadline.C:
+ return 0, nil, os.ErrDeadlineExceeded
+ case <-notifyCh:
+ }
+
+ w := tcpip.SliceWriter(p)
+
+ res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
+ NeedRemoteAddr: true,
+ })
+ if tcpipErr != nil {
+ return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
+ }
+
+ remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice())
+ return res.Count, &PingAddr{remoteAddr}, nil
+}
+
+func (pc *PingConn) Read(p []byte) (n int, err error) {
+ n, _, err = pc.ReadFrom(p)
+ return
+}
+
+func (pc *PingConn) SetDeadline(t time.Time) error {
+ // pc.SetWriteDeadline is unimplemented
+
+ return pc.SetReadDeadline(t)
+}
+
+func (pc *PingConn) SetReadDeadline(t time.Time) error {
+ pc.deadline.Reset(time.Until(t))
+ return nil
+}
+
+var (
+ errNoSuchHost = errors.New("no such host")
+ errLameReferral = errors.New("lame referral")
+ errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message")
+ errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message")
+ errServerMisbehaving = errors.New("server misbehaving")
+ errInvalidDNSResponse = errors.New("invalid DNS response")
+ errNoAnswerFromDNSServer = errors.New("no answer from DNS server")
+ errServerTemporarilyMisbehaving = errors.New("server misbehaving")
+ errCanceled = errors.New("operation was canceled")
+ errTimeout = errors.New("i/o timeout")
+ errNumericPort = errors.New("port must be numeric")
+ errNoSuitableAddress = errors.New("no suitable address found")
+ errMissingAddress = errors.New("missing address")
+)
+
+func (net *Net) LookupHost(host string) (addrs []string, err error) {
+ return net.LookupContextHost(context.Background(), host)
+}
+
+func isDomainName(s string) bool {
+ l := len(s)
+ if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
+ return false
+ }
+ last := byte('.')
+ nonNumeric := false
+ partlen := 0
+ for i := 0; i < len(s); i++ {
+ c := s[i]
+ switch {
+ default:
+ return false
+ case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
+ nonNumeric = true
+ partlen++
+ case '0' <= c && c <= '9':
+ partlen++
+ case c == '-':
+ if last == '.' {
+ return false
+ }
+ partlen++
+ nonNumeric = true
+ case c == '.':
+ if last == '.' || last == '-' {
+ return false
+ }
+ if partlen > 63 || partlen == 0 {
+ return false
+ }
+ partlen = 0
+ }
+ last = c
+ }
+ if last == '-' || partlen > 63 {
+ return false
+ }
+ return nonNumeric
+}
+
+func randU16() uint16 {
+ var b [2]byte
+ _, err := rand.Read(b[:])
+ if err != nil {
+ panic(err)
+ }
+ return binary.LittleEndian.Uint16(b[:])
+}
+
+func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
+ id = randU16()
+ b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
+ b.EnableCompression()
+ if err := b.StartQuestions(); err != nil {
+ return 0, nil, nil, err
+ }
+ if err := b.Question(q); err != nil {
+ return 0, nil, nil, err
+ }
+ tcpReq, err = b.Finish()
+ udpReq = tcpReq[2:]
+ l := len(tcpReq) - 2
+ tcpReq[0] = byte(l >> 8)
+ tcpReq[1] = byte(l)
+ return id, udpReq, tcpReq, err
+}
+
+func equalASCIIName(x, y dnsmessage.Name) bool {
+ if x.Length != y.Length {
+ return false
+ }
+ for i := 0; i < int(x.Length); i++ {
+ a := x.Data[i]
+ b := y.Data[i]
+ if 'A' <= a && a <= 'Z' {
+ a += 0x20
+ }
+ if 'A' <= b && b <= 'Z' {
+ b += 0x20
+ }
+ if a != b {
+ return false
+ }
+ }
+ return true
+}
+
+func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
+ if !respHdr.Response {
+ return false
+ }
+ if reqID != respHdr.ID {
+ return false
+ }
+ if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
+ return false
+ }
+ return true
+}
+
+func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
+ if _, err := c.Write(b); err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ b = make([]byte, 512)
+ for {
+ n, err := c.Read(b)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ var p dnsmessage.Parser
+ h, err := p.Start(b[:n])
+ if err != nil {
+ continue
+ }
+ q, err := p.Question()
+ if err != nil || !checkResponse(id, query, h, q) {
+ continue
+ }
+ return p, h, nil
+ }
+}
+
+func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
+ if _, err := c.Write(b); err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ b = make([]byte, 1280)
+ if _, err := io.ReadFull(c, b[:2]); err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ l := int(b[0])<<8 | int(b[1])
+ if l > len(b) {
+ b = make([]byte, l)
+ }
+ n, err := io.ReadFull(c, b[:l])
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ var p dnsmessage.Parser
+ h, err := p.Start(b[:n])
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
+ }
+ q, err := p.Question()
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
+ }
+ if !checkResponse(id, query, h, q) {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
+ }
+ return p, h, nil
+}
+
+func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
+ q.Class = dnsmessage.ClassINET
+ id, udpReq, tcpReq, err := newRequest(q)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
+ }
+
+ for _, useUDP := range []bool{true, false} {
+ ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
+ defer cancel()
+
+ var c net.Conn
+ var err error
+ if useUDP {
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
+ } else {
+ c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
+ }
+
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ if d, ok := ctx.Deadline(); ok && !d.IsZero() {
+ err := c.SetDeadline(d)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ }
+ var p dnsmessage.Parser
+ var h dnsmessage.Header
+ if useUDP {
+ p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
+ } else {
+ p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
+ }
+ c.Close()
+ if err != nil {
+ if err == context.Canceled {
+ err = errCanceled
+ } else if err == context.DeadlineExceeded {
+ err = errTimeout
+ }
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
+ }
+ if h.Truncated {
+ continue
+ }
+ return p, h, nil
+ }
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer
+}
+
+func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
+ if h.RCode == dnsmessage.RCodeNameError {
+ return errNoSuchHost
+ }
+ _, err := p.AnswerHeader()
+ if err != nil && err != dnsmessage.ErrSectionDone {
+ return errCannotUnmarshalDNSMessage
+ }
+ if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
+ return errLameReferral
+ }
+ if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
+ if h.RCode == dnsmessage.RCodeServerFailure {
+ return errServerTemporarilyMisbehaving
+ }
+ return errServerMisbehaving
+ }
+ return nil
+}
+
+func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ return errNoSuchHost
+ }
+ if err != nil {
+ return errCannotUnmarshalDNSMessage
+ }
+ if h.Type == qtype {
+ return nil
+ }
+ if err := p.SkipAnswer(); err != nil {
+ return errCannotUnmarshalDNSMessage
+ }
+ }
+}
+
+func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
+ var lastErr error
+
+ n, err := dnsmessage.NewName(name)
+ if err != nil {
+ return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
+ }
+ q := dnsmessage.Question{
+ Name: n,
+ Type: qtype,
+ Class: dnsmessage.ClassINET,
+ }
+
+ for i := 0; i < 2; i++ {
+ for _, server := range tnet.dnsServers {
+ p, h, err := tnet.exchange(ctx, server, q, time.Second*5)
+ if err != nil {
+ dnsErr := &net.DNSError{
+ Err: err.Error(),
+ Name: name,
+ Server: server.String(),
+ }
+ if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
+ dnsErr.IsTimeout = true
+ }
+ if _, ok := err.(*net.OpError); ok {
+ dnsErr.IsTemporary = true
+ }
+ lastErr = dnsErr
+ continue
+ }
+
+ if err := checkHeader(&p, h); err != nil {
+ dnsErr := &net.DNSError{
+ Err: err.Error(),
+ Name: name,
+ Server: server.String(),
+ }
+ if err == errServerTemporarilyMisbehaving {
+ dnsErr.IsTemporary = true
+ }
+ if err == errNoSuchHost {
+ dnsErr.IsNotFound = true
+ return p, server.String(), dnsErr
+ }
+ lastErr = dnsErr
+ continue
+ }
+
+ err = skipToAnswer(&p, qtype)
+ if err == nil {
+ return p, server.String(), nil
+ }
+ lastErr = &net.DNSError{
+ Err: err.Error(),
+ Name: name,
+ Server: server.String(),
+ }
+ if err == errNoSuchHost {
+ lastErr.(*net.DNSError).IsNotFound = true
+ return p, server.String(), lastErr
+ }
+ }
+ }
+ return dnsmessage.Parser{}, "", lastErr
+}
+
+func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) {
+ if host == "" || (!tnet.hasV6 && !tnet.hasV4) {
+ return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
+ }
+ zlen := len(host)
+ if strings.IndexByte(host, ':') != -1 {
+ if zidx := strings.LastIndexByte(host, '%'); zidx != -1 {
+ zlen = zidx
+ }
+ }
+ if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
+ return []string{ip.String()}, nil
+ }
+
+ if !isDomainName(host) {
+ return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
+ }
+ type result struct {
+ p dnsmessage.Parser
+ server string
+ error
+ }
+ var addrsV4, addrsV6 []netip.Addr
+ lanes := 0
+ if tnet.hasV4 {
+ lanes++
+ }
+ if tnet.hasV6 {
+ lanes++
+ }
+ lane := make(chan result, lanes)
+ var lastErr error
+ if tnet.hasV4 {
+ go func() {
+ p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA)
+ lane <- result{p, server, err}
+ }()
+ }
+ if tnet.hasV6 {
+ go func() {
+ p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA)
+ lane <- result{p, server, err}
+ }()
+ }
+ for l := 0; l < lanes; l++ {
+ result := <-lane
+ if result.error != nil {
+ if lastErr == nil {
+ lastErr = result.error
+ }
+ continue
+ }
+
+ loop:
+ for {
+ h, err := result.p.AnswerHeader()
+ if err != nil && err != dnsmessage.ErrSectionDone {
+ lastErr = &net.DNSError{
+ Err: errCannotMarshalDNSMessage.Error(),
+ Name: host,
+ Server: result.server,
+ }
+ }
+ if err != nil {
+ break
+ }
+ switch h.Type {
+ case dnsmessage.TypeA:
+ a, err := result.p.AResource()
+ if err != nil {
+ lastErr = &net.DNSError{
+ Err: errCannotMarshalDNSMessage.Error(),
+ Name: host,
+ Server: result.server,
+ }
+ break loop
+ }
+ addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
+
+ case dnsmessage.TypeAAAA:
+ aaaa, err := result.p.AAAAResource()
+ if err != nil {
+ lastErr = &net.DNSError{
+ Err: errCannotMarshalDNSMessage.Error(),
+ Name: host,
+ Server: result.server,
+ }
+ break loop
+ }
+ addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
+
+ default:
+ if err := result.p.SkipAnswer(); err != nil {
+ lastErr = &net.DNSError{
+ Err: errCannotMarshalDNSMessage.Error(),
+ Name: host,
+ Server: result.server,
+ }
+ break loop
+ }
+ continue
+ }
+ }
+ }
+ // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled
+ var addrs []netip.Addr
+ if tnet.hasV6 {
+ addrs = append(addrsV6, addrsV4...)
+ } else {
+ addrs = append(addrsV4, addrsV6...)
+ }
+
+ if len(addrs) == 0 && lastErr != nil {
+ return nil, lastErr
+ }
+ saddrs := make([]string, 0, len(addrs))
+ for _, ip := range addrs {
+ saddrs = append(saddrs, ip.String())
+ }
+ return saddrs, nil
+}
+
+func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
+ if deadline.IsZero() {
+ return deadline, nil
+ }
+ timeRemaining := deadline.Sub(now)
+ if timeRemaining <= 0 {
+ return time.Time{}, errTimeout
+ }
+ timeout := timeRemaining / time.Duration(addrsRemaining)
+ const saneMinimum = 2 * time.Second
+ if timeout < saneMinimum {
+ if timeRemaining < saneMinimum {
+ timeout = timeRemaining
+ } else {
+ timeout = saneMinimum
+ }
+ }
+ return now.Add(timeout), nil
+}
+
+var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
+
+func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+ if ctx == nil {
+ panic("nil context")
+ }
+ var acceptV4, acceptV6 bool
+ matches := protoSplitter.FindStringSubmatch(network)
+ if matches == nil {
+ return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
+ } else if len(matches[2]) == 0 {
+ acceptV4 = true
+ acceptV6 = true
+ } else {
+ acceptV4 = matches[2][0] == '4'
+ acceptV6 = !acceptV4
+ }
+ var host string
+ var port int
+ if matches[1] == "ping" {
+ host = address
+ } else {
+ var sport string
+ var err error
+ host, sport, err = net.SplitHostPort(address)
+ if err != nil {
+ return nil, &net.OpError{Op: "dial", Err: err}
+ }
+ port, err = strconv.Atoi(sport)
+ if err != nil || port < 0 || port > 65535 {
+ return nil, &net.OpError{Op: "dial", Err: errNumericPort}
+ }
+ }
+ allAddr, err := tnet.LookupContextHost(ctx, host)
+ if err != nil {
+ return nil, &net.OpError{Op: "dial", Err: err}
+ }
+ var addrs []netip.AddrPort
+ for _, addr := range allAddr {
+ ip, err := netip.ParseAddr(addr)
+ if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
+ addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
+ }
+ }
+ if len(addrs) == 0 && len(allAddr) != 0 {
+ return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress}
+ }
+
+ var firstErr error
+ for i, addr := range addrs {
+ select {
+ case <-ctx.Done():
+ err := ctx.Err()
+ if err == context.Canceled {
+ err = errCanceled
+ } else if err == context.DeadlineExceeded {
+ err = errTimeout
+ }
+ return nil, &net.OpError{Op: "dial", Err: err}
+ default:
+ }
+
+ dialCtx := ctx
+ if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
+ partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i)
+ if err != nil {
+ if firstErr == nil {
+ firstErr = &net.OpError{Op: "dial", Err: err}
+ }
+ break
+ }
+ if partialDeadline.Before(deadline) {
+ var cancel context.CancelFunc
+ dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
+ defer cancel()
+ }
+ }
+
+ var c net.Conn
+ switch matches[1] {
+ case "tcp":
+ c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
+ case "udp":
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
+ case "ping":
+ c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
+ }
+ if err == nil {
+ return c, nil
+ }
+ if firstErr == nil {
+ firstErr = err
+ }
+ }
+ if firstErr == nil {
+ firstErr = &net.OpError{Op: "dial", Err: errMissingAddress}
+ }
+ return nil, firstErr
+}
+
+func (tnet *Net) Dial(network, address string) (net.Conn, error) {
+ return tnet.DialContext(context.Background(), network, address)
+}
diff --git a/tun/offload_linux.go b/tun/offload_linux.go
new file mode 100644
index 0000000..9ff7fea
--- /dev/null
+++ b/tun/offload_linux.go
@@ -0,0 +1,993 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "io"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+)
+
+const tcpFlagsOffset = 13
+
+const (
+ tcpFlagFIN uint8 = 0x01
+ tcpFlagPSH uint8 = 0x08
+ tcpFlagACK uint8 = 0x10
+)
+
+// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
+// kernel symbol is virtio_net_hdr.
+type virtioNetHdr struct {
+ flags uint8
+ gsoType uint8
+ hdrLen uint16
+ gsoSize uint16
+ csumStart uint16
+ csumOffset uint16
+}
+
+func (v *virtioNetHdr) decode(b []byte) error {
+ if len(b) < virtioNetHdrLen {
+ return io.ErrShortBuffer
+ }
+ copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
+ return nil
+}
+
+func (v *virtioNetHdr) encode(b []byte) error {
+ if len(b) < virtioNetHdrLen {
+ return io.ErrShortBuffer
+ }
+ copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
+ return nil
+}
+
+const (
+ // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
+ // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
+ virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
+)
+
+// tcpFlowKey represents the key for a TCP flow.
+type tcpFlowKey struct {
+ srcAddr, dstAddr [16]byte
+ srcPort, dstPort uint16
+ rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
+ isV6 bool
+}
+
+// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO.
+type tcpGROTable struct {
+ itemsByFlow map[tcpFlowKey][]tcpGROItem
+ itemsPool [][]tcpGROItem
+}
+
+func newTCPGROTable() *tcpGROTable {
+ t := &tcpGROTable{
+ itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize),
+ itemsPool: make([][]tcpGROItem, conn.IdealBatchSize),
+ }
+ for i := range t.itemsPool {
+ t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize)
+ }
+ return t
+}
+
+func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey {
+ key := tcpFlowKey{}
+ addrSize := dstAddrOffset - srcAddrOffset
+ copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
+ copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
+ key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
+ key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
+ key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
+ key.isV6 = addrSize == 16
+ return key
+}
+
+// lookupOrInsert looks up a flow for the provided packet and metadata,
+// returning the packets found for the flow, or inserting a new one if none
+// is found.
+func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
+ key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+ items, ok := t.itemsByFlow[key]
+ if ok {
+ return items, ok
+ }
+ // TODO: insert() performs another map lookup. This could be rearranged to avoid.
+ t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
+ return nil, false
+}
+
+// insert an item in the table for the provided packet and packet metadata.
+func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
+ key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+ item := tcpGROItem{
+ key: key,
+ bufsIndex: uint16(bufsIndex),
+ gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
+ iphLen: uint8(tcphOffset),
+ tcphLen: uint8(tcphLen),
+ sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
+ pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
+ }
+ items, ok := t.itemsByFlow[key]
+ if !ok {
+ items = t.newItems()
+ }
+ items = append(items, item)
+ t.itemsByFlow[key] = items
+}
+
+func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
+ items, _ := t.itemsByFlow[item.key]
+ items[i] = item
+}
+
+func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) {
+ items, _ := t.itemsByFlow[key]
+ items = append(items[:i], items[i+1:]...)
+ t.itemsByFlow[key] = items
+}
+
+// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
+// of a GRO evaluation across a vector of packets.
+type tcpGROItem struct {
+ key tcpFlowKey
+ sentSeq uint32 // the sequence number
+ bufsIndex uint16 // the index into the original bufs slice
+ numMerged uint16 // the number of packets merged into this item
+ gsoSize uint16 // payload size
+ iphLen uint8 // ip header len
+ tcphLen uint8 // tcp header len
+ pshSet bool // psh flag is set
+}
+
+func (t *tcpGROTable) newItems() []tcpGROItem {
+ var items []tcpGROItem
+ items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
+ return items
+}
+
+func (t *tcpGROTable) reset() {
+ for k, items := range t.itemsByFlow {
+ items = items[:0]
+ t.itemsPool = append(t.itemsPool, items)
+ delete(t.itemsByFlow, k)
+ }
+}
+
+// udpFlowKey represents the key for a UDP flow.
+type udpFlowKey struct {
+ srcAddr, dstAddr [16]byte
+ srcPort, dstPort uint16
+ isV6 bool
+}
+
+// udpGROTable holds flow and coalescing information for the purposes of UDP GRO.
+type udpGROTable struct {
+ itemsByFlow map[udpFlowKey][]udpGROItem
+ itemsPool [][]udpGROItem
+}
+
+func newUDPGROTable() *udpGROTable {
+ u := &udpGROTable{
+ itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize),
+ itemsPool: make([][]udpGROItem, conn.IdealBatchSize),
+ }
+ for i := range u.itemsPool {
+ u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize)
+ }
+ return u
+}
+
+func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey {
+ key := udpFlowKey{}
+ addrSize := dstAddrOffset - srcAddrOffset
+ copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
+ copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
+ key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:])
+ key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:])
+ key.isV6 = addrSize == 16
+ return key
+}
+
+// lookupOrInsert looks up a flow for the provided packet and metadata,
+// returning the packets found for the flow, or inserting a new one if none
+// is found.
+func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) {
+ key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
+ items, ok := u.itemsByFlow[key]
+ if ok {
+ return items, ok
+ }
+ // TODO: insert() performs another map lookup. This could be rearranged to avoid.
+ u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false)
+ return nil, false
+}
+
+// insert an item in the table for the provided packet and packet metadata.
+func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) {
+ key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
+ item := udpGROItem{
+ key: key,
+ bufsIndex: uint16(bufsIndex),
+ gsoSize: uint16(len(pkt[udphOffset+udphLen:])),
+ iphLen: uint8(udphOffset),
+ cSumKnownInvalid: cSumKnownInvalid,
+ }
+ items, ok := u.itemsByFlow[key]
+ if !ok {
+ items = u.newItems()
+ }
+ items = append(items, item)
+ u.itemsByFlow[key] = items
+}
+
+func (u *udpGROTable) updateAt(item udpGROItem, i int) {
+ items, _ := u.itemsByFlow[item.key]
+ items[i] = item
+}
+
+// udpGROItem represents bookkeeping data for a UDP packet during the lifetime
+// of a GRO evaluation across a vector of packets.
+type udpGROItem struct {
+ key udpFlowKey
+ bufsIndex uint16 // the index into the original bufs slice
+ numMerged uint16 // the number of packets merged into this item
+ gsoSize uint16 // payload size
+ iphLen uint8 // ip header len
+ cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown.
+}
+
+func (u *udpGROTable) newItems() []udpGROItem {
+ var items []udpGROItem
+ items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1]
+ return items
+}
+
+func (u *udpGROTable) reset() {
+ for k, items := range u.itemsByFlow {
+ items = items[:0]
+ u.itemsPool = append(u.itemsPool, items)
+ delete(u.itemsByFlow, k)
+ }
+}
+
+// canCoalesce represents the outcome of checking if two TCP packets are
+// candidates for coalescing.
+type canCoalesce int
+
+const (
+ coalescePrepend canCoalesce = -1
+ coalesceUnavailable canCoalesce = 0
+ coalesceAppend canCoalesce = 1
+)
+
+// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB
+// meet all requirements to be merged as part of a GRO operation, otherwise it
+// returns false.
+func ipHeadersCanCoalesce(pktA, pktB []byte) bool {
+ if len(pktA) < 9 || len(pktB) < 9 {
+ return false
+ }
+ if pktA[0]>>4 == 6 {
+ if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 {
+ // cannot coalesce with unequal Traffic class values
+ return false
+ }
+ if pktA[7] != pktB[7] {
+ // cannot coalesce with unequal Hop limit values
+ return false
+ }
+ } else {
+ if pktA[1] != pktB[1] {
+ // cannot coalesce with unequal ToS values
+ return false
+ }
+ if pktA[6]>>5 != pktB[6]>>5 {
+ // cannot coalesce with unequal DF or reserved bits. MF is checked
+ // further up the stack.
+ return false
+ }
+ if pktA[8] != pktB[8] {
+ // cannot coalesce with unequal TTL values
+ return false
+ }
+ }
+ return true
+}
+
+// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
+// described by item. iphLen and gsoSize describe pkt. bufs is the vector of
+// packets involved in the current GRO evaluation. bufsOffset is the offset at
+// which packet data begins within bufs.
+func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
+ pktTarget := bufs[item.bufsIndex][bufsOffset:]
+ if !ipHeadersCanCoalesce(pkt, pktTarget) {
+ return coalesceUnavailable
+ }
+ if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 {
+ // A smaller than gsoSize packet has been appended previously.
+ // Nothing can come after a smaller packet on the end.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ return coalesceAppend
+}
+
+// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
+// described by item. This function makes considerations that match the kernel's
+// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
+func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
+ pktTarget := bufs[item.bufsIndex][bufsOffset:]
+ if tcphLen != item.tcphLen {
+ // cannot coalesce with unequal tcp options len
+ return coalesceUnavailable
+ }
+ if tcphLen > 20 {
+ if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
+ // cannot coalesce with unequal tcp options
+ return coalesceUnavailable
+ }
+ }
+ if !ipHeadersCanCoalesce(pkt, pktTarget) {
+ return coalesceUnavailable
+ }
+ // seq adjacency
+ lhsLen := item.gsoSize
+ lhsLen += item.numMerged * item.gsoSize
+ if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
+ if item.pshSet {
+ // We cannot append to a segment that has the PSH flag set, PSH
+ // can only be set on the final segment in a reassembled group.
+ return coalesceUnavailable
+ }
+ if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
+ // A smaller than gsoSize packet has been appended previously.
+ // Nothing can come after a smaller packet on the end.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ return coalesceAppend
+ } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
+ if pshSet {
+ // We cannot prepend with a segment that has the PSH flag set, PSH
+ // can only be set on the final segment in a reassembled group.
+ return coalesceUnavailable
+ }
+ if gsoSize < item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize && item.numMerged > 0 {
+ // There's at least one previous merge, and we're larger than all
+ // previous. This would put multiple smaller packets on the end.
+ return coalesceUnavailable
+ }
+ return coalescePrepend
+ }
+ return coalesceUnavailable
+}
+
+func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool {
+ srcAddrAt := ipv4SrcAddrOffset
+ addrSize := 4
+ if isV6 {
+ srcAddrAt = ipv6SrcAddrOffset
+ addrSize = 16
+ }
+ lenForPseudo := uint16(len(pkt) - int(iphLen))
+ cSum := pseudoHeaderChecksumNoFold(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo)
+ return ^checksum(pkt[iphLen:], cSum) == 0
+}
+
+// coalesceResult represents the result of attempting to coalesce two TCP
+// packets.
+type coalesceResult int
+
+const (
+ coalesceInsufficientCap coalesceResult = iota
+ coalescePSHEnding
+ coalesceItemInvalidCSum
+ coalescePktInvalidCSum
+ coalesceSuccess
+)
+
+// coalesceUDPPackets attempts to coalesce pkt with the packet described by
+// item, and returns the outcome.
+func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
+ pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front
+ headersLen := item.iphLen + udphLen
+ coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
+
+ if cap(pktHead)-bufsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if item.numMerged == 0 {
+ if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) {
+ return coalescePktInvalidCSum
+ }
+ extendBy := len(pkt) - int(headersLen)
+ bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
+ copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
+
+ item.numMerged++
+ return coalesceSuccess
+}
+
+// coalesceTCPPackets attempts to coalesce pkt with the packet described by
+// item, and returns the outcome. This function may swap bufs elements in the
+// event of a prepend as item's bufs index is already being tracked for writing
+// to a Device.
+func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
+ var pktHead []byte // the packet that will end up at the front
+ headersLen := item.iphLen + item.tcphLen
+ coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
+
+ // Copy data
+ if mode == coalescePrepend {
+ pktHead = pkt
+ if cap(pkt)-bufsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if pshSet {
+ return coalescePSHEnding
+ }
+ if item.numMerged == 0 {
+ if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
+ return coalescePktInvalidCSum
+ }
+ item.sentSeq = seq
+ extendBy := coalescedLen - len(pktHead)
+ bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
+ copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
+ // Flip the slice headers in bufs as part of prepend. The index of item
+ // is already being tracked for writing.
+ bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
+ } else {
+ pktHead = bufs[item.bufsIndex][bufsOffset:]
+ if cap(pktHead)-bufsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if item.numMerged == 0 {
+ if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
+ return coalescePktInvalidCSum
+ }
+ if pshSet {
+ // We are appending a segment with PSH set.
+ item.pshSet = pshSet
+ pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
+ }
+ extendBy := len(pkt) - int(headersLen)
+ bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
+ copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
+ }
+
+ if gsoSize > item.gsoSize {
+ item.gsoSize = gsoSize
+ }
+
+ item.numMerged++
+ return coalesceSuccess
+}
+
+const (
+ ipv4FlagMoreFragments uint8 = 0x20
+)
+
+const (
+ ipv4SrcAddrOffset = 12
+ ipv6SrcAddrOffset = 8
+ maxUint16 = 1<<16 - 1
+)
+
+type groResult int
+
+const (
+ groResultNoop groResult = iota
+ groResultTableInsert
+ groResultCoalesced
+)
+
+// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
+// existing packets tracked in table. It returns a groResultNoop when no
+// action was taken, groResultTableInsert when the evaluated packet was
+// inserted into table, and groResultCoalesced when the evaluated packet was
+// coalesced with another packet in table.
+func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult {
+ pkt := bufs[pktI][offset:]
+ if len(pkt) > maxUint16 {
+ // A valid IPv4 or IPv6 packet will never exceed this.
+ return groResultNoop
+ }
+ iphLen := int((pkt[0] & 0x0F) * 4)
+ if isV6 {
+ iphLen = 40
+ ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
+ if ipv6HPayloadLen != len(pkt)-iphLen {
+ return groResultNoop
+ }
+ } else {
+ totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
+ if totalLen != len(pkt) {
+ return groResultNoop
+ }
+ }
+ if len(pkt) < iphLen {
+ return groResultNoop
+ }
+ tcphLen := int((pkt[iphLen+12] >> 4) * 4)
+ if tcphLen < 20 || tcphLen > 60 {
+ return groResultNoop
+ }
+ if len(pkt) < iphLen+tcphLen {
+ return groResultNoop
+ }
+ if !isV6 {
+ if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
+ // no GRO support for fragmented segments for now
+ return groResultNoop
+ }
+ }
+ tcpFlags := pkt[iphLen+tcpFlagsOffset]
+ var pshSet bool
+ // not a candidate if any non-ACK flags (except PSH+ACK) are set
+ if tcpFlags != tcpFlagACK {
+ if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
+ return groResultNoop
+ }
+ pshSet = true
+ }
+ gsoSize := uint16(len(pkt) - tcphLen - iphLen)
+ // not a candidate if payload len is 0
+ if gsoSize < 1 {
+ return groResultNoop
+ }
+ seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
+ srcAddrOffset := ipv4SrcAddrOffset
+ addrLen := 4
+ if isV6 {
+ srcAddrOffset = ipv6SrcAddrOffset
+ addrLen = 16
+ }
+ items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
+ if !existing {
+ return groResultTableInsert
+ }
+ for i := len(items) - 1; i >= 0; i-- {
+ // In the best case of packets arriving in order iterating in reverse is
+ // more efficient if there are multiple items for a given flow. This
+ // also enables a natural table.deleteAt() in the
+ // coalesceItemInvalidCSum case without the need for index tracking.
+ // This algorithm makes a best effort to coalesce in the event of
+ // unordered packets, where pkt may land anywhere in items from a
+ // sequence number perspective, however once an item is inserted into
+ // the table it is never compared across other items later.
+ item := items[i]
+ can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
+ if can != coalesceUnavailable {
+ result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
+ switch result {
+ case coalesceSuccess:
+ table.updateAt(item, i)
+ return groResultCoalesced
+ case coalesceItemInvalidCSum:
+ // delete the item with an invalid csum
+ table.deleteAt(item.key, i)
+ case coalescePktInvalidCSum:
+ // no point in inserting an item that we can't coalesce
+ return groResultNoop
+ default:
+ }
+ }
+ }
+ // failed to coalesce with any other packets; store the item in the flow
+ table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
+ return groResultTableInsert
+}
+
+// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the
+// metadata found in table.
+func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error {
+ for _, items := range table.itemsByFlow {
+ for _, item := range items {
+ if item.numMerged > 0 {
+ hdr := virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
+ hdrLen: uint16(item.iphLen + item.tcphLen),
+ gsoSize: item.gsoSize,
+ csumStart: uint16(item.iphLen),
+ csumOffset: 16,
+ }
+ pkt := bufs[item.bufsIndex][offset:]
+
+ // Recalculate the total len (IPv4) or payload len (IPv6).
+ // Recalculate the (IPv4) header checksum.
+ if item.key.isV6 {
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
+ binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
+ } else {
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
+ pkt[10], pkt[11] = 0, 0
+ binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
+ iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum
+ binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
+ }
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+
+ // Calculate the pseudo header checksum and place it at the TCP
+ // checksum offset. Downstream checksum offloading will combine
+ // this with computation of the tcp header and payload checksum.
+ addrLen := 4
+ addrOffset := ipv4SrcAddrOffset
+ if item.key.isV6 {
+ addrLen = 16
+ addrOffset = ipv6SrcAddrOffset
+ }
+ srcAddrAt := offset + addrOffset
+ srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
+ dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
+ psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
+ binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
+ } else {
+ hdr := virtioNetHdr{}
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+ return nil
+}
+
+// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the
+// metadata found in table.
+func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error {
+ for _, items := range table.itemsByFlow {
+ for _, item := range items {
+ if item.numMerged > 0 {
+ hdr := virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
+ hdrLen: uint16(item.iphLen + udphLen),
+ gsoSize: item.gsoSize,
+ csumStart: uint16(item.iphLen),
+ csumOffset: 6,
+ }
+ pkt := bufs[item.bufsIndex][offset:]
+
+ // Recalculate the total len (IPv4) or payload len (IPv6).
+ // Recalculate the (IPv4) header checksum.
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
+ if item.key.isV6 {
+ binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
+ } else {
+ pkt[10], pkt[11] = 0, 0
+ binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
+ iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum
+ binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
+ }
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+
+ // Recalculate the UDP len field value
+ binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:])))
+
+ // Calculate the pseudo header checksum and place it at the UDP
+ // checksum offset. Downstream checksum offloading will combine
+ // this with computation of the udp header and payload checksum.
+ addrLen := 4
+ addrOffset := ipv4SrcAddrOffset
+ if item.key.isV6 {
+ addrLen = 16
+ addrOffset = ipv6SrcAddrOffset
+ }
+ srcAddrAt := offset + addrOffset
+ srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
+ dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
+ psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
+ binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
+ } else {
+ hdr := virtioNetHdr{}
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+ return nil
+}
+
+type groCandidateType uint8
+
+const (
+ notGROCandidate groCandidateType = iota
+ tcp4GROCandidate
+ tcp6GROCandidate
+ udp4GROCandidate
+ udp6GROCandidate
+)
+
+func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType {
+ if len(b) < 28 {
+ return notGROCandidate
+ }
+ if b[0]>>4 == 4 {
+ if b[0]&0x0F != 5 {
+ // IPv4 packets w/IP options do not coalesce
+ return notGROCandidate
+ }
+ if b[9] == unix.IPPROTO_TCP && len(b) >= 40 {
+ return tcp4GROCandidate
+ }
+ if b[9] == unix.IPPROTO_UDP && canUDPGRO {
+ return udp4GROCandidate
+ }
+ } else if b[0]>>4 == 6 {
+ if b[6] == unix.IPPROTO_TCP && len(b) >= 60 {
+ return tcp6GROCandidate
+ }
+ if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO {
+ return udp6GROCandidate
+ }
+ }
+ return notGROCandidate
+}
+
+const (
+ udphLen = 8
+)
+
+// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with
+// existing packets tracked in table. It returns a groResultNoop when no
+// action was taken, groResultTableInsert when the evaluated packet was
+// inserted into table, and groResultCoalesced when the evaluated packet was
+// coalesced with another packet in table.
+func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult {
+ pkt := bufs[pktI][offset:]
+ if len(pkt) > maxUint16 {
+ // A valid IPv4 or IPv6 packet will never exceed this.
+ return groResultNoop
+ }
+ iphLen := int((pkt[0] & 0x0F) * 4)
+ if isV6 {
+ iphLen = 40
+ ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
+ if ipv6HPayloadLen != len(pkt)-iphLen {
+ return groResultNoop
+ }
+ } else {
+ totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
+ if totalLen != len(pkt) {
+ return groResultNoop
+ }
+ }
+ if len(pkt) < iphLen {
+ return groResultNoop
+ }
+ if len(pkt) < iphLen+udphLen {
+ return groResultNoop
+ }
+ if !isV6 {
+ if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
+ // no GRO support for fragmented segments for now
+ return groResultNoop
+ }
+ }
+ gsoSize := uint16(len(pkt) - udphLen - iphLen)
+ // not a candidate if payload len is 0
+ if gsoSize < 1 {
+ return groResultNoop
+ }
+ srcAddrOffset := ipv4SrcAddrOffset
+ addrLen := 4
+ if isV6 {
+ srcAddrOffset = ipv6SrcAddrOffset
+ addrLen = 16
+ }
+ items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI)
+ if !existing {
+ return groResultTableInsert
+ }
+ // With UDP we only check the last item, otherwise we could reorder packets
+ // for a given flow. We must also always insert a new item, or successfully
+ // coalesce with an existing item, for the same reason.
+ item := items[len(items)-1]
+ can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset)
+ var pktCSumKnownInvalid bool
+ if can == coalesceAppend {
+ result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6)
+ switch result {
+ case coalesceSuccess:
+ table.updateAt(item, len(items)-1)
+ return groResultCoalesced
+ case coalesceItemInvalidCSum:
+ // If the existing item has an invalid csum we take no action. A new
+ // item will be stored after it, and the existing item will never be
+ // revisited as part of future coalescing candidacy checks.
+ case coalescePktInvalidCSum:
+ // We must insert a new item, but we also mark it as invalid csum
+ // to prevent a repeat checksum validation.
+ pktCSumKnownInvalid = true
+ default:
+ }
+ }
+ // failed to coalesce with any other packets; store the item in the flow
+ table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid)
+ return groResultTableInsert
+}
+
+// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
+// packets into toWrite. toWrite, tcpTable, and udpTable should initially be
+// empty (but non-nil), and are passed in to save allocs as the caller may reset
+// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is
+// supported.
+func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error {
+ for i := range bufs {
+ if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
+ return errors.New("invalid offset")
+ }
+ var result groResult
+ switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) {
+ case tcp4GROCandidate:
+ result = tcpGRO(bufs, offset, i, tcpTable, false)
+ case tcp6GROCandidate:
+ result = tcpGRO(bufs, offset, i, tcpTable, true)
+ case udp4GROCandidate:
+ result = udpGRO(bufs, offset, i, udpTable, false)
+ case udp6GROCandidate:
+ result = udpGRO(bufs, offset, i, udpTable, true)
+ }
+ switch result {
+ case groResultNoop:
+ hdr := virtioNetHdr{}
+ err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+ fallthrough
+ case groResultTableInsert:
+ *toWrite = append(*toWrite, i)
+ }
+ }
+ errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable)
+ errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable)
+ return errors.Join(errTCP, errUDP)
+}
+
+// gsoSplit splits packets from in into outBuffs, writing the size of each
+// element into sizes. It returns the number of buffers populated, and/or an
+// error.
+func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) {
+ iphLen := int(hdr.csumStart)
+ srcAddrOffset := ipv6SrcAddrOffset
+ addrLen := 16
+ if !isV6 {
+ in[10], in[11] = 0, 0 // clear ipv4 header checksum
+ srcAddrOffset = ipv4SrcAddrOffset
+ addrLen = 4
+ }
+ transportCsumAt := int(hdr.csumStart + hdr.csumOffset)
+ in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum
+ var firstTCPSeqNum uint32
+ var protocol uint8
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+ protocol = unix.IPPROTO_TCP
+ firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:])
+ } else {
+ protocol = unix.IPPROTO_UDP
+ }
+ nextSegmentDataAt := int(hdr.hdrLen)
+ i := 0
+ for ; nextSegmentDataAt < len(in); i++ {
+ if i == len(outBuffs) {
+ return i - 1, ErrTooManySegments
+ }
+ nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
+ if nextSegmentEnd > len(in) {
+ nextSegmentEnd = len(in)
+ }
+ segmentDataLen := nextSegmentEnd - nextSegmentDataAt
+ totalLen := int(hdr.hdrLen) + segmentDataLen
+ sizes[i] = totalLen
+ out := outBuffs[i][outOffset:]
+
+ copy(out, in[:iphLen])
+ if !isV6 {
+ // For IPv4 we are responsible for incrementing the ID field,
+ // updating the total len field, and recalculating the header
+ // checksum.
+ if i > 0 {
+ id := binary.BigEndian.Uint16(out[4:])
+ id += uint16(i)
+ binary.BigEndian.PutUint16(out[4:], id)
+ }
+ binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
+ ipv4CSum := ^checksum(out[:iphLen], 0)
+ binary.BigEndian.PutUint16(out[10:], ipv4CSum)
+ } else {
+ // For IPv6 we are responsible for updating the payload length field.
+ binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
+ }
+
+ // copy transport header
+ copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
+
+ if protocol == unix.IPPROTO_TCP {
+ // set TCP seq and adjust TCP flags
+ tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
+ binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
+ if nextSegmentEnd != len(in) {
+ // FIN and PSH should only be set on last segment
+ clearFlags := tcpFlagFIN | tcpFlagPSH
+ out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
+ }
+ } else {
+ // set UDP header len
+ binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart))
+ }
+
+ // payload
+ copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
+
+ // transport checksum
+ transportHeaderLen := int(hdr.hdrLen - hdr.csumStart)
+ lenForPseudo := uint16(transportHeaderLen + segmentDataLen)
+ transportCSumNoFold := pseudoHeaderChecksumNoFold(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo)
+ transportCSum := ^checksum(out[hdr.csumStart:totalLen], transportCSumNoFold)
+ binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum)
+
+ nextSegmentDataAt += int(hdr.gsoSize)
+ }
+ return i, nil
+}
+
+func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
+ cSumAt := cSumStart + cSumOffset
+ // The initial value at the checksum offset should be summed with the
+ // checksum we compute. This is typically the pseudo-header checksum.
+ initial := binary.BigEndian.Uint16(in[cSumAt:])
+ in[cSumAt], in[cSumAt+1] = 0, 0
+ binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
+ return nil
+}
diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go
new file mode 100644
index 0000000..ae55c8c
--- /dev/null
+++ b/tun/offload_linux_test.go
@@ -0,0 +1,752 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "net/netip"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+ offset = virtioNetHdrLen
+)
+
+var (
+ ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
+ ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
+ ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
+ ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
+ ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
+ ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
+)
+
+func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte {
+ totalLen := 28 + payloadLen
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv4H := header.IPv4(b[offset:])
+ srcAs4 := srcIPPort.Addr().As4()
+ dstAs4 := dstIPPort.Addr().As4()
+ ipFields := &header.IPv4Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
+ Protocol: unix.IPPROTO_UDP,
+ TTL: 64,
+ TotalLength: uint16(totalLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv4H.Encode(ipFields)
+ udpH := header.UDP(b[offset+20:])
+ udpH.Encode(&header.UDPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ Length: uint16(payloadLen + udphLen),
+ })
+ ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen))
+ udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
+ return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
+}
+
+func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte {
+ totalLen := 48 + payloadLen
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv6H := header.IPv6(b[offset:])
+ srcAs16 := srcIPPort.Addr().As16()
+ dstAs16 := dstIPPort.Addr().As16()
+ ipFields := &header.IPv6Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
+ TransportProtocol: unix.IPPROTO_UDP,
+ HopLimit: 64,
+ PayloadLength: uint16(payloadLen + udphLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv6H.Encode(ipFields)
+ udpH := header.UDP(b[offset+40:])
+ udpH.Encode(&header.UDPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ Length: uint16(payloadLen + udphLen),
+ })
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen))
+ udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
+ return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
+}
+
+func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte {
+ totalLen := 40 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv4H := header.IPv4(b[offset:])
+ srcAs4 := srcIPPort.Addr().As4()
+ dstAs4 := dstIPPort.Addr().As4()
+ ipFields := &header.IPv4Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
+ Protocol: unix.IPPROTO_TCP,
+ TTL: 64,
+ TotalLength: uint16(totalLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv4H.Encode(ipFields)
+ tcpH := header.TCP(b[offset+20:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
+}
+
+func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte {
+ totalLen := 60 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv6H := header.IPv6(b[offset:])
+ srcAs16 := srcIPPort.Addr().As16()
+ dstAs16 := dstIPPort.Addr().As16()
+ ipFields := &header.IPv6Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
+ TransportProtocol: unix.IPPROTO_TCP,
+ HopLimit: 64,
+ PayloadLength: uint16(segmentSize + 20),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv6H.Encode(ipFields)
+ tcpH := header.TCP(b[offset+40:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
+}
+
+func Test_handleVirtioRead(t *testing.T) {
+ tests := []struct {
+ name string
+ hdr virtioNetHdr
+ pktIn []byte
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "tcp4",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
+ gsoSize: 100,
+ hdrLen: 40,
+ csumStart: 20,
+ csumOffset: 16,
+ },
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{140, 140},
+ false,
+ },
+ {
+ "tcp6",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
+ gsoSize: 100,
+ hdrLen: 60,
+ csumStart: 40,
+ csumOffset: 16,
+ },
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{160, 160},
+ false,
+ },
+ {
+ "udp4",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
+ gsoSize: 100,
+ hdrLen: 28,
+ csumStart: 20,
+ csumOffset: 6,
+ },
+ udp4Packet(ip4PortA, ip4PortB, 200),
+ []int{128, 128},
+ false,
+ },
+ {
+ "udp6",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
+ gsoSize: 100,
+ hdrLen: 48,
+ csumStart: 40,
+ csumOffset: 6,
+ },
+ udp6Packet(ip6PortA, ip6PortB, 200),
+ []int{148, 148},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ out := make([][]byte, conn.IdealBatchSize)
+ sizes := make([]int, conn.IdealBatchSize)
+ for i := range out {
+ out[i] = make([]byte, 65535)
+ }
+ tt.hdr.encode(tt.pktIn)
+ n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if n != len(tt.wantLens) {
+ t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
+ }
+ for i := range tt.wantLens {
+ if tt.wantLens[i] != sizes[i] {
+ t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
+ }
+ }
+ })
+ }
+}
+
+func flipTCP4Checksum(b []byte) []byte {
+ at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
+ b[at] ^= 0xFF
+ b[at+1] ^= 0xFF
+ return b
+}
+
+func flipUDP4Checksum(b []byte) []byte {
+ at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6
+ b[at] ^= 0xFF
+ b[at+1] ^= 0xFF
+ return b
+}
+
+func Fuzz_handleGRO(f *testing.F) {
+ pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
+ pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
+ pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
+ pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
+ pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
+ pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
+ pkt6 := udp4Packet(ip4PortA, ip4PortB, 100)
+ pkt7 := udp4Packet(ip4PortA, ip4PortB, 100)
+ pkt8 := udp4Packet(ip4PortA, ip4PortC, 100)
+ pkt9 := udp6Packet(ip6PortA, ip6PortB, 100)
+ pkt10 := udp6Packet(ip6PortA, ip6PortB, 100)
+ pkt11 := udp6Packet(ip6PortA, ip6PortC, 100)
+ f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset)
+ f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) {
+ pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11}
+ toWrite := make([]int, 0, len(pkts))
+ handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite)
+ if len(toWrite) > len(pkts) {
+ t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
+ }
+ seenWriteI := make(map[int]bool)
+ for _, writeI := range toWrite {
+ if writeI < 0 || writeI > len(pkts)-1 {
+ t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
+ }
+ if seenWriteI[writeI] {
+ t.Errorf("duplicate toWrite value: %d", writeI)
+ }
+ seenWriteI[writeI] = true
+ }
+ })
+}
+
+func Test_handleGRO(t *testing.T) {
+ tests := []struct {
+ name string
+ pktsIn [][]byte
+ canUDPGRO bool
+ wantToWrite []int
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "multiple protocols and flows",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
+ tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ },
+ true,
+ []int{0, 1, 2, 4, 5, 7, 9},
+ []int{240, 228, 128, 140, 260, 160, 248},
+ false,
+ },
+ {
+ "multiple protocols and flows no UDP GRO",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
+ tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ },
+ false,
+ []int{0, 1, 2, 4, 5, 7, 8, 9, 10},
+ []int{240, 128, 128, 140, 260, 160, 128, 148, 148},
+ false,
+ },
+ {
+ "PSH interleaved",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1
+ },
+ true,
+ []int{0, 2, 4, 6},
+ []int{240, 240, 260, 260},
+ false,
+ },
+ {
+ "coalesceItemInvalidCSum",
+ [][]byte{
+ flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ },
+ true,
+ []int{0, 1, 3, 4},
+ []int{140, 240, 128, 228},
+ false,
+ },
+ {
+ "out of order",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ },
+ true,
+ []int{0},
+ []int{340},
+ false,
+ },
+ {
+ "unequal TTL",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.TTL++
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.TTL++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal ToS",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.TOS++
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.TOS++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal flags more fragments set",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.Flags = 1
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.Flags = 1
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal flags DF set",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.Flags = 2
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.Flags = 2
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "ipv6 unequal hop limit",
+ [][]byte{
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
+ tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
+ fields.HopLimit++
+ }),
+ udp6Packet(ip6PortA, ip6PortB, 100),
+ udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
+ fields.HopLimit++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{160, 160, 148, 148},
+ false,
+ },
+ {
+ "ipv6 unequal traffic class",
+ [][]byte{
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
+ tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
+ fields.TrafficClass++
+ }),
+ udp6Packet(ip6PortA, ip6PortB, 100),
+ udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
+ fields.TrafficClass++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{160, 160, 148, 148},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ toWrite := make([]int, 0, len(tt.pktsIn))
+ err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if len(toWrite) != len(tt.wantToWrite) {
+ t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
+ }
+ for i, pktI := range tt.wantToWrite {
+ if tt.wantToWrite[i] != toWrite[i] {
+ t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
+ }
+ if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
+ t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
+ }
+ }
+ })
+ }
+}
+
+func Test_packetIsGROCandidate(t *testing.T) {
+ tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
+ tcp4TooShort := tcp4[:39]
+ ip4InvalidHeaderLen := make([]byte, len(tcp4))
+ copy(ip4InvalidHeaderLen, tcp4)
+ ip4InvalidHeaderLen[0] = 0x46
+ ip4InvalidProtocol := make([]byte, len(tcp4))
+ copy(ip4InvalidProtocol, tcp4)
+ ip4InvalidProtocol[9] = unix.IPPROTO_GRE
+
+ tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
+ tcp6TooShort := tcp6[:59]
+ ip6InvalidProtocol := make([]byte, len(tcp6))
+ copy(ip6InvalidProtocol, tcp6)
+ ip6InvalidProtocol[6] = unix.IPPROTO_GRE
+
+ udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:]
+ udp4TooShort := udp4[:27]
+
+ udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:]
+ udp6TooShort := udp6[:47]
+
+ tests := []struct {
+ name string
+ b []byte
+ canUDPGRO bool
+ want groCandidateType
+ }{
+ {
+ "tcp4",
+ tcp4,
+ true,
+ tcp4GROCandidate,
+ },
+ {
+ "tcp6",
+ tcp6,
+ true,
+ tcp6GROCandidate,
+ },
+ {
+ "udp4",
+ udp4,
+ true,
+ udp4GROCandidate,
+ },
+ {
+ "udp4 no support",
+ udp4,
+ false,
+ notGROCandidate,
+ },
+ {
+ "udp6",
+ udp6,
+ true,
+ udp6GROCandidate,
+ },
+ {
+ "udp6 no support",
+ udp6,
+ false,
+ notGROCandidate,
+ },
+ {
+ "udp4 too short",
+ udp4TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "udp6 too short",
+ udp6TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "tcp4 too short",
+ tcp4TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "tcp6 too short",
+ tcp6TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "invalid IP version",
+ []byte{0x00},
+ true,
+ notGROCandidate,
+ },
+ {
+ "invalid IP header len",
+ ip4InvalidHeaderLen,
+ true,
+ notGROCandidate,
+ },
+ {
+ "ip4 invalid protocol",
+ ip4InvalidProtocol,
+ true,
+ notGROCandidate,
+ },
+ {
+ "ip6 invalid protocol",
+ ip6InvalidProtocol,
+ true,
+ notGROCandidate,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want {
+ t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func Test_udpPacketsCanCoalesce(t *testing.T) {
+ udp4a := udp4Packet(ip4PortA, ip4PortB, 100)
+ udp4b := udp4Packet(ip4PortA, ip4PortB, 100)
+ udp4c := udp4Packet(ip4PortA, ip4PortB, 110)
+
+ type args struct {
+ pkt []byte
+ iphLen uint8
+ gsoSize uint16
+ item udpGROItem
+ bufs [][]byte
+ bufsOffset int
+ }
+ tests := []struct {
+ name string
+ args args
+ want canCoalesce
+ }{
+ {
+ "coalesceAppend equal gso",
+ args{
+ pkt: udp4a[offset:],
+ iphLen: 20,
+ gsoSize: 100,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceAppend,
+ },
+ {
+ "coalesceAppend smaller gso",
+ args{
+ pkt: udp4a[offset : len(udp4a)-90],
+ iphLen: 20,
+ gsoSize: 10,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceAppend,
+ },
+ {
+ "coalesceUnavailable smaller gso previously appended",
+ args{
+ pkt: udp4a[offset:],
+ iphLen: 20,
+ gsoSize: 100,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4c,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceUnavailable,
+ },
+ {
+ "coalesceUnavailable larger following smaller",
+ args{
+ pkt: udp4c[offset:],
+ iphLen: 20,
+ gsoSize: 110,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4c,
+ },
+ bufsOffset: offset,
+ },
+ coalesceUnavailable,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want {
+ t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/tun/operateonfd.go b/tun/operateonfd.go
index 31747a2..f1beb6d 100644
--- a/tun/operateonfd.go
+++ b/tun/operateonfd.go
@@ -1,8 +1,8 @@
-// +build !windows
+//go:build darwin || freebsd
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
diff --git a/tun/tun.go b/tun/tun.go
index 5395bdb..0ae53d0 100644
--- a/tun/tun.go
+++ b/tun/tun.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -18,12 +18,36 @@ const (
)
type Device interface {
- File() *os.File // returns the file descriptor of the device
- Read([]byte, int) (int, error) // read a packet from the device (without any additional headers)
- Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
- Flush() error // flush all previous writes to the device
- MTU() (int, error) // returns the MTU of the device
- Name() (string, error) // fetches and returns the current name
- Events() chan Event // returns a constant channel of events related to the device
- Close() error // stops the device and closes the event channel
+ // File returns the file descriptor of the device.
+ File() *os.File
+
+ // Read one or more packets from the Device (without any additional headers).
+ // On a successful read it returns the number of packets read, and sets
+ // packet lengths within the sizes slice. len(sizes) must be >= len(bufs).
+ // A nonzero offset can be used to instruct the Device on where to begin
+ // reading into each element of the bufs slice.
+ Read(bufs [][]byte, sizes []int, offset int) (n int, err error)
+
+ // Write one or more packets to the device (without any additional headers).
+ // On a successful write it returns the number of packets written. A nonzero
+ // offset can be used to instruct the Device on where to begin writing from
+ // each packet contained within the bufs slice.
+ Write(bufs [][]byte, offset int) (int, error)
+
+ // MTU returns the MTU of the Device.
+ MTU() (int, error)
+
+ // Name returns the current name of the Device.
+ Name() (string, error)
+
+ // Events returns a channel of type Event, which is fed Device events.
+ Events() <-chan Event
+
+ // Close stops the Device and closes the Event channel.
+ Close() error
+
+ // BatchSize returns the preferred/max number of packets that can be read or
+ // written in a single read/write call. BatchSize must not change over the
+ // lifetime of a Device.
+ BatchSize() int
}
diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go
index 6d2e6dd..c9a6c0b 100644
--- a/tun/tun_darwin.go
+++ b/tun/tun_darwin.go
@@ -1,46 +1,46 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
+ "errors"
"fmt"
- "io/ioutil"
+ "io"
"net"
"os"
+ "sync"
"syscall"
+ "time"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)
const utunControlName = "com.apple.net.utun_control"
-// _CTLIOCGINFO value derived from /usr/include/sys/{kern_control,ioccom}.h
-const _CTLIOCGINFO = (0x40000000 | 0x80000000) | ((100 & 0x1fff) << 16) | uint32(byte('N'))<<8 | 3
-
-// sockaddr_ctl specifeid in /usr/include/sys/kern_control.h
-type sockaddrCtl struct {
- scLen uint8
- scFamily uint8
- ssSysaddr uint16
- scID uint32
- scUnit uint32
- scReserved [5]uint32
-}
-
type NativeTun struct {
name string
tunFile *os.File
events chan Event
errors chan error
routeSocket int
+ closeOnce sync.Once
}
-var sockaddrCtlSize uintptr = 32
+func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
+ for i := 0; i < 20; i++ {
+ iface, err = net.InterfaceByIndex(index)
+ if err != nil && errors.Is(err, unix.ENOMEM) {
+ time.Sleep(time.Duration(i) * time.Second / 3)
+ continue
+ }
+ return iface, err
+ }
+ return nil, err
+}
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
var (
@@ -55,7 +55,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
retry:
n, err := unix.Read(tun.routeSocket, data)
if err != nil {
- if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
+ if errno, ok := err.(unix.Errno); ok && errno == unix.EINTR {
goto retry
}
tun.errors <- err
@@ -74,7 +74,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
continue
}
- iface, err := net.InterfaceByIndex(ifindex)
+ iface, err := retryInterfaceByIndex(ifindex)
if err != nil {
tun.errors <- err
return
@@ -107,53 +107,33 @@ func CreateTUN(name string, mtu int) (Device, error) {
}
}
- fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2)
-
+ fd, err := socketCloexec(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2)
if err != nil {
return nil, err
}
- var ctlInfo = &struct {
- ctlID uint32
- ctlName [96]byte
- }{}
-
- copy(ctlInfo.ctlName[:], []byte(utunControlName))
-
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(fd),
- uintptr(_CTLIOCGINFO),
- uintptr(unsafe.Pointer(ctlInfo)),
- )
-
- if errno != 0 {
- return nil, fmt.Errorf("_CTLIOCGINFO: %v", errno)
+ ctlInfo := &unix.CtlInfo{}
+ copy(ctlInfo.Name[:], []byte(utunControlName))
+ err = unix.IoctlCtlInfo(fd, ctlInfo)
+ if err != nil {
+ unix.Close(fd)
+ return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err)
}
- sc := sockaddrCtl{
- scLen: uint8(sockaddrCtlSize),
- scFamily: unix.AF_SYSTEM,
- ssSysaddr: 2,
- scID: ctlInfo.ctlID,
- scUnit: uint32(ifIndex) + 1,
+ sc := &unix.SockaddrCtl{
+ ID: ctlInfo.Id,
+ Unit: uint32(ifIndex) + 1,
}
- scPointer := unsafe.Pointer(&sc)
-
- _, _, errno = unix.RawSyscall(
- unix.SYS_CONNECT,
- uintptr(fd),
- uintptr(scPointer),
- uintptr(sockaddrCtlSize),
- )
-
- if errno != 0 {
- return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
+ err = unix.Connect(fd, sc)
+ if err != nil {
+ unix.Close(fd)
+ return nil, err
}
- err = syscall.SetNonblock(fd, true)
+ err = unix.SetNonblock(fd, true)
if err != nil {
+ unix.Close(fd)
return nil, err
}
tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu)
@@ -161,7 +141,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
if err == nil && name == "utun" {
fname := os.Getenv("WG_TUN_NAME_FILE")
if fname != "" {
- ioutil.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0400)
+ os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400)
}
}
@@ -193,7 +173,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err
}
- tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+ tun.routeSocket, err = socketCloexec(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
tun.tunFile.Close()
return nil, err
@@ -213,27 +193,19 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
}
func (tun *NativeTun) Name() (string, error) {
- var ifName struct {
- name [16]byte
- }
- ifNameSize := uintptr(16)
-
- var errno syscall.Errno
+ var err error
tun.operateOnFd(func(fd uintptr) {
- _, _, errno = unix.Syscall6(
- unix.SYS_GETSOCKOPT,
- fd,
+ tun.name, err = unix.GetsockoptString(
+ int(fd),
2, /* #define SYSPROTO_CONTROL 2 */
2, /* #define UTUN_OPT_IFNAME 2 */
- uintptr(unsafe.Pointer(&ifName)),
- uintptr(unsafe.Pointer(&ifNameSize)), 0)
+ )
})
- if errno != 0 {
- return "", fmt.Errorf("SYS_GETSOCKOPT: %v", errno)
+ if err != nil {
+ return "", fmt.Errorf("GetSockoptString: %w", err)
}
- tun.name = string(ifName.name[:ifNameSize-1])
return tun.name, nil
}
@@ -241,61 +213,63 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+ // TODO: the BSDs look very similar in Read() and Write(). They should be
+ // collapsed, with platform-specific files containing the varying parts of
+ // their implementations.
select {
case err := <-tun.errors:
return 0, err
default:
- buff := buff[offset-4:]
- n, err := tun.tunFile.Read(buff[:])
+ buf := bufs[0][offset-4:]
+ n, err := tun.tunFile.Read(buf[:])
if n < 4 {
return 0, err
}
- return n - 4, err
+ sizes[0] = n - 4
+ return 1, err
}
}
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
-
- // reserve space for header
-
- buff = buff[offset-4:]
-
- // add packet information header
-
- buff[0] = 0x00
- buff[1] = 0x00
- buff[2] = 0x00
-
- if buff[4]>>4 == ipv6.Version {
- buff[3] = unix.AF_INET6
- } else {
- buff[3] = unix.AF_INET
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ if offset < 4 {
+ return 0, io.ErrShortBuffer
}
-
- // write
-
- return tun.tunFile.Write(buff)
-}
-
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ for i, buf := range bufs {
+ buf = buf[offset-4:]
+ buf[0] = 0x00
+ buf[1] = 0x00
+ buf[2] = 0x00
+ switch buf[4] >> 4 {
+ case 4:
+ buf[3] = unix.AF_INET
+ case 6:
+ buf[3] = unix.AF_INET6
+ default:
+ return i, unix.EAFNOSUPPORT
+ }
+ if _, err := tun.tunFile.Write(buf); err != nil {
+ return i, err
+ }
+ }
+ return len(bufs), nil
}
func (tun *NativeTun) Close() error {
- var err2 error
- err1 := tun.tunFile.Close()
- if tun.routeSocket != -1 {
- unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
- err2 = unix.Close(tun.routeSocket)
- } else if tun.events != nil {
- close(tun.events)
- }
+ var err1, err2 error
+ tun.closeOnce.Do(func() {
+ err1 = tun.tunFile.Close()
+ if tun.routeSocket != -1 {
+ unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
+ err2 = unix.Close(tun.routeSocket)
+ } else if tun.events != nil {
+ close(tun.events)
+ }
+ })
if err1 != nil {
return err1
}
@@ -303,71 +277,60 @@ func (tun *NativeTun) Close() error {
}
func (tun *NativeTun) setMTU(n int) error {
-
- // open datagram socket
-
- var fd int
-
- fd, err := unix.Socket(
+ fd, err := socketCloexec(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
-
if err != nil {
return err
}
defer unix.Close(fd)
- // do ioctl call
-
- var ifr [32]byte
- copy(ifr[:], tun.name)
- *(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(fd),
- uintptr(unix.SIOCSIFMTU),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
-
- if errno != 0 {
- return fmt.Errorf("failed to set MTU on %s", tun.name)
+ var ifr unix.IfreqMTU
+ copy(ifr.Name[:], tun.name)
+ ifr.MTU = int32(n)
+ err = unix.IoctlSetIfreqMTU(fd, &ifr)
+ if err != nil {
+ return fmt.Errorf("failed to set MTU on %s: %w", tun.name, err)
}
return nil
}
func (tun *NativeTun) MTU() (int, error) {
-
- // open datagram socket
-
- fd, err := unix.Socket(
+ fd, err := socketCloexec(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
-
if err != nil {
return 0, err
}
defer unix.Close(fd)
- // do ioctl call
-
- var ifr [64]byte
- copy(ifr[:], tun.name)
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(fd),
- uintptr(unix.SIOCGIFMTU),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
- if errno != 0 {
- return 0, fmt.Errorf("failed to get MTU on %s", tun.name)
+ ifr, err := unix.IoctlGetIfreqMTU(fd, tun.name)
+ if err != nil {
+ return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, err)
}
- return int(*(*int32)(unsafe.Pointer(&ifr[16]))), nil
+ return int(ifr.MTU), nil
+}
+
+func (tun *NativeTun) BatchSize() int {
+ return 1
+}
+
+func socketCloexec(family, sotype, proto int) (fd int, err error) {
+ // See go/src/net/sys_cloexec.go for background.
+ syscall.ForkLock.RLock()
+ defer syscall.ForkLock.RUnlock()
+
+ fd, err = unix.Socket(family, sotype, proto)
+ if err == nil {
+ unix.CloseOnExec(fd)
+ }
+ return
}
diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go
index 6cf9313..7c65fd9 100644
--- a/tun/tun_freebsd.go
+++ b/tun/tun_freebsd.go
@@ -1,66 +1,57 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
- "bytes"
"errors"
"fmt"
+ "io"
"net"
"os"
+ "sync"
"syscall"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)
-// _TUNSIFHEAD, value derived from sys/net/{if_tun,ioccom}.h
-// const _TUNSIFHEAD = ((0x80000000) | (((4) & ((1 << 13) - 1) ) << 16) | (uint32(byte('t')) << 8) | (96))
const (
_TUNSIFHEAD = 0x80047460
_TUNSIFMODE = 0x8004745e
+ _TUNGIFNAME = 0x4020745d
_TUNSIFPID = 0x2000745f
-)
-// TODO: move into x/sys/unix
-const (
- SIOCGIFINFO_IN6 = 0xc048696c
- SIOCSIFINFO_IN6 = 0xc048696d
- ND6_IFF_AUTO_LINKLOCAL = 0x20
- ND6_IFF_NO_DAD = 0x100
+ _SIOCGIFINFO_IN6 = 0xc048696c
+ _SIOCSIFINFO_IN6 = 0xc048696d
+ _ND6_IFF_AUTO_LINKLOCAL = 0x20
+ _ND6_IFF_NO_DAD = 0x100
)
-// Iface status string max len
-const _IFSTATMAX = 800
-
-const SIZEOF_UINTPTR = 4 << (^uintptr(0) >> 32 & 1)
+// Iface requests with just the name
+type ifreqName struct {
+ Name [unix.IFNAMSIZ]byte
+ _ [16]byte
+}
-// structure for iface requests with a pointer
-type ifreq_ptr struct {
+// Iface requests with a pointer
+type ifreqPtr struct {
Name [unix.IFNAMSIZ]byte
Data uintptr
- Pad0 [16 - SIZEOF_UINTPTR]byte
+ _ [16 - unsafe.Sizeof(uintptr(0))]byte
}
-// Structure for iface mtu get/set ioctls
-type ifreq_mtu struct {
+// Iface requests with MTU
+type ifreqMtu struct {
Name [unix.IFNAMSIZ]byte
MTU uint32
- Pad0 [12]byte
-}
-
-// Structure for interface status request ioctl
-type ifstat struct {
- IfsName [unix.IFNAMSIZ]byte
- Ascii [_IFSTATMAX]byte
+ _ [12]byte
}
-// Structures for nd6 flag manipulation
-type in6_ndireq struct {
+// ND6 flag manipulation
+type nd6Req struct {
Name [unix.IFNAMSIZ]byte
Linkmtu uint32
Maxmtu uint32
@@ -82,6 +73,7 @@ type NativeTun struct {
events chan Event
errors chan error
routeSocket int
+ closeOnce sync.Once
}
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
@@ -97,7 +89,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
retry:
n, err := unix.Read(tun.routeSocket, data)
if err != nil {
- if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
+ if errors.Is(err, syscall.EINTR) {
goto retry
}
tun.errors <- err
@@ -141,91 +133,17 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
}
func tunName(fd uintptr) (string, error) {
- //Terrible hack to make up for freebsd not having a TUNGIFNAME
-
- //First, make sure the tun pid matches this proc's pid
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(fd),
- uintptr(_TUNSIFPID),
- uintptr(0),
- )
-
- if errno != 0 {
- return "", fmt.Errorf("failed to set tun device PID: %s", errno.Error())
- }
-
- // Open iface control socket
-
- confd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM,
- 0,
- )
-
- if err != nil {
+ var ifreq ifreqName
+ _, _, err := unix.Syscall(unix.SYS_IOCTL, fd, _TUNGIFNAME, uintptr(unsafe.Pointer(&ifreq)))
+ if err != 0 {
return "", err
}
-
- defer unix.Close(confd)
-
- procPid := os.Getpid()
-
- //Try to find interface with matching PID
- for i := 1; ; i++ {
- iface, _ := net.InterfaceByIndex(i)
- if err != nil || iface == nil {
- break
- }
-
- // Structs for getting data in and out of SIOCGIFSTATUS ioctl
- var ifstatus ifstat
- copy(ifstatus.IfsName[:], iface.Name)
-
- // Make the syscall to get the status string
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(confd),
- uintptr(unix.SIOCGIFSTATUS),
- uintptr(unsafe.Pointer(&ifstatus)),
- )
-
- if errno != 0 {
- continue
- }
-
- nullStr := ifstatus.Ascii[:]
- i := bytes.IndexByte(nullStr, 0)
- if i < 1 {
- continue
- }
- statStr := string(nullStr[:i])
- var pidNum int = 0
-
- // Finally get the owning PID
- // Format string taken from sys/net/if_tun.c
- _, err := fmt.Sscanf(statStr, "\tOpened by PID %d\n", &pidNum)
- if err != nil {
- continue
- }
-
- if pidNum == procPid {
- return iface.Name, nil
- }
- }
-
- return "", nil
+ return unix.ByteSliceToString(ifreq.Name[:]), nil
}
// Destroy a named system interface
func tunDestroy(name string) error {
- // Open control socket.
- var fd int
- fd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM,
- 0,
- )
+ fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
return err
}
@@ -233,14 +151,9 @@ func tunDestroy(name string) error {
var ifr [32]byte
copy(ifr[:], name)
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(fd),
- uintptr(unix.SIOCIFDESTROY),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
+ _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCIFDESTROY), uintptr(unsafe.Pointer(&ifr[0])))
if errno != 0 {
- return fmt.Errorf("failed to destroy interface %s: %s", name, errno.Error())
+ return fmt.Errorf("failed to destroy interface %s: %w", name, errno)
}
return nil
@@ -257,7 +170,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
return nil, fmt.Errorf("interface %s already exists", name)
}
- tunFile, err := os.OpenFile("/dev/tun", unix.O_RDWR, 0)
+ tunFile, err := os.OpenFile("/dev/tun", unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil {
return nil, err
}
@@ -276,103 +189,94 @@ func CreateTUN(name string, mtu int) (Device, error) {
ifheadmode := 1
var errno syscall.Errno
tun.operateOnFd(func(fd uintptr) {
- _, _, errno = unix.Syscall(
- unix.SYS_IOCTL,
- fd,
- uintptr(_TUNSIFHEAD),
- uintptr(unsafe.Pointer(&ifheadmode)),
- )
+ _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFHEAD, uintptr(unsafe.Pointer(&ifheadmode)))
})
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
- return nil, fmt.Errorf("Unable to put into IFHEAD mode: %v", errno)
+ return nil, fmt.Errorf("unable to put into IFHEAD mode: %w", errno)
}
- // Open control sockets
- confd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM,
- 0,
- )
- if err != nil {
+ // Get out of PTP mode.
+ ifflags := syscall.IFF_BROADCAST | syscall.IFF_MULTICAST
+ tun.operateOnFd(func(fd uintptr) {
+ _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, uintptr(_TUNSIFMODE), uintptr(unsafe.Pointer(&ifflags)))
+ })
+
+ if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
- return nil, err
+ return nil, fmt.Errorf("unable to put into IFF_BROADCAST mode: %w", errno)
}
- defer unix.Close(confd)
- confd6, err := unix.Socket(
- unix.AF_INET6,
- unix.SOCK_DGRAM,
- 0,
- )
+
+ // Disable link-local v6, not just because WireGuard doesn't do that anyway, but
+ // also because there are serious races with attaching and detaching LLv6 addresses
+ // in relation to interface lifetime within the FreeBSD kernel.
+ confd6, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
tunFile.Close()
tunDestroy(assignedName)
return nil, err
}
defer unix.Close(confd6)
-
- // Disable link-local v6, not just because WireGuard doesn't do that anyway, but
- // also because there are serious races with attaching and detaching LLv6 addresses
- // in relation to interface lifetime within the FreeBSD kernel.
- var ndireq in6_ndireq
+ var ndireq nd6Req
copy(ndireq.Name[:], assignedName)
- _, _, errno = unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(confd6),
- uintptr(SIOCGIFINFO_IN6),
- uintptr(unsafe.Pointer(&ndireq)),
- )
+ _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCGIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq)))
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
- return nil, fmt.Errorf("Unable to get nd6 flags for %s: %v", assignedName, errno)
+ return nil, fmt.Errorf("unable to get nd6 flags for %s: %w", assignedName, errno)
}
- ndireq.Flags = ndireq.Flags &^ ND6_IFF_AUTO_LINKLOCAL
- ndireq.Flags = ndireq.Flags | ND6_IFF_NO_DAD
- _, _, errno = unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(confd6),
- uintptr(SIOCSIFINFO_IN6),
- uintptr(unsafe.Pointer(&ndireq)),
- )
+ ndireq.Flags = ndireq.Flags &^ _ND6_IFF_AUTO_LINKLOCAL
+ ndireq.Flags = ndireq.Flags | _ND6_IFF_NO_DAD
+ _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCSIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq)))
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
- return nil, fmt.Errorf("Unable to set nd6 flags for %s: %v", assignedName, errno)
+ return nil, fmt.Errorf("unable to set nd6 flags for %s: %w", assignedName, errno)
}
- // Rename the interface
- var newnp [unix.IFNAMSIZ]byte
- copy(newnp[:], name)
- var ifr ifreq_ptr
- copy(ifr.Name[:], assignedName)
- ifr.Data = uintptr(unsafe.Pointer(&newnp[0]))
- _, _, errno = unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(confd),
- uintptr(unix.SIOCSIFNAME),
- uintptr(unsafe.Pointer(&ifr)),
- )
- if errno != 0 {
- tunFile.Close()
- tunDestroy(assignedName)
- return nil, fmt.Errorf("Failed to rename %s to %s: %v", assignedName, name, errno)
+ if name != "" {
+ confd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
+ if err != nil {
+ tunFile.Close()
+ tunDestroy(assignedName)
+ return nil, err
+ }
+ defer unix.Close(confd)
+ var newnp [unix.IFNAMSIZ]byte
+ copy(newnp[:], name)
+ var ifr ifreqPtr
+ copy(ifr.Name[:], assignedName)
+ ifr.Data = uintptr(unsafe.Pointer(&newnp[0]))
+ _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd), uintptr(unix.SIOCSIFNAME), uintptr(unsafe.Pointer(&ifr)))
+ if errno != 0 {
+ tunFile.Close()
+ tunDestroy(assignedName)
+ return nil, fmt.Errorf("Failed to rename %s to %s: %w", assignedName, name, errno)
+ }
}
return CreateTUNFromFile(tunFile, mtu)
}
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
-
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 10),
errors: make(chan error, 1),
}
+ var errno syscall.Errno
+ tun.operateOnFd(func(fd uintptr) {
+ _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFPID, uintptr(0))
+ })
+ if errno != 0 {
+ tun.tunFile.Close()
+ return nil, fmt.Errorf("unable to become controlling TUN process: %w", errno)
+ }
+
name, err := tun.Name()
if err != nil {
tun.tunFile.Close()
@@ -391,7 +295,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err
}
- tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+ tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC)
if err != nil {
tun.tunFile.Close()
return nil, err
@@ -425,63 +329,65 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
select {
case err := <-tun.errors:
return 0, err
default:
- buff := buff[offset-4:]
- n, err := tun.tunFile.Read(buff[:])
+ buf := bufs[0][offset-4:]
+ n, err := tun.tunFile.Read(buf[:])
if n < 4 {
return 0, err
}
- return n - 4, err
+ sizes[0] = n - 4
+ return 1, err
}
}
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
-
- // reserve space for header
-
- buff = buff[offset-4:]
-
- // add packet information header
-
- buff[0] = 0x00
- buff[1] = 0x00
- buff[2] = 0x00
-
- if buff[4]>>4 == ipv6.Version {
- buff[3] = unix.AF_INET6
- } else {
- buff[3] = unix.AF_INET
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ if offset < 4 {
+ return 0, io.ErrShortBuffer
}
-
- // write
-
- return tun.tunFile.Write(buff)
-}
-
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ for i, buf := range bufs {
+ buf = buf[offset-4:]
+ if len(buf) < 5 {
+ return i, io.ErrShortBuffer
+ }
+ buf[0] = 0x00
+ buf[1] = 0x00
+ buf[2] = 0x00
+ switch buf[4] >> 4 {
+ case 4:
+ buf[3] = unix.AF_INET
+ case 6:
+ buf[3] = unix.AF_INET6
+ default:
+ return i, unix.EAFNOSUPPORT
+ }
+ if _, err := tun.tunFile.Write(buf); err != nil {
+ return i, err
+ }
+ }
+ return len(bufs), nil
}
func (tun *NativeTun) Close() error {
- var err3 error
- err1 := tun.tunFile.Close()
- err2 := tunDestroy(tun.name)
- if tun.routeSocket != -1 {
- unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
- err3 = unix.Close(tun.routeSocket)
- tun.routeSocket = -1
- } else if tun.events != nil {
- close(tun.events)
- }
+ var err1, err2, err3 error
+ tun.closeOnce.Do(func() {
+ err1 = tun.tunFile.Close()
+ err2 = tunDestroy(tun.name)
+ if tun.routeSocket != -1 {
+ unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
+ err3 = unix.Close(tun.routeSocket)
+ tun.routeSocket = -1
+ } else if tun.events != nil {
+ close(tun.events)
+ }
+ })
if err1 != nil {
return err1
}
@@ -492,70 +398,38 @@ func (tun *NativeTun) Close() error {
}
func (tun *NativeTun) setMTU(n int) error {
- // open datagram socket
-
- var fd int
-
- fd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM,
- 0,
- )
-
+ fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
return err
}
-
defer unix.Close(fd)
- // do ioctl call
-
- var ifr ifreq_mtu
+ var ifr ifreqMtu
copy(ifr.Name[:], tun.name)
ifr.MTU = uint32(n)
-
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(fd),
- uintptr(unix.SIOCSIFMTU),
- uintptr(unsafe.Pointer(&ifr)),
- )
-
+ _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCSIFMTU), uintptr(unsafe.Pointer(&ifr)))
if errno != 0 {
- return fmt.Errorf("failed to set MTU on %s", tun.name)
+ return fmt.Errorf("failed to set MTU on %s: %w", tun.name, errno)
}
-
return nil
}
func (tun *NativeTun) MTU() (int, error) {
- // open datagram socket
-
- fd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM,
- 0,
- )
-
+ fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
return 0, err
}
-
defer unix.Close(fd)
- // do ioctl call
- var ifr ifreq_mtu
+ var ifr ifreqMtu
copy(ifr.Name[:], tun.name)
-
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(fd),
- uintptr(unix.SIOCGIFMTU),
- uintptr(unsafe.Pointer(&ifr)),
- )
+ _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCGIFMTU), uintptr(unsafe.Pointer(&ifr)))
if errno != 0 {
- return 0, fmt.Errorf("failed to get MTU on %s", tun.name)
+ return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, errno)
}
-
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
}
+
+func (tun *NativeTun) BatchSize() int {
+ return 1
+}
diff --git a/tun/tun_linux.go b/tun/tun_linux.go
index 61902e9..bd69cb5 100644
--- a/tun/tun_linux.go
+++ b/tun/tun_linux.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -9,18 +9,16 @@ package tun
*/
import (
- "bytes"
"errors"
"fmt"
- "net"
"os"
"sync"
"syscall"
"time"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
@@ -32,14 +30,29 @@ const (
type NativeTun struct {
tunFile *os.File
index int32 // if index
- name string // name of interface
errors chan error // async error handling
events chan Event // device related events
- nopi bool // the device was pased IFF_NO_PI
netlinkSock int
netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{}
+ batchSize int
+ vnetHdr bool
+ udpGSO bool
+
+ closeOnce sync.Once
+
+ nameOnce sync.Once // guards calling initNameCache, which sets following fields
+ nameCache string // name of interface
+ nameErr error
+
+ readOpMu sync.Mutex // readOpMu guards readBuff
+ readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
+
+ writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable
+ toWrite []int
+ tcpGROTable *tcpGROTable
+ udpGROTable *udpGROTable
}
func (tun *NativeTun) File() *os.File {
@@ -51,6 +64,11 @@ func (tun *NativeTun) routineHackListener() {
/* This is needed for the detection to work across network namespaces
* If you are reading this and know a better method, please get in touch.
*/
+ last := 0
+ const (
+ up = 1
+ down = 2
+ )
for {
sysconn, err := tun.tunFile.SyscallConn()
if err != nil {
@@ -64,14 +82,25 @@ func (tun *NativeTun) routineHackListener() {
}
switch err {
case unix.EINVAL:
- tun.events <- EventUp
+ if last != up {
+ // If the tunnel is up, it reports that write() is
+ // allowed but we provided invalid data.
+ tun.events <- EventUp
+ last = up
+ }
case unix.EIO:
- tun.events <- EventDown
+ if last != down {
+ // If the tunnel is down, it reports that no I/O
+ // is possible, without checking our provided data.
+ tun.events <- EventDown
+ last = down
+ }
default:
return
}
select {
case <-time.After(time.Second):
+ // nothing
case <-tun.statusListenersShutdown:
return
}
@@ -79,13 +108,13 @@ func (tun *NativeTun) routineHackListener() {
}
func createNetlinkSocket() (int, error) {
- sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
+ sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
- Groups: uint32((1 << (unix.RTNLGRP_LINK - 1)) | (1 << (unix.RTNLGRP_IPV4_IFADDR - 1)) | (1 << (unix.RTNLGRP_IPV6_IFADDR - 1))),
+ Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
}
err = unix.Bind(sock, saddr)
if err != nil {
@@ -99,10 +128,10 @@ func (tun *NativeTun) routineNetlinkListener() {
unix.Close(tun.netlinkSock)
tun.hackListenerClosed.Lock()
close(tun.events)
+ tun.netlinkCancel.Close()
}()
for msg := make([]byte, 1<<16); ; {
-
var err error
var msgn int
for {
@@ -111,12 +140,12 @@ func (tun *NativeTun) routineNetlinkListener() {
break
}
if !tun.netlinkCancel.ReadyRead() {
- tun.errors <- fmt.Errorf("netlink socket closed: %s", err.Error())
+ tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
return
}
}
if err != nil {
- tun.errors <- fmt.Errorf("failed to receive netlink message: %s", err.Error())
+ tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
return
}
@@ -126,6 +155,7 @@ func (tun *NativeTun) routineNetlinkListener() {
default:
}
+ wasEverUp := false
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
@@ -149,10 +179,16 @@ func (tun *NativeTun) routineNetlinkListener() {
if info.Flags&unix.IFF_RUNNING != 0 {
tun.events <- EventUp
+ wasEverUp = true
}
if info.Flags&unix.IFF_RUNNING == 0 {
- tun.events <- EventDown
+ // Don't emit EventDown before we've ever emitted EventUp.
+ // This avoids a startup race with HackListener, which
+ // might detect Up before we have finished reporting Down.
+ if wasEverUp {
+ tun.events <- EventDown
+ }
}
tun.events <- EventMTUUpdate
@@ -164,15 +200,10 @@ func (tun *NativeTun) routineNetlinkListener() {
}
}
-func (tun *NativeTun) isUp() (bool, error) {
- inter, err := net.InterfaceByName(tun.name)
- return inter.Flags&net.FlagUp != 0, err
-}
-
func getIFIndex(name string) (int32, error) {
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
if err != nil {
@@ -198,13 +229,17 @@ func getIFIndex(name string) (int32, error) {
}
func (tun *NativeTun) setMTU(n int) error {
+ name, err := tun.Name()
+ if err != nil {
+ return err
+ }
+
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
-
if err != nil {
return err
}
@@ -212,9 +247,8 @@ func (tun *NativeTun) setMTU(n int) error {
defer unix.Close(fd)
// do ioctl call
-
var ifr [ifReqSize]byte
- copy(ifr[:], tun.name)
+ copy(ifr[:], name)
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
@@ -224,20 +258,24 @@ func (tun *NativeTun) setMTU(n int) error {
)
if errno != 0 {
- return errors.New("failed to set MTU of TUN device")
+ return fmt.Errorf("failed to set MTU of TUN device: %w", errno)
}
return nil
}
func (tun *NativeTun) MTU() (int, error) {
+ name, err := tun.Name()
+ if err != nil {
+ return 0, err
+ }
+
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
-
if err != nil {
return 0, err
}
@@ -247,7 +285,7 @@ func (tun *NativeTun) MTU() (int, error) {
// do ioctl call
var ifr [ifReqSize]byte
- copy(ifr[:], tun.name)
+ copy(ifr[:], name)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
@@ -255,13 +293,22 @@ func (tun *NativeTun) MTU() (int, error) {
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
- return 0, errors.New("failed to get MTU of TUN device: " + errno.Error())
+ return 0, fmt.Errorf("failed to get MTU of TUN device: %w", errno)
}
return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
}
func (tun *NativeTun) Name() (string, error) {
+ tun.nameOnce.Do(tun.initNameCache)
+ return tun.nameCache, tun.nameErr
+}
+
+func (tun *NativeTun) initNameCache() {
+ tun.nameCache, tun.nameErr = tun.nameSlow()
+}
+
+func (tun *NativeTun) nameSlow() (string, error) {
sysconn, err := tun.tunFile.SyscallConn()
if err != nil {
return "", err
@@ -277,147 +324,287 @@ func (tun *NativeTun) Name() (string, error) {
)
})
if err != nil {
- return "", errors.New("failed to get name of TUN device: " + err.Error())
+ return "", fmt.Errorf("failed to get name of TUN device: %w", err)
}
if errno != 0 {
- return "", errors.New("failed to get name of TUN device: " + errno.Error())
- }
- nullStr := ifr[:]
- i := bytes.IndexByte(nullStr, 0)
- if i != -1 {
- nullStr = nullStr[:i]
+ return "", fmt.Errorf("failed to get name of TUN device: %w", errno)
}
- tun.name = string(nullStr)
- return tun.name, nil
+ return unix.ByteSliceToString(ifr[:]), nil
}
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
-
- if tun.nopi {
- buff = buff[offset:]
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ tun.writeOpMu.Lock()
+ defer func() {
+ tun.tcpGROTable.reset()
+ tun.udpGROTable.reset()
+ tun.writeOpMu.Unlock()
+ }()
+ var (
+ errs error
+ total int
+ )
+ tun.toWrite = tun.toWrite[:0]
+ if tun.vnetHdr {
+ err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite)
+ if err != nil {
+ return 0, err
+ }
+ offset -= virtioNetHdrLen
} else {
- // reserve space for header
+ for i := range bufs {
+ tun.toWrite = append(tun.toWrite, i)
+ }
+ }
+ for _, bufsI := range tun.toWrite {
+ n, err := tun.tunFile.Write(bufs[bufsI][offset:])
+ if errors.Is(err, syscall.EBADFD) {
+ return total, os.ErrClosed
+ }
+ if err != nil {
+ errs = errors.Join(errs, err)
+ } else {
+ total += n
+ }
+ }
+ return total, errs
+}
- buff = buff[offset-4:]
+// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
+// each buffer. It mutates sizes to reflect the size of each element of bufs,
+// and returns the number of packets read.
+func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
+ var hdr virtioNetHdr
+ err := hdr.decode(in)
+ if err != nil {
+ return 0, err
+ }
+ in = in[virtioNetHdrLen:]
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
+ if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
+ // This means CHECKSUM_PARTIAL in skb context. We are responsible
+ // for computing the checksum starting at hdr.csumStart and placing
+ // at hdr.csumOffset.
+ err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
+ if err != nil {
+ return 0, err
+ }
+ }
+ if len(in) > len(bufs[0][offset:]) {
+ return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
+ }
+ n := copy(bufs[0][offset:], in)
+ sizes[0] = n
+ return 1, nil
+ }
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
+ }
- // add packet information header
+ ipVersion := in[0] >> 4
+ switch ipVersion {
+ case 4:
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+ }
+ case 6:
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+ }
+ default:
+ return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
+ }
- buff[0] = 0x00
- buff[1] = 0x00
+ // Don't trust hdr.hdrLen from the kernel as it can be equal to the length
+ // of the entire first packet when the kernel is handling it as part of a
+ // FORWARD path. Instead, parse the transport header length and add it onto
+ // csumStart, which is synonymous for IP header length.
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ hdr.hdrLen = hdr.csumStart + 8
+ } else {
+ if len(in) <= int(hdr.csumStart+12) {
+ return 0, errors.New("packet is too short")
+ }
- if buff[4]>>4 == ipv6.Version {
- buff[2] = 0x86
- buff[3] = 0xdd
- } else {
- buff[2] = 0x08
- buff[3] = 0x00
+ tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
+ if tcpHLen < 20 || tcpHLen > 60 {
+ // A TCP header must be between 20 and 60 bytes in length.
+ return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
}
+ hdr.hdrLen = hdr.csumStart + tcpHLen
}
- // write
+ if len(in) < int(hdr.hdrLen) {
+ return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
+ }
- return tun.tunFile.Write(buff)
-}
+ if hdr.hdrLen < hdr.csumStart {
+ return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
+ }
+ cSumAt := int(hdr.csumStart + hdr.csumOffset)
+ if cSumAt+1 >= len(in) {
+ return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
+ }
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6)
}
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+ tun.readOpMu.Lock()
+ defer tun.readOpMu.Unlock()
select {
case err := <-tun.errors:
return 0, err
default:
- if tun.nopi {
- return tun.tunFile.Read(buff[offset:])
+ readInto := bufs[0][offset:]
+ if tun.vnetHdr {
+ readInto = tun.readBuff[:]
+ }
+ n, err := tun.tunFile.Read(readInto)
+ if errors.Is(err, syscall.EBADFD) {
+ err = os.ErrClosed
+ }
+ if err != nil {
+ return 0, err
+ }
+ if tun.vnetHdr {
+ return handleVirtioRead(readInto[:n], bufs, sizes, offset)
} else {
- buff := buff[offset-4:]
- n, err := tun.tunFile.Read(buff[:])
- if n < 4 {
- return 0, err
- }
- return n - 4, err
+ sizes[0] = n
+ return 1, nil
}
}
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
func (tun *NativeTun) Close() error {
- var err1 error
- if tun.statusListenersShutdown != nil {
- close(tun.statusListenersShutdown)
- if tun.netlinkCancel != nil {
- err1 = tun.netlinkCancel.Cancel()
+ var err1, err2 error
+ tun.closeOnce.Do(func() {
+ if tun.statusListenersShutdown != nil {
+ close(tun.statusListenersShutdown)
+ if tun.netlinkCancel != nil {
+ err1 = tun.netlinkCancel.Cancel()
+ }
+ } else if tun.events != nil {
+ close(tun.events)
}
- } else if tun.events != nil {
- close(tun.events)
- }
- err2 := tun.tunFile.Close()
-
+ err2 = tun.tunFile.Close()
+ })
if err1 != nil {
return err1
}
return err2
}
+func (tun *NativeTun) BatchSize() int {
+ return tun.batchSize
+}
+
+const (
+ // TODO: support TSO with ECN bits
+ tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
+ tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
+)
+
+func (tun *NativeTun) initFromFlags(name string) error {
+ sc, err := tun.tunFile.SyscallConn()
+ if err != nil {
+ return err
+ }
+ if e := sc.Control(func(fd uintptr) {
+ var (
+ ifr *unix.Ifreq
+ )
+ ifr, err = unix.NewIfreq(name)
+ if err != nil {
+ return
+ }
+ err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
+ if err != nil {
+ return
+ }
+ got := ifr.Uint16()
+ if got&unix.IFF_VNET_HDR != 0 {
+ // tunTCPOffloads were added in Linux v2.6. We require their support
+ // if IFF_VNET_HDR is set.
+ err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads)
+ if err != nil {
+ return
+ }
+ tun.vnetHdr = true
+ tun.batchSize = conn.IdealBatchSize
+ // tunUDPOffloads were added in Linux v6.2. We do not return an
+ // error if they are unsupported at runtime.
+ tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil
+ } else {
+ tun.batchSize = 1
+ }
+ }); e != nil {
+ return e
+ }
+ return err
+}
+
+// CreateTUN creates a Device with the provided name and MTU.
func CreateTUN(name string, mtu int) (Device, error) {
- nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0)
+ nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil {
+ if os.IsNotExist(err) {
+ return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
+ }
return nil, err
}
- var ifr [ifReqSize]byte
- var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
- nameBytes := []byte(name)
- if len(nameBytes) >= unix.IFNAMSIZ {
- return nil, errors.New("interface name too long")
+ ifr, err := unix.NewIfreq(name)
+ if err != nil {
+ return nil, err
}
- copy(ifr[:], nameBytes)
- *(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags
-
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(nfd),
- uintptr(unix.TUNSETIFF),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
- if errno != 0 {
- return nil, errno
+ // IFF_VNET_HDR enables the "tun status hack" via routineHackListener()
+ // where a null write will return EINVAL indicating the TUN is up.
+ ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
+ err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr)
+ if err != nil {
+ return nil, err
}
- err = unix.SetNonblock(nfd, true)
- // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line.
-
- fd := os.NewFile(uintptr(nfd), cloneDevicePath)
+ err = unix.SetNonblock(nfd, true)
if err != nil {
+ unix.Close(nfd)
return nil, err
}
+ // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line.
+
+ fd := os.NewFile(uintptr(nfd), cloneDevicePath)
return CreateTUNFromFile(fd, mtu)
}
+// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 5),
errors: make(chan error, 5),
statusListenersShutdown: make(chan struct{}),
- nopi: false,
+ tcpGROTable: newTCPGROTable(),
+ udpGROTable: newUDPGROTable(),
+ toWrite: make([]int, 0, conn.IdealBatchSize),
}
- var err error
- _, err = tun.Name()
+ name, err := tun.Name()
if err != nil {
return nil, err
}
- // start event listener
+ err = tun.initFromFlags(name)
+ if err != nil {
+ return nil, err
+ }
- tun.index, err = getIFIndex(tun.name)
+ // start event listener
+ tun.index, err = getIFIndex(name)
if err != nil {
return nil, err
}
@@ -445,6 +632,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return tun, nil
}
+// CreateUnmonitoredTUNFromFD creates a Device from the provided file
+// descriptor.
func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
err := unix.SetNonblock(fd, true)
if err != nil {
@@ -452,14 +641,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
}
file := os.NewFile(uintptr(fd), "/dev/tun")
tun := &NativeTun{
- tunFile: file,
- events: make(chan Event, 5),
- errors: make(chan error, 5),
- nopi: true,
+ tunFile: file,
+ events: make(chan Event, 5),
+ errors: make(chan error, 5),
+ tcpGROTable: newTCPGROTable(),
+ udpGROTable: newUDPGROTable(),
+ toWrite: make([]int, 0, conn.IdealBatchSize),
}
name, err := tun.Name()
if err != nil {
return nil, "", err
}
- return tun, name, nil
+ err = tun.initFromFlags(name)
+ if err != nil {
+ return nil, "", err
+ }
+ return tun, name, err
}
diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go
index 44cedaa..ae571b9 100644
--- a/tun/tun_openbsd.go
+++ b/tun/tun_openbsd.go
@@ -1,19 +1,20 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
+ "errors"
"fmt"
- "io/ioutil"
+ "io"
"net"
"os"
+ "sync"
"syscall"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)
@@ -32,6 +33,7 @@ type NativeTun struct {
events chan Event
errors chan error
routeSocket int
+ closeOnce sync.Once
}
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
@@ -99,16 +101,6 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
}
}
-func errorIsEBUSY(err error) bool {
- if pe, ok := err.(*os.PathError); ok {
- err = pe.Err
- }
- if errno, ok := err.(syscall.Errno); ok && errno == syscall.EBUSY {
- return true
- }
- return false
-}
-
func CreateTUN(name string, mtu int) (Device, error) {
ifIndex := -1
if name != "tun" {
@@ -122,11 +114,11 @@ func CreateTUN(name string, mtu int) (Device, error) {
var err error
if ifIndex != -1 {
- tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0)
+ tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0)
} else {
- for ifIndex = 0; ifIndex < 256; ifIndex += 1 {
- tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0)
- if err == nil || !errorIsEBUSY(err) {
+ for ifIndex = 0; ifIndex < 256; ifIndex++ {
+ tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0)
+ if err == nil || !errors.Is(err, syscall.EBUSY) {
break
}
}
@@ -141,7 +133,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
if err == nil && name == "tun" {
fname := os.Getenv("WG_TUN_NAME_FILE")
if fname != "" {
- ioutil.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0400)
+ os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400)
}
}
@@ -173,7 +165,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err
}
- tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+ tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC)
if err != nil {
tun.tunFile.Close()
return nil, err
@@ -208,62 +200,61 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
select {
case err := <-tun.errors:
return 0, err
default:
- buff := buff[offset-4:]
- n, err := tun.tunFile.Read(buff[:])
+ buf := bufs[0][offset-4:]
+ n, err := tun.tunFile.Read(buf[:])
if n < 4 {
return 0, err
}
- return n - 4, err
+ sizes[0] = n - 4
+ return 1, err
}
}
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
-
- // reserve space for header
-
- buff = buff[offset-4:]
-
- // add packet information header
-
- buff[0] = 0x00
- buff[1] = 0x00
- buff[2] = 0x00
-
- if buff[4]>>4 == ipv6.Version {
- buff[3] = unix.AF_INET6
- } else {
- buff[3] = unix.AF_INET
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ if offset < 4 {
+ return 0, io.ErrShortBuffer
}
-
- // write
-
- return tun.tunFile.Write(buff)
-}
-
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ for i, buf := range bufs {
+ buf = buf[offset-4:]
+ buf[0] = 0x00
+ buf[1] = 0x00
+ buf[2] = 0x00
+ switch buf[4] >> 4 {
+ case 4:
+ buf[3] = unix.AF_INET
+ case 6:
+ buf[3] = unix.AF_INET6
+ default:
+ return i, unix.EAFNOSUPPORT
+ }
+ if _, err := tun.tunFile.Write(buf); err != nil {
+ return i, err
+ }
+ }
+ return len(bufs), nil
}
func (tun *NativeTun) Close() error {
- var err2 error
- err1 := tun.tunFile.Close()
- if tun.routeSocket != -1 {
- unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
- err2 = unix.Close(tun.routeSocket)
- tun.routeSocket = -1
- } else if tun.events != nil {
- close(tun.events)
- }
+ var err1, err2 error
+ tun.closeOnce.Do(func() {
+ err1 = tun.tunFile.Close()
+ if tun.routeSocket != -1 {
+ unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
+ err2 = unix.Close(tun.routeSocket)
+ tun.routeSocket = -1
+ } else if tun.events != nil {
+ close(tun.events)
+ }
+ })
if err1 != nil {
return err1
}
@@ -277,10 +268,9 @@ func (tun *NativeTun) setMTU(n int) error {
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
-
if err != nil {
return err
}
@@ -312,10 +302,9 @@ func (tun *NativeTun) MTU() (int, error) {
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
-
if err != nil {
return 0, err
}
@@ -338,3 +327,7 @@ func (tun *NativeTun) MTU() (int, error) {
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
}
+
+func (tun *NativeTun) BatchSize() int {
+ return 1
+}
diff --git a/tun/tun_windows.go b/tun/tun_windows.go
index daad4aa..2af8e3e 100644
--- a/tun/tun_windows.go
+++ b/tun/tun_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2018-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -9,13 +9,13 @@ import (
"errors"
"fmt"
"os"
+ "sync"
"sync/atomic"
"time"
- "unsafe"
+ _ "unsafe"
"golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/tun/wintun"
+ "golang.zx2c4.com/wintun"
)
const (
@@ -25,24 +25,31 @@ const (
)
type rateJuggler struct {
- current uint64
- nextByteCount uint64
- nextStartTime int64
- changing int32
+ current atomic.Uint64
+ nextByteCount atomic.Uint64
+ nextStartTime atomic.Int64
+ changing atomic.Bool
}
type NativeTun struct {
- wt *wintun.Interface
+ wt *wintun.Adapter
+ name string
handle windows.Handle
- close bool
- rings wintun.RingDescriptor
+ rate rateJuggler
+ session wintun.Session
+ readWait windows.Handle
events chan Event
- errors chan error
+ running sync.WaitGroup
+ closeOnce sync.Once
+ close atomic.Bool
forcedMTU int
- rate rateJuggler
+ outSizes []int
}
-const WintunPool = wintun.Pool("WireGuard")
+var (
+ WintunTunnelType = "WireGuard"
+ WintunStaticRequestedGUID *windows.GUID
+)
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
@@ -50,34 +57,18 @@ func procyield(cycles uint32)
//go:linkname nanotime runtime.nanotime
func nanotime() int64
-//
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused.
-//
func CreateTUN(ifname string, mtu int) (Device, error) {
- return CreateTUNWithRequestedGUID(ifname, nil, mtu)
+ return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
}
-//
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
-//
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
- var err error
- var wt *wintun.Interface
-
- // Does an interface with this name already exist?
- wt, err = WintunPool.GetInterface(ifname)
- if err == nil {
- // If so, we delete it, in case it has weird residual configuration.
- _, err = wt.DeleteInterface()
- if err != nil {
- return nil, fmt.Errorf("Error deleting already existing interface: %v", err)
- }
- }
- wt, _, err = WintunPool.CreateInterface(ifname, requestedGUID)
+ wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
if err != nil {
- return nil, fmt.Errorf("Error creating interface: %v", err)
+ return nil, fmt.Errorf("Error creating interface: %w", err)
}
forcedMTU := 1420
@@ -87,52 +78,46 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
tun := &NativeTun{
wt: wt,
+ name: ifname,
handle: windows.InvalidHandle,
events: make(chan Event, 10),
- errors: make(chan error, 1),
forcedMTU: forcedMTU,
}
- err = tun.rings.Init()
- if err != nil {
- tun.Close()
- return nil, fmt.Errorf("Error creating events: %v", err)
- }
-
- tun.handle, err = tun.wt.Register(&tun.rings)
+ tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
if err != nil {
- tun.Close()
- return nil, fmt.Errorf("Error registering rings: %v", err)
+ tun.wt.Close()
+ close(tun.events)
+ return nil, fmt.Errorf("Error starting session: %w", err)
}
+ tun.readWait = tun.session.ReadWaitEvent()
return tun, nil
}
func (tun *NativeTun) Name() (string, error) {
- return tun.wt.Name()
+ return tun.name, nil
}
func (tun *NativeTun) File() *os.File {
return nil
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
func (tun *NativeTun) Close() error {
- tun.close = true
- if tun.rings.Send.TailMoved != 0 {
- windows.SetEvent(tun.rings.Send.TailMoved) // wake the reader if it's sleeping
- }
- if tun.handle != windows.InvalidHandle {
- windows.CloseHandle(tun.handle)
- }
- tun.rings.Close()
var err error
- if tun.wt != nil {
- _, err = tun.wt.DeleteInterface()
- }
- close(tun.events)
+ tun.closeOnce.Do(func() {
+ tun.close.Store(true)
+ windows.SetEvent(tun.readWait)
+ tun.running.Wait()
+ tun.session.End()
+ if tun.wt != nil {
+ tun.wt.Close()
+ }
+ close(tun.events)
+ })
return err
}
@@ -142,129 +127,115 @@ func (tun *NativeTun) MTU() (int, error) {
// TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
func (tun *NativeTun) ForceMTU(mtu int) {
+ if tun.close.Load() {
+ return
+ }
+ update := tun.forcedMTU != mtu
tun.forcedMTU = mtu
+ if update {
+ tun.events <- EventMTUUpdate
+ }
+}
+
+func (tun *NativeTun) BatchSize() int {
+ // TODO: implement batching with wintun
+ return 1
}
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+ tun.running.Add(1)
+ defer tun.running.Done()
retry:
- select {
- case err := <-tun.errors:
- return 0, err
- default:
- }
- if tun.close {
- return 0, os.ErrClosed
- }
-
- buffHead := atomic.LoadUint32(&tun.rings.Send.Ring.Head)
- if buffHead >= wintun.PacketCapacity {
+ if tun.close.Load() {
return 0, os.ErrClosed
}
-
start := nanotime()
- shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
- var buffTail uint32
+ shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
for {
- buffTail = atomic.LoadUint32(&tun.rings.Send.Ring.Tail)
- if buffHead != buffTail {
- break
- }
- if tun.close {
+ if tun.close.Load() {
return 0, os.ErrClosed
}
- if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
- windows.WaitForSingleObject(tun.rings.Send.TailMoved, windows.INFINITE)
- goto retry
+ packet, err := tun.session.ReceivePacket()
+ switch err {
+ case nil:
+ n := copy(bufs[0][offset:], packet)
+ sizes[0] = n
+ tun.session.ReleaseReceivePacket(packet)
+ tun.rate.update(uint64(n))
+ return 1, nil
+ case windows.ERROR_NO_MORE_ITEMS:
+ if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
+ windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
+ goto retry
+ }
+ procyield(1)
+ continue
+ case windows.ERROR_HANDLE_EOF:
+ return 0, os.ErrClosed
+ case windows.ERROR_INVALID_DATA:
+ return 0, errors.New("Send ring corrupt")
}
- procyield(1)
- }
- if buffTail >= wintun.PacketCapacity {
- return 0, os.ErrClosed
- }
-
- buffContent := tun.rings.Send.Ring.Wrap(buffTail - buffHead)
- if buffContent < uint32(unsafe.Sizeof(wintun.PacketHeader{})) {
- return 0, errors.New("incomplete packet header in send ring")
- }
-
- packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Send.Ring.Data[buffHead]))
- if packet.Size > wintun.PacketSizeMax {
- return 0, errors.New("packet too big in send ring")
- }
-
- alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packet.Size)
- if alignedPacketSize > buffContent {
- return 0, errors.New("incomplete packet in send ring")
+ return 0, fmt.Errorf("Read failed: %w", err)
}
-
- copy(buff[offset:], packet.Data[:packet.Size])
- buffHead = tun.rings.Send.Ring.Wrap(buffHead + alignedPacketSize)
- atomic.StoreUint32(&tun.rings.Send.Ring.Head, buffHead)
- tun.rate.update(uint64(packet.Size))
- return int(packet.Size), nil
-}
-
-func (tun *NativeTun) Flush() error {
- return nil
}
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
- if tun.close {
- return 0, os.ErrClosed
- }
-
- packetSize := uint32(len(buff) - offset)
- tun.rate.update(uint64(packetSize))
- alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize)
-
- buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head)
- if buffHead >= wintun.PacketCapacity {
- return 0, os.ErrClosed
- }
-
- buffTail := atomic.LoadUint32(&tun.rings.Receive.Ring.Tail)
- if buffTail >= wintun.PacketCapacity {
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ tun.running.Add(1)
+ defer tun.running.Done()
+ if tun.close.Load() {
return 0, os.ErrClosed
}
- buffSpace := tun.rings.Receive.Ring.Wrap(buffHead - buffTail - wintun.PacketAlignment)
- if alignedPacketSize > buffSpace {
- return 0, nil // Dropping when ring is full.
- }
-
- packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Receive.Ring.Data[buffTail]))
- packet.Size = packetSize
- copy(packet.Data[:packetSize], buff[offset:])
- atomic.StoreUint32(&tun.rings.Receive.Ring.Tail, tun.rings.Receive.Ring.Wrap(buffTail+alignedPacketSize))
- if atomic.LoadInt32(&tun.rings.Receive.Ring.Alertable) != 0 {
- windows.SetEvent(tun.rings.Receive.TailMoved)
+ for i, buf := range bufs {
+ packetSize := len(buf) - offset
+ tun.rate.update(uint64(packetSize))
+
+ packet, err := tun.session.AllocateSendPacket(packetSize)
+ switch err {
+ case nil:
+ // TODO: Explore options to eliminate this copy.
+ copy(packet, buf[offset:])
+ tun.session.SendPacket(packet)
+ continue
+ case windows.ERROR_HANDLE_EOF:
+ return i, os.ErrClosed
+ case windows.ERROR_BUFFER_OVERFLOW:
+ continue // Dropping when ring is full.
+ default:
+ return i, fmt.Errorf("Write failed: %w", err)
+ }
}
- return int(packetSize), nil
+ return len(bufs), nil
}
// LUID returns Windows interface instance ID.
func (tun *NativeTun) LUID() uint64 {
+ tun.running.Add(1)
+ defer tun.running.Done()
+ if tun.close.Load() {
+ return 0
+ }
return tun.wt.LUID()
}
-// Version returns the version of the Wintun driver and NDIS system currently loaded.
-func (tun *NativeTun) Version() (driverVersion string, ndisVersion string, err error) {
- return tun.wt.Version()
+// RunningVersion returns the running version of the Wintun driver.
+func (tun *NativeTun) RunningVersion() (version uint32, err error) {
+ return wintun.RunningVersion()
}
func (rate *rateJuggler) update(packetLen uint64) {
now := nanotime()
- total := atomic.AddUint64(&rate.nextByteCount, packetLen)
- period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
+ total := rate.nextByteCount.Add(packetLen)
+ period := uint64(now - rate.nextStartTime.Load())
if period >= rateMeasurementGranularity {
- if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
+ if !rate.changing.CompareAndSwap(false, true) {
return
}
- atomic.StoreInt64(&rate.nextStartTime, now)
- atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
- atomic.StoreUint64(&rate.nextByteCount, 0)
- atomic.StoreInt32(&rate.changing, 0)
+ rate.nextStartTime.Store(now)
+ rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
+ rate.nextByteCount.Store(0)
+ rate.changing.Store(false)
}
}
diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go
new file mode 100644
index 0000000..d07e860
--- /dev/null
+++ b/tun/tuntest/tuntest.go
@@ -0,0 +1,155 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tuntest
+
+import (
+ "encoding/binary"
+ "io"
+ "net/netip"
+ "os"
+
+ "golang.zx2c4.com/wireguard/tun"
+)
+
+func Ping(dst, src netip.Addr) []byte {
+ localPort := uint16(1337)
+ seq := uint16(0)
+
+ payload := make([]byte, 4)
+ binary.BigEndian.PutUint16(payload[0:], localPort)
+ binary.BigEndian.PutUint16(payload[2:], seq)
+
+ return genICMPv4(payload, dst, src)
+}
+
+// Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
+func checksum(buf []byte, initial uint16) uint16 {
+ v := uint32(initial)
+ for i := 0; i < len(buf)-1; i += 2 {
+ v += uint32(binary.BigEndian.Uint16(buf[i:]))
+ }
+ if len(buf)%2 == 1 {
+ v += uint32(buf[len(buf)-1]) << 8
+ }
+ for v > 0xffff {
+ v = (v >> 16) + (v & 0xffff)
+ }
+ return ^uint16(v)
+}
+
+func genICMPv4(payload []byte, dst, src netip.Addr) []byte {
+ const (
+ icmpv4ProtocolNumber = 1
+ icmpv4Echo = 8
+ icmpv4ChecksumOffset = 2
+ icmpv4Size = 8
+ ipv4Size = 20
+ ipv4TotalLenOffset = 2
+ ipv4ChecksumOffset = 10
+ ttl = 65
+ headerSize = ipv4Size + icmpv4Size
+ )
+
+ pkt := make([]byte, headerSize+len(payload))
+
+ ip := pkt[0:ipv4Size]
+ icmpv4 := pkt[ipv4Size : ipv4Size+icmpv4Size]
+
+ // https://tools.ietf.org/html/rfc792
+ icmpv4[0] = icmpv4Echo // type
+ icmpv4[1] = 0 // code
+ chksum := ^checksum(icmpv4, checksum(payload, 0))
+ binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
+
+ // https://tools.ietf.org/html/rfc760 section 3.1
+ length := uint16(len(pkt))
+ ip[0] = (4 << 4) | (ipv4Size / 4)
+ binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
+ ip[8] = ttl
+ ip[9] = icmpv4ProtocolNumber
+ copy(ip[12:], src.AsSlice())
+ copy(ip[16:], dst.AsSlice())
+ chksum = ^checksum(ip[:], 0)
+ binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
+
+ copy(pkt[headerSize:], payload)
+ return pkt
+}
+
+type ChannelTUN struct {
+ Inbound chan []byte // incoming packets, closed on TUN close
+ Outbound chan []byte // outbound packets, blocks forever on TUN close
+
+ closed chan struct{}
+ events chan tun.Event
+ tun chTun
+}
+
+func NewChannelTUN() *ChannelTUN {
+ c := &ChannelTUN{
+ Inbound: make(chan []byte),
+ Outbound: make(chan []byte),
+ closed: make(chan struct{}),
+ events: make(chan tun.Event, 1),
+ }
+ c.tun.c = c
+ c.events <- tun.EventUp
+ return c
+}
+
+func (c *ChannelTUN) TUN() tun.Device {
+ return &c.tun
+}
+
+type chTun struct {
+ c *ChannelTUN
+}
+
+func (t *chTun) File() *os.File { return nil }
+
+func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) {
+ select {
+ case <-t.c.closed:
+ return 0, os.ErrClosed
+ case msg := <-t.c.Outbound:
+ n := copy(packets[0][offset:], msg)
+ sizes[0] = n
+ return 1, nil
+ }
+}
+
+// Write is called by the wireguard device to deliver a packet for routing.
+func (t *chTun) Write(packets [][]byte, offset int) (int, error) {
+ if offset == -1 {
+ close(t.c.closed)
+ close(t.c.events)
+ return 0, io.EOF
+ }
+ for i, data := range packets {
+ msg := make([]byte, len(data)-offset)
+ copy(msg, data[offset:])
+ select {
+ case <-t.c.closed:
+ return i, os.ErrClosed
+ case t.c.Inbound <- msg:
+ }
+ }
+ return len(packets), nil
+}
+
+func (t *chTun) BatchSize() int {
+ return 1
+}
+
+const DefaultMTU = 1420
+
+func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
+func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
+func (t *chTun) Events() <-chan tun.Event { return t.c.events }
+func (t *chTun) Close() error {
+ t.Write(nil, -1)
+ return nil
+}
diff --git a/tun/wintun/iphlpapi/conversion_windows.go b/tun/wintun/iphlpapi/conversion_windows.go
deleted file mode 100644
index a19e961..0000000
--- a/tun/wintun/iphlpapi/conversion_windows.go
+++ /dev/null
@@ -1,25 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package iphlpapi
-
-import "golang.org/x/sys/windows"
-
-//sys convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) = iphlpapi.ConvertInterfaceLuidToGuid
-//sys convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) = iphlpapi.ConvertInterfaceAliasToLuid
-
-func InterfaceGUIDFromAlias(alias string) (*windows.GUID, error) {
- var luid uint64
- var guid windows.GUID
- err := convertInterfaceAliasToLUID(windows.StringToUTF16Ptr(alias), &luid)
- if err != nil {
- return nil, err
- }
- err = convertInterfaceLUIDToGUID(&luid, &guid)
- if err != nil {
- return nil, err
- }
- return &guid, nil
-}
diff --git a/tun/wintun/iphlpapi/mksyscall.go b/tun/wintun/iphlpapi/mksyscall.go
deleted file mode 100644
index fc7dba4..0000000
--- a/tun/wintun/iphlpapi/mksyscall.go
+++ /dev/null
@@ -1,8 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package iphlpapi
-
-//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go conversion_windows.go
diff --git a/tun/wintun/iphlpapi/zsyscall_windows.go b/tun/wintun/iphlpapi/zsyscall_windows.go
deleted file mode 100644
index dc14294..0000000
--- a/tun/wintun/iphlpapi/zsyscall_windows.go
+++ /dev/null
@@ -1,60 +0,0 @@
-// Code generated by 'go generate'; DO NOT EDIT.
-
-package iphlpapi
-
-import (
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-var _ unsafe.Pointer
-
-// Do the interface allocations only once for common
-// Errno values.
-const (
- errnoERROR_IO_PENDING = 997
-)
-
-var (
- errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
-)
-
-// errnoErr returns common boxed Errno values, to prevent
-// allocations at runtime.
-func errnoErr(e syscall.Errno) error {
- switch e {
- case 0:
- return nil
- case errnoERROR_IO_PENDING:
- return errERROR_IO_PENDING
- }
- // TODO: add more here, after collecting data on the common
- // error values see on Windows. (perhaps when running
- // all.bat?)
- return e
-}
-
-var (
- modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
-
- procConvertInterfaceLuidToGuid = modiphlpapi.NewProc("ConvertInterfaceLuidToGuid")
- procConvertInterfaceAliasToLuid = modiphlpapi.NewProc("ConvertInterfaceAliasToLuid")
-)
-
-func convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) {
- r0, _, _ := syscall.Syscall(procConvertInterfaceLuidToGuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID)), 0)
- if r0 != 0 {
- ret = syscall.Errno(r0)
- }
- return
-}
-
-func convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) {
- r0, _, _ := syscall.Syscall(procConvertInterfaceAliasToLuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceAlias)), uintptr(unsafe.Pointer(interfaceLUID)), 0)
- if r0 != 0 {
- ret = syscall.Errno(r0)
- }
- return
-}
diff --git a/tun/wintun/namespace_windows.go b/tun/wintun/namespace_windows.go
deleted file mode 100644
index f4316fe..0000000
--- a/tun/wintun/namespace_windows.go
+++ /dev/null
@@ -1,98 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package wintun
-
-import (
- "encoding/hex"
- "errors"
- "fmt"
- "sync"
- "unsafe"
-
- "golang.org/x/crypto/blake2s"
- "golang.org/x/sys/windows"
- "golang.org/x/text/unicode/norm"
-
- "golang.zx2c4.com/wireguard/tun/wintun/namespaceapi"
-)
-
-var (
- wintunObjectSecurityAttributes *windows.SecurityAttributes
- hasInitializedNamespace bool
- initializingNamespace sync.Mutex
-)
-
-func initializeNamespace() error {
- initializingNamespace.Lock()
- defer initializingNamespace.Unlock()
- if hasInitializedNamespace {
- return nil
- }
- sd, err := windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
- if err != nil {
- return fmt.Errorf("SddlToSecurityDescriptor failed: %v", err)
- }
- wintunObjectSecurityAttributes = &windows.SecurityAttributes{
- Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
- SecurityDescriptor: sd,
- }
- sid, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid)
- if err != nil {
- return fmt.Errorf("CreateWellKnownSid(LOCAL_SYSTEM) failed: %v", err)
- }
-
- boundary, err := namespaceapi.CreateBoundaryDescriptor("Wintun")
- if err != nil {
- return fmt.Errorf("CreateBoundaryDescriptor failed: %v", err)
- }
- err = boundary.AddSid(sid)
- if err != nil {
- return fmt.Errorf("AddSIDToBoundaryDescriptor failed: %v", err)
- }
- for {
- _, err = namespaceapi.CreatePrivateNamespace(wintunObjectSecurityAttributes, boundary, "Wintun")
- if err == windows.ERROR_ALREADY_EXISTS {
- _, err = namespaceapi.OpenPrivateNamespace(boundary, "Wintun")
- if err == windows.ERROR_PATH_NOT_FOUND {
- continue
- }
- }
- if err != nil {
- return fmt.Errorf("Create/OpenPrivateNamespace failed: %v", err)
- }
- break
- }
- hasInitializedNamespace = true
- return nil
-}
-
-func (pool Pool) takeNameMutex() (windows.Handle, error) {
- err := initializeNamespace()
- if err != nil {
- return 0, err
- }
-
- const mutexLabel = "WireGuard Adapter Name Mutex Stable Suffix v1 jason@zx2c4.com"
- b2, _ := blake2s.New256(nil)
- b2.Write([]byte(mutexLabel))
- b2.Write(norm.NFC.Bytes([]byte(string(pool))))
- mutexName := `Wintun\Wintun-Name-Mutex-` + hex.EncodeToString(b2.Sum(nil))
- mutex, err := windows.CreateMutex(wintunObjectSecurityAttributes, false, windows.StringToUTF16Ptr(mutexName))
- if err != nil {
- err = fmt.Errorf("Error creating name mutex: %v", err)
- return 0, err
- }
- event, err := windows.WaitForSingleObject(mutex, windows.INFINITE)
- if err != nil {
- windows.CloseHandle(mutex)
- return 0, fmt.Errorf("Error waiting on name mutex: %v", err)
- }
- if event != windows.WAIT_OBJECT_0 && event != windows.WAIT_ABANDONED {
- windows.CloseHandle(mutex)
- return 0, errors.New("Error with event trigger of name mutex")
- }
- return mutex, nil
-}
diff --git a/tun/wintun/namespaceapi/mksyscall.go b/tun/wintun/namespaceapi/mksyscall.go
deleted file mode 100644
index 93d43b0..0000000
--- a/tun/wintun/namespaceapi/mksyscall.go
+++ /dev/null
@@ -1,8 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package namespaceapi
-
-//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go namespaceapi_windows.go
diff --git a/tun/wintun/namespaceapi/namespaceapi_windows.go b/tun/wintun/namespaceapi/namespaceapi_windows.go
deleted file mode 100644
index a3a6274..0000000
--- a/tun/wintun/namespaceapi/namespaceapi_windows.go
+++ /dev/null
@@ -1,83 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package namespaceapi
-
-import "golang.org/x/sys/windows"
-
-//sys createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) = kernel32.CreateBoundaryDescriptorW
-//sys deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) = kernel32.DeleteBoundaryDescriptor
-//sys addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) = kernel32.AddSIDToBoundaryDescriptor
-//sys createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.CreatePrivateNamespaceW
-//sys openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.OpenPrivateNamespaceW
-//sys closePrivateNamespace(handle windows.Handle, flags uint32) (err error) = kernel32.ClosePrivateNamespace
-
-// BoundaryDescriptor represents a boundary that defines how the objects in the namespace are to be isolated.
-type BoundaryDescriptor windows.Handle
-
-// CreateBoundaryDescriptor creates a boundary descriptor.
-func CreateBoundaryDescriptor(name string) (BoundaryDescriptor, error) {
- name16, err := windows.UTF16PtrFromString(name)
- if err != nil {
- return 0, err
- }
- handle, err := createBoundaryDescriptor(name16, 0)
- if err != nil {
- return 0, err
- }
- return BoundaryDescriptor(handle), nil
-}
-
-// Delete deletes the specified boundary descriptor.
-func (bd BoundaryDescriptor) Delete() {
- deleteBoundaryDescriptor(windows.Handle(bd))
-}
-
-// AddSid adds a security identifier (SID) to the specified boundary descriptor.
-func (bd *BoundaryDescriptor) AddSid(requiredSid *windows.SID) error {
- return addSIDToBoundaryDescriptor((*windows.Handle)(bd), requiredSid)
-}
-
-// PrivateNamespace represents a private namespace.
-type PrivateNamespace windows.Handle
-
-// CreatePrivateNamespace creates a private namespace.
-func CreatePrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) {
- aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix)
- if err != nil {
- return 0, err
- }
- handle, err := createPrivateNamespace(privateNamespaceAttributes, windows.Handle(boundaryDescriptor), aliasPrefix16)
- if err != nil {
- return 0, err
- }
- return PrivateNamespace(handle), nil
-}
-
-// OpenPrivateNamespace opens a private namespace.
-func OpenPrivateNamespace(boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) {
- aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix)
- if err != nil {
- return 0, err
- }
- handle, err := openPrivateNamespace(windows.Handle(boundaryDescriptor), aliasPrefix16)
- if err != nil {
- return 0, err
- }
- return PrivateNamespace(handle), nil
-}
-
-// ClosePrivateNamespaceFlags describes flags that are used by PrivateNamespace's Close() method.
-type ClosePrivateNamespaceFlags uint32
-
-const (
- // PrivateNamespaceFlagDestroy makes the close to destroy the namespace.
- PrivateNamespaceFlagDestroy = ClosePrivateNamespaceFlags(0x1)
-)
-
-// Close closes an open namespace handle.
-func (pns PrivateNamespace) Close(flags ClosePrivateNamespaceFlags) error {
- return closePrivateNamespace(windows.Handle(pns), uint32(flags))
-}
diff --git a/tun/wintun/namespaceapi/zsyscall_windows.go b/tun/wintun/namespaceapi/zsyscall_windows.go
deleted file mode 100644
index 508c223..0000000
--- a/tun/wintun/namespaceapi/zsyscall_windows.go
+++ /dev/null
@@ -1,116 +0,0 @@
-// Code generated by 'go generate'; DO NOT EDIT.
-
-package namespaceapi
-
-import (
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-var _ unsafe.Pointer
-
-// Do the interface allocations only once for common
-// Errno values.
-const (
- errnoERROR_IO_PENDING = 997
-)
-
-var (
- errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
-)
-
-// errnoErr returns common boxed Errno values, to prevent
-// allocations at runtime.
-func errnoErr(e syscall.Errno) error {
- switch e {
- case 0:
- return nil
- case errnoERROR_IO_PENDING:
- return errERROR_IO_PENDING
- }
- // TODO: add more here, after collecting data on the common
- // error values see on Windows. (perhaps when running
- // all.bat?)
- return e
-}
-
-var (
- modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
-
- procCreateBoundaryDescriptorW = modkernel32.NewProc("CreateBoundaryDescriptorW")
- procDeleteBoundaryDescriptor = modkernel32.NewProc("DeleteBoundaryDescriptor")
- procAddSIDToBoundaryDescriptor = modkernel32.NewProc("AddSIDToBoundaryDescriptor")
- procCreatePrivateNamespaceW = modkernel32.NewProc("CreatePrivateNamespaceW")
- procOpenPrivateNamespaceW = modkernel32.NewProc("OpenPrivateNamespaceW")
- procClosePrivateNamespace = modkernel32.NewProc("ClosePrivateNamespace")
-)
-
-func createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) {
- r0, _, e1 := syscall.Syscall(procCreateBoundaryDescriptorW.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(flags), 0)
- handle = windows.Handle(r0)
- if handle == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) {
- syscall.Syscall(procDeleteBoundaryDescriptor.Addr(), 1, uintptr(boundaryDescriptor), 0, 0)
- return
-}
-
-func addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) {
- r1, _, e1 := syscall.Syscall(procAddSIDToBoundaryDescriptor.Addr(), 2, uintptr(unsafe.Pointer(boundaryDescriptor)), uintptr(unsafe.Pointer(requiredSid)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) {
- r0, _, e1 := syscall.Syscall(procCreatePrivateNamespaceW.Addr(), 3, uintptr(unsafe.Pointer(privateNamespaceAttributes)), uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix)))
- handle = windows.Handle(r0)
- if handle == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) {
- r0, _, e1 := syscall.Syscall(procOpenPrivateNamespaceW.Addr(), 2, uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix)), 0)
- handle = windows.Handle(r0)
- if handle == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func closePrivateNamespace(handle windows.Handle, flags uint32) (err error) {
- r1, _, e1 := syscall.Syscall(procClosePrivateNamespace.Addr(), 2, uintptr(handle), uintptr(flags), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
diff --git a/tun/wintun/nci/mksyscall.go b/tun/wintun/nci/mksyscall.go
deleted file mode 100644
index 019da93..0000000
--- a/tun/wintun/nci/mksyscall.go
+++ /dev/null
@@ -1,8 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package nci
-
-//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go nci_windows.go
diff --git a/tun/wintun/nci/nci_windows.go b/tun/wintun/nci/nci_windows.go
deleted file mode 100644
index 9dc6699..0000000
--- a/tun/wintun/nci/nci_windows.go
+++ /dev/null
@@ -1,28 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package nci
-
-import "golang.org/x/sys/windows"
-
-//sys nciSetConnectionName(guid *windows.GUID, newName *uint16) (ret error) = nci.NciSetConnectionName
-//sys nciGetConnectionName(guid *windows.GUID, destName *uint16, inDestNameBytes uint32, outDestNameBytes *uint32) (ret error) = nci.NciGetConnectionName
-
-func SetConnectionName(guid *windows.GUID, newName string) error {
- newName16, err := windows.UTF16PtrFromString(newName)
- if err != nil {
- return err
- }
- return nciSetConnectionName(guid, newName16)
-}
-
-func ConnectionName(guid *windows.GUID) (string, error) {
- var name [0x400]uint16
- err := nciGetConnectionName(guid, &name[0], uint32(len(name)*2), nil)
- if err != nil {
- return "", err
- }
- return windows.UTF16ToString(name[:]), nil
-}
diff --git a/tun/wintun/nci/zsyscall_windows.go b/tun/wintun/nci/zsyscall_windows.go
deleted file mode 100644
index 2a7b79e..0000000
--- a/tun/wintun/nci/zsyscall_windows.go
+++ /dev/null
@@ -1,60 +0,0 @@
-// Code generated by 'go generate'; DO NOT EDIT.
-
-package nci
-
-import (
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-var _ unsafe.Pointer
-
-// Do the interface allocations only once for common
-// Errno values.
-const (
- errnoERROR_IO_PENDING = 997
-)
-
-var (
- errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
-)
-
-// errnoErr returns common boxed Errno values, to prevent
-// allocations at runtime.
-func errnoErr(e syscall.Errno) error {
- switch e {
- case 0:
- return nil
- case errnoERROR_IO_PENDING:
- return errERROR_IO_PENDING
- }
- // TODO: add more here, after collecting data on the common
- // error values see on Windows. (perhaps when running
- // all.bat?)
- return e
-}
-
-var (
- modnci = windows.NewLazySystemDLL("nci.dll")
-
- procNciSetConnectionName = modnci.NewProc("NciSetConnectionName")
- procNciGetConnectionName = modnci.NewProc("NciGetConnectionName")
-)
-
-func nciSetConnectionName(guid *windows.GUID, newName *uint16) (ret error) {
- r0, _, _ := syscall.Syscall(procNciSetConnectionName.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(newName)), 0)
- if r0 != 0 {
- ret = syscall.Errno(r0)
- }
- return
-}
-
-func nciGetConnectionName(guid *windows.GUID, destName *uint16, inDestNameBytes uint32, outDestNameBytes *uint32) (ret error) {
- r0, _, _ := syscall.Syscall6(procNciGetConnectionName.Addr(), 4, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(destName)), uintptr(inDestNameBytes), uintptr(unsafe.Pointer(outDestNameBytes)), 0, 0)
- if r0 != 0 {
- ret = syscall.Errno(r0)
- }
- return
-}
diff --git a/tun/wintun/registry/mksyscall.go b/tun/wintun/registry/mksyscall.go
deleted file mode 100644
index 6ad82d2..0000000
--- a/tun/wintun/registry/mksyscall.go
+++ /dev/null
@@ -1,8 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package registry
-
-//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zregistry_windows.go registry_windows.go
diff --git a/tun/wintun/registry/registry_windows.go b/tun/wintun/registry/registry_windows.go
deleted file mode 100644
index 12a0336..0000000
--- a/tun/wintun/registry/registry_windows.go
+++ /dev/null
@@ -1,272 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package registry
-
-import (
- "errors"
- "fmt"
- "runtime"
- "strings"
- "time"
- "unsafe"
-
- "golang.org/x/sys/windows"
- "golang.org/x/sys/windows/registry"
-)
-
-const (
- // REG_NOTIFY_CHANGE_NAME notifies the caller if a subkey is added or deleted.
- REG_NOTIFY_CHANGE_NAME uint32 = 0x00000001
-
- // REG_NOTIFY_CHANGE_ATTRIBUTES notifies the caller of changes to the attributes of the key, such as the security descriptor information.
- REG_NOTIFY_CHANGE_ATTRIBUTES uint32 = 0x00000002
-
- // REG_NOTIFY_CHANGE_LAST_SET notifies the caller of changes to a value of the key. This can include adding or deleting a value, or changing an existing value.
- REG_NOTIFY_CHANGE_LAST_SET uint32 = 0x00000004
-
- // REG_NOTIFY_CHANGE_SECURITY notifies the caller of changes to the security descriptor of the key.
- REG_NOTIFY_CHANGE_SECURITY uint32 = 0x00000008
-
- // REG_NOTIFY_THREAD_AGNOSTIC indicates that the lifetime of the registration must not be tied to the lifetime of the thread issuing the RegNotifyChangeKeyValue call. Note: This flag value is only supported in Windows 8 and later.
- REG_NOTIFY_THREAD_AGNOSTIC uint32 = 0x10000000
-)
-
-//sys regNotifyChangeKeyValue(key windows.Handle, watchSubtree bool, notifyFilter uint32, event windows.Handle, asynchronous bool) (regerrno error) = advapi32.RegNotifyChangeKeyValue
-
-func OpenKeyWait(k registry.Key, path string, access uint32, timeout time.Duration) (registry.Key, error) {
- runtime.LockOSThread()
- defer runtime.UnlockOSThread()
-
- deadline := time.Now().Add(timeout)
- pathSpl := strings.Split(path, "\\")
- for i := 0; ; i++ {
- keyName := pathSpl[i]
- isLast := i+1 == len(pathSpl)
-
- event, err := windows.CreateEvent(nil, 0, 0, nil)
- if err != nil {
- return 0, fmt.Errorf("Error creating event: %v", err)
- }
- defer windows.CloseHandle(event)
-
- var key registry.Key
- for {
- err = regNotifyChangeKeyValue(windows.Handle(k), false, REG_NOTIFY_CHANGE_NAME, windows.Handle(event), true)
- if err != nil {
- return 0, fmt.Errorf("Setting up change notification on registry key failed: %v", err)
- }
-
- var accessFlags uint32
- if isLast {
- accessFlags = access
- } else {
- accessFlags = registry.NOTIFY
- }
- key, err = registry.OpenKey(k, keyName, accessFlags)
- if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
- timeout := time.Until(deadline) / time.Millisecond
- if timeout < 0 {
- timeout = 0
- }
- s, err := windows.WaitForSingleObject(event, uint32(timeout))
- if err != nil {
- return 0, fmt.Errorf("Unable to wait on registry key: %v", err)
- }
- if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
- return 0, errors.New("Timeout waiting for registry key")
- }
- } else if err != nil {
- return 0, fmt.Errorf("Error opening registry key %v: %v", path, err)
- } else {
- if isLast {
- return key, nil
- }
- defer key.Close()
- break
- }
- }
-
- k = key
- }
-}
-
-func WaitForKey(k registry.Key, path string, timeout time.Duration) error {
- key, err := OpenKeyWait(k, path, registry.NOTIFY, timeout)
- if err != nil {
- return err
- }
- key.Close()
- return nil
-}
-
-//
-// getValue is more or less the same as windows/registry's getValue.
-//
-func getValue(k registry.Key, name string, buf []byte) (value []byte, valueType uint32, err error) {
- var name16 *uint16
- name16, err = windows.UTF16PtrFromString(name)
- if err != nil {
- return
- }
- n := uint32(len(buf))
- for {
- err = windows.RegQueryValueEx(windows.Handle(k), name16, nil, &valueType, (*byte)(unsafe.Pointer(&buf[0])), &n)
- if err == nil {
- value = buf[:n]
- return
- }
- if err != windows.ERROR_MORE_DATA {
- return
- }
- if n <= uint32(len(buf)) {
- return
- }
- buf = make([]byte, n)
- }
-}
-
-//
-// getValueRetry function reads any value from registry. It waits for
-// the registry value to become available or returns error on timeout.
-//
-// Key must be opened with at least QUERY_VALUE|NOTIFY access.
-//
-func getValueRetry(key registry.Key, name string, buf []byte, timeout time.Duration) ([]byte, uint32, error) {
- runtime.LockOSThread()
- defer runtime.UnlockOSThread()
-
- event, err := windows.CreateEvent(nil, 0, 0, nil)
- if err != nil {
- return nil, 0, fmt.Errorf("Error creating event: %v", err)
- }
- defer windows.CloseHandle(event)
-
- deadline := time.Now().Add(timeout)
- for {
- err := regNotifyChangeKeyValue(windows.Handle(key), false, REG_NOTIFY_CHANGE_LAST_SET, windows.Handle(event), true)
- if err != nil {
- return nil, 0, fmt.Errorf("Setting up change notification on registry value failed: %v", err)
- }
-
- buf, valueType, err := getValue(key, name, buf)
- if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
- timeout := time.Until(deadline) / time.Millisecond
- if timeout < 0 {
- timeout = 0
- }
- s, err := windows.WaitForSingleObject(event, uint32(timeout))
- if err != nil {
- return nil, 0, fmt.Errorf("Unable to wait on registry value: %v", err)
- }
- if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
- return nil, 0, errors.New("Timeout waiting for registry value")
- }
- } else if err != nil {
- return nil, 0, fmt.Errorf("Error reading registry value %v: %v", name, err)
- } else {
- return buf, valueType, nil
- }
- }
-}
-
-func toString(buf []byte, valueType uint32, err error) (string, error) {
- if err != nil {
- return "", err
- }
-
- var value string
- switch valueType {
- case registry.SZ, registry.EXPAND_SZ, registry.MULTI_SZ:
- if len(buf) == 0 {
- return "", nil
- }
- value = windows.UTF16ToString((*[(1 << 30) - 1]uint16)(unsafe.Pointer(&buf[0]))[:len(buf)/2])
-
- default:
- return "", registry.ErrUnexpectedType
- }
-
- if valueType != registry.EXPAND_SZ {
- // Value does not require expansion.
- return value, nil
- }
-
- valueExp, err := registry.ExpandString(value)
- if err != nil {
- // Expanding failed: return original sting value.
- return value, nil
- }
-
- // Return expanded value.
- return valueExp, nil
-}
-
-func toInteger(buf []byte, valueType uint32, err error) (uint64, error) {
- if err != nil {
- return 0, err
- }
-
- switch valueType {
- case registry.DWORD:
- if len(buf) != 4 {
- return 0, errors.New("DWORD value is not 4 bytes long")
- }
- var val uint32
- copy((*[4]byte)(unsafe.Pointer(&val))[:], buf)
- return uint64(val), nil
-
- case registry.QWORD:
- if len(buf) != 8 {
- return 0, errors.New("QWORD value is not 8 bytes long")
- }
- var val uint64
- copy((*[8]byte)(unsafe.Pointer(&val))[:], buf)
- return val, nil
-
- default:
- return 0, registry.ErrUnexpectedType
- }
-}
-
-//
-// GetStringValueWait function reads a string value from registry. It waits
-// for the registry value to become available or returns error on timeout.
-//
-// Key must be opened with at least QUERY_VALUE|NOTIFY access.
-//
-// If the value type is REG_EXPAND_SZ the environment variables are expanded.
-// Should expanding fail, original string value and nil error are returned.
-//
-// If the value type is REG_MULTI_SZ only the first string is returned.
-//
-func GetStringValueWait(key registry.Key, name string, timeout time.Duration) (string, error) {
- return toString(getValueRetry(key, name, make([]byte, 256), timeout))
-}
-
-//
-// GetStringValue function reads a string value from registry.
-//
-// Key must be opened with at least QUERY_VALUE access.
-//
-// If the value type is REG_EXPAND_SZ the environment variables are expanded.
-// Should expanding fail, original string value and nil error are returned.
-//
-// If the value type is REG_MULTI_SZ only the first string is returned.
-//
-func GetStringValue(key registry.Key, name string) (string, error) {
- return toString(getValue(key, name, make([]byte, 256)))
-}
-
-//
-// GetIntegerValueWait function reads a DWORD32 or QWORD value from registry.
-// It waits for the registry value to become available or returns error on
-// timeout.
-//
-// Key must be opened with at least QUERY_VALUE|NOTIFY access.
-//
-func GetIntegerValueWait(key registry.Key, name string, timeout time.Duration) (uint64, error) {
- return toInteger(getValueRetry(key, name, make([]byte, 8), timeout))
-}
diff --git a/tun/wintun/registry/registry_windows_test.go b/tun/wintun/registry/registry_windows_test.go
deleted file mode 100644
index c56b51b..0000000
--- a/tun/wintun/registry/registry_windows_test.go
+++ /dev/null
@@ -1,103 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package registry
-
-import (
- "testing"
- "time"
-
- "golang.org/x/sys/windows/registry"
-)
-
-const keyRoot = registry.CURRENT_USER
-const pathRoot = "Software\\WireGuardRegistryTest"
-const path = pathRoot + "\\foobar"
-const pathFake = pathRoot + "\\raboof"
-
-func Test_WaitForKey(t *testing.T) {
- registry.DeleteKey(keyRoot, path)
- registry.DeleteKey(keyRoot, pathRoot)
- go func() {
- time.Sleep(time.Second * 1)
- key, _, err := registry.CreateKey(keyRoot, pathFake, registry.QUERY_VALUE)
- if err != nil {
- t.Errorf("Error creating registry key: %v", err)
- }
- key.Close()
- registry.DeleteKey(keyRoot, pathFake)
-
- key, _, err = registry.CreateKey(keyRoot, path, registry.QUERY_VALUE)
- if err != nil {
- t.Errorf("Error creating registry key: %v", err)
- }
- key.Close()
- }()
- err := WaitForKey(keyRoot, path, time.Second*2)
- if err != nil {
- t.Errorf("Error waiting for registry key: %v", err)
- }
- registry.DeleteKey(keyRoot, path)
- registry.DeleteKey(keyRoot, pathRoot)
-
- err = WaitForKey(keyRoot, path, time.Second*1)
- if err == nil {
- t.Error("Registry key notification expected to timeout but it succeeded.")
- }
-}
-
-func Test_GetValueWait(t *testing.T) {
- registry.DeleteKey(keyRoot, path)
- registry.DeleteKey(keyRoot, pathRoot)
- go func() {
- time.Sleep(time.Second * 1)
- key, _, err := registry.CreateKey(keyRoot, path, registry.SET_VALUE)
- if err != nil {
- t.Errorf("Error creating registry key: %v", err)
- }
- time.Sleep(time.Second * 1)
- key.SetStringValue("name1", "eulav")
- key.SetExpandStringValue("name2", "value")
- time.Sleep(time.Second * 1)
- key.SetDWordValue("name3", ^uint32(123))
- key.SetDWordValue("name4", 123)
- key.Close()
- }()
-
- key, err := OpenKeyWait(keyRoot, path, registry.QUERY_VALUE|registry.NOTIFY, time.Second*2)
- if err != nil {
- t.Errorf("Error waiting for registry key: %v", err)
- }
-
- valueStr, err := GetStringValueWait(key, "name2", time.Second*2)
- if err != nil {
- t.Errorf("Error waiting for registry value: %v", err)
- }
- if valueStr != "value" {
- t.Errorf("Wrong value read: %v", valueStr)
- }
-
- _, err = GetStringValueWait(key, "nonexisting", time.Second*1)
- if err == nil {
- t.Error("Registry value notification expected to timeout but it succeeded.")
- }
-
- valueInt, err := GetIntegerValueWait(key, "name4", time.Second*2)
- if err != nil {
- t.Errorf("Error waiting for registry value: %v", err)
- }
- if valueInt != 123 {
- t.Errorf("Wrong value read: %v", valueInt)
- }
-
- _, err = GetIntegerValueWait(key, "nonexisting", time.Second*1)
- if err == nil {
- t.Error("Registry value notification expected to timeout but it succeeded.")
- }
-
- key.Close()
- registry.DeleteKey(keyRoot, path)
- registry.DeleteKey(keyRoot, pathRoot)
-}
diff --git a/tun/wintun/registry/zregistry_windows.go b/tun/wintun/registry/zregistry_windows.go
deleted file mode 100644
index f7ac33b..0000000
--- a/tun/wintun/registry/zregistry_windows.go
+++ /dev/null
@@ -1,63 +0,0 @@
-// Code generated by 'go generate'; DO NOT EDIT.
-
-package registry
-
-import (
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-var _ unsafe.Pointer
-
-// Do the interface allocations only once for common
-// Errno values.
-const (
- errnoERROR_IO_PENDING = 997
-)
-
-var (
- errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
-)
-
-// errnoErr returns common boxed Errno values, to prevent
-// allocations at runtime.
-func errnoErr(e syscall.Errno) error {
- switch e {
- case 0:
- return nil
- case errnoERROR_IO_PENDING:
- return errERROR_IO_PENDING
- }
- // TODO: add more here, after collecting data on the common
- // error values see on Windows. (perhaps when running
- // all.bat?)
- return e
-}
-
-var (
- modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
-
- procRegNotifyChangeKeyValue = modadvapi32.NewProc("RegNotifyChangeKeyValue")
-)
-
-func regNotifyChangeKeyValue(key windows.Handle, watchSubtree bool, notifyFilter uint32, event windows.Handle, asynchronous bool) (regerrno error) {
- var _p0 uint32
- if watchSubtree {
- _p0 = 1
- } else {
- _p0 = 0
- }
- var _p1 uint32
- if asynchronous {
- _p1 = 1
- } else {
- _p1 = 0
- }
- r0, _, _ := syscall.Syscall6(procRegNotifyChangeKeyValue.Addr(), 5, uintptr(key), uintptr(_p0), uintptr(notifyFilter), uintptr(event), uintptr(_p1), 0)
- if r0 != 0 {
- regerrno = syscall.Errno(r0)
- }
- return
-}
diff --git a/tun/wintun/ring_windows.go b/tun/wintun/ring_windows.go
deleted file mode 100644
index 8f46bc9..0000000
--- a/tun/wintun/ring_windows.go
+++ /dev/null
@@ -1,97 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package wintun
-
-import (
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-const (
- PacketAlignment = 4 // Number of bytes packets are aligned to in rings
- PacketSizeMax = 0xffff // Maximum packet size
- PacketCapacity = 0x800000 // Ring capacity, 8MiB
- PacketTrailingSize = uint32(unsafe.Sizeof(PacketHeader{})) + ((PacketSizeMax + (PacketAlignment - 1)) &^ (PacketAlignment - 1)) - PacketAlignment
- ioctlRegisterRings = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14)
-)
-
-type PacketHeader struct {
- Size uint32
-}
-
-type Packet struct {
- PacketHeader
- Data [PacketSizeMax]byte
-}
-
-type Ring struct {
- Head uint32
- Tail uint32
- Alertable int32
- Data [PacketCapacity + PacketTrailingSize]byte
-}
-
-type RingDescriptor struct {
- Send, Receive struct {
- Size uint32
- Ring *Ring
- TailMoved windows.Handle
- }
-}
-
-// Wrap returns value modulo ring capacity
-func (rb *Ring) Wrap(value uint32) uint32 {
- return value & (PacketCapacity - 1)
-}
-
-// Aligns a packet size to PacketAlignment
-func PacketAlign(size uint32) uint32 {
- return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1)
-}
-
-func (descriptor *RingDescriptor) Init() (err error) {
- descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{}))
- descriptor.Send.Ring = &Ring{}
- descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
- if err != nil {
- return
- }
-
- descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{}))
- descriptor.Receive.Ring = &Ring{}
- descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
- if err != nil {
- windows.CloseHandle(descriptor.Send.TailMoved)
- return
- }
-
- return
-}
-
-func (descriptor *RingDescriptor) Close() {
- if descriptor.Send.TailMoved != 0 {
- windows.CloseHandle(descriptor.Send.TailMoved)
- descriptor.Send.TailMoved = 0
- }
- if descriptor.Send.TailMoved != 0 {
- windows.CloseHandle(descriptor.Receive.TailMoved)
- descriptor.Receive.TailMoved = 0
- }
-}
-
-func (wintun *Interface) Register(descriptor *RingDescriptor) (windows.Handle, error) {
- handle, err := wintun.handle()
- if err != nil {
- return 0, err
- }
- var bytesReturned uint32
- err = windows.DeviceIoControl(handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(descriptor)), uint32(unsafe.Sizeof(*descriptor)), nil, 0, &bytesReturned, nil)
- if err != nil {
- return 0, err
- }
- return handle, nil
-}
diff --git a/tun/wintun/setupapi/mksyscall.go b/tun/wintun/setupapi/mksyscall.go
deleted file mode 100644
index ac103a1..0000000
--- a/tun/wintun/setupapi/mksyscall.go
+++ /dev/null
@@ -1,8 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package setupapi
-
-//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsetupapi_windows.go setupapi_windows.go
diff --git a/tun/wintun/setupapi/setupapi_windows.go b/tun/wintun/setupapi/setupapi_windows.go
deleted file mode 100644
index 60a8eb7..0000000
--- a/tun/wintun/setupapi/setupapi_windows.go
+++ /dev/null
@@ -1,506 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package setupapi
-
-import (
- "encoding/binary"
- "fmt"
- "runtime"
- "unsafe"
-
- "golang.org/x/sys/windows"
- "golang.org/x/sys/windows/registry"
-)
-
-//sys setupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName *uint16, reserved uintptr) (handle DevInfo, err error) [failretval==DevInfo(windows.InvalidHandle)] = setupapi.SetupDiCreateDeviceInfoListExW
-
-// SetupDiCreateDeviceInfoListEx function creates an empty device information set on a remote or a local computer and optionally associates the set with a device setup class.
-func SetupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName string) (deviceInfoSet DevInfo, err error) {
- var machineNameUTF16 *uint16
- if machineName != "" {
- machineNameUTF16, err = windows.UTF16PtrFromString(machineName)
- if err != nil {
- return
- }
- }
- return setupDiCreateDeviceInfoListEx(classGUID, hwndParent, machineNameUTF16, 0)
-}
-
-//sys setupDiGetDeviceInfoListDetail(deviceInfoSet DevInfo, deviceInfoSetDetailData *DevInfoListDetailData) (err error) = setupapi.SetupDiGetDeviceInfoListDetailW
-
-// SetupDiGetDeviceInfoListDetail function retrieves information associated with a device information set including the class GUID, remote computer handle, and remote computer name.
-func SetupDiGetDeviceInfoListDetail(deviceInfoSet DevInfo) (deviceInfoSetDetailData *DevInfoListDetailData, err error) {
- data := &DevInfoListDetailData{}
- data.size = sizeofDevInfoListDetailData
-
- return data, setupDiGetDeviceInfoListDetail(deviceInfoSet, data)
-}
-
-// DeviceInfoListDetail method retrieves information associated with a device information set including the class GUID, remote computer handle, and remote computer name.
-func (deviceInfoSet DevInfo) DeviceInfoListDetail() (*DevInfoListDetailData, error) {
- return SetupDiGetDeviceInfoListDetail(deviceInfoSet)
-}
-
-//sys setupDiCreateDeviceInfo(deviceInfoSet DevInfo, DeviceName *uint16, classGUID *windows.GUID, DeviceDescription *uint16, hwndParent uintptr, CreationFlags DICD, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiCreateDeviceInfoW
-
-// SetupDiCreateDeviceInfo function creates a new device information element and adds it as a new member to the specified device information set.
-func SetupDiCreateDeviceInfo(deviceInfoSet DevInfo, deviceName string, classGUID *windows.GUID, deviceDescription string, hwndParent uintptr, creationFlags DICD) (deviceInfoData *DevInfoData, err error) {
- deviceNameUTF16, err := windows.UTF16PtrFromString(deviceName)
- if err != nil {
- return
- }
-
- var deviceDescriptionUTF16 *uint16
- if deviceDescription != "" {
- deviceDescriptionUTF16, err = windows.UTF16PtrFromString(deviceDescription)
- if err != nil {
- return
- }
- }
-
- data := &DevInfoData{}
- data.size = uint32(unsafe.Sizeof(*data))
-
- return data, setupDiCreateDeviceInfo(deviceInfoSet, deviceNameUTF16, classGUID, deviceDescriptionUTF16, hwndParent, creationFlags, data)
-}
-
-// CreateDeviceInfo method creates a new device information element and adds it as a new member to the specified device information set.
-func (deviceInfoSet DevInfo) CreateDeviceInfo(deviceName string, classGUID *windows.GUID, deviceDescription string, hwndParent uintptr, creationFlags DICD) (*DevInfoData, error) {
- return SetupDiCreateDeviceInfo(deviceInfoSet, deviceName, classGUID, deviceDescription, hwndParent, creationFlags)
-}
-
-//sys setupDiEnumDeviceInfo(deviceInfoSet DevInfo, memberIndex uint32, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiEnumDeviceInfo
-
-// SetupDiEnumDeviceInfo function returns a DevInfoData structure that specifies a device information element in a device information set.
-func SetupDiEnumDeviceInfo(deviceInfoSet DevInfo, memberIndex int) (*DevInfoData, error) {
- data := &DevInfoData{}
- data.size = uint32(unsafe.Sizeof(*data))
-
- return data, setupDiEnumDeviceInfo(deviceInfoSet, uint32(memberIndex), data)
-}
-
-// EnumDeviceInfo method returns a DevInfoData structure that specifies a device information element in a device information set.
-func (deviceInfoSet DevInfo) EnumDeviceInfo(memberIndex int) (*DevInfoData, error) {
- return SetupDiEnumDeviceInfo(deviceInfoSet, memberIndex)
-}
-
-// SetupDiDestroyDeviceInfoList function deletes a device information set and frees all associated memory.
-//sys SetupDiDestroyDeviceInfoList(deviceInfoSet DevInfo) (err error) = setupapi.SetupDiDestroyDeviceInfoList
-
-// Close method deletes a device information set and frees all associated memory.
-func (deviceInfoSet DevInfo) Close() error {
- return SetupDiDestroyDeviceInfoList(deviceInfoSet)
-}
-
-//sys SetupDiBuildDriverInfoList(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT) (err error) = setupapi.SetupDiBuildDriverInfoList
-
-// BuildDriverInfoList method builds a list of drivers that is associated with a specific device or with the global class driver list for a device information set.
-func (deviceInfoSet DevInfo) BuildDriverInfoList(deviceInfoData *DevInfoData, driverType SPDIT) error {
- return SetupDiBuildDriverInfoList(deviceInfoSet, deviceInfoData, driverType)
-}
-
-//sys SetupDiCancelDriverInfoSearch(deviceInfoSet DevInfo) (err error) = setupapi.SetupDiCancelDriverInfoSearch
-
-// CancelDriverInfoSearch method cancels a driver list search that is currently in progress in a different thread.
-func (deviceInfoSet DevInfo) CancelDriverInfoSearch() error {
- return SetupDiCancelDriverInfoSearch(deviceInfoSet)
-}
-
-//sys setupDiEnumDriverInfo(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT, memberIndex uint32, driverInfoData *DrvInfoData) (err error) = setupapi.SetupDiEnumDriverInfoW
-
-// SetupDiEnumDriverInfo function enumerates the members of a driver list.
-func SetupDiEnumDriverInfo(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT, memberIndex int) (*DrvInfoData, error) {
- data := &DrvInfoData{}
- data.size = uint32(unsafe.Sizeof(*data))
-
- return data, setupDiEnumDriverInfo(deviceInfoSet, deviceInfoData, driverType, uint32(memberIndex), data)
-}
-
-// EnumDriverInfo method enumerates the members of a driver list.
-func (deviceInfoSet DevInfo) EnumDriverInfo(deviceInfoData *DevInfoData, driverType SPDIT, memberIndex int) (*DrvInfoData, error) {
- return SetupDiEnumDriverInfo(deviceInfoSet, deviceInfoData, driverType, memberIndex)
-}
-
-//sys setupDiGetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (err error) = setupapi.SetupDiGetSelectedDriverW
-
-// SetupDiGetSelectedDriver function retrieves the selected driver for a device information set or a particular device information element.
-func SetupDiGetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (*DrvInfoData, error) {
- data := &DrvInfoData{}
- data.size = uint32(unsafe.Sizeof(*data))
-
- return data, setupDiGetSelectedDriver(deviceInfoSet, deviceInfoData, data)
-}
-
-// SelectedDriver method retrieves the selected driver for a device information set or a particular device information element.
-func (deviceInfoSet DevInfo) SelectedDriver(deviceInfoData *DevInfoData) (*DrvInfoData, error) {
- return SetupDiGetSelectedDriver(deviceInfoSet, deviceInfoData)
-}
-
-//sys SetupDiSetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (err error) = setupapi.SetupDiSetSelectedDriverW
-
-// SetSelectedDriver method sets, or resets, the selected driver for a device information element or the selected class driver for a device information set.
-func (deviceInfoSet DevInfo) SetSelectedDriver(deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) error {
- return SetupDiSetSelectedDriver(deviceInfoSet, deviceInfoData, driverInfoData)
-}
-
-//sys setupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData, driverInfoDetailData *DrvInfoDetailData, driverInfoDetailDataSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetDriverInfoDetailW
-
-// SetupDiGetDriverInfoDetail function retrieves driver information detail for a device information set or a particular device information element in the device information set.
-func SetupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (*DrvInfoDetailData, error) {
- reqSize := uint32(2048)
- for {
- buf := make([]byte, reqSize)
- data := (*DrvInfoDetailData)(unsafe.Pointer(&buf[0]))
- data.size = sizeofDrvInfoDetailData
- err := setupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData, data, uint32(len(buf)), &reqSize)
- if err == windows.ERROR_INSUFFICIENT_BUFFER {
- continue
- }
- if err != nil {
- return nil, err
- }
- data.size = reqSize
- return data, nil
- }
-}
-
-// DriverInfoDetail method retrieves driver information detail for a device information set or a particular device information element in the device information set.
-func (deviceInfoSet DevInfo) DriverInfoDetail(deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (*DrvInfoDetailData, error) {
- return SetupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData)
-}
-
-//sys SetupDiDestroyDriverInfoList(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT) (err error) = setupapi.SetupDiDestroyDriverInfoList
-
-// DestroyDriverInfoList method deletes a driver list.
-func (deviceInfoSet DevInfo) DestroyDriverInfoList(deviceInfoData *DevInfoData, driverType SPDIT) error {
- return SetupDiDestroyDriverInfoList(deviceInfoSet, deviceInfoData, driverType)
-}
-
-//sys setupDiGetClassDevsEx(classGUID *windows.GUID, Enumerator *uint16, hwndParent uintptr, Flags DIGCF, deviceInfoSet DevInfo, machineName *uint16, reserved uintptr) (handle DevInfo, err error) [failretval==DevInfo(windows.InvalidHandle)] = setupapi.SetupDiGetClassDevsExW
-
-// SetupDiGetClassDevsEx function returns a handle to a device information set that contains requested device information elements for a local or a remote computer.
-func SetupDiGetClassDevsEx(classGUID *windows.GUID, enumerator string, hwndParent uintptr, flags DIGCF, deviceInfoSet DevInfo, machineName string) (handle DevInfo, err error) {
- var enumeratorUTF16 *uint16
- if enumerator != "" {
- enumeratorUTF16, err = windows.UTF16PtrFromString(enumerator)
- if err != nil {
- return
- }
- }
- var machineNameUTF16 *uint16
- if machineName != "" {
- machineNameUTF16, err = windows.UTF16PtrFromString(machineName)
- if err != nil {
- return
- }
- }
- return setupDiGetClassDevsEx(classGUID, enumeratorUTF16, hwndParent, flags, deviceInfoSet, machineNameUTF16, 0)
-}
-
-// SetupDiCallClassInstaller function calls the appropriate class installer, and any registered co-installers, with the specified installation request (DIF code).
-//sys SetupDiCallClassInstaller(installFunction DI_FUNCTION, deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiCallClassInstaller
-
-// CallClassInstaller member calls the appropriate class installer, and any registered co-installers, with the specified installation request (DIF code).
-func (deviceInfoSet DevInfo) CallClassInstaller(installFunction DI_FUNCTION, deviceInfoData *DevInfoData) error {
- return SetupDiCallClassInstaller(installFunction, deviceInfoSet, deviceInfoData)
-}
-
-//sys setupDiOpenDevRegKey(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, Scope DICS_FLAG, HwProfile uint32, KeyType DIREG, samDesired uint32) (key windows.Handle, err error) [failretval==windows.InvalidHandle] = setupapi.SetupDiOpenDevRegKey
-
-// SetupDiOpenDevRegKey function opens a registry key for device-specific configuration information.
-func SetupDiOpenDevRegKey(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, scope DICS_FLAG, hwProfile uint32, keyType DIREG, samDesired uint32) (registry.Key, error) {
- handle, err := setupDiOpenDevRegKey(deviceInfoSet, deviceInfoData, scope, hwProfile, keyType, samDesired)
- return registry.Key(handle), err
-}
-
-// OpenDevRegKey method opens a registry key for device-specific configuration information.
-func (deviceInfoSet DevInfo) OpenDevRegKey(DeviceInfoData *DevInfoData, Scope DICS_FLAG, HwProfile uint32, KeyType DIREG, samDesired uint32) (registry.Key, error) {
- return SetupDiOpenDevRegKey(deviceInfoSet, DeviceInfoData, Scope, HwProfile, KeyType, samDesired)
-}
-
-//sys setupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyRegDataType *uint32, propertyBuffer *byte, propertyBufferSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetDeviceRegistryPropertyW
-
-// SetupDiGetDeviceRegistryProperty function retrieves a specified Plug and Play device property.
-func SetupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP) (value interface{}, err error) {
- reqSize := uint32(256)
- for {
- var dataType uint32
- buf := make([]byte, reqSize)
- err = setupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, &dataType, &buf[0], uint32(len(buf)), &reqSize)
- if err == windows.ERROR_INSUFFICIENT_BUFFER {
- continue
- }
- if err != nil {
- return
- }
- return getRegistryValue(buf[:reqSize], dataType)
- }
-}
-
-func getRegistryValue(buf []byte, dataType uint32) (interface{}, error) {
- switch dataType {
- case windows.REG_SZ:
- ret := windows.UTF16ToString(bufToUTF16(buf))
- runtime.KeepAlive(buf)
- return ret, nil
- case windows.REG_EXPAND_SZ:
- ret, err := registry.ExpandString(windows.UTF16ToString(bufToUTF16(buf)))
- runtime.KeepAlive(buf)
- return ret, err
- case windows.REG_BINARY:
- return buf, nil
- case windows.REG_DWORD_LITTLE_ENDIAN:
- return binary.LittleEndian.Uint32(buf), nil
- case windows.REG_DWORD_BIG_ENDIAN:
- return binary.BigEndian.Uint32(buf), nil
- case windows.REG_MULTI_SZ:
- bufW := bufToUTF16(buf)
- a := []string{}
- for i := 0; i < len(bufW); {
- j := i + wcslen(bufW[i:])
- if i < j {
- a = append(a, windows.UTF16ToString(bufW[i:j]))
- }
- i = j + 1
- }
- runtime.KeepAlive(buf)
- return a, nil
- case windows.REG_QWORD_LITTLE_ENDIAN:
- return binary.LittleEndian.Uint64(buf), nil
- default:
- return nil, fmt.Errorf("Unsupported registry value type: %v", dataType)
- }
-}
-
-// bufToUTF16 function reinterprets []byte buffer as []uint16
-func bufToUTF16(buf []byte) []uint16 {
- sl := struct {
- addr *uint16
- len int
- cap int
- }{(*uint16)(unsafe.Pointer(&buf[0])), len(buf) / 2, cap(buf) / 2}
- return *(*[]uint16)(unsafe.Pointer(&sl))
-}
-
-// utf16ToBuf function reinterprets []uint16 as []byte
-func utf16ToBuf(buf []uint16) []byte {
- sl := struct {
- addr *byte
- len int
- cap int
- }{(*byte)(unsafe.Pointer(&buf[0])), len(buf) * 2, cap(buf) * 2}
- return *(*[]byte)(unsafe.Pointer(&sl))
-}
-
-func wcslen(str []uint16) int {
- for i := 0; i < len(str); i++ {
- if str[i] == 0 {
- return i
- }
- }
- return len(str)
-}
-
-// DeviceRegistryProperty method retrieves a specified Plug and Play device property.
-func (deviceInfoSet DevInfo) DeviceRegistryProperty(deviceInfoData *DevInfoData, property SPDRP) (interface{}, error) {
- return SetupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property)
-}
-
-//sys setupDiSetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyBuffer *byte, propertyBufferSize uint32) (err error) = setupapi.SetupDiSetDeviceRegistryPropertyW
-
-// SetupDiSetDeviceRegistryProperty function sets a Plug and Play device property for a device.
-func SetupDiSetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyBuffers []byte) error {
- return setupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, &propertyBuffers[0], uint32(len(propertyBuffers)))
-}
-
-// SetDeviceRegistryProperty function sets a Plug and Play device property for a device.
-func (deviceInfoSet DevInfo) SetDeviceRegistryProperty(deviceInfoData *DevInfoData, property SPDRP, propertyBuffers []byte) error {
- return SetupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, propertyBuffers)
-}
-
-// SetDeviceRegistryPropertyString method sets a Plug and Play device property string for a device.
-func (deviceInfoSet DevInfo) SetDeviceRegistryPropertyString(deviceInfoData *DevInfoData, property SPDRP, str string) error {
- str16, err := windows.UTF16FromString(str)
- if err != nil {
- return err
- }
- err = SetupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, utf16ToBuf(append(str16, 0)))
- runtime.KeepAlive(str16)
- return err
-}
-
-//sys setupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) = setupapi.SetupDiGetDeviceInstallParamsW
-
-// SetupDiGetDeviceInstallParams function retrieves device installation parameters for a device information set or a particular device information element.
-func SetupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (*DevInstallParams, error) {
- params := &DevInstallParams{}
- params.size = uint32(unsafe.Sizeof(*params))
-
- return params, setupDiGetDeviceInstallParams(deviceInfoSet, deviceInfoData, params)
-}
-
-// DeviceInstallParams method retrieves device installation parameters for a device information set or a particular device information element.
-func (deviceInfoSet DevInfo) DeviceInstallParams(deviceInfoData *DevInfoData) (*DevInstallParams, error) {
- return SetupDiGetDeviceInstallParams(deviceInfoSet, deviceInfoData)
-}
-
-//sys setupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, instanceId *uint16, instanceIdSize uint32, instanceIdRequiredSize *uint32) (err error) = setupapi.SetupDiGetDeviceInstanceIdW
-
-// SetupDiGetDeviceInstanceId function retrieves the instance ID of the device.
-func SetupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (string, error) {
- reqSize := uint32(1024)
- for {
- buf := make([]uint16, reqSize)
- err := setupDiGetDeviceInstanceId(deviceInfoSet, deviceInfoData, &buf[0], uint32(len(buf)), &reqSize)
- if err == windows.ERROR_INSUFFICIENT_BUFFER {
- continue
- }
- if err != nil {
- return "", err
- }
- return windows.UTF16ToString(buf), nil
- }
-}
-
-// DeviceInstanceID method retrieves the instance ID of the device.
-func (deviceInfoSet DevInfo) DeviceInstanceID(deviceInfoData *DevInfoData) (string, error) {
- return SetupDiGetDeviceInstanceId(deviceInfoSet, deviceInfoData)
-}
-
-// SetupDiGetClassInstallParams function retrieves class installation parameters for a device information set or a particular device information element.
-//sys SetupDiGetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetClassInstallParamsW
-
-// ClassInstallParams method retrieves class installation parameters for a device information set or a particular device information element.
-func (deviceInfoSet DevInfo) ClassInstallParams(deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) error {
- return SetupDiGetClassInstallParams(deviceInfoSet, deviceInfoData, classInstallParams, classInstallParamsSize, requiredSize)
-}
-
-//sys SetupDiSetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) = setupapi.SetupDiSetDeviceInstallParamsW
-
-// SetDeviceInstallParams member sets device installation parameters for a device information set or a particular device information element.
-func (deviceInfoSet DevInfo) SetDeviceInstallParams(deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) error {
- return SetupDiSetDeviceInstallParams(deviceInfoSet, deviceInfoData, deviceInstallParams)
-}
-
-// SetupDiSetClassInstallParams function sets or clears class install parameters for a device information set or a particular device information element.
-//sys SetupDiSetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32) (err error) = setupapi.SetupDiSetClassInstallParamsW
-
-// SetClassInstallParams method sets or clears class install parameters for a device information set or a particular device information element.
-func (deviceInfoSet DevInfo) SetClassInstallParams(deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32) error {
- return SetupDiSetClassInstallParams(deviceInfoSet, deviceInfoData, classInstallParams, classInstallParamsSize)
-}
-
-//sys setupDiClassNameFromGuidEx(classGUID *windows.GUID, className *uint16, classNameSize uint32, requiredSize *uint32, machineName *uint16, reserved uintptr) (err error) = setupapi.SetupDiClassNameFromGuidExW
-
-// SetupDiClassNameFromGuidEx function retrieves the class name associated with a class GUID. The class can be installed on a local or remote computer.
-func SetupDiClassNameFromGuidEx(classGUID *windows.GUID, machineName string) (className string, err error) {
- var classNameUTF16 [MAX_CLASS_NAME_LEN]uint16
-
- var machineNameUTF16 *uint16
- if machineName != "" {
- machineNameUTF16, err = windows.UTF16PtrFromString(machineName)
- if err != nil {
- return
- }
- }
-
- err = setupDiClassNameFromGuidEx(classGUID, &classNameUTF16[0], MAX_CLASS_NAME_LEN, nil, machineNameUTF16, 0)
- if err != nil {
- return
- }
-
- className = windows.UTF16ToString(classNameUTF16[:])
- return
-}
-
-//sys setupDiClassGuidsFromNameEx(className *uint16, classGuidList *windows.GUID, classGuidListSize uint32, requiredSize *uint32, machineName *uint16, reserved uintptr) (err error) = setupapi.SetupDiClassGuidsFromNameExW
-
-// SetupDiClassGuidsFromNameEx function retrieves the GUIDs associated with the specified class name. This resulting list contains the classes currently installed on a local or remote computer.
-func SetupDiClassGuidsFromNameEx(className string, machineName string) ([]windows.GUID, error) {
- classNameUTF16, err := windows.UTF16PtrFromString(className)
- if err != nil {
- return nil, err
- }
-
- var machineNameUTF16 *uint16
- if machineName != "" {
- machineNameUTF16, err = windows.UTF16PtrFromString(machineName)
- if err != nil {
- return nil, err
- }
- }
-
- reqSize := uint32(4)
- for {
- buf := make([]windows.GUID, reqSize)
- err = setupDiClassGuidsFromNameEx(classNameUTF16, &buf[0], uint32(len(buf)), &reqSize, machineNameUTF16, 0)
- if err == windows.ERROR_INSUFFICIENT_BUFFER {
- continue
- }
- if err != nil {
- return nil, err
- }
- return buf[:reqSize], nil
- }
-}
-
-//sys setupDiGetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiGetSelectedDevice
-
-// SetupDiGetSelectedDevice function retrieves the selected device information element in a device information set.
-func SetupDiGetSelectedDevice(deviceInfoSet DevInfo) (*DevInfoData, error) {
- data := &DevInfoData{}
- data.size = uint32(unsafe.Sizeof(*data))
-
- return data, setupDiGetSelectedDevice(deviceInfoSet, data)
-}
-
-// SelectedDevice method retrieves the selected device information element in a device information set.
-func (deviceInfoSet DevInfo) SelectedDevice() (*DevInfoData, error) {
- return SetupDiGetSelectedDevice(deviceInfoSet)
-}
-
-// SetupDiSetSelectedDevice function sets a device information element as the selected member of a device information set. This function is typically used by an installation wizard.
-//sys SetupDiSetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiSetSelectedDevice
-
-// SetSelectedDevice method sets a device information element as the selected member of a device information set. This function is typically used by an installation wizard.
-func (deviceInfoSet DevInfo) SetSelectedDevice(deviceInfoData *DevInfoData) error {
- return SetupDiSetSelectedDevice(deviceInfoSet, deviceInfoData)
-}
-
-//sys cm_Get_Device_Interface_List_Size(len *uint32, interfaceClass *windows.GUID, deviceID *uint16, flags uint32) (ret uint32) = CfgMgr32.CM_Get_Device_Interface_List_SizeW
-//sys cm_Get_Device_Interface_List(interfaceClass *windows.GUID, deviceID *uint16, buffer *uint16, bufferLen uint32, flags uint32) (ret uint32) = CfgMgr32.CM_Get_Device_Interface_ListW
-
-func CM_Get_Device_Interface_List(deviceID string, interfaceClass *windows.GUID, flags uint32) ([]string, error) {
- deviceID16, err := windows.UTF16PtrFromString(deviceID)
- if err != nil {
- return nil, err
- }
- var buf []uint16
- var buflen uint32
- for {
- if ret := cm_Get_Device_Interface_List_Size(&buflen, interfaceClass, deviceID16, flags); ret != CR_SUCCESS {
- return nil, fmt.Errorf("CfgMgr error: 0x%x", ret)
- }
- buf = make([]uint16, buflen)
- if ret := cm_Get_Device_Interface_List(interfaceClass, deviceID16, &buf[0], buflen, flags); ret == CR_SUCCESS {
- break
- } else if ret != CR_BUFFER_SMALL {
- return nil, fmt.Errorf("CfgMgr error: 0x%x", ret)
- }
- }
- var interfaces []string
- for i := 0; i < len(buf); {
- j := i + wcslen(buf[i:])
- if i < j {
- interfaces = append(interfaces, windows.UTF16ToString(buf[i:j]))
- }
- i = j + 1
- }
- if interfaces == nil {
- return nil, fmt.Errorf("no interfaces found")
- }
- return interfaces, nil
-}
diff --git a/tun/wintun/setupapi/setupapi_windows_test.go b/tun/wintun/setupapi/setupapi_windows_test.go
deleted file mode 100644
index a9e6b89..0000000
--- a/tun/wintun/setupapi/setupapi_windows_test.go
+++ /dev/null
@@ -1,488 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package setupapi
-
-import (
- "runtime"
- "strings"
- "testing"
-
- "golang.org/x/sys/windows"
-)
-
-var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0x11ce, Data4: [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}}
-var computerName string
-
-func init() {
- computerName, _ = windows.ComputerName()
-}
-
-func TestSetupDiCreateDeviceInfoListEx(t *testing.T) {
- devInfoList, err := SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "")
- if err != nil {
- t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error())
- } else {
- devInfoList.Close()
- }
-
- devInfoList, err = SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, computerName)
- if err != nil {
- t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error())
- } else {
- devInfoList.Close()
- }
-
- devInfoList, err = SetupDiCreateDeviceInfoListEx(nil, 0, "")
- if err != nil {
- t.Errorf("Error calling SetupDiCreateDeviceInfoListEx(nil): %s", err.Error())
- } else {
- devInfoList.Close()
- }
-}
-
-func TestSetupDiGetDeviceInfoListDetail(t *testing.T) {
- devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "")
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- data, err := devInfoList.DeviceInfoListDetail()
- if err != nil {
- t.Errorf("Error calling SetupDiGetDeviceInfoListDetail: %s", err.Error())
- } else {
- if data.ClassGUID != deviceClassNetGUID {
- t.Error("SetupDiGetDeviceInfoListDetail returned different class GUID")
- }
-
- if data.RemoteMachineHandle != windows.Handle(0) {
- t.Error("SetupDiGetDeviceInfoListDetail returned non-NULL remote machine handle")
- }
-
- if data.RemoteMachineName() != "" {
- t.Error("SetupDiGetDeviceInfoListDetail returned non-NULL remote machine name")
- }
- }
-
- devInfoList, err = SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), computerName)
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- data, err = devInfoList.DeviceInfoListDetail()
- if err != nil {
- t.Errorf("Error calling SetupDiGetDeviceInfoListDetail: %s", err.Error())
- } else {
- if data.ClassGUID != deviceClassNetGUID {
- t.Error("SetupDiGetDeviceInfoListDetail returned different class GUID")
- }
-
- if data.RemoteMachineHandle == windows.Handle(0) {
- t.Error("SetupDiGetDeviceInfoListDetail returned NULL remote machine handle")
- }
-
- if data.RemoteMachineName() != computerName {
- t.Error("SetupDiGetDeviceInfoListDetail returned different remote machine name")
- }
- }
-
- data = &DevInfoListDetailData{}
- data.SetRemoteMachineName("foobar")
- if data.RemoteMachineName() != "foobar" {
- t.Error("DevInfoListDetailData.(Get|Set)RemoteMachineName() differ")
- }
-}
-
-func TestSetupDiCreateDeviceInfo(t *testing.T) {
- devInfoList, err := SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, computerName)
- if err != nil {
- t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- deviceClassNetName, err := SetupDiClassNameFromGuidEx(&deviceClassNetGUID, computerName)
- if err != nil {
- t.Errorf("Error calling SetupDiClassNameFromGuidEx: %s", err.Error())
- }
-
- devInfoData, err := devInfoList.CreateDeviceInfo(deviceClassNetName, &deviceClassNetGUID, "This is a test device", 0, DICD_GENERATE_ID)
- if err != nil {
- // Access denied is expected, as the SetupDiCreateDeviceInfo() require elevation to succeed.
- if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_ACCESS_DENIED {
- t.Errorf("Error calling SetupDiCreateDeviceInfo: %s", err.Error())
- }
- } else if devInfoData.ClassGUID != deviceClassNetGUID {
- t.Error("SetupDiCreateDeviceInfo returned different class GUID")
- }
-}
-
-func TestSetupDiEnumDeviceInfo(t *testing.T) {
- devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "")
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- for i := 0; true; i++ {
- data, err := devInfoList.EnumDeviceInfo(i)
- if err != nil {
- if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- if data.ClassGUID != deviceClassNetGUID {
- t.Error("SetupDiEnumDeviceInfo returned different class GUID")
- }
-
- _, err = devInfoList.DeviceInstanceID(data)
- if err != nil {
- t.Errorf("Error calling SetupDiGetDeviceInstanceId: %s", err.Error())
- }
- }
-}
-
-func TestDevInfo_BuildDriverInfoList(t *testing.T) {
- devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "")
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- for i := 0; true; i++ {
- deviceData, err := devInfoList.EnumDeviceInfo(i)
- if err != nil {
- if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- const driverType SPDIT = SPDIT_COMPATDRIVER
- err = devInfoList.BuildDriverInfoList(deviceData, driverType)
- if err != nil {
- t.Errorf("Error calling SetupDiBuildDriverInfoList: %s", err.Error())
- }
- defer devInfoList.DestroyDriverInfoList(deviceData, driverType)
-
- var selectedDriverData *DrvInfoData
- for j := 0; true; j++ {
- driverData, err := devInfoList.EnumDriverInfo(deviceData, driverType, j)
- if err != nil {
- if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- if driverData.DriverType == 0 {
- continue
- }
-
- if !driverData.IsNewer(windows.Filetime{}, 0) {
- t.Error("Driver should have non-zero date and version")
- }
- if !driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime}, 0) {
- t.Error("Driver should have non-zero date and version")
- }
- if driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime + 1}, 0) {
- t.Error("Driver should report newer version on high-date-time")
- }
- if !driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime, LowDateTime: driverData.DriverDate.LowDateTime}, 0) {
- t.Error("Driver should have non-zero version")
- }
- if driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime, LowDateTime: driverData.DriverDate.LowDateTime + 1}, 0) {
- t.Error("Driver should report newer version on low-date-time")
- }
- if driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime, LowDateTime: driverData.DriverDate.LowDateTime}, driverData.DriverVersion) {
- t.Error("Driver should not be newer than itself")
- }
- if driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime, LowDateTime: driverData.DriverDate.LowDateTime}, driverData.DriverVersion+1) {
- t.Error("Driver should report newer version on version")
- }
-
- err = devInfoList.SetSelectedDriver(deviceData, driverData)
- if err != nil {
- t.Errorf("Error calling SetupDiSetSelectedDriver: %s", err.Error())
- } else {
- selectedDriverData = driverData
- }
-
- driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData)
- if err != nil {
- t.Errorf("Error calling SetupDiGetDriverInfoDetail: %s", err.Error())
- }
-
- if driverDetailData.IsCompatible("foobar-aab6e3a4-144e-4786-88d3-6cec361e1edd") {
- t.Error("Invalid HWID compatibitlity reported")
- }
- if !driverDetailData.IsCompatible(strings.ToUpper(driverDetailData.HardwareID())) {
- t.Error("HWID compatibitlity missed")
- }
- a := driverDetailData.CompatIDs()
- for k := range a {
- if !driverDetailData.IsCompatible(strings.ToUpper(a[k])) {
- t.Error("HWID compatibitlity missed")
- }
- }
- }
-
- selectedDriverData2, err := devInfoList.SelectedDriver(deviceData)
- if err != nil {
- t.Errorf("Error calling SetupDiGetSelectedDriver: %s", err.Error())
- } else if *selectedDriverData != *selectedDriverData2 {
- t.Error("SetupDiGetSelectedDriver should return driver selected with SetupDiSetSelectedDriver")
- }
- }
-
- data := &DrvInfoData{}
- data.SetDescription("foobar")
- if data.Description() != "foobar" {
- t.Error("DrvInfoData.(Get|Set)Description() differ")
- }
- data.SetMfgName("foobar")
- if data.MfgName() != "foobar" {
- t.Error("DrvInfoData.(Get|Set)MfgName() differ")
- }
- data.SetProviderName("foobar")
- if data.ProviderName() != "foobar" {
- t.Error("DrvInfoData.(Get|Set)ProviderName() differ")
- }
-}
-
-func TestSetupDiGetClassDevsEx(t *testing.T) {
- devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "PCI", 0, DIGCF_PRESENT, DevInfo(0), computerName)
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- } else {
- devInfoList.Close()
- }
-
- devInfoList, err = SetupDiGetClassDevsEx(nil, "", 0, DIGCF_PRESENT, DevInfo(0), "")
- if err != nil {
- if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_PARAMETER {
- t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail with ERROR_INVALID_PARAMETER")
- }
- } else {
- devInfoList.Close()
- t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail")
- }
-}
-
-func TestSetupDiOpenDevRegKey(t *testing.T) {
- devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "")
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- for i := 0; true; i++ {
- data, err := devInfoList.EnumDeviceInfo(i)
- if err != nil {
- if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- key, err := devInfoList.OpenDevRegKey(data, DICS_FLAG_GLOBAL, 0, DIREG_DRV, windows.KEY_READ)
- if err != nil {
- t.Errorf("Error calling SetupDiOpenDevRegKey: %s", err.Error())
- }
- defer key.Close()
- }
-}
-
-func TestSetupDiGetDeviceRegistryProperty(t *testing.T) {
- devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "")
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- for i := 0; true; i++ {
- data, err := devInfoList.EnumDeviceInfo(i)
- if err != nil {
- if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- val, err := devInfoList.DeviceRegistryProperty(data, SPDRP_CLASS)
- if err != nil {
- t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CLASS): %s", err.Error())
- } else if class, ok := val.(string); !ok || strings.ToLower(class) != "net" {
- t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASS) should return \"Net\"")
- }
-
- val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_CLASSGUID)
- if err != nil {
- t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID): %s", err.Error())
- } else if valStr, ok := val.(string); !ok {
- t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID) should return string")
- } else {
- classGUID, err := windows.GUIDFromString(valStr)
- if err != nil {
- t.Errorf("Error parsing GUID returned by SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID): %s", err.Error())
- } else if classGUID != deviceClassNetGUID {
- t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID) should return %x", deviceClassNetGUID)
- }
- }
-
- val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_COMPATIBLEIDS)
- if err != nil {
- // Some devices have no SPDRP_COMPATIBLEIDS.
- if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_DATA {
- t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_COMPATIBLEIDS): %s", err.Error())
- }
- }
-
- val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_CONFIGFLAGS)
- if err != nil {
- t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CONFIGFLAGS): %s", err.Error())
- }
-
- val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_DEVICE_POWER_DATA)
- if err != nil {
- t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_DEVICE_POWER_DATA): %s", err.Error())
- }
- }
-}
-
-func TestSetupDiGetDeviceInstallParams(t *testing.T) {
- devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "")
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- for i := 0; true; i++ {
- data, err := devInfoList.EnumDeviceInfo(i)
- if err != nil {
- if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- _, err = devInfoList.DeviceInstallParams(data)
- if err != nil {
- t.Errorf("Error calling SetupDiGetDeviceInstallParams: %s", err.Error())
- }
- }
-
- params := &DevInstallParams{}
- params.SetDriverPath("foobar")
- if params.DriverPath() != "foobar" {
- t.Error("DevInstallParams.(Get|Set)DriverPath() differ")
- }
-}
-
-func TestSetupDiClassNameFromGuidEx(t *testing.T) {
- deviceClassNetName, err := SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "")
- if err != nil {
- t.Errorf("Error calling SetupDiClassNameFromGuidEx: %s", err.Error())
- } else if strings.ToLower(deviceClassNetName) != "net" {
- t.Errorf("SetupDiClassNameFromGuidEx(%x) should return \"Net\"", deviceClassNetGUID)
- }
-
- deviceClassNetName, err = SetupDiClassNameFromGuidEx(&deviceClassNetGUID, computerName)
- if err != nil {
- t.Errorf("Error calling SetupDiClassNameFromGuidEx: %s", err.Error())
- } else if strings.ToLower(deviceClassNetName) != "net" {
- t.Errorf("SetupDiClassNameFromGuidEx(%x) should return \"Net\"", deviceClassNetGUID)
- }
-
- _, err = SetupDiClassNameFromGuidEx(nil, "")
- if err != nil {
- if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_USER_BUFFER {
- t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail with ERROR_INVALID_USER_BUFFER")
- }
- } else {
- t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail")
- }
-}
-
-func TestSetupDiClassGuidsFromNameEx(t *testing.T) {
- ClassGUIDs, err := SetupDiClassGuidsFromNameEx("Net", "")
- if err != nil {
- t.Errorf("Error calling SetupDiClassGuidsFromNameEx: %s", err.Error())
- } else {
- found := false
- for i := range ClassGUIDs {
- if ClassGUIDs[i] == deviceClassNetGUID {
- found = true
- break
- }
- }
- if !found {
- t.Errorf("SetupDiClassGuidsFromNameEx(\"Net\") should return %x", deviceClassNetGUID)
- }
- }
-
- ClassGUIDs, err = SetupDiClassGuidsFromNameEx("foobar-34274a51-a6e6-45f0-80d6-c62be96dd5fe", computerName)
- if err != nil {
- t.Errorf("Error calling SetupDiClassGuidsFromNameEx: %s", err.Error())
- } else if len(ClassGUIDs) != 0 {
- t.Errorf("SetupDiClassGuidsFromNameEx(\"foobar-34274a51-a6e6-45f0-80d6-c62be96dd5fe\") should return an empty GUID set")
- }
-}
-
-func TestSetupDiGetSelectedDevice(t *testing.T) {
- devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "")
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- for i := 0; true; i++ {
- data, err := devInfoList.EnumDeviceInfo(i)
- if err != nil {
- if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- err = devInfoList.SetSelectedDevice(data)
- if err != nil {
- t.Errorf("Error calling SetupDiSetSelectedDevice: %s", err.Error())
- }
-
- data2, err := devInfoList.SelectedDevice()
- if err != nil {
- t.Errorf("Error calling SetupDiGetSelectedDevice: %s", err.Error())
- } else if *data != *data2 {
- t.Error("SetupDiGetSelectedDevice returned different data than was set by SetupDiSetSelectedDevice")
- }
- }
-
- err = devInfoList.SetSelectedDevice(nil)
- if err != nil {
- if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_PARAMETER {
- t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER")
- }
- } else {
- t.Errorf("SetupDiSetSelectedDevice(nil) should fail")
- }
-}
-
-func TestUTF16ToBuf(t *testing.T) {
- buf := []uint16{0x0123, 0x4567, 0x89ab, 0xcdef}
- buf2 := utf16ToBuf(buf)
- if len(buf)*2 != len(buf2) ||
- cap(buf)*2 != cap(buf2) ||
- buf2[0] != 0x23 || buf2[1] != 0x01 ||
- buf2[2] != 0x67 || buf2[3] != 0x45 ||
- buf2[4] != 0xab || buf2[5] != 0x89 ||
- buf2[6] != 0xef || buf2[7] != 0xcd {
- t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER")
- }
- runtime.KeepAlive(buf)
-}
diff --git a/tun/wintun/setupapi/types_windows.go b/tun/wintun/setupapi/types_windows.go
deleted file mode 100644
index 136b4be..0000000
--- a/tun/wintun/setupapi/types_windows.go
+++ /dev/null
@@ -1,568 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package setupapi
-
-import (
- "strings"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-const (
- MAX_DEVICE_ID_LEN = 200
- MAX_DEVNODE_ID_LEN = MAX_DEVICE_ID_LEN
- MAX_GUID_STRING_LEN = 39 // 38 chars + terminator null
- MAX_CLASS_NAME_LEN = 32
- MAX_PROFILE_LEN = 80
- MAX_CONFIG_VALUE = 9999
- MAX_INSTANCE_VALUE = 9999
- CONFIGMG_VERSION = 0x0400
-)
-
-//
-// Define maximum string length constants
-//
-const (
- ANYSIZE_ARRAY = 1
- LINE_LEN = 256 // Windows 9x-compatible maximum for displayable strings coming from a device INF.
- MAX_INF_STRING_LENGTH = 4096 // Actual maximum size of an INF string (including string substitutions).
- MAX_INF_SECTION_NAME_LENGTH = 255 // For Windows 9x compatibility, INF section names should be constrained to 32 characters.
- MAX_TITLE_LEN = 60
- MAX_INSTRUCTION_LEN = 256
- MAX_LABEL_LEN = 30
- MAX_SERVICE_NAME_LEN = 256
- MAX_SUBTITLE_LEN = 256
-)
-
-const (
- // SP_MAX_MACHINENAME_LENGTH defines maximum length of a machine name in the format expected by ConfigMgr32 CM_Connect_Machine (i.e., "\\\\MachineName\0").
- SP_MAX_MACHINENAME_LENGTH = windows.MAX_PATH + 3
-)
-
-// HSPFILEQ is type for setup file queue
-type HSPFILEQ uintptr
-
-// DevInfo holds reference to device information set
-type DevInfo windows.Handle
-
-// DevInfoData is a device information structure (references a device instance that is a member of a device information set)
-type DevInfoData struct {
- size uint32
- ClassGUID windows.GUID
- DevInst uint32 // DEVINST handle
- _ uintptr
-}
-
-// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supercedes the functionality of SetupDiGetDeviceInfoListClass).
-type DevInfoListDetailData struct {
- size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const.
- ClassGUID windows.GUID
- RemoteMachineHandle windows.Handle
- remoteMachineName [SP_MAX_MACHINENAME_LENGTH]uint16
-}
-
-func (data *DevInfoListDetailData) RemoteMachineName() string {
- return windows.UTF16ToString(data.remoteMachineName[:])
-}
-
-func (data *DevInfoListDetailData) SetRemoteMachineName(remoteMachineName string) error {
- str, err := windows.UTF16FromString(remoteMachineName)
- if err != nil {
- return err
- }
- copy(data.remoteMachineName[:], str)
- return nil
-}
-
-// DI_FUNCTION is function type for device installer
-type DI_FUNCTION uint32
-
-const (
- DIF_SELECTDEVICE DI_FUNCTION = 0x00000001
- DIF_INSTALLDEVICE DI_FUNCTION = 0x00000002
- DIF_ASSIGNRESOURCES DI_FUNCTION = 0x00000003
- DIF_PROPERTIES DI_FUNCTION = 0x00000004
- DIF_REMOVE DI_FUNCTION = 0x00000005
- DIF_FIRSTTIMESETUP DI_FUNCTION = 0x00000006
- DIF_FOUNDDEVICE DI_FUNCTION = 0x00000007
- DIF_SELECTCLASSDRIVERS DI_FUNCTION = 0x00000008
- DIF_VALIDATECLASSDRIVERS DI_FUNCTION = 0x00000009
- DIF_INSTALLCLASSDRIVERS DI_FUNCTION = 0x0000000A
- DIF_CALCDISKSPACE DI_FUNCTION = 0x0000000B
- DIF_DESTROYPRIVATEDATA DI_FUNCTION = 0x0000000C
- DIF_VALIDATEDRIVER DI_FUNCTION = 0x0000000D
- DIF_DETECT DI_FUNCTION = 0x0000000F
- DIF_INSTALLWIZARD DI_FUNCTION = 0x00000010
- DIF_DESTROYWIZARDDATA DI_FUNCTION = 0x00000011
- DIF_PROPERTYCHANGE DI_FUNCTION = 0x00000012
- DIF_ENABLECLASS DI_FUNCTION = 0x00000013
- DIF_DETECTVERIFY DI_FUNCTION = 0x00000014
- DIF_INSTALLDEVICEFILES DI_FUNCTION = 0x00000015
- DIF_UNREMOVE DI_FUNCTION = 0x00000016
- DIF_SELECTBESTCOMPATDRV DI_FUNCTION = 0x00000017
- DIF_ALLOW_INSTALL DI_FUNCTION = 0x00000018
- DIF_REGISTERDEVICE DI_FUNCTION = 0x00000019
- DIF_NEWDEVICEWIZARD_PRESELECT DI_FUNCTION = 0x0000001A
- DIF_NEWDEVICEWIZARD_SELECT DI_FUNCTION = 0x0000001B
- DIF_NEWDEVICEWIZARD_PREANALYZE DI_FUNCTION = 0x0000001C
- DIF_NEWDEVICEWIZARD_POSTANALYZE DI_FUNCTION = 0x0000001D
- DIF_NEWDEVICEWIZARD_FINISHINSTALL DI_FUNCTION = 0x0000001E
- DIF_INSTALLINTERFACES DI_FUNCTION = 0x00000020
- DIF_DETECTCANCEL DI_FUNCTION = 0x00000021
- DIF_REGISTER_COINSTALLERS DI_FUNCTION = 0x00000022
- DIF_ADDPROPERTYPAGE_ADVANCED DI_FUNCTION = 0x00000023
- DIF_ADDPROPERTYPAGE_BASIC DI_FUNCTION = 0x00000024
- DIF_TROUBLESHOOTER DI_FUNCTION = 0x00000026
- DIF_POWERMESSAGEWAKE DI_FUNCTION = 0x00000027
- DIF_ADDREMOTEPROPERTYPAGE_ADVANCED DI_FUNCTION = 0x00000028
- DIF_UPDATEDRIVER_UI DI_FUNCTION = 0x00000029
- DIF_FINISHINSTALL_ACTION DI_FUNCTION = 0x0000002A
-)
-
-// DevInstallParams is device installation parameters structure (associated with a particular device information element, or globally with a device information set)
-type DevInstallParams struct {
- size uint32
- Flags DI_FLAGS
- FlagsEx DI_FLAGSEX
- hwndParent uintptr
- InstallMsgHandler uintptr
- InstallMsgHandlerContext uintptr
- FileQueue HSPFILEQ
- _ uintptr
- _ uint32
- driverPath [windows.MAX_PATH]uint16
-}
-
-func (params *DevInstallParams) DriverPath() string {
- return windows.UTF16ToString(params.driverPath[:])
-}
-
-func (params *DevInstallParams) SetDriverPath(driverPath string) error {
- str, err := windows.UTF16FromString(driverPath)
- if err != nil {
- return err
- }
- copy(params.driverPath[:], str)
- return nil
-}
-
-// DI_FLAGS is SP_DEVINSTALL_PARAMS.Flags values
-type DI_FLAGS uint32
-
-const (
- // Flags for choosing a device
- DI_SHOWOEM DI_FLAGS = 0x00000001 // support Other... button
- DI_SHOWCOMPAT DI_FLAGS = 0x00000002 // show compatibility list
- DI_SHOWCLASS DI_FLAGS = 0x00000004 // show class list
- DI_SHOWALL DI_FLAGS = 0x00000007 // both class & compat list shown
- DI_NOVCP DI_FLAGS = 0x00000008 // don't create a new copy queue--use caller-supplied FileQueue
- DI_DIDCOMPAT DI_FLAGS = 0x00000010 // Searched for compatible devices
- DI_DIDCLASS DI_FLAGS = 0x00000020 // Searched for class devices
- DI_AUTOASSIGNRES DI_FLAGS = 0x00000040 // No UI for resources if possible
-
- // Flags returned by DiInstallDevice to indicate need to reboot/restart
- DI_NEEDRESTART DI_FLAGS = 0x00000080 // Reboot required to take effect
- DI_NEEDREBOOT DI_FLAGS = 0x00000100 // ""
-
- // Flags for device installation
- DI_NOBROWSE DI_FLAGS = 0x00000200 // no Browse... in InsertDisk
-
- // Flags set by DiBuildDriverInfoList
- DI_MULTMFGS DI_FLAGS = 0x00000400 // Set if multiple manufacturers in class driver list
-
- // Flag indicates that device is disabled
- DI_DISABLED DI_FLAGS = 0x00000800 // Set if device disabled
-
- // Flags for Device/Class Properties
- DI_GENERALPAGE_ADDED DI_FLAGS = 0x00001000
- DI_RESOURCEPAGE_ADDED DI_FLAGS = 0x00002000
-
- // Flag to indicate the setting properties for this Device (or class) caused a change so the Dev Mgr UI probably needs to be updated.
- DI_PROPERTIES_CHANGE DI_FLAGS = 0x00004000
-
- // Flag to indicate that the sorting from the INF file should be used.
- DI_INF_IS_SORTED DI_FLAGS = 0x00008000
-
- // Flag to indicate that only the the INF specified by SP_DEVINSTALL_PARAMS.DriverPath should be searched.
- DI_ENUMSINGLEINF DI_FLAGS = 0x00010000
-
- // Flag that prevents ConfigMgr from removing/re-enumerating devices during device
- // registration, installation, and deletion.
- DI_DONOTCALLCONFIGMG DI_FLAGS = 0x00020000
-
- // The following flag can be used to install a device disabled
- DI_INSTALLDISABLED DI_FLAGS = 0x00040000
-
- // Flag that causes SetupDiBuildDriverInfoList to build a device's compatible driver
- // list from its existing class driver list, instead of the normal INF search.
- DI_COMPAT_FROM_CLASS DI_FLAGS = 0x00080000
-
- // This flag is set if the Class Install params should be used.
- DI_CLASSINSTALLPARAMS DI_FLAGS = 0x00100000
-
- // This flag is set if the caller of DiCallClassInstaller does NOT want the internal default action performed if the Class installer returns ERROR_DI_DO_DEFAULT.
- DI_NODI_DEFAULTACTION DI_FLAGS = 0x00200000
-
- // Flags for device installation
- DI_QUIETINSTALL DI_FLAGS = 0x00800000 // don't confuse the user with questions or excess info
- DI_NOFILECOPY DI_FLAGS = 0x01000000 // No file Copy necessary
- DI_FORCECOPY DI_FLAGS = 0x02000000 // Force files to be copied from install path
- DI_DRIVERPAGE_ADDED DI_FLAGS = 0x04000000 // Prop provider added Driver page.
- DI_USECI_SELECTSTRINGS DI_FLAGS = 0x08000000 // Use Class Installer Provided strings in the Select Device Dlg
- DI_OVERRIDE_INFFLAGS DI_FLAGS = 0x10000000 // Override INF flags
- DI_PROPS_NOCHANGEUSAGE DI_FLAGS = 0x20000000 // No Enable/Disable in General Props
-
- DI_NOSELECTICONS DI_FLAGS = 0x40000000 // No small icons in select device dialogs
-
- DI_NOWRITE_IDS DI_FLAGS = 0x80000000 // Don't write HW & Compat IDs on install
-)
-
-// DI_FLAGSEX is SP_DEVINSTALL_PARAMS.FlagsEx values
-type DI_FLAGSEX uint32
-
-const (
- DI_FLAGSEX_CI_FAILED DI_FLAGSEX = 0x00000004 // Failed to Load/Call class installer
- DI_FLAGSEX_FINISHINSTALL_ACTION DI_FLAGSEX = 0x00000008 // Class/co-installer wants to get a DIF_FINISH_INSTALL action in client context.
- DI_FLAGSEX_DIDINFOLIST DI_FLAGSEX = 0x00000010 // Did the Class Info List
- DI_FLAGSEX_DIDCOMPATINFO DI_FLAGSEX = 0x00000020 // Did the Compat Info List
- DI_FLAGSEX_FILTERCLASSES DI_FLAGSEX = 0x00000040
- DI_FLAGSEX_SETFAILEDINSTALL DI_FLAGSEX = 0x00000080
- DI_FLAGSEX_DEVICECHANGE DI_FLAGSEX = 0x00000100
- DI_FLAGSEX_ALWAYSWRITEIDS DI_FLAGSEX = 0x00000200
- DI_FLAGSEX_PROPCHANGE_PENDING DI_FLAGSEX = 0x00000400 // One or more device property sheets have had changes made to them, and need to have a DIF_PROPERTYCHANGE occur.
- DI_FLAGSEX_ALLOWEXCLUDEDDRVS DI_FLAGSEX = 0x00000800
- DI_FLAGSEX_NOUIONQUERYREMOVE DI_FLAGSEX = 0x00001000
- DI_FLAGSEX_USECLASSFORCOMPAT DI_FLAGSEX = 0x00002000 // Use the device's class when building compat drv list. (Ignored if DI_COMPAT_FROM_CLASS flag is specified.)
- DI_FLAGSEX_NO_DRVREG_MODIFY DI_FLAGSEX = 0x00008000 // Don't run AddReg and DelReg for device's software (driver) key.
- DI_FLAGSEX_IN_SYSTEM_SETUP DI_FLAGSEX = 0x00010000 // Installation is occurring during initial system setup.
- DI_FLAGSEX_INET_DRIVER DI_FLAGSEX = 0x00020000 // Driver came from Windows Update
- DI_FLAGSEX_APPENDDRIVERLIST DI_FLAGSEX = 0x00040000 // Cause SetupDiBuildDriverInfoList to append a new driver list to an existing list.
- DI_FLAGSEX_PREINSTALLBACKUP DI_FLAGSEX = 0x00080000 // not used
- DI_FLAGSEX_BACKUPONREPLACE DI_FLAGSEX = 0x00100000 // not used
- DI_FLAGSEX_DRIVERLIST_FROM_URL DI_FLAGSEX = 0x00200000 // build driver list from INF(s) retrieved from URL specified in SP_DEVINSTALL_PARAMS.DriverPath (empty string means Windows Update website)
- DI_FLAGSEX_EXCLUDE_OLD_INET_DRIVERS DI_FLAGSEX = 0x00800000 // Don't include old Internet drivers when building a driver list. Ignored on Windows Vista and later.
- DI_FLAGSEX_POWERPAGE_ADDED DI_FLAGSEX = 0x01000000 // class installer added their own power page
- DI_FLAGSEX_FILTERSIMILARDRIVERS DI_FLAGSEX = 0x02000000 // only include similar drivers in class list
- DI_FLAGSEX_INSTALLEDDRIVER DI_FLAGSEX = 0x04000000 // only add the installed driver to the class or compat driver list. Used in calls to SetupDiBuildDriverInfoList
- DI_FLAGSEX_NO_CLASSLIST_NODE_MERGE DI_FLAGSEX = 0x08000000 // Don't remove identical driver nodes from the class list
- DI_FLAGSEX_ALTPLATFORM_DRVSEARCH DI_FLAGSEX = 0x10000000 // Build driver list based on alternate platform information specified in associated file queue
- DI_FLAGSEX_RESTART_DEVICE_ONLY DI_FLAGSEX = 0x20000000 // only restart the device drivers are being installed on as opposed to restarting all devices using those drivers.
- DI_FLAGSEX_RECURSIVESEARCH DI_FLAGSEX = 0x40000000 // Tell SetupDiBuildDriverInfoList to do a recursive search
- DI_FLAGSEX_SEARCH_PUBLISHED_INFS DI_FLAGSEX = 0x80000000 // Tell SetupDiBuildDriverInfoList to do a "published INF" search
-)
-
-// ClassInstallHeader is the first member of any class install parameters structure. It contains the device installation request code that defines the format of the rest of the install parameters structure.
-type ClassInstallHeader struct {
- size uint32
- InstallFunction DI_FUNCTION
-}
-
-func MakeClassInstallHeader(installFunction DI_FUNCTION) *ClassInstallHeader {
- hdr := &ClassInstallHeader{InstallFunction: installFunction}
- hdr.size = uint32(unsafe.Sizeof(*hdr))
- return hdr
-}
-
-// DICS_STATE specifies values indicating a change in a device's state
-type DICS_STATE uint32
-
-const (
- DICS_ENABLE DICS_STATE = 0x00000001 // The device is being enabled.
- DICS_DISABLE DICS_STATE = 0x00000002 // The device is being disabled.
- DICS_PROPCHANGE DICS_STATE = 0x00000003 // The properties of the device have changed.
- DICS_START DICS_STATE = 0x00000004 // The device is being started (if the request is for the currently active hardware profile).
- DICS_STOP DICS_STATE = 0x00000005 // The device is being stopped. The driver stack will be unloaded and the CSCONFIGFLAG_DO_NOT_START flag will be set for the device.
-)
-
-// DICS_FLAG specifies the scope of a device property change
-type DICS_FLAG uint32
-
-const (
- DICS_FLAG_GLOBAL DICS_FLAG = 0x00000001 // make change in all hardware profiles
- DICS_FLAG_CONFIGSPECIFIC DICS_FLAG = 0x00000002 // make change in specified profile only
- DICS_FLAG_CONFIGGENERAL DICS_FLAG = 0x00000004 // 1 or more hardware profile-specific changes to follow (obsolete)
-)
-
-// PropChangeParams is a structure corresponding to a DIF_PROPERTYCHANGE install function.
-type PropChangeParams struct {
- ClassInstallHeader ClassInstallHeader
- StateChange DICS_STATE
- Scope DICS_FLAG
- HwProfile uint32
-}
-
-// DI_REMOVEDEVICE specifies the scope of the device removal
-type DI_REMOVEDEVICE uint32
-
-const (
- DI_REMOVEDEVICE_GLOBAL DI_REMOVEDEVICE = 0x00000001 // Make this change in all hardware profiles. Remove information about the device from the registry.
- DI_REMOVEDEVICE_CONFIGSPECIFIC DI_REMOVEDEVICE = 0x00000002 // Make this change to only the hardware profile specified by HwProfile. this flag only applies to root-enumerated devices. When Windows removes the device from the last hardware profile in which it was configured, Windows performs a global removal.
-)
-
-// RemoveDeviceParams is a structure corresponding to a DIF_REMOVE install function.
-type RemoveDeviceParams struct {
- ClassInstallHeader ClassInstallHeader
- Scope DI_REMOVEDEVICE
- HwProfile uint32
-}
-
-// DrvInfoData is driver information structure (member of a driver info list that may be associated with a particular device instance, or (globally) with a device information set)
-type DrvInfoData struct {
- size uint32
- DriverType uint32
- _ uintptr
- description [LINE_LEN]uint16
- mfgName [LINE_LEN]uint16
- providerName [LINE_LEN]uint16
- DriverDate windows.Filetime
- DriverVersion uint64
-}
-
-func (data *DrvInfoData) Description() string {
- return windows.UTF16ToString(data.description[:])
-}
-
-func (data *DrvInfoData) SetDescription(description string) error {
- str, err := windows.UTF16FromString(description)
- if err != nil {
- return err
- }
- copy(data.description[:], str)
- return nil
-}
-
-func (data *DrvInfoData) MfgName() string {
- return windows.UTF16ToString(data.mfgName[:])
-}
-
-func (data *DrvInfoData) SetMfgName(mfgName string) error {
- str, err := windows.UTF16FromString(mfgName)
- if err != nil {
- return err
- }
- copy(data.mfgName[:], str)
- return nil
-}
-
-func (data *DrvInfoData) ProviderName() string {
- return windows.UTF16ToString(data.providerName[:])
-}
-
-func (data *DrvInfoData) SetProviderName(providerName string) error {
- str, err := windows.UTF16FromString(providerName)
- if err != nil {
- return err
- }
- copy(data.providerName[:], str)
- return nil
-}
-
-// IsNewer method returns true if DrvInfoData date and version is newer than supplied parameters.
-func (data *DrvInfoData) IsNewer(driverDate windows.Filetime, driverVersion uint64) bool {
- if data.DriverDate.HighDateTime > driverDate.HighDateTime {
- return true
- }
- if data.DriverDate.HighDateTime < driverDate.HighDateTime {
- return false
- }
-
- if data.DriverDate.LowDateTime > driverDate.LowDateTime {
- return true
- }
- if data.DriverDate.LowDateTime < driverDate.LowDateTime {
- return false
- }
-
- if data.DriverVersion > driverVersion {
- return true
- }
- if data.DriverVersion < driverVersion {
- return false
- }
-
- return false
-}
-
-// DrvInfoDetailData is driver information details structure (provides detailed information about a particular driver information structure)
-type DrvInfoDetailData struct {
- size uint32 // Warning: unsafe.Sizeof(DrvInfoDetailData) > sizeof(SP_DRVINFO_DETAIL_DATA) when GOARCH == 386 => use sizeofDrvInfoDetailData const.
- InfDate windows.Filetime
- compatIDsOffset uint32
- compatIDsLength uint32
- _ uintptr
- sectionName [LINE_LEN]uint16
- infFileName [windows.MAX_PATH]uint16
- drvDescription [LINE_LEN]uint16
- hardwareID [ANYSIZE_ARRAY]uint16
-}
-
-func (data *DrvInfoDetailData) SectionName() string {
- return windows.UTF16ToString(data.sectionName[:])
-}
-
-func (data *DrvInfoDetailData) InfFileName() string {
- return windows.UTF16ToString(data.infFileName[:])
-}
-
-func (data *DrvInfoDetailData) DrvDescription() string {
- return windows.UTF16ToString(data.drvDescription[:])
-}
-
-func (data *DrvInfoDetailData) HardwareID() string {
- if data.compatIDsOffset > 1 {
- bufW := data.getBuf()
- return windows.UTF16ToString(bufW[:wcslen(bufW)])
- }
-
- return ""
-}
-
-func (data *DrvInfoDetailData) CompatIDs() []string {
- a := make([]string, 0)
-
- if data.compatIDsLength > 0 {
- bufW := data.getBuf()
- bufW = bufW[data.compatIDsOffset : data.compatIDsOffset+data.compatIDsLength]
- for i := 0; i < len(bufW); {
- j := i + wcslen(bufW[i:])
- if i < j {
- a = append(a, windows.UTF16ToString(bufW[i:j]))
- }
- i = j + 1
- }
- }
-
- return a
-}
-
-func (data *DrvInfoDetailData) getBuf() []uint16 {
- len := (data.size - uint32(unsafe.Offsetof(data.hardwareID))) / 2
- sl := struct {
- addr *uint16
- len int
- cap int
- }{&data.hardwareID[0], int(len), int(len)}
- return *(*[]uint16)(unsafe.Pointer(&sl))
-}
-
-// IsCompatible method tests if given hardware ID matches the driver or is listed on the compatible ID list.
-func (data *DrvInfoDetailData) IsCompatible(hwid string) bool {
- hwidLC := strings.ToLower(hwid)
- if strings.ToLower(data.HardwareID()) == hwidLC {
- return true
- }
- a := data.CompatIDs()
- for i := range a {
- if strings.ToLower(a[i]) == hwidLC {
- return true
- }
- }
-
- return false
-}
-
-// DICD flags control SetupDiCreateDeviceInfo
-type DICD uint32
-
-const (
- DICD_GENERATE_ID DICD = 0x00000001
- DICD_INHERIT_CLASSDRVS DICD = 0x00000002
-)
-
-//
-// SPDIT flags to distinguish between class drivers and
-// device drivers.
-// (Passed in 'DriverType' parameter of driver information list APIs)
-//
-type SPDIT uint32
-
-const (
- SPDIT_NODRIVER SPDIT = 0x00000000
- SPDIT_CLASSDRIVER SPDIT = 0x00000001
- SPDIT_COMPATDRIVER SPDIT = 0x00000002
-)
-
-// DIGCF flags control what is included in the device information set built by SetupDiGetClassDevs
-type DIGCF uint32
-
-const (
- DIGCF_DEFAULT DIGCF = 0x00000001 // only valid with DIGCF_DEVICEINTERFACE
- DIGCF_PRESENT DIGCF = 0x00000002
- DIGCF_ALLCLASSES DIGCF = 0x00000004
- DIGCF_PROFILE DIGCF = 0x00000008
- DIGCF_DEVICEINTERFACE DIGCF = 0x00000010
-)
-
-// DIREG specifies values for SetupDiCreateDevRegKey, SetupDiOpenDevRegKey, and SetupDiDeleteDevRegKey.
-type DIREG uint32
-
-const (
- DIREG_DEV DIREG = 0x00000001 // Open/Create/Delete device key
- DIREG_DRV DIREG = 0x00000002 // Open/Create/Delete driver key
- DIREG_BOTH DIREG = 0x00000004 // Delete both driver and Device key
-)
-
-//
-// SPDRP specifies device registry property codes
-// (Codes marked as read-only (R) may only be used for
-// SetupDiGetDeviceRegistryProperty)
-//
-// These values should cover the same set of registry properties
-// as defined by the CM_DRP codes in cfgmgr32.h.
-//
-// Note that SPDRP codes are zero based while CM_DRP codes are one based!
-//
-type SPDRP uint32
-
-const (
- SPDRP_DEVICEDESC SPDRP = 0x00000000 // DeviceDesc (R/W)
- SPDRP_HARDWAREID SPDRP = 0x00000001 // HardwareID (R/W)
- SPDRP_COMPATIBLEIDS SPDRP = 0x00000002 // CompatibleIDs (R/W)
- SPDRP_SERVICE SPDRP = 0x00000004 // Service (R/W)
- SPDRP_CLASS SPDRP = 0x00000007 // Class (R--tied to ClassGUID)
- SPDRP_CLASSGUID SPDRP = 0x00000008 // ClassGUID (R/W)
- SPDRP_DRIVER SPDRP = 0x00000009 // Driver (R/W)
- SPDRP_CONFIGFLAGS SPDRP = 0x0000000A // ConfigFlags (R/W)
- SPDRP_MFG SPDRP = 0x0000000B // Mfg (R/W)
- SPDRP_FRIENDLYNAME SPDRP = 0x0000000C // FriendlyName (R/W)
- SPDRP_LOCATION_INFORMATION SPDRP = 0x0000000D // LocationInformation (R/W)
- SPDRP_PHYSICAL_DEVICE_OBJECT_NAME SPDRP = 0x0000000E // PhysicalDeviceObjectName (R)
- SPDRP_CAPABILITIES SPDRP = 0x0000000F // Capabilities (R)
- SPDRP_UI_NUMBER SPDRP = 0x00000010 // UiNumber (R)
- SPDRP_UPPERFILTERS SPDRP = 0x00000011 // UpperFilters (R/W)
- SPDRP_LOWERFILTERS SPDRP = 0x00000012 // LowerFilters (R/W)
- SPDRP_BUSTYPEGUID SPDRP = 0x00000013 // BusTypeGUID (R)
- SPDRP_LEGACYBUSTYPE SPDRP = 0x00000014 // LegacyBusType (R)
- SPDRP_BUSNUMBER SPDRP = 0x00000015 // BusNumber (R)
- SPDRP_ENUMERATOR_NAME SPDRP = 0x00000016 // Enumerator Name (R)
- SPDRP_SECURITY SPDRP = 0x00000017 // Security (R/W, binary form)
- SPDRP_SECURITY_SDS SPDRP = 0x00000018 // Security (W, SDS form)
- SPDRP_DEVTYPE SPDRP = 0x00000019 // Device Type (R/W)
- SPDRP_EXCLUSIVE SPDRP = 0x0000001A // Device is exclusive-access (R/W)
- SPDRP_CHARACTERISTICS SPDRP = 0x0000001B // Device Characteristics (R/W)
- SPDRP_ADDRESS SPDRP = 0x0000001C // Device Address (R)
- SPDRP_UI_NUMBER_DESC_FORMAT SPDRP = 0x0000001D // UiNumberDescFormat (R/W)
- SPDRP_DEVICE_POWER_DATA SPDRP = 0x0000001E // Device Power Data (R)
- SPDRP_REMOVAL_POLICY SPDRP = 0x0000001F // Removal Policy (R)
- SPDRP_REMOVAL_POLICY_HW_DEFAULT SPDRP = 0x00000020 // Hardware Removal Policy (R)
- SPDRP_REMOVAL_POLICY_OVERRIDE SPDRP = 0x00000021 // Removal Policy Override (RW)
- SPDRP_INSTALL_STATE SPDRP = 0x00000022 // Device Install State (R)
- SPDRP_LOCATION_PATHS SPDRP = 0x00000023 // Device Location Paths (R)
- SPDRP_BASE_CONTAINERID SPDRP = 0x00000024 // Base ContainerID (R)
-
- SPDRP_MAXIMUM_PROPERTY SPDRP = 0x00000025 // Upper bound on ordinals
-)
-
-const (
- CR_SUCCESS = 0x0
- CR_BUFFER_SMALL = 0x1a
-)
-
-const (
- CM_GET_DEVICE_INTERFACE_LIST_PRESENT = 0 // only currently 'live' device interfaces
- CM_GET_DEVICE_INTERFACE_LIST_ALL_DEVICES = 1 // all registered device interfaces, live or not
-)
diff --git a/tun/wintun/setupapi/types_windows_386.go b/tun/wintun/setupapi/types_windows_386.go
deleted file mode 100644
index 132f921..0000000
--- a/tun/wintun/setupapi/types_windows_386.go
+++ /dev/null
@@ -1,11 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package setupapi
-
-const (
- sizeofDevInfoListDetailData uint32 = 550
- sizeofDrvInfoDetailData uint32 = 1570
-)
diff --git a/tun/wintun/setupapi/types_windows_amd64.go b/tun/wintun/setupapi/types_windows_amd64.go
deleted file mode 100644
index d4dd65c..0000000
--- a/tun/wintun/setupapi/types_windows_amd64.go
+++ /dev/null
@@ -1,11 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package setupapi
-
-const (
- sizeofDevInfoListDetailData uint32 = 560
- sizeofDrvInfoDetailData uint32 = 1584
-)
diff --git a/tun/wintun/setupapi/zsetupapi_windows.go b/tun/wintun/setupapi/zsetupapi_windows.go
deleted file mode 100644
index 375862d..0000000
--- a/tun/wintun/setupapi/zsetupapi_windows.go
+++ /dev/null
@@ -1,398 +0,0 @@
-// Code generated by 'go generate'; DO NOT EDIT.
-
-package setupapi
-
-import (
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-var _ unsafe.Pointer
-
-// Do the interface allocations only once for common
-// Errno values.
-const (
- errnoERROR_IO_PENDING = 997
-)
-
-var (
- errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
-)
-
-// errnoErr returns common boxed Errno values, to prevent
-// allocations at runtime.
-func errnoErr(e syscall.Errno) error {
- switch e {
- case 0:
- return nil
- case errnoERROR_IO_PENDING:
- return errERROR_IO_PENDING
- }
- // TODO: add more here, after collecting data on the common
- // error values see on Windows. (perhaps when running
- // all.bat?)
- return e
-}
-
-var (
- modsetupapi = windows.NewLazySystemDLL("setupapi.dll")
- modCfgMgr32 = windows.NewLazySystemDLL("CfgMgr32.dll")
-
- procSetupDiCreateDeviceInfoListExW = modsetupapi.NewProc("SetupDiCreateDeviceInfoListExW")
- procSetupDiGetDeviceInfoListDetailW = modsetupapi.NewProc("SetupDiGetDeviceInfoListDetailW")
- procSetupDiCreateDeviceInfoW = modsetupapi.NewProc("SetupDiCreateDeviceInfoW")
- procSetupDiEnumDeviceInfo = modsetupapi.NewProc("SetupDiEnumDeviceInfo")
- procSetupDiDestroyDeviceInfoList = modsetupapi.NewProc("SetupDiDestroyDeviceInfoList")
- procSetupDiBuildDriverInfoList = modsetupapi.NewProc("SetupDiBuildDriverInfoList")
- procSetupDiCancelDriverInfoSearch = modsetupapi.NewProc("SetupDiCancelDriverInfoSearch")
- procSetupDiEnumDriverInfoW = modsetupapi.NewProc("SetupDiEnumDriverInfoW")
- procSetupDiGetSelectedDriverW = modsetupapi.NewProc("SetupDiGetSelectedDriverW")
- procSetupDiSetSelectedDriverW = modsetupapi.NewProc("SetupDiSetSelectedDriverW")
- procSetupDiGetDriverInfoDetailW = modsetupapi.NewProc("SetupDiGetDriverInfoDetailW")
- procSetupDiDestroyDriverInfoList = modsetupapi.NewProc("SetupDiDestroyDriverInfoList")
- procSetupDiGetClassDevsExW = modsetupapi.NewProc("SetupDiGetClassDevsExW")
- procSetupDiCallClassInstaller = modsetupapi.NewProc("SetupDiCallClassInstaller")
- procSetupDiOpenDevRegKey = modsetupapi.NewProc("SetupDiOpenDevRegKey")
- procSetupDiGetDeviceRegistryPropertyW = modsetupapi.NewProc("SetupDiGetDeviceRegistryPropertyW")
- procSetupDiSetDeviceRegistryPropertyW = modsetupapi.NewProc("SetupDiSetDeviceRegistryPropertyW")
- procSetupDiGetDeviceInstallParamsW = modsetupapi.NewProc("SetupDiGetDeviceInstallParamsW")
- procSetupDiGetDeviceInstanceIdW = modsetupapi.NewProc("SetupDiGetDeviceInstanceIdW")
- procSetupDiGetClassInstallParamsW = modsetupapi.NewProc("SetupDiGetClassInstallParamsW")
- procSetupDiSetDeviceInstallParamsW = modsetupapi.NewProc("SetupDiSetDeviceInstallParamsW")
- procSetupDiSetClassInstallParamsW = modsetupapi.NewProc("SetupDiSetClassInstallParamsW")
- procSetupDiClassNameFromGuidExW = modsetupapi.NewProc("SetupDiClassNameFromGuidExW")
- procSetupDiClassGuidsFromNameExW = modsetupapi.NewProc("SetupDiClassGuidsFromNameExW")
- procSetupDiGetSelectedDevice = modsetupapi.NewProc("SetupDiGetSelectedDevice")
- procSetupDiSetSelectedDevice = modsetupapi.NewProc("SetupDiSetSelectedDevice")
- procCM_Get_Device_Interface_List_SizeW = modCfgMgr32.NewProc("CM_Get_Device_Interface_List_SizeW")
- procCM_Get_Device_Interface_ListW = modCfgMgr32.NewProc("CM_Get_Device_Interface_ListW")
-)
-
-func setupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName *uint16, reserved uintptr) (handle DevInfo, err error) {
- r0, _, e1 := syscall.Syscall6(procSetupDiCreateDeviceInfoListExW.Addr(), 4, uintptr(unsafe.Pointer(classGUID)), uintptr(hwndParent), uintptr(unsafe.Pointer(machineName)), uintptr(reserved), 0, 0)
- handle = DevInfo(r0)
- if handle == DevInfo(windows.InvalidHandle) {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiGetDeviceInfoListDetail(deviceInfoSet DevInfo, deviceInfoSetDetailData *DevInfoListDetailData) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiGetDeviceInfoListDetailW.Addr(), 2, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoSetDetailData)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiCreateDeviceInfo(deviceInfoSet DevInfo, DeviceName *uint16, classGUID *windows.GUID, DeviceDescription *uint16, hwndParent uintptr, CreationFlags DICD, deviceInfoData *DevInfoData) (err error) {
- r1, _, e1 := syscall.Syscall9(procSetupDiCreateDeviceInfoW.Addr(), 7, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(DeviceName)), uintptr(unsafe.Pointer(classGUID)), uintptr(unsafe.Pointer(DeviceDescription)), uintptr(hwndParent), uintptr(CreationFlags), uintptr(unsafe.Pointer(deviceInfoData)), 0, 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiEnumDeviceInfo(deviceInfoSet DevInfo, memberIndex uint32, deviceInfoData *DevInfoData) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiEnumDeviceInfo.Addr(), 3, uintptr(deviceInfoSet), uintptr(memberIndex), uintptr(unsafe.Pointer(deviceInfoData)))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiDestroyDeviceInfoList(deviceInfoSet DevInfo) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiDestroyDeviceInfoList.Addr(), 1, uintptr(deviceInfoSet), 0, 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiBuildDriverInfoList(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiBuildDriverInfoList.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(driverType))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiCancelDriverInfoSearch(deviceInfoSet DevInfo) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiCancelDriverInfoSearch.Addr(), 1, uintptr(deviceInfoSet), 0, 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiEnumDriverInfo(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT, memberIndex uint32, driverInfoData *DrvInfoData) (err error) {
- r1, _, e1 := syscall.Syscall6(procSetupDiEnumDriverInfoW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(driverType), uintptr(memberIndex), uintptr(unsafe.Pointer(driverInfoData)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiGetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiGetSelectedDriverW.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(driverInfoData)))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiSetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiSetSelectedDriverW.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(driverInfoData)))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData, driverInfoDetailData *DrvInfoDetailData, driverInfoDetailDataSize uint32, requiredSize *uint32) (err error) {
- r1, _, e1 := syscall.Syscall6(procSetupDiGetDriverInfoDetailW.Addr(), 6, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(driverInfoData)), uintptr(unsafe.Pointer(driverInfoDetailData)), uintptr(driverInfoDetailDataSize), uintptr(unsafe.Pointer(requiredSize)))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiDestroyDriverInfoList(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiDestroyDriverInfoList.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(driverType))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiGetClassDevsEx(classGUID *windows.GUID, Enumerator *uint16, hwndParent uintptr, Flags DIGCF, deviceInfoSet DevInfo, machineName *uint16, reserved uintptr) (handle DevInfo, err error) {
- r0, _, e1 := syscall.Syscall9(procSetupDiGetClassDevsExW.Addr(), 7, uintptr(unsafe.Pointer(classGUID)), uintptr(unsafe.Pointer(Enumerator)), uintptr(hwndParent), uintptr(Flags), uintptr(deviceInfoSet), uintptr(unsafe.Pointer(machineName)), uintptr(reserved), 0, 0)
- handle = DevInfo(r0)
- if handle == DevInfo(windows.InvalidHandle) {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiCallClassInstaller(installFunction DI_FUNCTION, deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiCallClassInstaller.Addr(), 3, uintptr(installFunction), uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiOpenDevRegKey(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, Scope DICS_FLAG, HwProfile uint32, KeyType DIREG, samDesired uint32) (key windows.Handle, err error) {
- r0, _, e1 := syscall.Syscall6(procSetupDiOpenDevRegKey.Addr(), 6, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(Scope), uintptr(HwProfile), uintptr(KeyType), uintptr(samDesired))
- key = windows.Handle(r0)
- if key == windows.InvalidHandle {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyRegDataType *uint32, propertyBuffer *byte, propertyBufferSize uint32, requiredSize *uint32) (err error) {
- r1, _, e1 := syscall.Syscall9(procSetupDiGetDeviceRegistryPropertyW.Addr(), 7, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(property), uintptr(unsafe.Pointer(propertyRegDataType)), uintptr(unsafe.Pointer(propertyBuffer)), uintptr(propertyBufferSize), uintptr(unsafe.Pointer(requiredSize)), 0, 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiSetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyBuffer *byte, propertyBufferSize uint32) (err error) {
- r1, _, e1 := syscall.Syscall6(procSetupDiSetDeviceRegistryPropertyW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(property), uintptr(unsafe.Pointer(propertyBuffer)), uintptr(propertyBufferSize), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiGetDeviceInstallParamsW.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(deviceInstallParams)))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, instanceId *uint16, instanceIdSize uint32, instanceIdRequiredSize *uint32) (err error) {
- r1, _, e1 := syscall.Syscall6(procSetupDiGetDeviceInstanceIdW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(instanceId)), uintptr(instanceIdSize), uintptr(unsafe.Pointer(instanceIdRequiredSize)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiGetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) (err error) {
- r1, _, e1 := syscall.Syscall6(procSetupDiGetClassInstallParamsW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(classInstallParams)), uintptr(classInstallParamsSize), uintptr(unsafe.Pointer(requiredSize)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiSetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiSetDeviceInstallParamsW.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(deviceInstallParams)))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiSetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32) (err error) {
- r1, _, e1 := syscall.Syscall6(procSetupDiSetClassInstallParamsW.Addr(), 4, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(classInstallParams)), uintptr(classInstallParamsSize), 0, 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiClassNameFromGuidEx(classGUID *windows.GUID, className *uint16, classNameSize uint32, requiredSize *uint32, machineName *uint16, reserved uintptr) (err error) {
- r1, _, e1 := syscall.Syscall6(procSetupDiClassNameFromGuidExW.Addr(), 6, uintptr(unsafe.Pointer(classGUID)), uintptr(unsafe.Pointer(className)), uintptr(classNameSize), uintptr(unsafe.Pointer(requiredSize)), uintptr(unsafe.Pointer(machineName)), uintptr(reserved))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiClassGuidsFromNameEx(className *uint16, classGuidList *windows.GUID, classGuidListSize uint32, requiredSize *uint32, machineName *uint16, reserved uintptr) (err error) {
- r1, _, e1 := syscall.Syscall6(procSetupDiClassGuidsFromNameExW.Addr(), 6, uintptr(unsafe.Pointer(className)), uintptr(unsafe.Pointer(classGuidList)), uintptr(classGuidListSize), uintptr(unsafe.Pointer(requiredSize)), uintptr(unsafe.Pointer(machineName)), uintptr(reserved))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiGetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiGetSelectedDevice.Addr(), 2, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiSetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiSetSelectedDevice.Addr(), 2, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func cm_Get_Device_Interface_List_Size(len *uint32, interfaceClass *windows.GUID, deviceID *uint16, flags uint32) (ret uint32) {
- r0, _, _ := syscall.Syscall6(procCM_Get_Device_Interface_List_SizeW.Addr(), 4, uintptr(unsafe.Pointer(len)), uintptr(unsafe.Pointer(interfaceClass)), uintptr(unsafe.Pointer(deviceID)), uintptr(flags), 0, 0)
- ret = uint32(r0)
- return
-}
-
-func cm_Get_Device_Interface_List(interfaceClass *windows.GUID, deviceID *uint16, buffer *uint16, bufferLen uint32, flags uint32) (ret uint32) {
- r0, _, _ := syscall.Syscall6(procCM_Get_Device_Interface_ListW.Addr(), 5, uintptr(unsafe.Pointer(interfaceClass)), uintptr(unsafe.Pointer(deviceID)), uintptr(unsafe.Pointer(buffer)), uintptr(bufferLen), uintptr(flags), 0)
- ret = uint32(r0)
- return
-}
diff --git a/tun/wintun/setupapi/zsetupapi_windows_test.go b/tun/wintun/setupapi/zsetupapi_windows_test.go
deleted file mode 100644
index 915b427..0000000
--- a/tun/wintun/setupapi/zsetupapi_windows_test.go
+++ /dev/null
@@ -1,20 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package setupapi
-
-import (
- "syscall"
- "testing"
-
- "golang.org/x/sys/windows"
-)
-
-func TestSetupDiDestroyDeviceInfoList(t *testing.T) {
- err := SetupDiDestroyDeviceInfoList(DevInfo(windows.InvalidHandle))
- if errWin, ok := err.(syscall.Errno); !ok || errWin != windows.ERROR_INVALID_HANDLE {
- t.Errorf("SetupDiDestroyDeviceInfoList(nil, ...) should fail with ERROR_INVALID_HANDLE")
- }
-}
diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go
deleted file mode 100644
index 4c12d97..0000000
--- a/tun/wintun/wintun_windows.go
+++ /dev/null
@@ -1,803 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package wintun
-
-import (
- "errors"
- "fmt"
- "strings"
- "time"
- "unsafe"
-
- "golang.org/x/sys/windows"
- "golang.org/x/sys/windows/registry"
-
- "golang.zx2c4.com/wireguard/tun/wintun/iphlpapi"
- "golang.zx2c4.com/wireguard/tun/wintun/nci"
- registryEx "golang.zx2c4.com/wireguard/tun/wintun/registry"
- "golang.zx2c4.com/wireguard/tun/wintun/setupapi"
-)
-
-type Pool string
-
-type Interface struct {
- cfgInstanceID windows.GUID
- devInstanceID string
- luidIndex uint32
- ifType uint32
- pool Pool
-}
-
-var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0x11ce, Data4: [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}}
-var deviceInterfaceNetGUID = windows.GUID{Data1: 0xcac88484, Data2: 0x7515, Data3: 0x4c03, Data4: [8]byte{0x82, 0xe6, 0x71, 0xa8, 0x7a, 0xba, 0xc3, 0x61}}
-
-const (
- hardwareID = "Wintun"
- waitForRegistryTimeout = time.Second * 10
-)
-
-// makeWintun creates a Wintun interface handle and populates it from the device's registry key.
-func makeWintun(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData, pool Pool) (*Interface, error) {
- // Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
- key, err := devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE)
- if err != nil {
- return nil, fmt.Errorf("Device-specific registry key open failed: %v", err)
- }
- defer key.Close()
-
- // Read the NetCfgInstanceId value.
- valueStr, err := registryEx.GetStringValue(key, "NetCfgInstanceId")
- if err != nil {
- return nil, fmt.Errorf("RegQueryStringValue(\"NetCfgInstanceId\") failed: %v", err)
- }
-
- // Convert to GUID.
- ifid, err := windows.GUIDFromString(valueStr)
- if err != nil {
- return nil, fmt.Errorf("NetCfgInstanceId registry value is not a GUID (expected: \"{...}\", provided: %q)", valueStr)
- }
-
- // Read the NetLuidIndex value.
- luidIdx, _, err := key.GetIntegerValue("NetLuidIndex")
- if err != nil {
- return nil, fmt.Errorf("RegQueryValue(\"NetLuidIndex\") failed: %v", err)
- }
-
- // Read the NetLuidIndex value.
- ifType, _, err := key.GetIntegerValue("*IfType")
- if err != nil {
- return nil, fmt.Errorf("RegQueryValue(\"*IfType\") failed: %v", err)
- }
-
- instanceID, err := devInfo.DeviceInstanceID(devInfoData)
- if err != nil {
- return nil, fmt.Errorf("DeviceInstanceID failed: %v", err)
- }
-
- return &Interface{
- cfgInstanceID: ifid,
- devInstanceID: instanceID,
- luidIndex: uint32(luidIdx),
- ifType: uint32(ifType),
- pool: pool,
- }, nil
-}
-
-func removeNumberedSuffix(ifname string) string {
- removed := strings.TrimRight(ifname, "0123456789")
- if removed != ifname && len(removed) > 1 && removed[len(removed)-1] == ' ' {
- return removed[:len(removed)-1]
- }
- return ifname
-}
-
-// GetInterface finds a Wintun interface by its name. This function returns
-// the interface if found, or windows.ERROR_OBJECT_NOT_FOUND otherwise. If
-// the interface is found but not a Wintun-class or a member of the pool,
-// this function returns windows.ERROR_ALREADY_EXISTS.
-func (pool Pool) GetInterface(ifname string) (*Interface, error) {
- mutex, err := pool.takeNameMutex()
- if err != nil {
- return nil, err
- }
- defer func() {
- windows.ReleaseMutex(mutex)
- windows.CloseHandle(mutex)
- }()
-
- // Create a list of network devices.
- devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
- if err != nil {
- return nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err)
- }
- defer devInfo.Close()
-
- // Windows requires each interface to have a different name. When
- // enforcing this, Windows treats interface names case-insensitive. If an
- // interface "FooBar" exists and this function reports there is no
- // interface "foobar", an attempt to create a new interface and name it
- // "foobar" would cause conflict with "FooBar".
- ifname = strings.ToLower(ifname)
-
- for index := 0; ; index++ {
- devInfoData, err := devInfo.EnumDeviceInfo(index)
- if err != nil {
- if err == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- // Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices.
- property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID)
- if err != nil {
- continue
- }
- if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
- continue
- }
-
- wintun, err := makeWintun(devInfo, devInfoData, pool)
- if err != nil {
- continue
- }
-
- // TODO: is there a better way than comparing ifnames?
- ifname2, err := wintun.Name()
- if err != nil {
- continue
- }
- ifname2 = strings.ToLower(ifname2)
- ifname3 := removeNumberedSuffix(ifname2)
-
- if ifname == ifname2 || ifname == ifname3 {
- err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
- if err != nil {
- return nil, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
- }
- defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
-
- for index := 0; ; index++ {
- driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index)
- if err != nil {
- if err == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- // Get driver info details.
- driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
- if err != nil {
- continue
- }
-
- if driverDetailData.IsCompatible(hardwareID) {
- isMember, err := pool.isMember(devInfo, devInfoData)
- if err != nil {
- return nil, err
- }
- if !isMember {
- return nil, windows.ERROR_ALREADY_EXISTS
- }
-
- return wintun, nil
- }
- }
-
- // This interface is not using Wintun driver.
- return nil, windows.ERROR_ALREADY_EXISTS
- }
- }
-
- return nil, windows.ERROR_OBJECT_NOT_FOUND
-}
-
-// CreateInterface creates a Wintun interface. ifname is the requested name of
-// the interface, while requestedGUID is the GUID of the created network
-// interface, which then influences NLA generation deterministically. If it is
-// set to nil, the GUID is chosen by the system at random, and hence a new NLA
-// entry is created for each new interface. It is called "requested" GUID
-// because the API it uses is completely undocumented, and so there could be minor
-// interesting complications with its usage. This function returns the network
-// interface ID and a flag if reboot is required.
-func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wintun *Interface, rebootRequired bool, err error) {
- mutex, err := pool.takeNameMutex()
- if err != nil {
- return
- }
- defer func() {
- windows.ReleaseMutex(mutex)
- windows.CloseHandle(mutex)
- }()
-
- // Create an empty device info set for network adapter device class.
- devInfo, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "")
- if err != nil {
- err = fmt.Errorf("SetupDiCreateDeviceInfoListEx(%v) failed: %v", deviceClassNetGUID, err)
- return
- }
- defer devInfo.Close()
-
- // Get the device class name from GUID.
- className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "")
- if err != nil {
- err = fmt.Errorf("SetupDiClassNameFromGuidEx(%v) failed: %v", deviceClassNetGUID, err)
- return
- }
-
- // Create a new device info element and add it to the device info set.
- deviceTypeName := pool.deviceTypeName()
- devInfoData, err := devInfo.CreateDeviceInfo(className, &deviceClassNetGUID, deviceTypeName, 0, setupapi.DICD_GENERATE_ID)
- if err != nil {
- err = fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err)
- return
- }
-
- err = setQuietInstall(devInfo, devInfoData)
- if err != nil {
- err = fmt.Errorf("Setting quiet installation failed: %v", err)
- return
- }
-
- // Set a device information element as the selected member of a device information set.
- err = devInfo.SetSelectedDevice(devInfoData)
- if err != nil {
- err = fmt.Errorf("SetupDiSetSelectedDevice failed: %v", err)
- return
- }
-
- // Set Plug&Play device hardware ID property.
- err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_HARDWAREID, hardwareID)
- if err != nil {
- err = fmt.Errorf("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: %v", err)
- return
- }
-
- err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) // TODO: This takes ~510ms
- if err != nil {
- err = fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
- return
- }
- defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
-
- driverDate := windows.Filetime{}
- driverVersion := uint64(0)
- for index := 0; ; index++ { // TODO: This loop takes ~600ms
- driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index)
- if err != nil {
- if err == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- // Check the driver version first, since the check is trivial and will save us iterating over hardware IDs for any driver versioned prior our best match.
- if driverData.IsNewer(driverDate, driverVersion) {
- driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
- if err != nil {
- continue
- }
-
- if driverDetailData.IsCompatible(hardwareID) {
- err := devInfo.SetSelectedDriver(devInfoData, driverData)
- if err != nil {
- continue
- }
-
- driverDate = driverData.DriverDate
- driverVersion = driverData.DriverVersion
- }
- }
- }
-
- if driverVersion == 0 {
- err = fmt.Errorf("No driver for device %q installed", hardwareID)
- return
- }
-
- defer func() {
- if err != nil {
- // The interface failed to install, or the interface ID was unobtainable. Clean-up.
- removeDeviceParams := setupapi.RemoveDeviceParams{
- ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE),
- Scope: setupapi.DI_REMOVEDEVICE_GLOBAL,
- }
-
- // Set class installer parameters for DIF_REMOVE.
- if devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil {
- // Call appropriate class installer.
- if devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData) == nil {
- rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData)
- }
- }
-
- wintun = nil
- }
- }()
-
- // Call appropriate class installer.
- err = devInfo.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, devInfoData)
- if err != nil {
- err = fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err)
- return
- }
-
- // Register device co-installers if any. (Ignore errors)
- devInfo.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, devInfoData)
-
- var netDevRegKey registry.Key
- const pollTimeout = time.Millisecond * 50
- for i := 0; i < int(waitForRegistryTimeout/pollTimeout); i++ {
- if i != 0 {
- time.Sleep(pollTimeout)
- }
- netDevRegKey, err = devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY)
- if err == nil {
- break
- }
- }
- if err != nil {
- err = fmt.Errorf("SetupDiOpenDevRegKey failed: %v", err)
- return
- }
- defer netDevRegKey.Close()
- if requestedGUID != nil {
- err = netDevRegKey.SetStringValue("NetSetupAnticipatedInstanceId", requestedGUID.String())
- if err != nil {
- err = fmt.Errorf("SetStringValue(NetSetupAnticipatedInstanceId) failed: %v", err)
- return
- }
- }
-
- // Install interfaces if any. (Ignore errors)
- devInfo.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, devInfoData)
-
- // Install the device.
- err = devInfo.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, devInfoData)
- if err != nil {
- err = fmt.Errorf("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: %v", err)
- return
- }
- rebootRequired = checkReboot(devInfo, devInfoData)
-
- err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_DEVICEDESC, deviceTypeName)
- if err != nil {
- err = fmt.Errorf("SetDeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
- return
- }
-
- // DIF_INSTALLDEVICE returns almost immediately, while the device installation
- // continues in the background. It might take a while, before all registry
- // keys and values are populated.
- _, err = registryEx.GetStringValueWait(netDevRegKey, "NetCfgInstanceId", waitForRegistryTimeout)
- if err != nil {
- err = fmt.Errorf("GetStringValueWait(NetCfgInstanceId) failed: %v", err)
- return
- }
- _, err = registryEx.GetIntegerValueWait(netDevRegKey, "NetLuidIndex", waitForRegistryTimeout)
- if err != nil {
- err = fmt.Errorf("GetIntegerValueWait(NetLuidIndex) failed: %v", err)
- return
- }
- _, err = registryEx.GetIntegerValueWait(netDevRegKey, "*IfType", waitForRegistryTimeout)
- if err != nil {
- err = fmt.Errorf("GetIntegerValueWait(*IfType) failed: %v", err)
- return
- }
-
- // Get network interface.
- wintun, err = makeWintun(devInfo, devInfoData, pool)
- if err != nil {
- err = fmt.Errorf("makeWintun failed: %v", err)
- return
- }
-
- // Wait for TCP/IP adapter registry key to emerge and populate.
- tcpipAdapterRegKey, err := registryEx.OpenKeyWait(
- registry.LOCAL_MACHINE,
- wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE|registry.NOTIFY,
- waitForRegistryTimeout)
- if err != nil {
- err = fmt.Errorf("OpenKeyWait(HKLM\\%s) failed: %v", wintun.tcpipAdapterRegKeyName(), err)
- return
- }
- defer tcpipAdapterRegKey.Close()
- _, err = registryEx.GetStringValueWait(tcpipAdapterRegKey, "IpConfig", waitForRegistryTimeout)
- if err != nil {
- err = fmt.Errorf("GetStringValueWait(IpConfig) failed: %v", err)
- return
- }
-
- tcpipInterfaceRegKeyName, err := wintun.tcpipInterfaceRegKeyName()
- if err != nil {
- err = fmt.Errorf("tcpipInterfaceRegKeyName failed: %v", err)
- return
- }
-
- // Wait for TCP/IP interface registry key to emerge.
- tcpipInterfaceRegKey, err := registryEx.OpenKeyWait(
- registry.LOCAL_MACHINE,
- tcpipInterfaceRegKeyName, registry.QUERY_VALUE|registry.SET_VALUE,
- waitForRegistryTimeout)
- if err != nil {
- err = fmt.Errorf("OpenKeyWait(HKLM\\%s) failed: %v", tcpipInterfaceRegKeyName, err)
- return
- }
- defer tcpipInterfaceRegKey.Close()
- // Disable dead gateway detection on our interface.
- tcpipInterfaceRegKey.SetDWordValue("EnableDeadGWDetect", 0)
-
- err = wintun.SetName(ifname)
- if err != nil {
- err = fmt.Errorf("Unable to set name of Wintun interface: %v", err)
- return
- }
-
- return
-}
-
-// DeleteInterface deletes a Wintun interface. This function succeeds
-// if the interface was not found. It returns a bool indicating whether
-// a reboot is required.
-func (wintun *Interface) DeleteInterface() (rebootRequired bool, err error) {
- devInfo, devInfoData, err := wintun.devInfoData()
- if err == windows.ERROR_OBJECT_NOT_FOUND {
- return false, nil
- }
- if err != nil {
- return false, err
- }
- defer devInfo.Close()
-
- // Remove the device.
- removeDeviceParams := setupapi.RemoveDeviceParams{
- ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE),
- Scope: setupapi.DI_REMOVEDEVICE_GLOBAL,
- }
-
- // Set class installer parameters for DIF_REMOVE.
- err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
- if err != nil {
- return false, fmt.Errorf("SetupDiSetClassInstallParams failed: %v", err)
- }
-
- // Call appropriate class installer.
- err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData)
- if err != nil {
- return false, fmt.Errorf("SetupDiCallClassInstaller failed: %v", err)
- }
-
- return checkReboot(devInfo, devInfoData), nil
-}
-
-// DeleteMatchingInterfaces deletes all Wintun interfaces, which match
-// given criteria, and returns which ones it deleted, whether a reboot
-// is required after, and which errors occurred during the process.
-func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool) (deviceInstancesDeleted []uint32, rebootRequired bool, errors []error) {
- mutex, err := pool.takeNameMutex()
- if err != nil {
- errors = append(errors, err)
- return
- }
- defer func() {
- windows.ReleaseMutex(mutex)
- windows.CloseHandle(mutex)
- }()
-
- devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
- if err != nil {
- return nil, false, []error{fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())}
- }
- defer devInfo.Close()
-
- for i := 0; ; i++ {
- devInfoData, err := devInfo.EnumDeviceInfo(i)
- if err != nil {
- if err == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- // Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices.
- property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID)
- if err != nil {
- continue
- }
- if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
- continue
- }
-
- err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
- if err != nil {
- continue
- }
- defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
-
- isWintun := false
- for j := 0; ; j++ {
- driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, j)
- if err != nil {
- if err == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
- driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
- if err != nil {
- continue
- }
- if driverDetailData.IsCompatible(hardwareID) {
- isWintun = true
- break
- }
- }
- if !isWintun {
- continue
- }
-
- isMember, err := pool.isMember(devInfo, devInfoData)
- if err != nil {
- errors = append(errors, err)
- continue
- }
- if !isMember {
- continue
- }
-
- wintun, err := makeWintun(devInfo, devInfoData, pool)
- if err != nil {
- errors = append(errors, fmt.Errorf("Unable to make Wintun interface object: %v", err))
- continue
- }
- if !matches(wintun) {
- continue
- }
-
- err = setQuietInstall(devInfo, devInfoData)
- if err != nil {
- errors = append(errors, err)
- continue
- }
-
- inst := devInfoData.DevInst
- removeDeviceParams := setupapi.RemoveDeviceParams{
- ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE),
- Scope: setupapi.DI_REMOVEDEVICE_GLOBAL,
- }
- err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
- if err != nil {
- errors = append(errors, err)
- continue
- }
- err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData)
- if err != nil {
- errors = append(errors, err)
- continue
- }
- rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData)
- deviceInstancesDeleted = append(deviceInstancesDeleted, inst)
- }
- return
-}
-
-// isMember checks if SPDRP_DEVICEDESC or SPDRP_FRIENDLYNAME match device type name.
-func (pool Pool) isMember(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) (bool, error) {
- deviceDescVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_DEVICEDESC)
- if err != nil {
- return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
- }
- deviceDesc, _ := deviceDescVal.(string)
- friendlyNameVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_FRIENDLYNAME)
- if err != nil {
- return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_FRIENDLYNAME) failed: %v", err)
- }
- friendlyName, _ := friendlyNameVal.(string)
- deviceTypeName := pool.deviceTypeName()
- return friendlyName == deviceTypeName || deviceDesc == deviceTypeName ||
- removeNumberedSuffix(friendlyName) == deviceTypeName || removeNumberedSuffix(deviceDesc) == deviceTypeName, nil
-}
-
-// checkReboot checks device install parameters if a system reboot is required.
-func checkReboot(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) bool {
- devInstallParams, err := devInfo.DeviceInstallParams(devInfoData)
- if err != nil {
- return false
- }
-
- return (devInstallParams.Flags & (setupapi.DI_NEEDREBOOT | setupapi.DI_NEEDRESTART)) != 0
-}
-
-// setQuietInstall sets device install parameters for a quiet installation
-func setQuietInstall(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) error {
- devInstallParams, err := devInfo.DeviceInstallParams(devInfoData)
- if err != nil {
- return err
- }
-
- devInstallParams.Flags |= setupapi.DI_QUIETINSTALL
- return devInfo.SetDeviceInstallParams(devInfoData, devInstallParams)
-}
-
-// deviceTypeName returns pool-specific device type name.
-func (pool Pool) deviceTypeName() string {
- return fmt.Sprintf("%s Tunnel", pool)
-}
-
-// Name returns the name of the Wintun interface.
-func (wintun *Interface) Name() (string, error) {
- return nci.ConnectionName(&wintun.cfgInstanceID)
-}
-
-// SetName sets name of the Wintun interface.
-func (wintun *Interface) SetName(ifname string) error {
- const maxSuffix = 1000
- availableIfname := ifname
- for i := 0; ; i++ {
- err := nci.SetConnectionName(&wintun.cfgInstanceID, availableIfname)
- if err == windows.ERROR_DUP_NAME {
- duplicateGuid, err2 := iphlpapi.InterfaceGUIDFromAlias(availableIfname)
- if err2 == nil {
- for j := 0; j < maxSuffix; j++ {
- proposal := fmt.Sprintf("%s %d", ifname, j+1)
- if proposal == availableIfname {
- continue
- }
- err2 = nci.SetConnectionName(duplicateGuid, proposal)
- if err2 == windows.ERROR_DUP_NAME {
- continue
- }
- if err2 == nil {
- err = nci.SetConnectionName(&wintun.cfgInstanceID, availableIfname)
- if err == nil {
- break
- }
- }
- break
- }
- }
- }
- if err == nil {
- break
- }
-
- if i > maxSuffix || err != windows.ERROR_DUP_NAME {
- return fmt.Errorf("NciSetConnectionName failed: %v", err)
- }
- availableIfname = fmt.Sprintf("%s %d", ifname, i+1)
- }
-
- // TODO: This should use NetSetup2 so that it doesn't get unset.
- deviceRegKey, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.deviceRegKeyName(), registry.SET_VALUE)
- if err != nil {
- return fmt.Errorf("Device-level registry key open failed: %v", err)
- }
- defer deviceRegKey.Close()
- err = deviceRegKey.SetStringValue("FriendlyName", wintun.pool.deviceTypeName())
- if err != nil {
- return fmt.Errorf("SetStringValue(FriendlyName) failed: %v", err)
- }
- return nil
-}
-
-// tcpipAdapterRegKeyName returns the adapter-specific TCP/IP network registry key name.
-func (wintun *Interface) tcpipAdapterRegKeyName() string {
- return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", wintun.cfgInstanceID)
-}
-
-// deviceRegKeyName returns the device-level registry key name.
-func (wintun *Interface) deviceRegKeyName() string {
- return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Enum\\%v", wintun.devInstanceID)
-}
-
-// Version returns the version of the Wintun driver and NDIS system currently loaded.
-func (wintun *Interface) Version() (driverVersion string, ndisVersion string, err error) {
- key, err := registry.OpenKey(registry.LOCAL_MACHINE, "SYSTEM\\CurrentControlSet\\Services\\Wintun", registry.QUERY_VALUE)
- if err != nil {
- return
- }
- defer key.Close()
- driverMajor, _, err := key.GetIntegerValue("DriverMajorVersion")
- if err != nil {
- return
- }
- driverMinor, _, err := key.GetIntegerValue("DriverMinorVersion")
- if err != nil {
- return
- }
- ndisMajor, _, err := key.GetIntegerValue("NdisMajorVersion")
- if err != nil {
- return
- }
- ndisMinor, _, err := key.GetIntegerValue("NdisMinorVersion")
- if err != nil {
- return
- }
- driverVersion = fmt.Sprintf("%d.%d", driverMajor, driverMinor)
- ndisVersion = fmt.Sprintf("%d.%d", ndisMajor, ndisMinor)
- return
-}
-
-// tcpipInterfaceRegKeyName returns the interface-specific TCP/IP network registry key name.
-func (wintun *Interface) tcpipInterfaceRegKeyName() (path string, err error) {
- key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE)
- if err != nil {
- return "", fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %v", err)
- }
- paths, _, err := key.GetStringsValue("IpConfig")
- key.Close()
- if err != nil {
- return "", fmt.Errorf("Error reading IpConfig registry key: %v", err)
- }
- if len(paths) == 0 {
- return "", errors.New("No TCP/IP interfaces found on adapter")
- }
- return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), nil
-}
-
-// devInfoData returns TUN device info list handle and interface device info
-// data. The device info list handle must be closed after use. In case the
-// device is not found, windows.ERROR_OBJECT_NOT_FOUND is returned.
-func (wintun *Interface) devInfoData() (setupapi.DevInfo, *setupapi.DevInfoData, error) {
- // Create a list of network devices.
- devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
- if err != nil {
- return 0, nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())
- }
-
- for index := 0; ; index++ {
- devInfoData, err := devInfo.EnumDeviceInfo(index)
- if err != nil {
- if err == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- // Get interface ID.
- // TODO: Store some ID in the Wintun object such that this call isn't required.
- wintun2, err := makeWintun(devInfo, devInfoData, wintun.pool)
- if err != nil {
- continue
- }
-
- if wintun.cfgInstanceID == wintun2.cfgInstanceID {
- err = setQuietInstall(devInfo, devInfoData)
- if err != nil {
- devInfo.Close()
- return 0, nil, fmt.Errorf("Setting quiet installation failed: %v", err)
- }
- return devInfo, devInfoData, nil
- }
- }
-
- devInfo.Close()
- return 0, nil, windows.ERROR_OBJECT_NOT_FOUND
-}
-
-// handle returns a handle to the interface device object.
-func (wintun *Interface) handle() (windows.Handle, error) {
- interfaces, err := setupapi.CM_Get_Device_Interface_List(wintun.devInstanceID, &deviceInterfaceNetGUID, setupapi.CM_GET_DEVICE_INTERFACE_LIST_PRESENT)
- if err != nil {
- return windows.InvalidHandle, fmt.Errorf("Error listing NDIS interfaces: %v", err)
- }
- handle, err := windows.CreateFile(windows.StringToUTF16Ptr(interfaces[0]), windows.GENERIC_READ|windows.GENERIC_WRITE, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, nil, windows.OPEN_EXISTING, 0, 0)
- if err != nil {
- return windows.InvalidHandle, fmt.Errorf("Error opening NDIS device: %v", err)
- }
- return handle, nil
-}
-
-// GUID returns the GUID of the interface.
-func (wintun *Interface) GUID() windows.GUID {
- return wintun.cfgInstanceID
-}
-
-// LUID returns the LUID of the interface.
-func (wintun *Interface) LUID() uint64 {
- return ((uint64(wintun.luidIndex) & ((1 << 24) - 1)) << 24) | ((uint64(wintun.ifType) & ((1 << 16) - 1)) << 48)
-}
diff --git a/version.go b/version.go
new file mode 100644
index 0000000..db75bb9
--- /dev/null
+++ b/version.go
@@ -0,0 +1,3 @@
+package main
+
+const Version = "0.0.20230223"