aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Makefile2
-rw-r--r--README.md4
-rw-r--r--conn/bind_linux.go585
-rw-r--r--conn/bind_std.go524
-rw-r--r--conn/bind_std_test.go250
-rw-r--r--conn/bind_windows.go152
-rw-r--r--conn/bindtest/bindtest.go61
-rw-r--r--conn/boundif_android.go10
-rw-r--r--conn/conn.go62
-rw-r--r--conn/conn_test.go24
-rw-r--r--conn/controlfns.go43
-rw-r--r--conn/controlfns_linux.go69
-rw-r--r--conn/controlfns_unix.go35
-rw-r--r--conn/controlfns_windows.go23
-rw-r--r--conn/default.go4
-rw-r--r--conn/errors_default.go12
-rw-r--r--conn/errors_linux.go26
-rw-r--r--conn/features_default.go15
-rw-r--r--conn/features_linux.go29
-rw-r--r--conn/gso_default.go21
-rw-r--r--conn/gso_linux.go65
-rw-r--r--conn/mark_default.go6
-rw-r--r--conn/mark_unix.go14
-rw-r--r--conn/sticky_default.go42
-rw-r--r--conn/sticky_linux.go112
-rw-r--r--conn/sticky_linux_test.go266
-rw-r--r--conn/winrio/rio_windows.go10
-rw-r--r--device/alignment_test.go65
-rw-r--r--device/allowedips.go89
-rw-r--r--device/allowedips_rand_test.go7
-rw-r--r--device/allowedips_test.go10
-rw-r--r--device/bind_test.go12
-rw-r--r--device/channels.go40
-rw-r--r--device/constants.go2
-rw-r--r--device/cookie.go7
-rw-r--r--device/cookie_test.go7
-rw-r--r--device/device.go90
-rw-r--r--device/device_test.go85
-rw-r--r--device/endpoint_test.go40
-rw-r--r--device/indextable.go2
-rw-r--r--device/ip.go2
-rw-r--r--device/kdf_test.go4
-rw-r--r--device/keypair.go15
-rw-r--r--device/logger.go10
-rw-r--r--device/misc.go41
-rw-r--r--device/mobilequirks.go11
-rw-r--r--device/noise-helpers.go12
-rw-r--r--device/noise-protocol.go86
-rw-r--r--device/noise-types.go2
-rw-r--r--device/noise_test.go12
-rw-r--r--device/peer.go134
-rw-r--r--device/pools.go58
-rw-r--r--device/pools_test.go77
-rw-r--r--device/queueconstants_android.go8
-rw-r--r--device/queueconstants_default.go8
-rw-r--r--device/queueconstants_ios.go4
-rw-r--r--device/queueconstants_windows.go2
-rw-r--r--device/race_disabled_test.go4
-rw-r--r--device/race_enabled_test.go4
-rw-r--r--device/receive.go385
-rw-r--r--device/send.go355
-rw-r--r--device/sticky_default.go2
-rw-r--r--device/sticky_linux.go45
-rw-r--r--device/timers.go55
-rw-r--r--device/tun.go5
-rw-r--r--device/uapi.go82
-rw-r--r--format_test.go2
-rw-r--r--go.mod15
-rw-r--r--go.sum36
-rw-r--r--ipc/namedpipe/file.go (renamed from ipc/winpipe/file.go)39
-rw-r--r--ipc/namedpipe/namedpipe.go (renamed from ipc/winpipe/winpipe.go)91
-rw-r--r--ipc/namedpipe/namedpipe_test.go (renamed from ipc/winpipe/winpipe_test.go)124
-rw-r--r--ipc/uapi_bsd.go7
-rw-r--r--ipc/uapi_linux.go7
-rw-r--r--ipc/uapi_unix.go8
-rw-r--r--ipc/uapi_wasm.go15
-rw-r--r--ipc/uapi_windows.go10
-rw-r--r--main.go19
-rw-r--r--main_windows.go7
-rw-r--r--ratelimiter/ratelimiter.go59
-rw-r--r--ratelimiter/ratelimiter_test.go34
-rw-r--r--replay/replay.go4
-rw-r--r--replay/replay_test.go2
-rw-r--r--rwcancel/rwcancel.go4
-rw-r--r--rwcancel/rwcancel_stub.go (renamed from rwcancel/rwcancel_windows.go)5
-rw-r--r--tai64n/tai64n.go10
-rw-r--r--tai64n/tai64n_test.go2
-rw-r--r--tun/alignment_windows_test.go2
-rw-r--r--tun/checksum.go118
-rw-r--r--tun/checksum_test.go35
-rw-r--r--tun/errors.go12
-rw-r--r--tun/netstack/examples/http_client.go18
-rw-r--r--tun/netstack/examples/http_server.go17
-rw-r--r--tun/netstack/examples/ping_client.go75
-rw-r--r--tun/netstack/go.mod11
-rw-r--r--tun/netstack/go.sum605
-rw-r--r--tun/netstack/tun.go535
-rw-r--r--tun/offload_linux.go993
-rw-r--r--tun/offload_linux_test.go752
-rw-r--r--tun/operateonfd.go4
-rw-r--r--tun/tun.go42
-rw-r--r--tun/tun_darwin.go102
-rw-r--r--tun/tun_freebsd.go73
-rw-r--r--tun/tun_linux.go315
-rw-r--r--tun/tun_openbsd.go79
-rw-r--r--tun/tun_windows.go141
-rw-r--r--tun/tuntest/tuntest.go47
-rw-r--r--tun/wintun/dll_fromfile_windows.go54
-rw-r--r--tun/wintun/dll_fromrsrc_windows.go61
-rw-r--r--tun/wintun/dll_windows.go59
-rw-r--r--tun/wintun/memmod/memmod_windows.go635
-rw-r--r--tun/wintun/memmod/memmod_windows_32.go16
-rw-r--r--tun/wintun/memmod/memmod_windows_386.go8
-rw-r--r--tun/wintun/memmod/memmod_windows_64.go36
-rw-r--r--tun/wintun/memmod/memmod_windows_amd64.go8
-rw-r--r--tun/wintun/memmod/memmod_windows_arm.go8
-rw-r--r--tun/wintun/memmod/memmod_windows_arm64.go8
-rw-r--r--tun/wintun/memmod/syscall_windows.go398
-rw-r--r--tun/wintun/memmod/syscall_windows_32.go96
-rw-r--r--tun/wintun/memmod/syscall_windows_64.go95
-rw-r--r--tun/wintun/session_windows.go108
-rw-r--r--tun/wintun/wintun_windows.go215
-rw-r--r--version.go2
123 files changed, 5808 insertions, 4820 deletions
diff --git a/Makefile b/Makefile
index c86ffa3..3f6e407 100644
--- a/Makefile
+++ b/Makefile
@@ -23,7 +23,7 @@ install: wireguard-go
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
test:
- go test -v ./...
+ go test ./...
clean:
rm -f wireguard-go
diff --git a/README.md b/README.md
index 9e01081..074f7ec 100644
--- a/README.md
+++ b/README.md
@@ -46,7 +46,7 @@ This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapp
## Building
-This requires an installation of [go](https://golang.org) ≥ 1.16.
+This requires an installation of the latest version of [Go](https://go.dev/).
```
$ git clone https://git.zx2c4.com/wireguard-go
@@ -56,7 +56,7 @@ $ make
## License
- Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
diff --git a/conn/bind_linux.go b/conn/bind_linux.go
deleted file mode 100644
index 7b970e6..0000000
--- a/conn/bind_linux.go
+++ /dev/null
@@ -1,585 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package conn
-
-import (
- "errors"
- "net"
- "strconv"
- "sync"
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/unix"
-)
-
-type ipv4Source struct {
- Src [4]byte
- Ifindex int32
-}
-
-type ipv6Source struct {
- src [16]byte
- // ifindex belongs in dst.ZoneId
-}
-
-type LinuxSocketEndpoint struct {
- mu sync.Mutex
- dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
- src [unsafe.Sizeof(ipv6Source{})]byte
- isV6 bool
-}
-
-func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() }
-func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
-func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 }
-
-func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
- return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source {
- return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 {
- return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
- return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
-}
-
-// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
-type LinuxSocketBind struct {
- // mu guards sock4 and sock6 and the associated fds.
- // As long as someone holds mu (read or write), the associated fds are valid.
- mu sync.RWMutex
- sock4 int
- sock6 int
-}
-
-func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
-func NewDefaultBind() Bind { return NewLinuxSocketBind() }
-
-var _ Endpoint = (*LinuxSocketEndpoint)(nil)
-var _ Bind = (*LinuxSocketBind)(nil)
-
-func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
- var end LinuxSocketEndpoint
- addr, err := parseEndpoint(s)
- if err != nil {
- return nil, err
- }
-
- ipv4 := addr.IP.To4()
- if ipv4 != nil {
- dst := end.dst4()
- end.isV6 = false
- dst.Port = addr.Port
- copy(dst.Addr[:], ipv4)
- end.ClearSrc()
- return &end, nil
- }
-
- ipv6 := addr.IP.To16()
- if ipv6 != nil {
- zone, err := zoneToUint32(addr.Zone)
- if err != nil {
- return nil, err
- }
- dst := end.dst6()
- end.isV6 = true
- dst.Port = addr.Port
- dst.ZoneId = zone
- copy(dst.Addr[:], ipv6[:])
- end.ClearSrc()
- return &end, nil
- }
-
- return nil, errors.New("invalid IP address")
-}
-
-func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
- bind.mu.Lock()
- defer bind.mu.Unlock()
-
- var err error
- var newPort uint16
- var tries int
-
- if bind.sock4 != -1 || bind.sock6 != -1 {
- return nil, 0, ErrBindAlreadyOpen
- }
-
- originalPort := port
-
-again:
- port = originalPort
- var sock4, sock6 int
- // Attempt ipv6 bind, update port if successful.
- sock6, newPort, err = create6(port)
- if err != nil {
- if !errors.Is(err, syscall.EAFNOSUPPORT) {
- return nil, 0, err
- }
- } else {
- port = newPort
- }
-
- // Attempt ipv4 bind, update port if successful.
- sock4, newPort, err = create4(port)
- if err != nil {
- if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
- unix.Close(sock6)
- tries++
- goto again
- }
- if !errors.Is(err, syscall.EAFNOSUPPORT) {
- unix.Close(sock6)
- return nil, 0, err
- }
- } else {
- port = newPort
- }
-
- var fns []ReceiveFunc
- if sock4 != -1 {
- bind.sock4 = sock4
- fns = append(fns, bind.receiveIPv4)
- }
- if sock6 != -1 {
- bind.sock6 = sock6
- fns = append(fns, bind.receiveIPv6)
- }
- if len(fns) == 0 {
- return nil, 0, syscall.EAFNOSUPPORT
- }
- return fns, port, nil
-}
-
-func (bind *LinuxSocketBind) SetMark(value uint32) error {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
-
- if bind.sock6 != -1 {
- err := unix.SetsockoptInt(
- bind.sock6,
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
-
- if err != nil {
- return err
- }
- }
-
- if bind.sock4 != -1 {
- err := unix.SetsockoptInt(
- bind.sock4,
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
-
- if err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (bind *LinuxSocketBind) Close() error {
- // Take a readlock to shut down the sockets...
- bind.mu.RLock()
- if bind.sock6 != -1 {
- unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
- }
- if bind.sock4 != -1 {
- unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
- }
- bind.mu.RUnlock()
- // ...and a write lock to close the fd.
- // This ensures that no one else is using the fd.
- bind.mu.Lock()
- defer bind.mu.Unlock()
- var err1, err2 error
- if bind.sock6 != -1 {
- err1 = unix.Close(bind.sock6)
- bind.sock6 = -1
- }
- if bind.sock4 != -1 {
- err2 = unix.Close(bind.sock4)
- bind.sock4 = -1
- }
-
- if err1 != nil {
- return err1
- }
- return err2
-}
-
-func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if bind.sock4 == -1 {
- return 0, nil, net.ErrClosed
- }
- var end LinuxSocketEndpoint
- n, err := receive4(bind.sock4, buf, &end)
- return n, &end, err
-}
-
-func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if bind.sock6 == -1 {
- return 0, nil, net.ErrClosed
- }
- var end LinuxSocketEndpoint
- n, err := receive6(bind.sock6, buf, &end)
- return n, &end, err
-}
-
-func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
- nend, ok := end.(*LinuxSocketEndpoint)
- if !ok {
- return ErrWrongEndpointType
- }
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if !nend.isV6 {
- if bind.sock4 == -1 {
- return net.ErrClosed
- }
- return send4(bind.sock4, nend, buff)
- } else {
- if bind.sock6 == -1 {
- return net.ErrClosed
- }
- return send6(bind.sock6, nend, buff)
- }
-}
-
-func (end *LinuxSocketEndpoint) SrcIP() net.IP {
- if !end.isV6 {
- return net.IPv4(
- end.src4().Src[0],
- end.src4().Src[1],
- end.src4().Src[2],
- end.src4().Src[3],
- )
- } else {
- return end.src6().src[:]
- }
-}
-
-func (end *LinuxSocketEndpoint) DstIP() net.IP {
- if !end.isV6 {
- return net.IPv4(
- end.dst4().Addr[0],
- end.dst4().Addr[1],
- end.dst4().Addr[2],
- end.dst4().Addr[3],
- )
- } else {
- return end.dst6().Addr[:]
- }
-}
-
-func (end *LinuxSocketEndpoint) DstToBytes() []byte {
- if !end.isV6 {
- return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
- } else {
- return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
- }
-}
-
-func (end *LinuxSocketEndpoint) SrcToString() string {
- return end.SrcIP().String()
-}
-
-func (end *LinuxSocketEndpoint) DstToString() string {
- var udpAddr net.UDPAddr
- udpAddr.IP = end.DstIP()
- if !end.isV6 {
- udpAddr.Port = end.dst4().Port
- } else {
- udpAddr.Port = end.dst6().Port
- }
- return udpAddr.String()
-}
-
-func (end *LinuxSocketEndpoint) ClearDst() {
- for i := range end.dst {
- end.dst[i] = 0
- }
-}
-
-func (end *LinuxSocketEndpoint) ClearSrc() {
- for i := range end.src {
- end.src[i] = 0
- }
-}
-
-func zoneToUint32(zone string) (uint32, error) {
- if zone == "" {
- return 0, nil
- }
- if intr, err := net.InterfaceByName(zone); err == nil {
- return uint32(intr.Index), nil
- }
- n, err := strconv.ParseUint(zone, 10, 32)
- return uint32(n), err
-}
-
-func create4(port uint16) (int, uint16, error) {
-
- // create socket
-
- fd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM,
- 0,
- )
-
- if err != nil {
- return -1, 0, err
- }
-
- addr := unix.SockaddrInet4{
- Port: int(port),
- }
-
- // set sockopts and bind
-
- if err := func() error {
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IP,
- unix.IP_PKTINFO,
- 1,
- ); err != nil {
- return err
- }
-
- return unix.Bind(fd, &addr)
- }(); err != nil {
- unix.Close(fd)
- return -1, 0, err
- }
-
- sa, err := unix.Getsockname(fd)
- if err == nil {
- addr.Port = sa.(*unix.SockaddrInet4).Port
- }
-
- return fd, uint16(addr.Port), err
-}
-
-func create6(port uint16) (int, uint16, error) {
-
- // create socket
-
- fd, err := unix.Socket(
- unix.AF_INET6,
- unix.SOCK_DGRAM,
- 0,
- )
-
- if err != nil {
- return -1, 0, err
- }
-
- // set sockopts and bind
-
- addr := unix.SockaddrInet6{
- Port: int(port),
- }
-
- if err := func() error {
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IPV6,
- unix.IPV6_RECVPKTINFO,
- 1,
- ); err != nil {
- return err
- }
-
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IPV6,
- unix.IPV6_V6ONLY,
- 1,
- ); err != nil {
- return err
- }
-
- return unix.Bind(fd, &addr)
-
- }(); err != nil {
- unix.Close(fd)
- return -1, 0, err
- }
-
- sa, err := unix.Getsockname(fd)
- if err == nil {
- addr.Port = sa.(*unix.SockaddrInet6).Port
- }
-
- return fd, uint16(addr.Port), err
-}
-
-func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error {
-
- // construct message header
-
- cmsg := struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
- }{
- unix.Cmsghdr{
- Level: unix.IPPROTO_IP,
- Type: unix.IP_PKTINFO,
- Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
- },
- unix.Inet4Pktinfo{
- Spec_dst: end.src4().Src,
- Ifindex: end.src4().Ifindex,
- },
- }
-
- end.mu.Lock()
- _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
- end.mu.Unlock()
-
- if err == nil {
- return nil
- }
-
- // clear src and retry
-
- if err == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet4Pktinfo{}
- end.mu.Lock()
- _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
- end.mu.Unlock()
- }
-
- return err
-}
-
-func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error {
-
- // construct message header
-
- cmsg := struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
- }{
- unix.Cmsghdr{
- Level: unix.IPPROTO_IPV6,
- Type: unix.IPV6_PKTINFO,
- Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
- },
- unix.Inet6Pktinfo{
- Addr: end.src6().src,
- Ifindex: end.dst6().ZoneId,
- },
- }
-
- if cmsg.pktinfo.Addr == [16]byte{} {
- cmsg.pktinfo.Ifindex = 0
- }
-
- end.mu.Lock()
- _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
- end.mu.Unlock()
-
- if err == nil {
- return nil
- }
-
- // clear src and retry
-
- if err == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet6Pktinfo{}
- end.mu.Lock()
- _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
- end.mu.Unlock()
- }
-
- return err
-}
-
-func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
-
- // construct message header
-
- var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
- }
-
- size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
-
- if err != nil {
- return 0, err
- }
- end.isV6 = false
-
- if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
- *end.dst4() = *newDst4
- }
-
- // update source cache
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
- cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
- end.src4().Src = cmsg.pktinfo.Spec_dst
- end.src4().Ifindex = cmsg.pktinfo.Ifindex
- }
-
- return size, nil
-}
-
-func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
-
- // construct message header
-
- var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
- }
-
- size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
-
- if err != nil {
- return 0, err
- }
- end.isV6 = true
-
- if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
- *end.dst6() = *newDst6
- }
-
- // update source cache
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
- cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
- end.src6().src = cmsg.pktinfo.Addr
- end.dst6().ZoneId = cmsg.pktinfo.Ifindex
- }
-
- return size, nil
-}
diff --git a/conn/bind_std.go b/conn/bind_std.go
index cb85cfd..46df7fd 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -1,72 +1,126 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
+ "context"
"errors"
+ "fmt"
"net"
+ "net/netip"
+ "runtime"
+ "strconv"
"sync"
"syscall"
+
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+)
+
+var (
+ _ Bind = (*StdNetBind)(nil)
)
-// StdNetBind is meant to be a temporary solution on platforms for which
-// the sticky socket / source caching behavior has not yet been implemented.
-// It uses the Go's net package to implement networking.
-// See LinuxSocketBind for a proper implementation on the Linux platform.
+// StdNetBind implements Bind for all platforms. While Windows has its own Bind
+// (see bind_windows.go), it may fall back to StdNetBind.
+// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
+// methods for sending and receiving multiple datagrams per-syscall. See the
+// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
- mu sync.Mutex // protects following fields
- ipv4 *net.UDPConn
- ipv6 *net.UDPConn
+ mu sync.Mutex // protects all fields except as specified
+ ipv4 *net.UDPConn
+ ipv6 *net.UDPConn
+ ipv4PC *ipv4.PacketConn // will be nil on non-Linux
+ ipv6PC *ipv6.PacketConn // will be nil on non-Linux
+ ipv4TxOffload bool
+ ipv4RxOffload bool
+ ipv6TxOffload bool
+ ipv6RxOffload bool
+
+ // these two fields are not guarded by mu
+ udpAddrPool sync.Pool
+ msgsPool sync.Pool
+
blackhole4 bool
blackhole6 bool
}
-func NewStdNetBind() Bind { return &StdNetBind{} }
+func NewStdNetBind() Bind {
+ return &StdNetBind{
+ udpAddrPool: sync.Pool{
+ New: func() any {
+ return &net.UDPAddr{
+ IP: make([]byte, 16),
+ }
+ },
+ },
+
+ msgsPool: sync.Pool{
+ New: func() any {
+ // ipv6.Message and ipv4.Message are interchangeable as they are
+ // both aliases for x/net/internal/socket.Message.
+ msgs := make([]ipv6.Message, IdealBatchSize)
+ for i := range msgs {
+ msgs[i].Buffers = make(net.Buffers, 1)
+ msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
+ }
+ return &msgs
+ },
+ },
+ }
+}
-type StdNetEndpoint net.UDPAddr
+type StdNetEndpoint struct {
+ // AddrPort is the endpoint destination.
+ netip.AddrPort
+ // src is the current sticky source address and interface index, if
+ // supported. Typically this is a PKTINFO structure from/for control
+ // messages, see unix.PKTINFO for an example.
+ src []byte
+}
-var _ Bind = (*StdNetBind)(nil)
-var _ Endpoint = (*StdNetEndpoint)(nil)
+var (
+ _ Bind = (*StdNetBind)(nil)
+ _ Endpoint = &StdNetEndpoint{}
+)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
- addr, err := parseEndpoint(s)
- return (*StdNetEndpoint)(addr), err
+ e, err := netip.ParseAddrPort(s)
+ if err != nil {
+ return nil, err
+ }
+ return &StdNetEndpoint{
+ AddrPort: e,
+ }, nil
}
-func (*StdNetEndpoint) ClearSrc() {}
-
-func (e *StdNetEndpoint) DstIP() net.IP {
- return (*net.UDPAddr)(e).IP
+func (e *StdNetEndpoint) ClearSrc() {
+ if e.src != nil {
+ // Truncate src, no need to reallocate.
+ e.src = e.src[:0]
+ }
}
-func (e *StdNetEndpoint) SrcIP() net.IP {
- return nil // not supported
+func (e *StdNetEndpoint) DstIP() netip.Addr {
+ return e.AddrPort.Addr()
}
+// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
+
func (e *StdNetEndpoint) DstToBytes() []byte {
- addr := (*net.UDPAddr)(e)
- out := addr.IP.To4()
- if out == nil {
- out = addr.IP
- }
- out = append(out, byte(addr.Port&0xff))
- out = append(out, byte((addr.Port>>8)&0xff))
- return out
+ b, _ := e.AddrPort.MarshalBinary()
+ return b
}
func (e *StdNetEndpoint) DstToString() string {
- return (*net.UDPAddr)(e).String()
-}
-
-func (e *StdNetEndpoint) SrcToString() string {
- return ""
+ return e.AddrPort.String()
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
- conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
+ conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil {
return nil, 0, err
}
@@ -80,17 +134,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
if err != nil {
return nil, 0, err
}
- return conn, uaddr.Port, nil
+ return conn.(*net.UDPConn), uaddr.Port, nil
}
-func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
- bind.mu.Lock()
- defer bind.mu.Unlock()
+func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
var err error
var tries int
- if bind.ipv4 != nil || bind.ipv6 != nil {
+ if s.ipv4 != nil || s.ipv6 != nil {
return nil, 0, ErrBindAlreadyOpen
}
@@ -98,92 +152,207 @@ func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
// If uport is 0, we can retry on failure.
again:
port := int(uport)
- var ipv4, ipv6 *net.UDPConn
+ var v4conn, v6conn *net.UDPConn
+ var v4pc *ipv4.PacketConn
+ var v6pc *ipv6.PacketConn
- ipv4, port, err = listenNet("udp4", port)
+ v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
- ipv6, port, err = listenNet("udp6", port)
+ v6conn, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
- ipv4.Close()
+ v4conn.Close()
tries++
goto again
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
- ipv4.Close()
+ v4conn.Close()
return nil, 0, err
}
var fns []ReceiveFunc
- if ipv4 != nil {
- fns = append(fns, bind.makeReceiveIPv4(ipv4))
- bind.ipv4 = ipv4
+ if v4conn != nil {
+ s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ v4pc = ipv4.NewPacketConn(v4conn)
+ s.ipv4PC = v4pc
+ }
+ fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
+ s.ipv4 = v4conn
}
- if ipv6 != nil {
- fns = append(fns, bind.makeReceiveIPv6(ipv6))
- bind.ipv6 = ipv6
+ if v6conn != nil {
+ s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ v6pc = ipv6.NewPacketConn(v6conn)
+ s.ipv6PC = v6pc
+ }
+ fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
+ s.ipv6 = v6conn
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
+
return fns, uint16(port), nil
}
-func (bind *StdNetBind) Close() error {
- bind.mu.Lock()
- defer bind.mu.Unlock()
+func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
+ for i := range *msgs {
+ (*msgs)[i].OOB = (*msgs)[i].OOB[:0]
+ (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
+ }
+ s.msgsPool.Put(msgs)
+}
+
+func (s *StdNetBind) getMessages() *[]ipv6.Message {
+ return s.msgsPool.Get().(*[]ipv6.Message)
+}
+
+var (
+ // If compilation fails here these are no longer the same underlying type.
+ _ ipv6.Message = ipv4.Message{}
+)
+
+type batchReader interface {
+ ReadBatch([]ipv6.Message, int) (int, error)
+}
+
+type batchWriter interface {
+ WriteBatch([]ipv6.Message, int) (int, error)
+}
+
+func (s *StdNetBind) receiveIP(
+ br batchReader,
+ conn *net.UDPConn,
+ rxOffload bool,
+ bufs [][]byte,
+ sizes []int,
+ eps []Endpoint,
+) (n int, err error) {
+ msgs := s.getMessages()
+ for i := range bufs {
+ (*msgs)[i].Buffers[0] = bufs[i]
+ (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
+ }
+ defer s.putMessages(msgs)
+ var numMsgs int
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ if rxOffload {
+ readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
+ numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
+ if err != nil {
+ return 0, err
+ }
+ numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
+ if err != nil {
+ return 0, err
+ }
+ } else {
+ numMsgs, err = br.ReadBatch(*msgs, 0)
+ if err != nil {
+ return 0, err
+ }
+ }
+ } else {
+ msg := &(*msgs)[0]
+ msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+ if err != nil {
+ return 0, err
+ }
+ numMsgs = 1
+ }
+ for i := 0; i < numMsgs; i++ {
+ msg := &(*msgs)[i]
+ sizes[i] = msg.N
+ if sizes[i] == 0 {
+ continue
+ }
+ addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+ ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
+ getSrcFromControl(msg.OOB[:msg.NN], ep)
+ eps[i] = ep
+ }
+ return numMsgs, nil
+}
+
+func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
+ return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
+ }
+}
+
+func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
+ return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
+ }
+}
+
+// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
+// rename the IdealBatchSize constant to BatchSize.
+func (s *StdNetBind) BatchSize() int {
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ return IdealBatchSize
+ }
+ return 1
+}
+
+func (s *StdNetBind) Close() error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
var err1, err2 error
- if bind.ipv4 != nil {
- err1 = bind.ipv4.Close()
- bind.ipv4 = nil
+ if s.ipv4 != nil {
+ err1 = s.ipv4.Close()
+ s.ipv4 = nil
+ s.ipv4PC = nil
}
- if bind.ipv6 != nil {
- err2 = bind.ipv6.Close()
- bind.ipv6 = nil
+ if s.ipv6 != nil {
+ err2 = s.ipv6.Close()
+ s.ipv6 = nil
+ s.ipv6PC = nil
}
- bind.blackhole4 = false
- bind.blackhole6 = false
+ s.blackhole4 = false
+ s.blackhole6 = false
+ s.ipv4TxOffload = false
+ s.ipv4RxOffload = false
+ s.ipv6TxOffload = false
+ s.ipv6RxOffload = false
if err1 != nil {
return err1
}
return err2
}
-func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
- return func(buff []byte) (int, Endpoint, error) {
- n, endpoint, err := conn.ReadFromUDP(buff)
- if endpoint != nil {
- endpoint.IP = endpoint.IP.To4()
- }
- return n, (*StdNetEndpoint)(endpoint), err
- }
+type ErrUDPGSODisabled struct {
+ onLaddr string
+ RetryErr error
}
-func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
- return func(buff []byte) (int, Endpoint, error) {
- n, endpoint, err := conn.ReadFromUDP(buff)
- return n, (*StdNetEndpoint)(endpoint), err
- }
+func (e ErrUDPGSODisabled) Error() string {
+ return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
}
-func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
- var err error
- nend, ok := endpoint.(*StdNetEndpoint)
- if !ok {
- return ErrWrongEndpointType
- }
+func (e ErrUDPGSODisabled) Unwrap() error {
+ return e.RetryErr
+}
- bind.mu.Lock()
- blackhole := bind.blackhole4
- conn := bind.ipv4
- if nend.IP.To4() == nil {
- blackhole = bind.blackhole6
- conn = bind.ipv6
+func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
+ s.mu.Lock()
+ blackhole := s.blackhole4
+ conn := s.ipv4
+ offload := s.ipv4TxOffload
+ br := batchWriter(s.ipv4PC)
+ is6 := false
+ if endpoint.DstIP().Is6() {
+ blackhole = s.blackhole6
+ conn = s.ipv6
+ br = s.ipv6PC
+ is6 = true
+ offload = s.ipv6TxOffload
}
- bind.mu.Unlock()
+ s.mu.Unlock()
if blackhole {
return nil
@@ -191,6 +360,185 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
if conn == nil {
return syscall.EAFNOSUPPORT
}
- _, err = conn.WriteToUDP(buff, (*net.UDPAddr)(nend))
+
+ msgs := s.getMessages()
+ defer s.putMessages(msgs)
+ ua := s.udpAddrPool.Get().(*net.UDPAddr)
+ defer s.udpAddrPool.Put(ua)
+ if is6 {
+ as16 := endpoint.DstIP().As16()
+ copy(ua.IP, as16[:])
+ ua.IP = ua.IP[:16]
+ } else {
+ as4 := endpoint.DstIP().As4()
+ copy(ua.IP, as4[:])
+ ua.IP = ua.IP[:4]
+ }
+ ua.Port = int(endpoint.(*StdNetEndpoint).Port())
+ var (
+ retried bool
+ err error
+ )
+retry:
+ if offload {
+ n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
+ err = s.send(conn, br, (*msgs)[:n])
+ if err != nil && offload && errShouldDisableUDPGSO(err) {
+ offload = false
+ s.mu.Lock()
+ if is6 {
+ s.ipv6TxOffload = false
+ } else {
+ s.ipv4TxOffload = false
+ }
+ s.mu.Unlock()
+ retried = true
+ goto retry
+ }
+ } else {
+ for i := range bufs {
+ (*msgs)[i].Addr = ua
+ (*msgs)[i].Buffers[0] = bufs[i]
+ setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
+ }
+ err = s.send(conn, br, (*msgs)[:len(bufs)])
+ }
+ if retried {
+ return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
+ }
+ return err
+}
+
+func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
+ var (
+ n int
+ err error
+ start int
+ )
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ for {
+ n, err = pc.WriteBatch(msgs[start:], 0)
+ if err != nil || n == len(msgs[start:]) {
+ break
+ }
+ start += n
+ }
+ } else {
+ for _, msg := range msgs {
+ _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
+ if err != nil {
+ break
+ }
+ }
+ }
return err
}
+
+const (
+ // Exceeding these values results in EMSGSIZE. They account for layer3 and
+ // layer4 headers. IPv6 does not need to account for itself as the payload
+ // length field is self excluding.
+ maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
+ maxIPv6PayloadLen = 1<<16 - 1 - 8
+
+ // This is a hard limit imposed by the kernel.
+ udpSegmentMaxDatagrams = 64
+)
+
+type setGSOFunc func(control *[]byte, gsoSize uint16)
+
+func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
+ var (
+ base = -1 // index of msg we are currently coalescing into
+ gsoSize int // segmentation size of msgs[base]
+ dgramCnt int // number of dgrams coalesced into msgs[base]
+ endBatch bool // tracking flag to start a new batch on next iteration of bufs
+ )
+ maxPayloadLen := maxIPv4PayloadLen
+ if ep.DstIP().Is6() {
+ maxPayloadLen = maxIPv6PayloadLen
+ }
+ for i, buf := range bufs {
+ if i > 0 {
+ msgLen := len(buf)
+ baseLenBefore := len(msgs[base].Buffers[0])
+ freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
+ if msgLen+baseLenBefore <= maxPayloadLen &&
+ msgLen <= gsoSize &&
+ msgLen <= freeBaseCap &&
+ dgramCnt < udpSegmentMaxDatagrams &&
+ !endBatch {
+ msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
+ if i == len(bufs)-1 {
+ setGSO(&msgs[base].OOB, uint16(gsoSize))
+ }
+ dgramCnt++
+ if msgLen < gsoSize {
+ // A smaller than gsoSize packet on the tail is legal, but
+ // it must end the batch.
+ endBatch = true
+ }
+ continue
+ }
+ }
+ if dgramCnt > 1 {
+ setGSO(&msgs[base].OOB, uint16(gsoSize))
+ }
+ // Reset prior to incrementing base since we are preparing to start a
+ // new potential batch.
+ endBatch = false
+ base++
+ gsoSize = len(buf)
+ setSrcControl(&msgs[base].OOB, ep)
+ msgs[base].Buffers[0] = buf
+ msgs[base].Addr = addr
+ dgramCnt = 1
+ }
+ return base + 1
+}
+
+type getGSOFunc func(control []byte) (int, error)
+
+func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
+ for i := firstMsgAt; i < len(msgs); i++ {
+ msg := &msgs[i]
+ if msg.N == 0 {
+ return n, err
+ }
+ var (
+ gsoSize int
+ start int
+ end = msg.N
+ numToSplit = 1
+ )
+ gsoSize, err = getGSO(msg.OOB[:msg.NN])
+ if err != nil {
+ return n, err
+ }
+ if gsoSize > 0 {
+ numToSplit = (msg.N + gsoSize - 1) / gsoSize
+ end = gsoSize
+ }
+ for j := 0; j < numToSplit; j++ {
+ if n > i {
+ return n, errors.New("splitting coalesced packet resulted in overflow")
+ }
+ copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
+ msgs[n].N = copied
+ msgs[n].Addr = msg.Addr
+ start = end
+ end += gsoSize
+ if end > msg.N {
+ end = msg.N
+ }
+ n++
+ }
+ if i != n-1 {
+ // It is legal for bytes to move within msg.Buffers[0] as a result
+ // of splitting, so we only zero the source msg len when it is not
+ // the destination of the last split operation above.
+ msg.N = 0
+ }
+ }
+ return n, nil
+}
diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go
new file mode 100644
index 0000000..34a3c9a
--- /dev/null
+++ b/conn/bind_std_test.go
@@ -0,0 +1,250 @@
+package conn
+
+import (
+ "encoding/binary"
+ "net"
+ "testing"
+
+ "golang.org/x/net/ipv6"
+)
+
+func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
+ bind := NewStdNetBind().(*StdNetBind)
+ fns, _, err := bind.Open(0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ bind.Close()
+ bufs := make([][]byte, 1)
+ bufs[0] = make([]byte, 1)
+ sizes := make([]int, 1)
+ eps := make([]Endpoint, 1)
+ for _, fn := range fns {
+ // The ReceiveFuncs must not access conn-related fields on StdNetBind
+ // unguarded. Close() nils the conn-related fields resulting in a panic
+ // if they violate the mutex.
+ fn(bufs, sizes, eps)
+ }
+}
+
+func mockSetGSOSize(control *[]byte, gsoSize uint16) {
+ *control = (*control)[:cap(*control)]
+ binary.LittleEndian.PutUint16(*control, gsoSize)
+}
+
+func Test_coalesceMessages(t *testing.T) {
+ cases := []struct {
+ name string
+ buffs [][]byte
+ wantLens []int
+ wantGSO []int
+ }{
+ {
+ name: "one message no coalesce",
+ buffs: [][]byte{
+ make([]byte, 1, 1),
+ },
+ wantLens: []int{1},
+ wantGSO: []int{0},
+ },
+ {
+ name: "two messages equal len coalesce",
+ buffs: [][]byte{
+ make([]byte, 1, 2),
+ make([]byte, 1, 1),
+ },
+ wantLens: []int{2},
+ wantGSO: []int{1},
+ },
+ {
+ name: "two messages unequal len coalesce",
+ buffs: [][]byte{
+ make([]byte, 2, 3),
+ make([]byte, 1, 1),
+ },
+ wantLens: []int{3},
+ wantGSO: []int{2},
+ },
+ {
+ name: "three messages second unequal len coalesce",
+ buffs: [][]byte{
+ make([]byte, 2, 3),
+ make([]byte, 1, 1),
+ make([]byte, 2, 2),
+ },
+ wantLens: []int{3, 2},
+ wantGSO: []int{2, 0},
+ },
+ {
+ name: "three messages limited cap coalesce",
+ buffs: [][]byte{
+ make([]byte, 2, 4),
+ make([]byte, 2, 2),
+ make([]byte, 2, 2),
+ },
+ wantLens: []int{4, 2},
+ wantGSO: []int{2, 0},
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ addr := &net.UDPAddr{
+ IP: net.ParseIP("127.0.0.1").To4(),
+ Port: 1,
+ }
+ msgs := make([]ipv6.Message, len(tt.buffs))
+ for i := range msgs {
+ msgs[i].Buffers = make([][]byte, 1)
+ msgs[i].OOB = make([]byte, 0, 2)
+ }
+ got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
+ if got != len(tt.wantLens) {
+ t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
+ }
+ for i := 0; i < got; i++ {
+ if msgs[i].Addr != addr {
+ t.Errorf("msgs[%d].Addr != passed addr", i)
+ }
+ gotLen := len(msgs[i].Buffers[0])
+ if gotLen != tt.wantLens[i] {
+ t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
+ }
+ gotGSO, err := mockGetGSOSize(msgs[i].OOB)
+ if err != nil {
+ t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
+ }
+ if gotGSO != tt.wantGSO[i] {
+ t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
+ }
+ }
+ })
+ }
+}
+
+func mockGetGSOSize(control []byte) (int, error) {
+ if len(control) < 2 {
+ return 0, nil
+ }
+ return int(binary.LittleEndian.Uint16(control)), nil
+}
+
+func Test_splitCoalescedMessages(t *testing.T) {
+ newMsg := func(n, gso int) ipv6.Message {
+ msg := ipv6.Message{
+ Buffers: [][]byte{make([]byte, 1<<16-1)},
+ N: n,
+ OOB: make([]byte, 2),
+ }
+ binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
+ if gso > 0 {
+ msg.NN = 2
+ }
+ return msg
+ }
+
+ cases := []struct {
+ name string
+ msgs []ipv6.Message
+ firstMsgAt int
+ wantNumEval int
+ wantMsgLens []int
+ wantErr bool
+ }{
+ {
+ name: "second last split last empty",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(3, 1),
+ newMsg(0, 0),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 3,
+ wantMsgLens: []int{1, 1, 1, 0},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last empty",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(0, 0),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 1,
+ wantMsgLens: []int{1, 0, 0, 0},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last no split",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(1, 0),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 2,
+ wantMsgLens: []int{1, 1, 0, 0},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last split",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(3, 1),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 4,
+ wantMsgLens: []int{1, 1, 1, 1},
+ wantErr: false,
+ },
+ {
+ name: "second last split last split",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(2, 1),
+ newMsg(2, 1),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 4,
+ wantMsgLens: []int{1, 1, 1, 1},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last split overflow",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(4, 1),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 4,
+ wantMsgLens: []int{1, 1, 1, 1},
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
+ if err != nil && !tt.wantErr {
+ t.Fatalf("err: %v", err)
+ }
+ if got != tt.wantNumEval {
+ t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
+ }
+ for i, msg := range tt.msgs {
+ if msg.N != tt.wantMsgLens[i] {
+ t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
+ }
+ }
+ })
+ }
+}
diff --git a/conn/bind_windows.go b/conn/bind_windows.go
index d744987..d5095e0 100644
--- a/conn/bind_windows.go
+++ b/conn/bind_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
@@ -9,6 +9,7 @@ import (
"encoding/binary"
"io"
"net"
+ "net/netip"
"strconv"
"sync"
"sync/atomic"
@@ -73,7 +74,7 @@ type afWinRingBind struct {
type WinRingBind struct {
v4, v6 afWinRingBind
mu sync.RWMutex
- isOpen uint32
+ isOpen atomic.Uint32 // 0, 1, or 2
}
func NewDefaultBind() Bind { return NewWinRingBind() }
@@ -90,8 +91,10 @@ type WinRingEndpoint struct {
data [30]byte
}
-var _ Bind = (*WinRingBind)(nil)
-var _ Endpoint = (*WinRingEndpoint)(nil)
+var (
+ _ Bind = (*WinRingBind)(nil)
+ _ Endpoint = (*WinRingEndpoint)(nil)
+)
func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
host, port, err := net.SplitHostPort(s)
@@ -121,27 +124,25 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
return nil, windows.ERROR_INVALID_ADDRESS
}
- var src []byte
var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
- unsafeSlice(unsafe.Pointer(&src), unsafe.Pointer(addrinfo.Addr), int(addrinfo.Addrlen))
- copy(dst[:], src)
+ copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen))
return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
}
func (*WinRingEndpoint) ClearSrc() {}
-func (e *WinRingEndpoint) DstIP() net.IP {
+func (e *WinRingEndpoint) DstIP() netip.Addr {
switch e.family {
case windows.AF_INET:
- return append([]byte{}, e.data[2:6]...)
+ return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
case windows.AF_INET6:
- return append([]byte{}, e.data[6:22]...)
+ return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
}
- return nil
+ return netip.Addr{}
}
-func (e *WinRingEndpoint) SrcIP() net.IP {
- return nil // not supported
+func (e *WinRingEndpoint) SrcIP() netip.Addr {
+ return netip.Addr{} // not supported
}
func (e *WinRingEndpoint) DstToBytes() []byte {
@@ -163,15 +164,13 @@ func (e *WinRingEndpoint) DstToBytes() []byte {
func (e *WinRingEndpoint) DstToString() string {
switch e.family {
case windows.AF_INET:
- addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
- return addr.String()
+ return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
case windows.AF_INET6:
var zone string
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
zone = strconv.FormatUint(uint64(scope), 10)
}
- addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
- return addr.String()
+ return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
}
return ""
}
@@ -213,7 +212,7 @@ func (bind *afWinRingBind) CloseAndZero() {
}
func (bind *WinRingBind) closeAndZero() {
- atomic.StoreUint32(&bind.isOpen, 0)
+ bind.isOpen.Store(0)
bind.v4.CloseAndZero()
bind.v6.CloseAndZero()
}
@@ -277,7 +276,7 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
bind.closeAndZero()
}
}()
- if atomic.LoadUint32(&bind.isOpen) != 0 {
+ if bind.isOpen.Load() != 0 {
return nil, 0, ErrBindAlreadyOpen
}
var sa windows.Sockaddr
@@ -300,17 +299,17 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
return nil, 0, err
}
}
- atomic.StoreUint32(&bind.isOpen, 1)
+ bind.isOpen.Store(1)
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
}
func (bind *WinRingBind) Close() error {
bind.mu.RLock()
- if atomic.LoadUint32(&bind.isOpen) != 1 {
+ if bind.isOpen.Load() != 1 {
bind.mu.RUnlock()
return nil
}
- atomic.StoreUint32(&bind.isOpen, 2)
+ bind.isOpen.Store(2)
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
@@ -322,6 +321,13 @@ func (bind *WinRingBind) Close() error {
return nil
}
+// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
+// rename the IdealBatchSize constant to BatchSize.
+func (bind *WinRingBind) BatchSize() int {
+ // TODO: implement batching in and out of the ring
+ return 1
+}
+
func (bind *WinRingBind) SetMark(mark uint32) error {
return nil
}
@@ -346,8 +352,8 @@ func (bind *afWinRingBind) InsertReceiveRequest() error {
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
-func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) {
- if atomic.LoadUint32(isOpen) != 1 {
+func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
+ if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
bind.rx.mu.Lock()
@@ -360,7 +366,7 @@ retry:
count = 0
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
if tries > 0 {
- if atomic.LoadUint32(isOpen) != 1 {
+ if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
procyield(1)
@@ -379,13 +385,12 @@ retry:
if err != nil {
return 0, nil, err
}
- if atomic.LoadUint32(isOpen) != 1 {
+ if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
if count == 0 {
return 0, nil, io.ErrNoProgress
-
}
}
bind.rx.Return(1)
@@ -397,7 +402,7 @@ retry:
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
// attacker bandwidth, just like the rest of the receive path.
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
- if atomic.LoadUint32(isOpen) != 1 {
+ if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
goto retry
@@ -411,20 +416,26 @@ retry:
return n, &ep, nil
}
-func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
+func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
- return bind.v4.Receive(buf, &bind.isOpen)
+ n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
+ sizes[0] = n
+ eps[0] = ep
+ return 1, err
}
-func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
+func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
- return bind.v6.Receive(buf, &bind.isOpen)
+ n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
+ sizes[0] = n
+ eps[0] = ep
+ return 1, err
}
-func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error {
- if atomic.LoadUint32(isOpen) != 1 {
+func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
+ if isOpen.Load() != 1 {
return net.ErrClosed
}
if len(buf) > bytesPerPacket {
@@ -446,7 +457,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint3
if err != nil {
return err
}
- if atomic.LoadUint32(isOpen) != 1 {
+ if isOpen.Load() != 1 {
return net.ErrClosed
}
count = winrio.DequeueCompletion(bind.tx.cq, results[:])
@@ -475,32 +486,38 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint3
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
-func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
+func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
nend, ok := endpoint.(*WinRingEndpoint)
if !ok {
return ErrWrongEndpointType
}
bind.mu.RLock()
defer bind.mu.RUnlock()
- switch nend.family {
- case windows.AF_INET:
- if bind.v4.blackhole {
- return nil
- }
- return bind.v4.Send(buf, nend, &bind.isOpen)
- case windows.AF_INET6:
- if bind.v6.blackhole {
- return nil
+ for _, buf := range bufs {
+ switch nend.family {
+ case windows.AF_INET:
+ if bind.v4.blackhole {
+ continue
+ }
+ if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
+ return err
+ }
+ case windows.AF_INET6:
+ if bind.v6.blackhole {
+ continue
+ }
+ if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
+ return err
+ }
}
- return bind.v6.Send(buf, nend, &bind.isOpen)
}
return nil
}
-func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
- bind.mu.Lock()
- defer bind.mu.Unlock()
- sysconn, err := bind.ipv4.SyscallConn()
+func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ sysconn, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
@@ -513,14 +530,14 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
if err != nil {
return err
}
- bind.blackhole4 = blackhole
+ s.blackhole4 = blackhole
return nil
}
-func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
- bind.mu.Lock()
- defer bind.mu.Unlock()
- sysconn, err := bind.ipv6.SyscallConn()
+func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ sysconn, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
@@ -533,13 +550,14 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole
if err != nil {
return err
}
- bind.blackhole6 = blackhole
+ s.blackhole6 = blackhole
return nil
}
+
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
- if atomic.LoadUint32(&bind.isOpen) != 1 {
+ if bind.isOpen.Load() != 1 {
return net.ErrClosed
}
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
@@ -553,7 +571,7 @@ func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
- if atomic.LoadUint32(&bind.isOpen) != 1 {
+ if bind.isOpen.Load() != 1 {
return net.ErrClosed
}
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
@@ -581,21 +599,3 @@ func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error
const IPV6_UNICAST_IF = 31
return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
}
-
-// unsafeSlice updates the slice slicePtr to be a slice
-// referencing the provided data with its length & capacity set to
-// lenCap.
-//
-// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
-// update callers to use unsafe.Slice instead of this.
-func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
- type sliceHeader struct {
- Data unsafe.Pointer
- Len int
- Cap int
- }
- h := (*sliceHeader)(slicePtr)
- h.Data = data
- h.Len = lenCap
- h.Cap = lenCap
-}
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
index 7d43fb3..74e7add 100644
--- a/conn/bindtest/bindtest.go
+++ b/conn/bindtest/bindtest.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bindtest
@@ -9,8 +9,8 @@ import (
"fmt"
"math/rand"
"net"
+ "net/netip"
"os"
- "strconv"
"golang.zx2c4.com/wireguard/conn"
)
@@ -25,8 +25,10 @@ type ChannelBind struct {
type ChannelEndpoint uint16
-var _ conn.Bind = (*ChannelBind)(nil)
-var _ conn.Endpoint = (*ChannelEndpoint)(nil)
+var (
+ _ conn.Bind = (*ChannelBind)(nil)
+ _ conn.Endpoint = (*ChannelEndpoint)(nil)
+)
func NewChannelBinds() [2]conn.Bind {
arx4 := make(chan []byte, 8192)
@@ -61,9 +63,9 @@ func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
-func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
+func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
-func (c ChannelEndpoint) SrcIP() net.IP { return nil }
+func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
c.closeSignal = make(chan bool)
@@ -87,45 +89,48 @@ func (c *ChannelBind) Close() error {
return nil
}
+func (c *ChannelBind) BatchSize() int { return 1 }
+
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
- return func(b []byte) (n int, ep conn.Endpoint, err error) {
+ return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
select {
case <-c.closeSignal:
- return 0, nil, net.ErrClosed
+ return 0, net.ErrClosed
case rx := <-ch:
- return copy(b, rx), c.target6, nil
+ copied := copy(bufs[0], rx)
+ sizes[0] = copied
+ eps[0] = c.target6
+ return 1, nil
}
}
}
-func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
- select {
- case <-c.closeSignal:
- return net.ErrClosed
- default:
- bc := make([]byte, len(b))
- copy(bc, b)
- if ep.(ChannelEndpoint) == c.target4 {
- *c.tx4 <- bc
- } else if ep.(ChannelEndpoint) == c.target6 {
- *c.tx6 <- bc
- } else {
- return os.ErrInvalid
+func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
+ for _, b := range bufs {
+ select {
+ case <-c.closeSignal:
+ return net.ErrClosed
+ default:
+ bc := make([]byte, len(b))
+ copy(bc, b)
+ if ep.(ChannelEndpoint) == c.target4 {
+ *c.tx4 <- bc
+ } else if ep.(ChannelEndpoint) == c.target6 {
+ *c.tx6 <- bc
+ } else {
+ return os.ErrInvalid
+ }
}
}
return nil
}
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
- _, port, err := net.SplitHostPort(s)
- if err != nil {
- return nil, err
- }
- i, err := strconv.ParseUint(port, 10, 16)
+ addr, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
- return ChannelEndpoint(i), nil
+ return ChannelEndpoint(addr.Port()), nil
}
diff --git a/conn/boundif_android.go b/conn/boundif_android.go
index 8b82bfc..dd3ca5b 100644
--- a/conn/boundif_android.go
+++ b/conn/boundif_android.go
@@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
-func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
- sysconn, err := bind.ipv4.SyscallConn()
+func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
+ sysconn, err := s.ipv4.SyscallConn()
if err != nil {
return -1, err
}
@@ -19,8 +19,8 @@ func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
return
}
-func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
- sysconn, err := bind.ipv6.SyscallConn()
+func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
+ sysconn, err := s.ipv6.SyscallConn()
if err != nil {
return -1, err
}
diff --git a/conn/conn.go b/conn/conn.go
index 9cce9ad..a1f57d2 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
// Package conn implements WireGuard's network connections.
@@ -9,16 +9,23 @@ package conn
import (
"errors"
"fmt"
- "net"
+ "net/netip"
"reflect"
"runtime"
"strings"
)
-// A ReceiveFunc receives a single inbound packet from the network.
-// It writes the data into b. n is the length of the packet.
-// ep is the remote endpoint.
-type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
+const (
+ IdealBatchSize = 128 // maximum number of packets handled per read and write
+)
+
+// A ReceiveFunc receives at least one packet from the network and writes them
+// into packets. On a successful read it returns the number of elements of
+// sizes, packets, and endpoints that should be evaluated. Some elements of
+// sizes may be zero, and callers should ignore them. Callers must pass a sizes
+// and eps slice with a length greater than or equal to the length of packets.
+// These lengths must not exceed the length of the associated Bind.BatchSize().
+type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
//
@@ -38,11 +45,16 @@ type Bind interface {
// This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error
- // Send writes a packet b to address ep.
- Send(b []byte, ep Endpoint) error
+ // Send writes one or more packets in bufs to address ep. The length of
+ // bufs must not exceed BatchSize().
+ Send(bufs [][]byte, ep Endpoint) error
// ParseEndpoint creates a new endpoint from a string.
ParseEndpoint(s string) (Endpoint, error)
+
+ // BatchSize is the number of buffers expected to be passed to
+ // the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
+ BatchSize() int
}
// BindSocketToInterface is implemented by Bind objects that support being
@@ -68,8 +80,8 @@ type Endpoint interface {
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
- DstIP() net.IP
- SrcIP() net.IP
+ DstIP() netip.Addr
+ SrcIP() netip.Addr
}
var (
@@ -119,33 +131,3 @@ func (fn ReceiveFunc) PrettyName() string {
}
return name
}
-
-func parseEndpoint(s string) (*net.UDPAddr, error) {
- // ensure that the host is an IP address
-
- host, _, err := net.SplitHostPort(s)
- if err != nil {
- return nil, err
- }
- if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
- // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
- // trying to make sure with a small sanity test that this is a real IP address and
- // not something that's likely to incur DNS lookups.
- host = host[:i]
- }
- if ip := net.ParseIP(host); ip == nil {
- return nil, errors.New("Failed to parse IP address: " + host)
- }
-
- // parse address and port
-
- addr, err := net.ResolveUDPAddr("udp", s)
- if err != nil {
- return nil, err
- }
- ip4 := addr.IP.To4()
- if ip4 != nil {
- addr.IP = ip4
- }
- return addr, err
-}
diff --git a/conn/conn_test.go b/conn/conn_test.go
new file mode 100644
index 0000000..c6194ee
--- /dev/null
+++ b/conn/conn_test.go
@@ -0,0 +1,24 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "testing"
+)
+
+func TestPrettyName(t *testing.T) {
+ var (
+ recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
+ )
+
+ const want = "TestPrettyName"
+
+ t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
+ if got := recvFunc.PrettyName(); got != want {
+ t.Errorf("PrettyName() = %v, want %v", got, want)
+ }
+ })
+}
diff --git a/conn/controlfns.go b/conn/controlfns.go
new file mode 100644
index 0000000..4f7d90f
--- /dev/null
+++ b/conn/controlfns.go
@@ -0,0 +1,43 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net"
+ "syscall"
+)
+
+// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
+// the max supported by a default configuration of macOS. Some platforms will
+// silently clamp the value to other maximums, such as linux clamping to
+// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
+// around this limitation)
+const socketBufferSize = 7 << 20
+
+// controlFn is the callback function signature from net.ListenConfig.Control.
+// It is used to apply platform specific configuration to the socket prior to
+// bind.
+type controlFn func(network, address string, c syscall.RawConn) error
+
+// controlFns is a list of functions that are called from the listen config
+// that can apply socket options.
+var controlFns = []controlFn{}
+
+// listenConfig returns a net.ListenConfig that applies the controlFns to the
+// socket prior to bind. This is used to apply socket buffer sizing and packet
+// information OOB configuration for sticky sockets.
+func listenConfig() *net.ListenConfig {
+ return &net.ListenConfig{
+ Control: func(network, address string, c syscall.RawConn) error {
+ for _, fn := range controlFns {
+ if err := fn(network, address, c); err != nil {
+ return err
+ }
+ }
+ return nil
+ },
+ }
+}
diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go
new file mode 100644
index 0000000..f6ab1d2
--- /dev/null
+++ b/conn/controlfns_linux.go
@@ -0,0 +1,69 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "fmt"
+ "runtime"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ controlFns = append(controlFns,
+
+ // Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
+ // using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
+ // fail silently - the result of failure is lower performance on very fast
+ // links or high latency links.
+ func(network, address string, c syscall.RawConn) error {
+ return c.Control(func(fd uintptr) {
+ // Set up to *mem_max
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
+ // Set beyond *mem_max if CAP_NET_ADMIN
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
+ })
+ },
+
+ // Enable receiving of the packet information (IP_PKTINFO for IPv4,
+ // IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
+ func(network, address string, c syscall.RawConn) error {
+ var err error
+ switch network {
+ case "udp4":
+ if runtime.GOOS != "android" {
+ c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
+ })
+ }
+ case "udp6":
+ c.Control(func(fd uintptr) {
+ if runtime.GOOS != "android" {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
+ if err != nil {
+ return
+ }
+ }
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
+ })
+ default:
+ err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
+ }
+ return err
+ },
+
+ // Attempt to enable UDP_GRO
+ func(network, address string, c syscall.RawConn) error {
+ c.Control(func(fd uintptr) {
+ _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
+ })
+ return nil
+ },
+ )
+}
diff --git a/conn/controlfns_unix.go b/conn/controlfns_unix.go
new file mode 100644
index 0000000..91692c0
--- /dev/null
+++ b/conn/controlfns_unix.go
@@ -0,0 +1,35 @@
+//go:build !windows && !linux && !wasm
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ controlFns = append(controlFns,
+ func(network, address string, c syscall.RawConn) error {
+ return c.Control(func(fd uintptr) {
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
+ })
+ },
+
+ func(network, address string, c syscall.RawConn) error {
+ var err error
+ if network == "udp6" {
+ c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
+ })
+ }
+ return err
+ },
+ )
+}
diff --git a/conn/controlfns_windows.go b/conn/controlfns_windows.go
new file mode 100644
index 0000000..c3bdf7d
--- /dev/null
+++ b/conn/controlfns_windows.go
@@ -0,0 +1,23 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/windows"
+)
+
+func init() {
+ controlFns = append(controlFns,
+ func(network, address string, c syscall.RawConn) error {
+ return c.Control(func(fd uintptr) {
+ _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize)
+ _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize)
+ })
+ },
+ )
+}
diff --git a/conn/default.go b/conn/default.go
index 161454a..b6f761b 100644
--- a/conn/default.go
+++ b/conn/default.go
@@ -1,8 +1,8 @@
-// +build !linux,!windows
+//go:build !windows
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
diff --git a/conn/errors_default.go b/conn/errors_default.go
new file mode 100644
index 0000000..f1e5b90
--- /dev/null
+++ b/conn/errors_default.go
@@ -0,0 +1,12 @@
+//go:build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+func errShouldDisableUDPGSO(err error) bool {
+ return false
+}
diff --git a/conn/errors_linux.go b/conn/errors_linux.go
new file mode 100644
index 0000000..8e61000
--- /dev/null
+++ b/conn/errors_linux.go
@@ -0,0 +1,26 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "errors"
+ "os"
+
+ "golang.org/x/sys/unix"
+)
+
+func errShouldDisableUDPGSO(err error) bool {
+ var serr *os.SyscallError
+ if errors.As(err, &serr) {
+ // EIO is returned by udp_send_skb() if the device driver does not have
+ // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
+ // See:
+ // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
+ // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
+ return serr.Err == unix.EIO
+ }
+ return false
+}
diff --git a/conn/features_default.go b/conn/features_default.go
new file mode 100644
index 0000000..d53ff5f
--- /dev/null
+++ b/conn/features_default.go
@@ -0,0 +1,15 @@
+//go:build !linux
+// +build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import "net"
+
+func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
+ return
+}
diff --git a/conn/features_linux.go b/conn/features_linux.go
new file mode 100644
index 0000000..8959d93
--- /dev/null
+++ b/conn/features_linux.go
@@ -0,0 +1,29 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net"
+
+ "golang.org/x/sys/unix"
+)
+
+func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
+ rc, err := conn.SyscallConn()
+ if err != nil {
+ return
+ }
+ err = rc.Control(func(fd uintptr) {
+ _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
+ txOffload = errSyscall == nil
+ opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
+ rxOffload = errSyscall == nil && opt == 1
+ })
+ if err != nil {
+ return false, false
+ }
+ return txOffload, rxOffload
+}
diff --git a/conn/gso_default.go b/conn/gso_default.go
new file mode 100644
index 0000000..57780db
--- /dev/null
+++ b/conn/gso_default.go
@@ -0,0 +1,21 @@
+//go:build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
+func getGSOSize(control []byte) (int, error) {
+ return 0, nil
+}
+
+// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
+func setGSOSize(control *[]byte, gsoSize uint16) {
+}
+
+// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
+// offloading control data.
+const gsoControlSize = 0
diff --git a/conn/gso_linux.go b/conn/gso_linux.go
new file mode 100644
index 0000000..8596b29
--- /dev/null
+++ b/conn/gso_linux.go
@@ -0,0 +1,65 @@
+//go:build linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "fmt"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+const (
+ sizeOfGSOData = 2
+)
+
+// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
+func getGSOSize(control []byte) (int, error) {
+ var (
+ hdr unix.Cmsghdr
+ data []byte
+ rem = control
+ err error
+ )
+
+ for len(rem) > unix.SizeofCmsghdr {
+ hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
+ if err != nil {
+ return 0, fmt.Errorf("error parsing socket control message: %w", err)
+ }
+ if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
+ var gso uint16
+ copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
+ return int(gso), nil
+ }
+ }
+ return 0, nil
+}
+
+// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
+// data in control untouched.
+func setGSOSize(control *[]byte, gsoSize uint16) {
+ existingLen := len(*control)
+ avail := cap(*control) - existingLen
+ space := unix.CmsgSpace(sizeOfGSOData)
+ if avail < space {
+ return
+ }
+ *control = (*control)[:cap(*control)]
+ gsoControl := (*control)[existingLen:]
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
+ hdr.Level = unix.SOL_UDP
+ hdr.Type = unix.UDP_SEGMENT
+ hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
+ copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
+ *control = (*control)[:existingLen+space]
+}
+
+// gsoControlSize returns the recommended buffer size for pooling UDP
+// offloading control data.
+var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
diff --git a/conn/mark_default.go b/conn/mark_default.go
index c315f4b..3102384 100644
--- a/conn/mark_default.go
+++ b/conn/mark_default.go
@@ -1,12 +1,12 @@
-// +build !linux,!openbsd,!freebsd
+//go:build !linux && !openbsd && !freebsd
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
-func (bind *StdNetBind) SetMark(mark uint32) error {
+func (s *StdNetBind) SetMark(mark uint32) error {
return nil
}
diff --git a/conn/mark_unix.go b/conn/mark_unix.go
index 18eb581..d9e46ee 100644
--- a/conn/mark_unix.go
+++ b/conn/mark_unix.go
@@ -1,8 +1,8 @@
-// +build linux openbsd freebsd
+//go:build linux || openbsd || freebsd
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
@@ -26,13 +26,13 @@ func init() {
}
}
-func (bind *StdNetBind) SetMark(mark uint32) error {
+func (s *StdNetBind) SetMark(mark uint32) error {
var operr error
if fwmarkIoctl == 0 {
return nil
}
- if bind.ipv4 != nil {
- fd, err := bind.ipv4.SyscallConn()
+ if s.ipv4 != nil {
+ fd, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
@@ -46,8 +46,8 @@ func (bind *StdNetBind) SetMark(mark uint32) error {
return err
}
}
- if bind.ipv6 != nil {
- fd, err := bind.ipv6.SyscallConn()
+ if s.ipv6 != nil {
+ fd, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
diff --git a/conn/sticky_default.go b/conn/sticky_default.go
new file mode 100644
index 0000000..0b21386
--- /dev/null
+++ b/conn/sticky_default.go
@@ -0,0 +1,42 @@
+//go:build !linux || android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import "net/netip"
+
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+ return netip.Addr{}
+}
+
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+ return 0
+}
+
+func (e *StdNetEndpoint) SrcToString() string {
+ return ""
+}
+
+// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
+// {get,set}srcControl feature set, but use alternatively named flags and need
+// ports and require testing.
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+}
+
+// setSrcControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+}
+
+// stickyControlSize returns the recommended buffer size for pooling sticky
+// offloading control data.
+const stickyControlSize = 0
+
+const StdNetSupportsStickySockets = false
diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go
new file mode 100644
index 0000000..8e206e9
--- /dev/null
+++ b/conn/sticky_linux.go
@@ -0,0 +1,112 @@
+//go:build linux && !android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net/netip"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+ switch len(e.src) {
+ case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ return netip.AddrFrom4(info.Spec_dst)
+ case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ // TODO: set zone. in order to do so we need to check if the address is
+ // link local, and if it is perform a syscall to turn the ifindex into a
+ // zone string because netip uses string zones.
+ return netip.AddrFrom16(info.Addr)
+ }
+ return netip.Addr{}
+}
+
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+ switch len(e.src) {
+ case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ return info.Ifindex
+ case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ return int32(info.Ifindex)
+ }
+ return 0
+}
+
+func (e *StdNetEndpoint) SrcToString() string {
+ return e.SrcIP().String()
+}
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+ ep.ClearSrc()
+
+ var (
+ hdr unix.Cmsghdr
+ data []byte
+ rem []byte = control
+ err error
+ )
+
+ for len(rem) > unix.SizeofCmsghdr {
+ hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
+ if err != nil {
+ return
+ }
+
+ if hdr.Level == unix.IPPROTO_IP &&
+ hdr.Type == unix.IP_PKTINFO {
+
+ if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
+ ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ }
+ ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
+
+ hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
+ copy(ep.src, hdrBuf)
+ copy(ep.src[unix.CmsgLen(0):], data)
+ return
+ }
+
+ if hdr.Level == unix.IPPROTO_IPV6 &&
+ hdr.Type == unix.IPV6_PKTINFO {
+
+ if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
+ ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
+ }
+
+ ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
+
+ hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
+ copy(ep.src, hdrBuf)
+ copy(ep.src[unix.CmsgLen(0):], data)
+ return
+ }
+ }
+}
+
+// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
+// and source ifindex found in ep. control's len will be set to 0 in the event
+// that ep is a default value.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+ if cap(*control) < len(ep.src) {
+ return
+ }
+ *control = (*control)[:0]
+ *control = append(*control, ep.src...)
+}
+
+// stickyControlSize returns the recommended buffer size for pooling sticky
+// offloading control data.
+var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
+
+const StdNetSupportsStickySockets = true
diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go
new file mode 100644
index 0000000..d2bd584
--- /dev/null
+++ b/conn/sticky_linux_test.go
@@ -0,0 +1,266 @@
+//go:build linux && !android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "context"
+ "net"
+ "net/netip"
+ "runtime"
+ "testing"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
+ var buf []byte
+ if addr.Is4() {
+ buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ hdr := unix.Cmsghdr{
+ Level: unix.IPPROTO_IP,
+ Type: unix.IP_PKTINFO,
+ }
+ hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
+ copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
+
+ info := unix.Inet4Pktinfo{
+ Ifindex: ifidx,
+ Spec_dst: addr.As4(),
+ }
+ copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
+ } else {
+ buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
+ hdr := unix.Cmsghdr{
+ Level: unix.IPPROTO_IPV6,
+ Type: unix.IPV6_PKTINFO,
+ }
+ hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
+ copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
+
+ info := unix.Inet6Pktinfo{
+ Ifindex: uint32(ifidx),
+ Addr: addr.As16(),
+ }
+ copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
+ }
+
+ ep.src = buf
+}
+
+func Test_setSrcControl(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ ep := &StdNetEndpoint{
+ AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
+ }
+ setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
+
+ control := make([]byte, stickyControlSize)
+
+ setSrcControl(&control, ep)
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ if hdr.Level != unix.IPPROTO_IP {
+ t.Errorf("unexpected level: %d", hdr.Level)
+ }
+ if hdr.Type != unix.IP_PKTINFO {
+ t.Errorf("unexpected type: %d", hdr.Type)
+ }
+ if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
+ t.Errorf("unexpected length: %d", hdr.Len)
+ }
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
+ t.Errorf("unexpected address: %v", info.Spec_dst)
+ }
+ if info.Ifindex != 5 {
+ t.Errorf("unexpected ifindex: %d", info.Ifindex)
+ }
+ })
+
+ t.Run("IPv6", func(t *testing.T) {
+ ep := &StdNetEndpoint{
+ AddrPort: netip.MustParseAddrPort("[::1]:1234"),
+ }
+ setSrc(ep, netip.MustParseAddr("::1"), 5)
+
+ control := make([]byte, stickyControlSize)
+
+ setSrcControl(&control, ep)
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ if hdr.Level != unix.IPPROTO_IPV6 {
+ t.Errorf("unexpected level: %d", hdr.Level)
+ }
+ if hdr.Type != unix.IPV6_PKTINFO {
+ t.Errorf("unexpected type: %d", hdr.Type)
+ }
+ if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
+ t.Errorf("unexpected length: %d", hdr.Len)
+ }
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ if info.Addr != ep.SrcIP().As16() {
+ t.Errorf("unexpected address: %v", info.Addr)
+ }
+ if info.Ifindex != 5 {
+ t.Errorf("unexpected ifindex: %d", info.Ifindex)
+ }
+ })
+
+ t.Run("ClearOnNoSrc", func(t *testing.T) {
+ control := make([]byte, stickyControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = 1
+ hdr.Type = 2
+ hdr.Len = 3
+
+ setSrcControl(&control, &StdNetEndpoint{})
+
+ if len(control) != 0 {
+ t.Errorf("unexpected control: %v", control)
+ }
+ })
+}
+
+func Test_getSrcFromControl(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ control := make([]byte, stickyControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IP
+ hdr.Type = unix.IP_PKTINFO
+ hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Spec_dst = [4]byte{127, 0, 0, 1}
+ info.Ifindex = 5
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(control, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+ t.Run("IPv6", func(t *testing.T) {
+ control := make([]byte, stickyControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IPV6
+ hdr.Type = unix.IPV6_PKTINFO
+ hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
+ info.Ifindex = 5
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(control, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("::1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+ t.Run("ClearOnEmpty", func(t *testing.T) {
+ var control []byte
+ ep := &StdNetEndpoint{}
+ setSrc(ep, netip.MustParseAddr("::1"), 5)
+
+ getSrcFromControl(control, ep)
+ if ep.SrcIP().IsValid() {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 0 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+ t.Run("Multiple", func(t *testing.T) {
+ zeroControl := make([]byte, unix.CmsgSpace(0))
+ zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
+ zeroHdr.SetLen(unix.CmsgLen(0))
+
+ control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IP
+ hdr.Type = unix.IP_PKTINFO
+ hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Spec_dst = [4]byte{127, 0, 0, 1}
+ info.Ifindex = 5
+
+ combined := make([]byte, 0)
+ combined = append(combined, zeroControl...)
+ combined = append(combined, control...)
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(combined, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+}
+
+func Test_listenConfig(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ sc, err := conn.(*net.UDPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if runtime.GOOS == "linux" {
+ var i int
+ sc.Control(func(fd uintptr) {
+ i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if i != 1 {
+ t.Error("IP_PKTINFO not set!")
+ }
+ } else {
+ t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
+ }
+ })
+ t.Run("IPv6", func(t *testing.T) {
+ conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ sc, err := conn.(*net.UDPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if runtime.GOOS == "linux" {
+ var i int
+ sc.Control(func(fd uintptr) {
+ i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if i != 1 {
+ t.Error("IPV6_PKTINFO not set!")
+ }
+ } else {
+ t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
+ }
+ })
+}
diff --git a/conn/winrio/rio_windows.go b/conn/winrio/rio_windows.go
index 0a07c65..d1037bb 100644
--- a/conn/winrio/rio_windows.go
+++ b/conn/winrio/rio_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package winrio
@@ -84,8 +84,10 @@ type iocpNotificationCompletion struct {
overlapped *windows.Overlapped
}
-var initialized sync.Once
-var available bool
+var (
+ initialized sync.Once
+ available bool
+)
func Initialize() bool {
initialized.Do(func() {
@@ -108,7 +110,7 @@ func Initialize() bool {
return
}
defer windows.CloseHandle(socket)
- var WSAID_MULTIPLE_RIO = &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
+ WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
ob := uint32(0)
err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
diff --git a/device/alignment_test.go b/device/alignment_test.go
deleted file mode 100644
index a918112..0000000
--- a/device/alignment_test.go
+++ /dev/null
@@ -1,65 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
- "reflect"
- "testing"
- "unsafe"
-)
-
-func checkAlignment(t *testing.T, name string, offset uintptr) {
- t.Helper()
- if offset%8 != 0 {
- t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8))
- }
-}
-
-// TestPeerAlignment checks that atomically-accessed fields are
-// aligned to 64-bit boundaries, as required by the atomic package.
-//
-// Unfortunately, violating this rule on 32-bit platforms results in a
-// hard segfault at runtime.
-func TestPeerAlignment(t *testing.T) {
- var p Peer
-
- typ := reflect.TypeOf(&p).Elem()
- t.Logf("Peer type size: %d, with fields:", typ.Size())
- for i := 0; i < typ.NumField(); i++ {
- field := typ.Field(i)
- t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
- field.Name,
- field.Offset,
- field.Type.Size(),
- field.Type.Align(),
- )
- }
-
- checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats))
- checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
-}
-
-// TestDeviceAlignment checks that atomically-accessed fields are
-// aligned to 64-bit boundaries, as required by the atomic package.
-//
-// Unfortunately, violating this rule on 32-bit platforms results in a
-// hard segfault at runtime.
-func TestDeviceAlignment(t *testing.T) {
- var d Device
-
- typ := reflect.TypeOf(&d).Elem()
- t.Logf("Device type size: %d, with fields:", typ.Size())
- for i := 0; i < typ.NumField(); i++ {
- field := typ.Field(i)
- t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
- field.Name,
- field.Offset,
- field.Type.Size(),
- field.Type.Align(),
- )
- }
- checkAlignment(t, "Device.rate.underLoadUntil", unsafe.Offsetof(d.rate)+unsafe.Offsetof(d.rate.underLoadUntil))
-}
diff --git a/device/allowedips.go b/device/allowedips.go
index c08399b..fa46f97 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -1,15 +1,17 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"container/list"
+ "encoding/binary"
"errors"
"math/bits"
"net"
+ "net/netip"
"sync"
"unsafe"
)
@@ -26,49 +28,28 @@ type trieEntry struct {
cidr uint8
bitAtByte uint8
bitAtShift uint8
- bits net.IP
+ bits []byte
perPeerElem *list.Element
}
-func isLittleEndian() bool {
- one := uint32(1)
- return *(*byte)(unsafe.Pointer(&one)) != 0
-}
-
-func swapU32(i uint32) uint32 {
- if !isLittleEndian() {
- return i
- }
-
- return bits.ReverseBytes32(i)
-}
-
-func swapU64(i uint64) uint64 {
- if !isLittleEndian() {
- return i
- }
-
- return bits.ReverseBytes64(i)
-}
-
-func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
+func commonBits(ip1, ip2 []byte) uint8 {
size := len(ip1)
if size == net.IPv4len {
- a := (*uint32)(unsafe.Pointer(&ip1[0]))
- b := (*uint32)(unsafe.Pointer(&ip2[0]))
- x := *a ^ *b
- return uint8(bits.LeadingZeros32(swapU32(x)))
+ a := binary.BigEndian.Uint32(ip1)
+ b := binary.BigEndian.Uint32(ip2)
+ x := a ^ b
+ return uint8(bits.LeadingZeros32(x))
} else if size == net.IPv6len {
- a := (*uint64)(unsafe.Pointer(&ip1[0]))
- b := (*uint64)(unsafe.Pointer(&ip2[0]))
- x := *a ^ *b
+ a := binary.BigEndian.Uint64(ip1)
+ b := binary.BigEndian.Uint64(ip2)
+ x := a ^ b
if x != 0 {
- return uint8(bits.LeadingZeros64(swapU64(x)))
+ return uint8(bits.LeadingZeros64(x))
}
- a = (*uint64)(unsafe.Pointer(&ip1[8]))
- b = (*uint64)(unsafe.Pointer(&ip2[8]))
- x = *a ^ *b
- return 64 + uint8(bits.LeadingZeros64(swapU64(x)))
+ a = binary.BigEndian.Uint64(ip1[8:])
+ b = binary.BigEndian.Uint64(ip2[8:])
+ x = a ^ b
+ return 64 + uint8(bits.LeadingZeros64(x))
} else {
panic("Wrong size bit string")
}
@@ -85,7 +66,7 @@ func (node *trieEntry) removeFromPeerEntries() {
}
}
-func (node *trieEntry) choose(ip net.IP) byte {
+func (node *trieEntry) choose(ip []byte) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
@@ -104,7 +85,7 @@ func (node *trieEntry) zeroizePointers() {
node.parent.parentBit = nil
}
-func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
+func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
if parent.cidr == cidr {
@@ -117,7 +98,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry,
return
}
-func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
+func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
if *trie.parentBit == nil {
node := &trieEntry{
peer: peer,
@@ -207,7 +188,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
}
}
-func (node *trieEntry) lookup(ip net.IP) *Peer {
+func (node *trieEntry) lookup(ip []byte) *Peer {
var found *Peer
size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
@@ -229,13 +210,14 @@ type AllowedIPs struct {
mutex sync.RWMutex
}
-func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
+func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock()
defer table.mutex.RUnlock()
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
node := elem.Value.(*trieEntry)
- if !cb(node.bits, node.cidr) {
+ a, _ := netip.AddrFromSlice(node.bits)
+ if !cb(netip.PrefixFrom(a, int(node.cidr))) {
return
}
}
@@ -283,28 +265,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
}
}
-func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
+func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
- switch len(ip) {
- case net.IPv6len:
- parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
- case net.IPv4len:
- parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
- default:
+ if prefix.Addr().Is6() {
+ ip := prefix.Addr().As16()
+ parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else if prefix.Addr().Is4() {
+ ip := prefix.Addr().As4()
+ parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else {
panic(errors.New("inserting unknown address type"))
}
}
-func (table *AllowedIPs) Lookup(address []byte) *Peer {
+func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
- switch len(address) {
+ switch len(ip) {
case net.IPv6len:
- return table.IPv6.lookup(address)
+ return table.IPv6.lookup(ip)
case net.IPv4len:
- return table.IPv4.lookup(address)
+ return table.IPv4.lookup(ip)
default:
panic(errors.New("looking up unknown address type"))
}
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
index 16de170..07065c3 100644
--- a/device/allowedips_rand_test.go
+++ b/device/allowedips_rand_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,6 +8,7 @@ package device
import (
"math/rand"
"net"
+ "net/netip"
"sort"
"testing"
)
@@ -93,14 +94,14 @@ func TestTrieRandom(t *testing.T) {
rand.Read(addr4[:])
cidr := uint8(rand.Intn(32) + 1)
index := rand.Intn(NumberOfPeers)
- allowedIPs.Insert(addr4[:], cidr, peers[index])
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
var addr6 [16]byte
rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
- allowedIPs.Insert(addr6[:], cidr, peers[index])
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
}
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
index 2059a88..cde068e 100644
--- a/device/allowedips_test.go
+++ b/device/allowedips_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,6 +8,7 @@ package device
import (
"math/rand"
"net"
+ "net/netip"
"testing"
)
@@ -18,7 +19,6 @@ type testPairCommonBits struct {
}
func TestCommonBits(t *testing.T) {
-
tests := []testPairCommonBits{
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
@@ -39,7 +39,7 @@ func TestCommonBits(t *testing.T) {
}
}
-func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
+func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
var trie *trieEntry
var peers []*Peer
root := parentIndirection{&trie, 2}
@@ -98,7 +98,7 @@ func TestTrieIPv4(t *testing.T) {
var allowedIPs AllowedIPs
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
- allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
@@ -208,7 +208,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- allowedIPs.Insert(addr, cidr, peer)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d uint32) {
diff --git a/device/bind_test.go b/device/bind_test.go
index 1c0d247..302a521 100644
--- a/device/bind_test.go
+++ b/device/bind_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -26,21 +26,21 @@ func (b *DummyBind) SetMark(v uint32) error {
return nil
}
-func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) {
+func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in6
if !ok {
return 0, nil, errors.New("closed")
}
- copy(buff, datagram.msg)
+ copy(buf, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil
}
-func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) {
+func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in4
if !ok {
return 0, nil, errors.New("closed")
}
- copy(buff, datagram.msg)
+ copy(buf, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil
}
@@ -51,6 +51,6 @@ func (b *DummyBind) Close() error {
return nil
}
-func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error {
+func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error {
return nil
}
diff --git a/device/channels.go b/device/channels.go
index 6f4370a..e526f6b 100644
--- a/device/channels.go
+++ b/device/channels.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -19,13 +19,13 @@ import (
// call wg.Done to remove the initial reference.
// When the refcount hits 0, the queue's channel is closed.
type outboundQueue struct {
- c chan *QueueOutboundElement
+ c chan *QueueOutboundElementsContainer
wg sync.WaitGroup
}
func newOutboundQueue() *outboundQueue {
q := &outboundQueue{
- c: make(chan *QueueOutboundElement, QueueOutboundSize),
+ c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
}
q.wg.Add(1)
go func() {
@@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue {
// A inboundQueue is similar to an outboundQueue; see those docs.
type inboundQueue struct {
- c chan *QueueInboundElement
+ c chan *QueueInboundElementsContainer
wg sync.WaitGroup
}
func newInboundQueue() *inboundQueue {
q := &inboundQueue{
- c: make(chan *QueueInboundElement, QueueInboundSize),
+ c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
}
q.wg.Add(1)
go func() {
@@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue {
}
type autodrainingInboundQueue struct {
- c chan *QueueInboundElement
+ c chan *QueueInboundElementsContainer
}
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
@@ -81,7 +81,7 @@ type autodrainingInboundQueue struct {
// some other means, such as sending a sentinel nil values.
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
q := &autodrainingInboundQueue{
- c: make(chan *QueueInboundElement, QueueInboundSize),
+ c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
}
runtime.SetFinalizer(q, device.flushInboundQueue)
return q
@@ -90,10 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
for {
select {
- case elem := <-q.c:
- elem.Lock()
- device.PutMessageBuffer(elem.buffer)
- device.PutInboundElement(elem)
+ case elemsContainer := <-q.c:
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ device.PutInboundElementsContainer(elemsContainer)
default:
return
}
@@ -101,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
}
type autodrainingOutboundQueue struct {
- c chan *QueueOutboundElement
+ c chan *QueueOutboundElementsContainer
}
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
@@ -111,7 +114,7 @@ type autodrainingOutboundQueue struct {
// All sends to the channel must be best-effort, because there may be no receivers.
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
q := &autodrainingOutboundQueue{
- c: make(chan *QueueOutboundElement, QueueOutboundSize),
+ c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
}
runtime.SetFinalizer(q, device.flushOutboundQueue)
return q
@@ -120,10 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
for {
select {
- case elem := <-q.c:
- elem.Lock()
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ case elemsContainer := <-q.c:
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsContainer(elemsContainer)
default:
return
}
diff --git a/device/constants.go b/device/constants.go
index 53a2526..59854a1 100644
--- a/device/constants.go
+++ b/device/constants.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/cookie.go b/device/cookie.go
index b02b769..876f05d 100644
--- a/device/cookie.go
+++ b/device/cookie.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -83,7 +83,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
return hmac.Equal(mac1[:], msg[smac1:smac2])
}
-func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
+func (st *CookieChecker) CheckMAC2(msg, src []byte) bool {
st.RLock()
defer st.RUnlock()
@@ -119,7 +119,6 @@ func (st *CookieChecker) CreateReply(
recv uint32,
src []byte,
) (*MessageCookieReply, error) {
-
st.RLock()
// refresh cookie secret
@@ -204,7 +203,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
_, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
-
if err != nil {
return false
}
@@ -215,7 +213,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
}
func (st *CookieGenerator) AddMacs(msg []byte) {
-
size := len(msg)
smac2 := size - blake2s.Size128
diff --git a/device/cookie_test.go b/device/cookie_test.go
index 02e01d1..4f1e50a 100644
--- a/device/cookie_test.go
+++ b/device/cookie_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -10,7 +10,6 @@ import (
)
func TestCookieMAC1(t *testing.T) {
-
// setup generator / checker
var (
@@ -132,12 +131,12 @@ func TestCookieMAC1(t *testing.T) {
msg[5] ^= 0x20
- srcBad1 := []byte{192, 168, 13, 37, 40, 01}
+ srcBad1 := []byte{192, 168, 13, 37, 40, 1}
if checker.CheckMAC2(msg, srcBad1) {
t.Fatal("MAC2 generation/verification failed")
}
- srcBad2 := []byte{192, 168, 13, 38, 40, 01}
+ srcBad2 := []byte{192, 168, 13, 38, 40, 1}
if checker.CheckMAC2(msg, srcBad2) {
t.Fatal("MAC2 generation/verification failed")
}
diff --git a/device/device.go b/device/device.go
index 5644c8a..83c33ee 100644
--- a/device/device.go
+++ b/device/device.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -30,7 +30,7 @@ type Device struct {
// will become the actual state; Up can fail.
// The device can also change state multiple times between time of check and time of use.
// Unsynchronized uses of state must therefore be advisory/best-effort only.
- state uint32 // actually a deviceState, but typed uint32 for convenience
+ state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
// stopping blocks until all inputs to Device have been closed.
stopping sync.WaitGroup
// mu protects state changes.
@@ -44,6 +44,7 @@ type Device struct {
netlinkCancel *rwcancel.RWCancel
port uint16 // listening port
fwmark uint32 // mark value (0 = disabled)
+ brokenRoaming bool
}
staticIdentity struct {
@@ -52,24 +53,26 @@ type Device struct {
publicKey NoisePublicKey
}
- rate struct {
- underLoadUntil int64
- limiter ratelimiter.Ratelimiter
- }
-
peers struct {
sync.RWMutex // protects keyMap
keyMap map[NoisePublicKey]*Peer
}
+ rate struct {
+ underLoadUntil atomic.Int64
+ limiter ratelimiter.Ratelimiter
+ }
+
allowedips AllowedIPs
indexTable IndexTable
cookieChecker CookieChecker
pool struct {
- messageBuffers *WaitPool
- inboundElements *WaitPool
- outboundElements *WaitPool
+ inboundElementsContainer *WaitPool
+ outboundElementsContainer *WaitPool
+ messageBuffers *WaitPool
+ inboundElements *WaitPool
+ outboundElements *WaitPool
}
queue struct {
@@ -80,7 +83,7 @@ type Device struct {
tun struct {
device tun.Device
- mtu int32
+ mtu atomic.Int32
}
ipcMutex sync.RWMutex
@@ -92,10 +95,9 @@ type Device struct {
// There are three states: down, up, closed.
// Transitions:
//
-// down -----+
-// ↑↓ ↓
-// up -> closed
-//
+// down -----+
+// ↑↓ ↓
+// up -> closed
type deviceState uint32
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
@@ -108,7 +110,7 @@ const (
// deviceState returns device.state.state as a deviceState
// See those docs for how to interpret this value.
func (device *Device) deviceState() deviceState {
- return deviceState(atomic.LoadUint32(&device.state.state))
+ return deviceState(device.state.state.Load())
}
// isClosed reports whether the device is closed (or is closing).
@@ -147,14 +149,14 @@ func (device *Device) changeState(want deviceState) (err error) {
case old:
return nil
case deviceStateUp:
- atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
+ device.state.state.Store(uint32(deviceStateUp))
err = device.upLocked()
if err == nil {
break
}
fallthrough // up failed; bring the device all the way back down
case deviceStateDown:
- atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
+ device.state.state.Store(uint32(deviceStateDown))
errDown := device.downLocked()
if err == nil {
err = errDown
@@ -172,10 +174,15 @@ func (device *Device) upLocked() error {
return err
}
+ // The IPC set operation waits for peers to be created before calling Start() on them,
+ // so if there's a concurrent IPC set request happening, we should wait for it to complete.
+ device.ipcMutex.Lock()
+ defer device.ipcMutex.Unlock()
+
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Start()
- if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
+ if peer.persistentKeepaliveInterval.Load() > 0 {
peer.SendKeepalive()
}
}
@@ -212,11 +219,11 @@ func (device *Device) IsUnderLoad() bool {
now := time.Now()
underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
if underLoad {
- atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano())
+ device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
return true
}
// check if recently under load
- return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano()
+ return device.rate.underLoadUntil.Load() > now.UnixNano()
}
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
@@ -260,7 +267,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
+ handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
expiredPeers = append(expiredPeers, peer)
}
@@ -276,7 +283,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device := new(Device)
- device.state.state = uint32(deviceStateDown)
+ device.state.state.Store(uint32(deviceStateDown))
device.closed = make(chan struct{})
device.log = logger
device.net.bind = bind
@@ -286,10 +293,11 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
mtu = DefaultMTU
}
- device.tun.mtu = int32(mtu)
+ device.tun.mtu.Store(int32(mtu))
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init()
device.indexTable.Init()
+
device.PopulatePools()
// create queues
@@ -317,6 +325,19 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
return device
}
+// BatchSize returns the BatchSize for the device as a whole which is the max of
+// the bind batch size and the tun batch size. The batch size reported by device
+// is the size used to construct memory pools, and is the allowed batch size for
+// the lifetime of the device.
+func (device *Device) BatchSize() int {
+ size := device.net.bind.BatchSize()
+ dSize := device.tun.device.BatchSize()
+ if size < dSize {
+ size = dSize
+ }
+ return size
+}
+
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.RLock()
defer device.peers.RUnlock()
@@ -349,10 +370,12 @@ func (device *Device) RemoveAllPeers() {
func (device *Device) Close() {
device.state.Lock()
defer device.state.Unlock()
+ device.ipcMutex.Lock()
+ defer device.ipcMutex.Unlock()
if device.isClosed() {
return
}
- atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
+ device.state.state.Store(uint32(deviceStateClosed))
device.log.Verbosef("Device closing")
device.tun.device.Close()
@@ -438,11 +461,7 @@ func (device *Device) BindSetMark(mark uint32) error {
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
- peer.Lock()
- defer peer.Unlock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
+ peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
@@ -467,11 +486,13 @@ func (device *Device) BindUpdate() error {
var err error
var recvFns []conn.ReceiveFunc
netc := &device.net
+
recvFns, netc.port, err = netc.bind.Open(netc.port)
if err != nil {
netc.port = 0
return err
}
+
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
if err != nil {
netc.bind.Close()
@@ -490,11 +511,7 @@ func (device *Device) BindUpdate() error {
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
- peer.Lock()
- defer peer.Unlock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
+ peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
@@ -502,8 +519,9 @@ func (device *Device) BindUpdate() error {
device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
+ batchSize := netc.bind.BatchSize()
for _, fn := range recvFns {
- go device.RoutineReceiveIncoming(fn)
+ go device.RoutineReceiveIncoming(batchSize, fn)
}
device.log.Verbosef("UDP bind has been updated")
diff --git a/device/device_test.go b/device/device_test.go
index 29daeb9..fff172b 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -11,7 +11,8 @@ import (
"fmt"
"io"
"math/rand"
- "net"
+ "net/netip"
+ "os"
"runtime"
"runtime/pprof"
"sync"
@@ -21,6 +22,7 @@ import (
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/conn/bindtest"
+ "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/tuntest"
)
@@ -48,7 +50,7 @@ func uapiCfg(cfg ...string) string {
// genConfigs generates a pair of configs that connect to each other.
// The configs use distinct, probably-usable ports.
-func genConfigs(tb testing.TB) (cfgs [2]string, endpointCfgs [2]string) {
+func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
var key1, key2 NoisePrivateKey
_, err := rand.Read(key1[:])
if err != nil {
@@ -96,7 +98,7 @@ type testPair [2]testPeer
type testPeer struct {
tun *tuntest.ChannelTUN
dev *Device
- ip net.IP
+ ip netip.Addr
}
type SendDirection bool
@@ -159,7 +161,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
for i := range pair {
p := &pair[i]
p.tun = tuntest.NewChannelTUN()
- p.ip = net.IPv4(1, 0, 0, byte(i+1))
+ p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
level := LogLevelVerbose
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError
@@ -307,6 +309,17 @@ func TestConcurrencySafety(t *testing.T) {
}
})
+ // Perform bind updates and keepalive sends concurrently with tunnel use.
+ t.Run("bindUpdate and keepalive", func(t *testing.T) {
+ const iters = 10
+ for i := 0; i < iters; i++ {
+ for _, peer := range pair {
+ peer.dev.BindUpdate()
+ peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
+ }
+ }
+ })
+
close(done)
}
@@ -333,7 +346,7 @@ func BenchmarkThroughput(b *testing.B) {
// Measure how long it takes to receive b.N packets,
// starting when we receive the first packet.
- var recv uint64
+ var recv atomic.Uint64
var elapsed time.Duration
var wg sync.WaitGroup
wg.Add(1)
@@ -342,7 +355,7 @@ func BenchmarkThroughput(b *testing.B) {
var start time.Time
for {
<-pair[0].tun.Inbound
- new := atomic.AddUint64(&recv, 1)
+ new := recv.Add(1)
if new == 1 {
start = time.Now()
}
@@ -358,7 +371,7 @@ func BenchmarkThroughput(b *testing.B) {
ping := tuntest.Ping(pair[0].ip, pair[1].ip)
pingc := pair[1].tun.Outbound
var sent uint64
- for atomic.LoadUint64(&recv) != uint64(b.N) {
+ for recv.Load() != uint64(b.N) {
sent++
pingc <- ping
}
@@ -405,3 +418,59 @@ func goroutineLeakCheck(t *testing.T) {
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
})
}
+
+type fakeBindSized struct {
+ size int
+}
+
+func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+ return nil, 0, nil
+}
+func (b *fakeBindSized) Close() error { return nil }
+func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
+func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
+func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
+func (b *fakeBindSized) BatchSize() int { return b.size }
+
+type fakeTUNDeviceSized struct {
+ size int
+}
+
+func (t *fakeTUNDeviceSized) File() *os.File { return nil }
+func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
+ return 0, nil
+}
+func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
+func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
+func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
+func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
+func (t *fakeTUNDeviceSized) Close() error { return nil }
+func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
+
+func TestBatchSize(t *testing.T) {
+ d := Device{}
+
+ d.net.bind = &fakeBindSized{1}
+ d.tun.device = &fakeTUNDeviceSized{1}
+ if want, got := 1, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{1}
+ d.tun.device = &fakeTUNDeviceSized{128}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{128}
+ d.tun.device = &fakeTUNDeviceSized{1}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{128}
+ d.tun.device = &fakeTUNDeviceSized{128}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+}
diff --git a/device/endpoint_test.go b/device/endpoint_test.go
index 57c361c..93a4998 100644
--- a/device/endpoint_test.go
+++ b/device/endpoint_test.go
@@ -1,53 +1,49 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"math/rand"
- "net"
+ "net/netip"
)
type DummyEndpoint struct {
- src [16]byte
- dst [16]byte
+ src, dst netip.Addr
}
func CreateDummyEndpoint() (*DummyEndpoint, error) {
- var end DummyEndpoint
- if _, err := rand.Read(end.src[:]); err != nil {
+ var src, dst [16]byte
+ if _, err := rand.Read(src[:]); err != nil {
return nil, err
}
- _, err := rand.Read(end.dst[:])
- return &end, err
+ _, err := rand.Read(dst[:])
+ return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
}
func (e *DummyEndpoint) ClearSrc() {}
func (e *DummyEndpoint) SrcToString() string {
- var addr net.UDPAddr
- addr.IP = e.SrcIP()
- addr.Port = 1000
- return addr.String()
+ return netip.AddrPortFrom(e.SrcIP(), 1000).String()
}
func (e *DummyEndpoint) DstToString() string {
- var addr net.UDPAddr
- addr.IP = e.DstIP()
- addr.Port = 1000
- return addr.String()
+ return netip.AddrPortFrom(e.DstIP(), 1000).String()
}
-func (e *DummyEndpoint) SrcToBytes() []byte {
- return e.src[:]
+func (e *DummyEndpoint) DstToBytes() []byte {
+ out := e.DstIP().AsSlice()
+ out = append(out, byte(1000&0xff))
+ out = append(out, byte((1000>>8)&0xff))
+ return out
}
-func (e *DummyEndpoint) DstIP() net.IP {
- return e.dst[:]
+func (e *DummyEndpoint) DstIP() netip.Addr {
+ return e.dst
}
-func (e *DummyEndpoint) SrcIP() net.IP {
- return e.src[:]
+func (e *DummyEndpoint) SrcIP() netip.Addr {
+ return e.src
}
diff --git a/device/indextable.go b/device/indextable.go
index 15b8b0d..00ade7d 100644
--- a/device/indextable.go
+++ b/device/indextable.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/ip.go b/device/ip.go
index 50c5040..eaf2363 100644
--- a/device/ip.go
+++ b/device/ip.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/kdf_test.go b/device/kdf_test.go
index b14aa6d..f9c76d6 100644
--- a/device/kdf_test.go
+++ b/device/kdf_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -20,7 +20,7 @@ type KDFTest struct {
t2 string
}
-func assertEquals(t *testing.T, a string, b string) {
+func assertEquals(t *testing.T, a, b string) {
if a != b {
t.Fatal("expected", a, "=", b)
}
diff --git a/device/keypair.go b/device/keypair.go
index 788c947..e3540d7 100644
--- a/device/keypair.go
+++ b/device/keypair.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -10,7 +10,6 @@ import (
"sync"
"sync/atomic"
"time"
- "unsafe"
"golang.zx2c4.com/wireguard/replay"
)
@@ -23,7 +22,7 @@ import (
*/
type Keypair struct {
- sendNonce uint64 // accessed atomically
+ sendNonce atomic.Uint64
send cipher.AEAD
receive cipher.AEAD
replayFilter replay.Filter
@@ -37,15 +36,7 @@ type Keypairs struct {
sync.RWMutex
current *Keypair
previous *Keypair
- next *Keypair
-}
-
-func (kp *Keypairs) storeNext(next *Keypair) {
- atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
-}
-
-func (kp *Keypairs) loadNext() *Keypair {
- return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
+ next atomic.Pointer[Keypair]
}
func (kp *Keypairs) Current() *Keypair {
diff --git a/device/logger.go b/device/logger.go
index c79c971..22b0df0 100644
--- a/device/logger.go
+++ b/device/logger.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -16,8 +16,8 @@ import (
// They do not require a trailing newline in the format.
// If nil, that level of logging will be silent.
type Logger struct {
- Verbosef func(format string, args ...interface{})
- Errorf func(format string, args ...interface{})
+ Verbosef func(format string, args ...any)
+ Errorf func(format string, args ...any)
}
// Log levels for use with NewLogger.
@@ -28,14 +28,14 @@ const (
)
// Function for use in Logger for discarding logged lines.
-func DiscardLogf(format string, args ...interface{}) {}
+func DiscardLogf(format string, args ...any) {}
// NewLogger constructs a Logger that writes to stdout.
// It logs at the specified log level and above.
// It decorates log lines with the log level, date, time, and prepend.
func NewLogger(level int, prepend string) *Logger {
logger := &Logger{DiscardLogf, DiscardLogf}
- logf := func(prefix string) func(string, ...interface{}) {
+ logf := func(prefix string) func(string, ...any) {
return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
}
if level >= LogLevelVerbose {
diff --git a/device/misc.go b/device/misc.go
deleted file mode 100644
index 4126704..0000000
--- a/device/misc.go
+++ /dev/null
@@ -1,41 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
- "sync/atomic"
-)
-
-/* Atomic Boolean */
-
-const (
- AtomicFalse = int32(iota)
- AtomicTrue
-)
-
-type AtomicBool struct {
- int32
-}
-
-func (a *AtomicBool) Get() bool {
- return atomic.LoadInt32(&a.int32) == AtomicTrue
-}
-
-func (a *AtomicBool) Swap(val bool) bool {
- flag := AtomicFalse
- if val {
- flag = AtomicTrue
- }
- return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
-}
-
-func (a *AtomicBool) Set(val bool) {
- flag := AtomicFalse
- if val {
- flag = AtomicTrue
- }
- atomic.StoreInt32(&a.int32, flag)
-}
diff --git a/device/mobilequirks.go b/device/mobilequirks.go
index f27d9d7..0a0080e 100644
--- a/device/mobilequirks.go
+++ b/device/mobilequirks.go
@@ -1,16 +1,19 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
+// DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created,
+// though it will try to deal with it, and race maybe, if called after.
func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
+ device.net.brokenRoaming = true
device.peers.RLock()
for _, peer := range device.peers.keyMap {
- peer.Lock()
- peer.disableRoaming = peer.endpoint != nil
- peer.Unlock()
+ peer.endpoint.Lock()
+ peer.endpoint.disableRoaming = peer.endpoint.val != nil
+ peer.endpoint.Unlock()
}
device.peers.RUnlock()
}
diff --git a/device/noise-helpers.go b/device/noise-helpers.go
index b6b51fd..c2f356b 100644
--- a/device/noise-helpers.go
+++ b/device/noise-helpers.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -9,6 +9,7 @@ import (
"crypto/hmac"
"crypto/rand"
"crypto/subtle"
+ "errors"
"hash"
"golang.org/x/crypto/blake2s"
@@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
return
}
-func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
+var errInvalidPublicKey = errors.New("invalid public key")
+
+func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarMult(&ss, ask, apk)
- return ss
+ if isZero(ss[:]) {
+ return ss, errInvalidPublicKey
+ }
+ return ss, nil
}
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index 0212b7d..e8f6145 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -138,11 +138,11 @@ var (
ZeroNonce [chacha20poly1305.NonceSize]byte
)
-func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
+func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
KDF1(dst, c[:], data)
}
-func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
+func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
hash, _ := blake2s.New256(nil)
hash.Write(h[:])
hash.Write(data)
@@ -175,8 +175,6 @@ func init() {
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
- var errZeroECDHResult = errors.New("ECDH returned all zeros")
-
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -204,9 +202,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(msg.Ephemeral[:])
// encrypt static key
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- if isZero(ss[:]) {
- return nil, errZeroECDHResult
+ ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if err != nil {
+ return nil, err
}
var key [chacha20poly1305.KeySize]byte
KDF2(
@@ -221,7 +219,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
// encrypt timestamp
if isZero(handshake.precomputedStaticStatic[:]) {
- return nil, errZeroECDHResult
+ return nil, errInvalidPublicKey
}
KDF2(
&handshake.chainKey,
@@ -264,11 +262,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key
- var err error
var peerPK NoisePublicKey
var key [chacha20poly1305.KeySize]byte
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- if isZero(ss[:]) {
+ ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if err != nil {
return nil
}
KDF2(&chainKey, &key, chainKey[:], ss[:])
@@ -282,7 +279,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// lookup peer
peer := device.LookupPeer(peerPK)
- if peer == nil {
+ if peer == nil || !peer.isRunning.Load() {
return nil
}
@@ -384,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:])
- func() {
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
- handshake.mixKey(ss[:])
- ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- handshake.mixKey(ss[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
+ if err != nil {
+ return nil, err
+ }
+ handshake.mixKey(ss[:])
+ ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if err != nil {
+ return nil, err
+ }
+ handshake.mixKey(ss[:])
// add preshared key
@@ -406,11 +407,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(tau[:])
- func() {
- aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
- handshake.mixHash(msg.Empty[:])
- }()
+ aead, _ := chacha20poly1305.New(key[:])
+ aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
+ handshake.mixHash(msg.Empty[:])
handshake.state = handshakeResponseCreated
@@ -436,7 +435,6 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
)
ok := func() bool {
-
// lock handshake state
handshake.mutex.RLock()
@@ -456,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
- func() {
- ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
- mixKey(&chainKey, &chainKey, ss[:])
- setZero(ss[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return false
+ }
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
- func() {
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- mixKey(&chainKey, &chainKey, ss[:])
- setZero(ss[:])
- }()
+ ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return false
+ }
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
// add preshared key (psk)
@@ -484,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// authenticate transcript
aead, _ := chacha20poly1305.New(key[:])
- _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
+ _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil {
return false
}
@@ -582,12 +582,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock()
previous := keypairs.previous
- next := keypairs.loadNext()
+ next := keypairs.next.Load()
current := keypairs.current
if isInitiator {
if next != nil {
- keypairs.storeNext(nil)
+ keypairs.next.Store(nil)
keypairs.previous = next
device.DeleteKeypair(current)
} else {
@@ -596,7 +596,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous)
keypairs.current = keypair
} else {
- keypairs.storeNext(keypair)
+ keypairs.next.Store(keypair)
device.DeleteKeypair(next)
keypairs.previous = nil
device.DeleteKeypair(previous)
@@ -608,18 +608,18 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs
- if keypairs.loadNext() != receivedKeypair {
+ if keypairs.next.Load() != receivedKeypair {
return false
}
keypairs.Lock()
defer keypairs.Unlock()
- if keypairs.loadNext() != receivedKeypair {
+ if keypairs.next.Load() != receivedKeypair {
return false
}
old := keypairs.previous
keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old)
- keypairs.current = keypairs.loadNext()
- keypairs.storeNext(nil)
+ keypairs.current = keypairs.next.Load()
+ keypairs.next.Store(nil)
return true
}
diff --git a/device/noise-types.go b/device/noise-types.go
index 6e850e7..e850359 100644
--- a/device/noise-types.go
+++ b/device/noise-types.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/noise_test.go b/device/noise_test.go
index 807ca2d..2dd5324 100644
--- a/device/noise_test.go
+++ b/device/noise_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -24,10 +24,10 @@ func TestCurveWrappers(t *testing.T) {
pk1 := sk1.publicKey()
pk2 := sk2.publicKey()
- ss1 := sk1.sharedSecret(pk2)
- ss2 := sk2.sharedSecret(pk1)
+ ss1, err1 := sk1.sharedSecret(pk2)
+ ss2, err2 := sk2.sharedSecret(pk1)
- if ss1 != ss2 {
+ if ss1 != ss2 || err1 != nil || err2 != nil {
t.Fatal("Failed to compute shared secet")
}
}
@@ -71,6 +71,8 @@ func TestNoiseHandshake(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ peer1.Start()
+ peer2.Start()
assertEqual(
t,
@@ -146,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) {
t.Fatal("failed to derive keypair for peer 2", err)
}
- key1 := peer1.keypairs.loadNext()
+ key1 := peer1.keypairs.next.Load()
key2 := peer2.keypairs.current
// encrypting / decryption test
diff --git a/device/peer.go b/device/peer.go
index c8b825d..47a2f14 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -16,36 +16,31 @@ import (
)
type Peer struct {
- isRunning AtomicBool
- sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
- keypairs Keypairs
- handshake Handshake
- device *Device
- endpoint conn.Endpoint
- stopping sync.WaitGroup // routines pending stop
-
- // These fields are accessed with atomic operations, which must be
- // 64-bit aligned even on 32-bit platforms. Go guarantees that an
- // allocated struct will be 64-bit aligned. So we place
- // atomically-accessed fields up front, so that they can share in
- // this alignment before smaller fields throw it off.
- stats struct {
- txBytes uint64 // bytes send to peer (endpoint)
- rxBytes uint64 // bytes received from peer
- lastHandshakeNano int64 // nano seconds since epoch
+ isRunning atomic.Bool
+ keypairs Keypairs
+ handshake Handshake
+ device *Device
+ stopping sync.WaitGroup // routines pending stop
+ txBytes atomic.Uint64 // bytes send to peer (endpoint)
+ rxBytes atomic.Uint64 // bytes received from peer
+ lastHandshakeNano atomic.Int64 // nano seconds since epoch
+
+ endpoint struct {
+ sync.Mutex
+ val conn.Endpoint
+ clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
+ disableRoaming bool
}
- disableRoaming bool
-
timers struct {
retransmitHandshake *Timer
sendKeepalive *Timer
newHandshake *Timer
zeroKeyMaterial *Timer
persistentKeepalive *Timer
- handshakeAttempts uint32
- needAnotherKeepalive AtomicBool
- sentLastMinuteHandshake AtomicBool
+ handshakeAttempts atomic.Uint32
+ needAnotherKeepalive atomic.Bool
+ sentLastMinuteHandshake atomic.Bool
}
state struct {
@@ -53,14 +48,14 @@ type Peer struct {
}
queue struct {
- staged chan *QueueOutboundElement // staged packets before a handshake is available
- outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
- inbound *autodrainingInboundQueue // sequential ordering of tun writing
+ staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
+ outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
+ inbound *autodrainingInboundQueue // sequential ordering of tun writing
}
cookieGenerator CookieGenerator
trieEntries list.List
- persistentKeepaliveInterval uint32 // accessed atomically
+ persistentKeepaliveInterval atomic.Uint32
}
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
@@ -82,14 +77,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// create peer
peer := new(Peer)
- peer.Lock()
- defer peer.Unlock()
peer.cookieGenerator.Init(pk)
peer.device = device
peer.queue.outbound = newAutodrainingOutboundQueue(device)
peer.queue.inbound = newAutodrainingInboundQueue(device)
- peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize)
+ peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
// map public key
_, ok := device.peers.keyMap[pk]
@@ -100,26 +93,27 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// pre-compute DH
handshake := &peer.handshake
handshake.mutex.Lock()
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
+ handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.remoteStatic = pk
handshake.mutex.Unlock()
// reset endpoint
- peer.endpoint = nil
+ peer.endpoint.Lock()
+ peer.endpoint.val = nil
+ peer.endpoint.disableRoaming = false
+ peer.endpoint.clearSrcOnTx = false
+ peer.endpoint.Unlock()
+
+ // init timers
+ peer.timersInit()
// add
device.peers.keyMap[pk] = peer
- // start peer
- peer.timersInit()
- if peer.device.isUp() {
- peer.Start()
- }
-
return peer, nil
}
-func (peer *Peer) SendBuffer(buffer []byte) error {
+func (peer *Peer) SendBuffers(buffers [][]byte) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()
@@ -127,16 +121,25 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
return nil
}
- peer.RLock()
- defer peer.RUnlock()
-
- if peer.endpoint == nil {
+ peer.endpoint.Lock()
+ endpoint := peer.endpoint.val
+ if endpoint == nil {
+ peer.endpoint.Unlock()
return errors.New("no known endpoint for peer")
}
+ if peer.endpoint.clearSrcOnTx {
+ endpoint.ClearSrc()
+ peer.endpoint.clearSrcOnTx = false
+ }
+ peer.endpoint.Unlock()
- err := peer.device.net.bind.Send(buffer, peer.endpoint)
+ err := peer.device.net.bind.Send(buffers, endpoint)
if err == nil {
- atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer)))
+ var totalLen uint64
+ for _, b := range buffers {
+ totalLen += uint64(len(b))
+ }
+ peer.txBytes.Add(totalLen)
}
return err
}
@@ -177,7 +180,7 @@ func (peer *Peer) Start() {
peer.state.Lock()
defer peer.state.Unlock()
- if peer.isRunning.Get() {
+ if peer.isRunning.Load() {
return
}
@@ -198,10 +201,14 @@ func (peer *Peer) Start() {
device.flushInboundQueue(peer.queue.inbound)
device.flushOutboundQueue(peer.queue.outbound)
- go peer.RoutineSequentialSender()
- go peer.RoutineSequentialReceiver()
- peer.isRunning.Set(true)
+ // Use the device batch size, not the bind batch size, as the device size is
+ // the size of the batch pools.
+ batchSize := peer.device.BatchSize()
+ go peer.RoutineSequentialSender(batchSize)
+ go peer.RoutineSequentialReceiver(batchSize)
+
+ peer.isRunning.Store(true)
}
func (peer *Peer) ZeroAndFlushAll() {
@@ -213,10 +220,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
- device.DeleteKeypair(keypairs.loadNext())
+ device.DeleteKeypair(keypairs.next.Load())
keypairs.previous = nil
keypairs.current = nil
- keypairs.storeNext(nil)
+ keypairs.next.Store(nil)
keypairs.Unlock()
// clear handshake state
@@ -241,11 +248,10 @@ func (peer *Peer) ExpireCurrentKeypairs() {
keypairs := &peer.keypairs
keypairs.Lock()
if keypairs.current != nil {
- atomic.StoreUint64(&keypairs.current.sendNonce, RejectAfterMessages)
+ keypairs.current.sendNonce.Store(RejectAfterMessages)
}
- if keypairs.next != nil {
- next := keypairs.loadNext()
- atomic.StoreUint64(&next.sendNonce, RejectAfterMessages)
+ if next := keypairs.next.Load(); next != nil {
+ next.sendNonce.Store(RejectAfterMessages)
}
keypairs.Unlock()
}
@@ -271,10 +277,20 @@ func (peer *Peer) Stop() {
}
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
- if peer.disableRoaming {
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ if peer.endpoint.disableRoaming {
+ return
+ }
+ peer.endpoint.clearSrcOnTx = false
+ peer.endpoint.val = endpoint
+}
+
+func (peer *Peer) markEndpointSrcForClearing() {
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ if peer.endpoint.val == nil {
return
}
- peer.Lock()
- peer.endpoint = endpoint
- peer.Unlock()
+ peer.endpoint.clearSrcOnTx = true
}
diff --git a/device/pools.go b/device/pools.go
index f1d1fa0..94f3dc7 100644
--- a/device/pools.go
+++ b/device/pools.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -14,49 +14,85 @@ type WaitPool struct {
pool sync.Pool
cond sync.Cond
lock sync.Mutex
- count uint32
+ count atomic.Uint32
max uint32
}
-func NewWaitPool(max uint32, new func() interface{}) *WaitPool {
+func NewWaitPool(max uint32, new func() any) *WaitPool {
p := &WaitPool{pool: sync.Pool{New: new}, max: max}
p.cond = sync.Cond{L: &p.lock}
return p
}
-func (p *WaitPool) Get() interface{} {
+func (p *WaitPool) Get() any {
if p.max != 0 {
p.lock.Lock()
- for atomic.LoadUint32(&p.count) >= p.max {
+ for p.count.Load() >= p.max {
p.cond.Wait()
}
- atomic.AddUint32(&p.count, 1)
+ p.count.Add(1)
p.lock.Unlock()
}
return p.pool.Get()
}
-func (p *WaitPool) Put(x interface{}) {
+func (p *WaitPool) Put(x any) {
p.pool.Put(x)
if p.max == 0 {
return
}
- atomic.AddUint32(&p.count, ^uint32(0))
+ p.count.Add(^uint32(0))
p.cond.Signal()
}
func (device *Device) PopulatePools() {
- device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
+ device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ s := make([]*QueueInboundElement, 0, device.BatchSize())
+ return &QueueInboundElementsContainer{elems: s}
+ })
+ device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ s := make([]*QueueOutboundElement, 0, device.BatchSize())
+ return &QueueOutboundElementsContainer{elems: s}
+ })
+ device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new([MaxMessageSize]byte)
})
- device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
+ device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new(QueueInboundElement)
})
- device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
+ device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new(QueueOutboundElement)
})
}
+func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
+ c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
+ c.Mutex = sync.Mutex{}
+ return c
+}
+
+func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
+ for i := range c.elems {
+ c.elems[i] = nil
+ }
+ c.elems = c.elems[:0]
+ device.pool.inboundElementsContainer.Put(c)
+}
+
+func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
+ c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
+ c.Mutex = sync.Mutex{}
+ return c
+}
+
+func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
+ for i := range c.elems {
+ c.elems[i] = nil
+ }
+ c.elems = c.elems[:0]
+ device.pool.outboundElementsContainer.Put(c)
+}
+
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
}
diff --git a/device/pools_test.go b/device/pools_test.go
index b0840e1..82d7493 100644
--- a/device/pools_test.go
+++ b/device/pools_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -17,29 +17,31 @@ import (
func TestWaitPool(t *testing.T) {
t.Skip("Currently disabled")
var wg sync.WaitGroup
- trials := int32(100000)
+ var trials atomic.Int32
+ startTrials := int32(100000)
if raceEnabled {
// This test can be very slow with -race.
- trials /= 10
+ startTrials /= 10
}
+ trials.Store(startTrials)
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
t.Skip("Not enough cores")
}
- p := NewWaitPool(uint32(workers-4), func() interface{} { return make([]byte, 16) })
+ p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
wg.Add(workers)
- max := uint32(0)
+ var max atomic.Uint32
updateMax := func() {
- count := atomic.LoadUint32(&p.count)
+ count := p.count.Load()
if count > p.max {
t.Errorf("count (%d) > max (%d)", count, p.max)
}
for {
- old := atomic.LoadUint32(&max)
+ old := max.Load()
if count <= old {
break
}
- if atomic.CompareAndSwapUint32(&max, old, count) {
+ if max.CompareAndSwap(old, count) {
break
}
}
@@ -47,7 +49,7 @@ func TestWaitPool(t *testing.T) {
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
- for atomic.AddInt32(&trials, -1) > 0 {
+ for trials.Add(-1) > 0 {
updateMax()
x := p.Get()
updateMax()
@@ -59,25 +61,74 @@ func TestWaitPool(t *testing.T) {
}()
}
wg.Wait()
- if max != p.max {
+ if max.Load() != p.max {
t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
}
}
func BenchmarkWaitPool(b *testing.B) {
var wg sync.WaitGroup
- trials := int32(b.N)
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
b.Skip("Not enough cores")
}
- p := NewWaitPool(uint32(workers-4), func() interface{} { return make([]byte, 16) })
+ p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
wg.Add(workers)
b.ResetTimer()
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
- for atomic.AddInt32(&trials, -1) > 0 {
+ for trials.Add(-1) > 0 {
+ x := p.Get()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ p.Put(x)
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func BenchmarkWaitPoolEmpty(b *testing.B) {
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ b.Skip("Not enough cores")
+ }
+ p := NewWaitPool(0, func() any { return make([]byte, 16) })
+ wg.Add(workers)
+ b.ResetTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
+ x := p.Get()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ p.Put(x)
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func BenchmarkSyncPool(b *testing.B) {
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ b.Skip("Not enough cores")
+ }
+ p := sync.Pool{New: func() any { return make([]byte, 16) }}
+ wg.Add(workers)
+ b.ResetTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x)
diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go
index 7abc33a..25f700a 100644
--- a/device/queueconstants_android.go
+++ b/device/queueconstants_android.go
@@ -1,17 +1,19 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
+import "golang.zx2c4.com/wireguard/conn"
+
/* Reduce memory consumption for Android */
const (
- QueueStagedSize = 128
+ QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
- MaxSegmentSize = 2200
+ MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
PreallocatedBuffersPerPool = 4096
)
diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go
index d5c6927..ea763d0 100644
--- a/device/queueconstants_default.go
+++ b/device/queueconstants_default.go
@@ -1,14 +1,16 @@
-// +build !android,!ios,!windows
+//go:build !android && !ios && !windows
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
+import "golang.zx2c4.com/wireguard/conn"
+
const (
- QueueStagedSize = 128
+ QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
diff --git a/device/queueconstants_ios.go b/device/queueconstants_ios.go
index 36c8704..acd3cec 100644
--- a/device/queueconstants_ios.go
+++ b/device/queueconstants_ios.go
@@ -1,8 +1,8 @@
-// +build ios
+//go:build ios
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/queueconstants_windows.go b/device/queueconstants_windows.go
index e330a6b..1eee32b 100644
--- a/device/queueconstants_windows.go
+++ b/device/queueconstants_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/race_disabled_test.go b/device/race_disabled_test.go
index 65fac7e..bb5c450 100644
--- a/device/race_disabled_test.go
+++ b/device/race_disabled_test.go
@@ -1,8 +1,8 @@
-//+build !race
+//go:build !race
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/race_enabled_test.go b/device/race_enabled_test.go
index f8ccac3..4e9daea 100644
--- a/device/race_enabled_test.go
+++ b/device/race_enabled_test.go
@@ -1,8 +1,8 @@
-//+build race
+//go:build race
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/receive.go b/device/receive.go
index 5857481..1ab3e29 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -11,13 +11,11 @@ import (
"errors"
"net"
"sync"
- "sync/atomic"
"time"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
-
"golang.zx2c4.com/wireguard/conn"
)
@@ -29,7 +27,6 @@ type QueueHandshakeElement struct {
}
type QueueInboundElement struct {
- sync.Mutex
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
@@ -37,6 +34,11 @@ type QueueInboundElement struct {
endpoint conn.Endpoint
}
+type QueueInboundElementsContainer struct {
+ sync.Mutex
+ elems []*QueueInboundElement
+}
+
// clearPointers clears elem fields that contain pointers.
// This makes the garbage collector's life easier and
// avoids accidentally keeping other objects around unnecessarily.
@@ -53,12 +55,12 @@ func (elem *QueueInboundElement) clearPointers() {
* NOTE: Not thread safe, but called by sequential receiver!
*/
func (peer *Peer) keepKeyFreshReceiving() {
- if peer.timers.sentLastMinuteHandshake.Get() {
+ if peer.timers.sentLastMinuteHandshake.Load() {
return
}
keypair := peer.keypairs.Current()
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
- peer.timers.sentLastMinuteHandshake.Set(true)
+ peer.timers.sentLastMinuteHandshake.Store(true)
peer.SendHandshakeInitiation(false)
}
}
@@ -68,7 +70,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
-func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
+func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
recvName := recv.PrettyName()
defer func() {
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
@@ -81,20 +83,33 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
// receive datagrams until conn is closed
- buffer := device.GetMessageBuffer()
-
var (
+ bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
+ bufs = make([][]byte, maxBatchSize)
err error
- size int
- endpoint conn.Endpoint
+ sizes = make([]int, maxBatchSize)
+ count int
+ endpoints = make([]conn.Endpoint, maxBatchSize)
deathSpiral int
+ elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
)
- for {
- size, endpoint, err = recv(buffer[:])
+ for i := range bufsArrs {
+ bufsArrs[i] = device.GetMessageBuffer()
+ bufs[i] = bufsArrs[i][:]
+ }
+
+ defer func() {
+ for i := 0; i < maxBatchSize; i++ {
+ if bufsArrs[i] != nil {
+ device.PutMessageBuffer(bufsArrs[i])
+ }
+ }
+ }()
+ for {
+ count, err = recv(bufs, sizes, endpoints)
if err != nil {
- device.PutMessageBuffer(buffer)
if errors.Is(err, net.ErrClosed) {
return
}
@@ -105,101 +120,119 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
if deathSpiral < 10 {
deathSpiral++
time.Sleep(time.Second / 3)
- buffer = device.GetMessageBuffer()
continue
}
return
}
deathSpiral = 0
- if size < MinMessageSize {
- continue
- }
+ // handle each packet in the batch
+ for i, size := range sizes[:count] {
+ if size < MinMessageSize {
+ continue
+ }
- // check size of packet
+ // check size of packet
- packet := buffer[:size]
- msgType := binary.LittleEndian.Uint32(packet[:4])
+ packet := bufsArrs[i][:size]
+ msgType := binary.LittleEndian.Uint32(packet[:4])
- var okay bool
+ switch msgType {
- switch msgType {
+ // check if transport
- // check if transport
+ case MessageTransportType:
- case MessageTransportType:
+ // check size
- // check size
+ if len(packet) < MessageTransportSize {
+ continue
+ }
- if len(packet) < MessageTransportSize {
- continue
- }
+ // lookup key pair
- // lookup key pair
+ receiver := binary.LittleEndian.Uint32(
+ packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
+ )
+ value := device.indexTable.Lookup(receiver)
+ keypair := value.keypair
+ if keypair == nil {
+ continue
+ }
- receiver := binary.LittleEndian.Uint32(
- packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
- )
- value := device.indexTable.Lookup(receiver)
- keypair := value.keypair
- if keypair == nil {
- continue
- }
+ // check keypair expiry
- // check keypair expiry
+ if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
+ continue
+ }
- if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
+ // create work element
+ peer := value.peer
+ elem := device.GetInboundElement()
+ elem.packet = packet
+ elem.buffer = bufsArrs[i]
+ elem.keypair = keypair
+ elem.endpoint = endpoints[i]
+ elem.counter = 0
+
+ elemsForPeer, ok := elemsByPeer[peer]
+ if !ok {
+ elemsForPeer = device.GetInboundElementsContainer()
+ elemsForPeer.Lock()
+ elemsByPeer[peer] = elemsForPeer
+ }
+ elemsForPeer.elems = append(elemsForPeer.elems, elem)
+ bufsArrs[i] = device.GetMessageBuffer()
+ bufs[i] = bufsArrs[i][:]
continue
- }
-
- // create work element
- peer := value.peer
- elem := device.GetInboundElement()
- elem.packet = packet
- elem.buffer = buffer
- elem.keypair = keypair
- elem.endpoint = endpoint
- elem.counter = 0
- elem.Mutex = sync.Mutex{}
- elem.Lock()
-
- // add to decryption queues
- if peer.isRunning.Get() {
- peer.queue.inbound.c <- elem
- device.queue.decryption.c <- elem
- buffer = device.GetMessageBuffer()
- } else {
- device.PutInboundElement(elem)
- }
- continue
- // otherwise it is a fixed size & handshake related packet
+ // otherwise it is a fixed size & handshake related packet
- case MessageInitiationType:
- okay = len(packet) == MessageInitiationSize
+ case MessageInitiationType:
+ if len(packet) != MessageInitiationSize {
+ continue
+ }
- case MessageResponseType:
- okay = len(packet) == MessageResponseSize
+ case MessageResponseType:
+ if len(packet) != MessageResponseSize {
+ continue
+ }
- case MessageCookieReplyType:
- okay = len(packet) == MessageCookieReplySize
+ case MessageCookieReplyType:
+ if len(packet) != MessageCookieReplySize {
+ continue
+ }
- default:
- device.log.Verbosef("Received message with unknown type")
- }
+ default:
+ device.log.Verbosef("Received message with unknown type")
+ continue
+ }
- if okay {
select {
case device.queue.handshake.c <- QueueHandshakeElement{
msgType: msgType,
- buffer: buffer,
+ buffer: bufsArrs[i],
packet: packet,
- endpoint: endpoint,
+ endpoint: endpoints[i],
}:
- buffer = device.GetMessageBuffer()
+ bufsArrs[i] = device.GetMessageBuffer()
+ bufs[i] = bufsArrs[i][:]
default:
}
}
+ for peer, elemsContainer := range elemsByPeer {
+ if peer.isRunning.Load() {
+ peer.queue.inbound.c <- elemsContainer
+ device.queue.decryption.c <- elemsContainer
+ } else {
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ device.PutInboundElementsContainer(elemsContainer)
+ }
+ delete(elemsByPeer, peer)
+ }
}
}
@@ -209,26 +242,28 @@ func (device *Device) RoutineDecryption(id int) {
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
device.log.Verbosef("Routine: decryption worker %d - started", id)
- for elem := range device.queue.decryption.c {
- // split message into fields
- counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
- content := elem.packet[MessageTransportOffsetContent:]
-
- // decrypt and release to consumer
- var err error
- elem.counter = binary.LittleEndian.Uint64(counter)
- // copy counter to nonce
- binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
- elem.packet, err = elem.keypair.receive.Open(
- content[:0],
- nonce[:],
- content,
- nil,
- )
- if err != nil {
- elem.packet = nil
+ for elemsContainer := range device.queue.decryption.c {
+ for _, elem := range elemsContainer.elems {
+ // split message into fields
+ counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
+ content := elem.packet[MessageTransportOffsetContent:]
+
+ // decrypt and release to consumer
+ var err error
+ elem.counter = binary.LittleEndian.Uint64(counter)
+ // copy counter to nonce
+ binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
+ elem.packet, err = elem.keypair.receive.Open(
+ content[:0],
+ nonce[:],
+ content,
+ nil,
+ )
+ if err != nil {
+ elem.packet = nil
+ }
}
- elem.Unlock()
+ elemsContainer.Unlock()
}
}
@@ -269,7 +304,7 @@ func (device *Device) RoutineHandshake(id int) {
// consume reply
- if peer := entry.peer; peer.isRunning.Get() {
+ if peer := entry.peer; peer.isRunning.Load() {
device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
if !peer.cookieGenerator.ConsumeReply(&reply) {
device.log.Verbosef("Could not decrypt invalid cookie response")
@@ -342,7 +377,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SetEndpointFromPacket(elem.endpoint)
device.log.Verbosef("%v - Received handshake initiation", peer)
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+ peer.rxBytes.Add(uint64(len(elem.packet)))
peer.SendHandshakeResponse()
@@ -370,7 +405,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SetEndpointFromPacket(elem.endpoint)
device.log.Verbosef("%v - Received handshake response", peer)
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+ peer.rxBytes.Add(uint64(len(elem.packet)))
// update timers
@@ -395,7 +430,7 @@ func (device *Device) RoutineHandshake(id int) {
}
}
-func (peer *Peer) RoutineSequentialReceiver() {
+func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
device := peer.device
defer func() {
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
@@ -403,89 +438,103 @@ func (peer *Peer) RoutineSequentialReceiver() {
}()
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
- for elem := range peer.queue.inbound.c {
- if elem == nil {
+ bufs := make([][]byte, 0, maxBatchSize)
+
+ for elemsContainer := range peer.queue.inbound.c {
+ if elemsContainer == nil {
return
}
- var err error
- elem.Lock()
- if elem.packet == nil {
- // decryption failed
- goto skip
- }
+ elemsContainer.Lock()
+ validTailPacket := -1
+ dataPacketReceived := false
+ rxBytesLen := uint64(0)
+ for i, elem := range elemsContainer.elems {
+ if elem.packet == nil {
+ // decryption failed
+ continue
+ }
- if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
- goto skip
- }
+ if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
+ continue
+ }
- peer.SetEndpointFromPacket(elem.endpoint)
- if peer.ReceivedWithKeypair(elem.keypair) {
- peer.timersHandshakeComplete()
- peer.SendStagedPackets()
- }
+ validTailPacket = i
+ if peer.ReceivedWithKeypair(elem.keypair) {
+ peer.SetEndpointFromPacket(elem.endpoint)
+ peer.timersHandshakeComplete()
+ peer.SendStagedPackets()
+ }
+ rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
- peer.keepKeyFreshReceiving()
- peer.timersAnyAuthenticatedPacketTraversal()
- peer.timersAnyAuthenticatedPacketReceived()
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
+ if len(elem.packet) == 0 {
+ device.log.Verbosef("%v - Receiving keepalive packet", peer)
+ continue
+ }
+ dataPacketReceived = true
- if len(elem.packet) == 0 {
- device.log.Verbosef("%v - Receiving keepalive packet", peer)
- goto skip
- }
- peer.timersDataReceived()
+ switch elem.packet[0] >> 4 {
+ case 4:
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
+ field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
+ length := binary.BigEndian.Uint16(field)
+ if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+ continue
+ }
+ elem.packet = elem.packet[:length]
+ src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
+ if device.allowedips.Lookup(src) != peer {
+ device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
+ continue
+ }
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
- if len(elem.packet) < ipv4.HeaderLen {
- goto skip
- }
- field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
- length := binary.BigEndian.Uint16(field)
- if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
- goto skip
- }
- elem.packet = elem.packet[:length]
- src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
- if device.allowedips.Lookup(src) != peer {
- device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
- goto skip
- }
+ case 6:
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
+ field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
+ length := binary.BigEndian.Uint16(field)
+ length += ipv6.HeaderLen
+ if int(length) > len(elem.packet) {
+ continue
+ }
+ elem.packet = elem.packet[:length]
+ src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
+ if device.allowedips.Lookup(src) != peer {
+ device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
+ continue
+ }
- case ipv6.Version:
- if len(elem.packet) < ipv6.HeaderLen {
- goto skip
- }
- field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
- length := binary.BigEndian.Uint16(field)
- length += ipv6.HeaderLen
- if int(length) > len(elem.packet) {
- goto skip
- }
- elem.packet = elem.packet[:length]
- src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
- if device.allowedips.Lookup(src) != peer {
- device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
- goto skip
+ default:
+ device.log.Verbosef("Packet with invalid IP version from %v", peer)
+ continue
}
- default:
- device.log.Verbosef("Packet with invalid IP version from %v", peer)
- goto skip
+ bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
}
- _, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
- if err != nil && !device.isClosed() {
- device.log.Errorf("Failed to write packet to TUN device: %v", err)
+ peer.rxBytes.Add(rxBytesLen)
+ if validTailPacket >= 0 {
+ peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
+ peer.keepKeyFreshReceiving()
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketReceived()
}
- if len(peer.queue.inbound.c) == 0 {
- err = device.tun.device.Flush()
- if err != nil {
- peer.device.log.Errorf("Unable to flush packets: %v", err)
+ if dataPacketReceived {
+ peer.timersDataReceived()
+ }
+ if len(bufs) > 0 {
+ _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
+ if err != nil && !device.isClosed() {
+ device.log.Errorf("Failed to write packets to TUN device: %v", err)
}
}
- skip:
- device.PutMessageBuffer(elem.buffer)
- device.PutInboundElement(elem)
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ bufs = bufs[:0]
+ device.PutInboundElementsContainer(elemsContainer)
}
}
diff --git a/device/send.go b/device/send.go
index b05c69e..769720a 100644
--- a/device/send.go
+++ b/device/send.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -12,12 +12,13 @@ import (
"net"
"os"
"sync"
- "sync/atomic"
"time"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/tun"
)
/* Outbound flow
@@ -45,7 +46,6 @@ import (
*/
type QueueOutboundElement struct {
- sync.Mutex
buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption
@@ -53,10 +53,14 @@ type QueueOutboundElement struct {
peer *Peer // related peer
}
+type QueueOutboundElementsContainer struct {
+ sync.Mutex
+ elems []*QueueOutboundElement
+}
+
func (device *Device) NewOutboundElement() *QueueOutboundElement {
elem := device.GetOutboundElement()
elem.buffer = device.GetMessageBuffer()
- elem.Mutex = sync.Mutex{}
elem.nonce = 0
// keypair and peer were cleared (if necessary) by clearPointers.
return elem
@@ -76,14 +80,17 @@ func (elem *QueueOutboundElement) clearPointers() {
/* Queues a keepalive if no packets are queued for peer
*/
func (peer *Peer) SendKeepalive() {
- if len(peer.queue.staged) == 0 && peer.isRunning.Get() {
+ if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
elem := peer.device.NewOutboundElement()
+ elemsContainer := peer.device.GetOutboundElementsContainer()
+ elemsContainer.elems = append(elemsContainer.elems, elem)
select {
- case peer.queue.staged <- elem:
+ case peer.queue.staged <- elemsContainer:
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
default:
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
+ peer.device.PutOutboundElementsContainer(elemsContainer)
}
}
peer.SendStagedPackets()
@@ -91,7 +98,7 @@ func (peer *Peer) SendKeepalive() {
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if !isRetry {
- atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
+ peer.timers.handshakeAttempts.Store(0)
}
peer.handshake.mutex.RLock()
@@ -117,8 +124,8 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
return err
}
- var buff [MessageInitiationSize]byte
- writer := bytes.NewBuffer(buff[:0])
+ var buf [MessageInitiationSize]byte
+ writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
@@ -126,7 +133,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- err = peer.SendBuffer(packet)
+ err = peer.SendBuffers([][]byte{packet})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
@@ -148,8 +155,8 @@ func (peer *Peer) SendHandshakeResponse() error {
return err
}
- var buff [MessageResponseSize]byte
- writer := bytes.NewBuffer(buff[:0])
+ var buf [MessageResponseSize]byte
+ writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, response)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
@@ -164,7 +171,8 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- err = peer.SendBuffer(packet)
+ // TODO: allocation could be avoided
+ err = peer.SendBuffers([][]byte{packet})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
}
@@ -181,10 +189,11 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement)
return err
}
- var buff [MessageCookieReplySize]byte
- writer := bytes.NewBuffer(buff[:0])
+ var buf [MessageCookieReplySize]byte
+ writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply)
- device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
+ // TODO: allocation could be avoided
+ device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
return nil
}
@@ -193,17 +202,12 @@ func (peer *Peer) keepKeyFreshSending() {
if keypair == nil {
return
}
- nonce := atomic.LoadUint64(&keypair.sendNonce)
+ nonce := keypair.sendNonce.Load()
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
peer.SendHandshakeInitiation(false)
}
}
-/* Reads packets from the TUN and inserts
- * into staged queue for peer
- *
- * Obs. Single instance per TUN device
- */
func (device *Device) RoutineReadFromTUN() {
defer func() {
device.log.Verbosef("Routine: TUN reader - stopped")
@@ -213,82 +217,123 @@ func (device *Device) RoutineReadFromTUN() {
device.log.Verbosef("Routine: TUN reader - started")
- var elem *QueueOutboundElement
+ var (
+ batchSize = device.BatchSize()
+ readErr error
+ elems = make([]*QueueOutboundElement, batchSize)
+ bufs = make([][]byte, batchSize)
+ elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
+ count = 0
+ sizes = make([]int, batchSize)
+ offset = MessageTransportHeaderSize
+ )
+
+ for i := range elems {
+ elems[i] = device.NewOutboundElement()
+ bufs[i] = elems[i].buffer[:]
+ }
- for {
- if elem != nil {
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ defer func() {
+ for _, elem := range elems {
+ if elem != nil {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
}
- elem = device.NewOutboundElement()
-
- // read packet
-
- offset := MessageTransportHeaderSize
- size, err := device.tun.device.Read(elem.buffer[:], offset)
+ }()
- if err != nil {
- if !device.isClosed() {
- if !errors.Is(err, os.ErrClosed) {
- device.log.Errorf("Failed to read packet from TUN device: %v", err)
- }
- go device.Close()
+ for {
+ // read packets
+ count, readErr = device.tun.device.Read(bufs, sizes, offset)
+ for i := 0; i < count; i++ {
+ if sizes[i] < 1 {
+ continue
}
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
- return
- }
- if size == 0 || size > MaxContentSize {
- continue
- }
+ elem := elems[i]
+ elem.packet = bufs[i][offset : offset+sizes[i]]
- elem.packet = elem.buffer[offset : offset+size]
+ // lookup peer
+ var peer *Peer
+ switch elem.packet[0] >> 4 {
+ case 4:
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
+ dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+ peer = device.allowedips.Lookup(dst)
- // lookup peer
+ case 6:
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
+ dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+ peer = device.allowedips.Lookup(dst)
- var peer *Peer
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
- if len(elem.packet) < ipv4.HeaderLen {
- continue
+ default:
+ device.log.Verbosef("Received packet with unknown IP version")
}
- dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
- peer = device.allowedips.Lookup(dst)
- case ipv6.Version:
- if len(elem.packet) < ipv6.HeaderLen {
+ if peer == nil {
continue
}
- dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
- peer = device.allowedips.Lookup(dst)
-
- default:
- device.log.Verbosef("Received packet with unknown IP version")
+ elemsForPeer, ok := elemsByPeer[peer]
+ if !ok {
+ elemsForPeer = device.GetOutboundElementsContainer()
+ elemsByPeer[peer] = elemsForPeer
+ }
+ elemsForPeer.elems = append(elemsForPeer.elems, elem)
+ elems[i] = device.NewOutboundElement()
+ bufs[i] = elems[i].buffer[:]
}
- if peer == nil {
- continue
+ for peer, elemsForPeer := range elemsByPeer {
+ if peer.isRunning.Load() {
+ peer.StagePackets(elemsForPeer)
+ peer.SendStagedPackets()
+ } else {
+ for _, elem := range elemsForPeer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsContainer(elemsForPeer)
+ }
+ delete(elemsByPeer, peer)
}
- if peer.isRunning.Get() {
- peer.StagePacket(elem)
- elem = nil
- peer.SendStagedPackets()
+
+ if readErr != nil {
+ if errors.Is(readErr, tun.ErrTooManySegments) {
+ // TODO: record stat for this
+ // This will happen if MSS is surprisingly small (< 576)
+ // coincident with reasonably high throughput.
+ device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
+ continue
+ }
+ if !device.isClosed() {
+ if !errors.Is(readErr, os.ErrClosed) {
+ device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
+ }
+ go device.Close()
+ }
+ return
}
}
}
-func (peer *Peer) StagePacket(elem *QueueOutboundElement) {
+func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
for {
select {
- case peer.queue.staged <- elem:
+ case peer.queue.staged <- elems:
return
default:
}
select {
case tooOld := <-peer.queue.staged:
- peer.device.PutMessageBuffer(tooOld.buffer)
- peer.device.PutOutboundElement(tooOld)
+ for _, elem := range tooOld.elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ }
+ peer.device.PutOutboundElementsContainer(tooOld)
default:
}
}
@@ -301,32 +346,59 @@ top:
}
keypair := peer.keypairs.Current()
- if keypair == nil || atomic.LoadUint64(&keypair.sendNonce) >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
+ if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
peer.SendHandshakeInitiation(false)
return
}
for {
+ var elemsContainerOOO *QueueOutboundElementsContainer
select {
- case elem := <-peer.queue.staged:
- elem.peer = peer
- elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
- if elem.nonce >= RejectAfterMessages {
- atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
- peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans
- goto top
+ case elemsContainer := <-peer.queue.staged:
+ i := 0
+ for _, elem := range elemsContainer.elems {
+ elem.peer = peer
+ elem.nonce = keypair.sendNonce.Add(1) - 1
+ if elem.nonce >= RejectAfterMessages {
+ keypair.sendNonce.Store(RejectAfterMessages)
+ if elemsContainerOOO == nil {
+ elemsContainerOOO = peer.device.GetOutboundElementsContainer()
+ }
+ elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
+ continue
+ } else {
+ elemsContainer.elems[i] = elem
+ i++
+ }
+
+ elem.keypair = keypair
}
+ elemsContainer.Lock()
+ elemsContainer.elems = elemsContainer.elems[:i]
- elem.keypair = keypair
- elem.Lock()
+ if elemsContainerOOO != nil {
+ peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
+ }
+
+ if len(elemsContainer.elems) == 0 {
+ peer.device.PutOutboundElementsContainer(elemsContainer)
+ goto top
+ }
// add to parallel and sequential queue
- if peer.isRunning.Get() {
- peer.queue.outbound.c <- elem
- peer.device.queue.encryption.c <- elem
+ if peer.isRunning.Load() {
+ peer.queue.outbound.c <- elemsContainer
+ peer.device.queue.encryption.c <- elemsContainer
} else {
- peer.device.PutMessageBuffer(elem.buffer)
- peer.device.PutOutboundElement(elem)
+ for _, elem := range elemsContainer.elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ }
+ peer.device.PutOutboundElementsContainer(elemsContainer)
+ }
+
+ if elemsContainerOOO != nil {
+ goto top
}
default:
return
@@ -337,9 +409,12 @@ top:
func (peer *Peer) FlushStagedPackets() {
for {
select {
- case elem := <-peer.queue.staged:
- peer.device.PutMessageBuffer(elem.buffer)
- peer.device.PutOutboundElement(elem)
+ case elemsContainer := <-peer.queue.staged:
+ for _, elem := range elemsContainer.elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ }
+ peer.device.PutOutboundElementsContainer(elemsContainer)
default:
return
}
@@ -373,41 +448,38 @@ func (device *Device) RoutineEncryption(id int) {
defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
device.log.Verbosef("Routine: encryption worker %d - started", id)
- for elem := range device.queue.encryption.c {
- // populate header fields
- header := elem.buffer[:MessageTransportHeaderSize]
+ for elemsContainer := range device.queue.encryption.c {
+ for _, elem := range elemsContainer.elems {
+ // populate header fields
+ header := elem.buffer[:MessageTransportHeaderSize]
- fieldType := header[0:4]
- fieldReceiver := header[4:8]
- fieldNonce := header[8:16]
+ fieldType := header[0:4]
+ fieldReceiver := header[4:8]
+ fieldNonce := header[8:16]
- binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
- binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
- binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
+ binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
+ binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
+ binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
- // pad content to multiple of 16
- paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
- elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
+ // pad content to multiple of 16
+ paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
+ elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
- // encrypt content and release to consumer
+ // encrypt content and release to consumer
- binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
- elem.packet = elem.keypair.send.Seal(
- header,
- nonce[:],
- elem.packet,
- nil,
- )
- elem.Unlock()
+ binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
+ elem.packet = elem.keypair.send.Seal(
+ header,
+ nonce[:],
+ elem.packet,
+ nil,
+ )
+ }
+ elemsContainer.Unlock()
}
}
-/* Sequentially reads packets from queue and sends to endpoint
- *
- * Obs. Single instance per peer.
- * The routine terminates then the outbound queue is closed.
- */
-func (peer *Peer) RoutineSequentialSender() {
+func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
device := peer.device
defer func() {
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
@@ -415,36 +487,57 @@ func (peer *Peer) RoutineSequentialSender() {
}()
device.log.Verbosef("%v - Routine: sequential sender - started", peer)
- for elem := range peer.queue.outbound.c {
- if elem == nil {
+ bufs := make([][]byte, 0, maxBatchSize)
+
+ for elemsContainer := range peer.queue.outbound.c {
+ bufs = bufs[:0]
+ if elemsContainer == nil {
return
}
- elem.Lock()
- if !peer.isRunning.Get() {
+ if !peer.isRunning.Load() {
// peer has been stopped; return re-usable elems to the shared pool.
// This is an optimization only. It is possible for the peer to be stopped
// immediately after this check, in which case, elem will get processed.
- // The timers and SendBuffer code are resilient to a few stragglers.
+ // The timers and SendBuffers code are resilient to a few stragglers.
// TODO: rework peer shutdown order to ensure
// that we never accidentally keep timers alive longer than necessary.
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
continue
}
+ dataSent := false
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ if len(elem.packet) != MessageKeepaliveSize {
+ dataSent = true
+ }
+ bufs = append(bufs, elem.packet)
+ }
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- // send message and return buffer to pool
-
- err := peer.SendBuffer(elem.packet)
- if len(elem.packet) != MessageKeepaliveSize {
+ err := peer.SendBuffers(bufs)
+ if dataSent {
peer.timersDataSent()
}
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsContainer(elemsContainer)
+ if err != nil {
+ var errGSO conn.ErrUDPGSODisabled
+ if errors.As(err, &errGSO) {
+ device.log.Verbosef(err.Error())
+ err = errGSO.RetryErr
+ }
+ }
if err != nil {
- device.log.Errorf("%v - Failed to send data packet: %v", peer, err)
+ device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
continue
}
diff --git a/device/sticky_default.go b/device/sticky_default.go
index 1cc52f6..1038256 100644
--- a/device/sticky_default.go
+++ b/device/sticky_default.go
@@ -1,4 +1,4 @@
-// +build !linux
+//go:build !linux
package device
diff --git a/device/sticky_linux.go b/device/sticky_linux.go
index 6193ea3..6057ff1 100644
--- a/device/sticky_linux.go
+++ b/device/sticky_linux.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
@@ -25,7 +25,10 @@ import (
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
- if _, ok := bind.(*conn.LinuxSocketBind); !ok {
+ if !conn.StdNetSupportsStickySockets {
+ return nil, nil
+ }
+ if _, ok := bind.(*conn.StdNetBind); !ok {
return nil, nil
}
@@ -107,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
if !ok {
break
}
- pePtr.peer.Lock()
- if &pePtr.peer.endpoint != pePtr.endpoint {
- pePtr.peer.Unlock()
+ pePtr.peer.endpoint.Lock()
+ if &pePtr.peer.endpoint.val != pePtr.endpoint {
+ pePtr.peer.endpoint.Unlock()
break
}
- if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx {
- pePtr.peer.Unlock()
+ if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
+ pePtr.peer.endpoint.Unlock()
break
}
- pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc()
- pePtr.peer.Unlock()
+ pePtr.peer.endpoint.clearSrcOnTx = true
+ pePtr.peer.endpoint.Unlock()
}
attr = attr[attrhdr.Len:]
}
@@ -131,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
device.peers.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
- peer.RLock()
- if peer.endpoint == nil {
- peer.RUnlock()
+ peer.endpoint.Lock()
+ if peer.endpoint.val == nil {
+ peer.endpoint.Unlock()
continue
}
- nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint)
+ nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
if nativeEP == nil {
- peer.RUnlock()
+ peer.endpoint.Unlock()
continue
}
- if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
- peer.RUnlock()
+ if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
+ peer.endpoint.Unlock()
break
}
nlmsg := struct {
@@ -169,12 +172,12 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
Len: 8,
Type: unix.RTA_DST,
},
- nativeEP.Dst4().Addr,
+ nativeEP.DstIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
- nativeEP.Src4().Src,
+ nativeEP.SrcIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_MARK,
@@ -185,10 +188,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
- endpoint: &peer.endpoint,
+ endpoint: &peer.endpoint.val,
}
reqPeerLock.Unlock()
- peer.RUnlock()
+ peer.endpoint.Unlock()
i++
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {
@@ -204,7 +207,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
}
func createNetlinkRouteSocket() (int, error) {
- sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
+ sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
diff --git a/device/timers.go b/device/timers.go
index ee191e5..d4a4ed4 100644
--- a/device/timers.go
+++ b/device/timers.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*
* This is based heavily on timers.c from the kernel implementation.
*/
@@ -8,12 +8,14 @@
package device
import (
- "math/rand"
"sync"
- "sync/atomic"
"time"
+ _ "unsafe"
)
+//go:linkname fastrandn runtime.fastrandn
+func fastrandn(n uint32) uint32
+
// A Timer manages time-based aspects of the WireGuard protocol.
// Timer roughly copies the interface of the Linux kernel's struct timer_list.
type Timer struct {
@@ -71,11 +73,11 @@ func (timer *Timer) IsPending() bool {
}
func (peer *Peer) timersActive() bool {
- return peer.isRunning.Get() && peer.device != nil && peer.device.isUp()
+ return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
}
func expiredRetransmitHandshake(peer *Peer) {
- if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes {
+ if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
if peer.timersActive() {
@@ -94,15 +96,11 @@ func expiredRetransmitHandshake(peer *Peer) {
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
}
} else {
- atomic.AddUint32(&peer.timers.handshakeAttempts, 1)
- peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1)
+ peer.timers.handshakeAttempts.Add(1)
+ peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */
- peer.Lock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- peer.Unlock()
+ peer.markEndpointSrcForClearing()
peer.SendHandshakeInitiation(true)
}
@@ -110,8 +108,8 @@ func expiredRetransmitHandshake(peer *Peer) {
func expiredSendKeepalive(peer *Peer) {
peer.SendKeepalive()
- if peer.timers.needAnotherKeepalive.Get() {
- peer.timers.needAnotherKeepalive.Set(false)
+ if peer.timers.needAnotherKeepalive.Load() {
+ peer.timers.needAnotherKeepalive.Store(false)
if peer.timersActive() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
}
@@ -121,13 +119,8 @@ func expiredSendKeepalive(peer *Peer) {
func expiredNewHandshake(peer *Peer) {
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
/* We clear the endpoint address src address, in case this is the cause of trouble. */
- peer.Lock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- peer.Unlock()
+ peer.markEndpointSrcForClearing()
peer.SendHandshakeInitiation(false)
-
}
func expiredZeroKeyMaterial(peer *Peer) {
@@ -136,7 +129,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
}
func expiredPersistentKeepalive(peer *Peer) {
- if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
+ if peer.persistentKeepaliveInterval.Load() > 0 {
peer.SendKeepalive()
}
}
@@ -144,7 +137,7 @@ func expiredPersistentKeepalive(peer *Peer) {
/* Should be called after an authenticated data packet is sent. */
func (peer *Peer) timersDataSent() {
if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
- peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
+ peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
}
}
@@ -154,7 +147,7 @@ func (peer *Peer) timersDataReceived() {
if !peer.timers.sendKeepalive.IsPending() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} else {
- peer.timers.needAnotherKeepalive.Set(true)
+ peer.timers.needAnotherKeepalive.Store(true)
}
}
}
@@ -176,7 +169,7 @@ func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
/* Should be called after a handshake initiation message is sent. */
func (peer *Peer) timersHandshakeInitiated() {
if peer.timersActive() {
- peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
+ peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
}
}
@@ -185,9 +178,9 @@ func (peer *Peer) timersHandshakeComplete() {
if peer.timersActive() {
peer.timers.retransmitHandshake.Del()
}
- atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
- peer.timers.sentLastMinuteHandshake.Set(false)
- atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
+ peer.timers.handshakeAttempts.Store(0)
+ peer.timers.sentLastMinuteHandshake.Store(false)
+ peer.lastHandshakeNano.Store(time.Now().UnixNano())
}
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
@@ -199,7 +192,7 @@ func (peer *Peer) timersSessionDerived() {
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
- keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval)
+ keepalive := peer.persistentKeepaliveInterval.Load()
if keepalive > 0 && peer.timersActive() {
peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
}
@@ -214,9 +207,9 @@ func (peer *Peer) timersInit() {
}
func (peer *Peer) timersStart() {
- atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
- peer.timers.sentLastMinuteHandshake.Set(false)
- peer.timers.needAnotherKeepalive.Set(false)
+ peer.timers.handshakeAttempts.Store(0)
+ peer.timers.sentLastMinuteHandshake.Store(false)
+ peer.timers.needAnotherKeepalive.Store(false)
}
func (peer *Peer) timersStop() {
diff --git a/device/tun.go b/device/tun.go
index 4af9548..2a2ace9 100644
--- a/device/tun.go
+++ b/device/tun.go
@@ -1,13 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"fmt"
- "sync/atomic"
"golang.zx2c4.com/wireguard/tun"
)
@@ -33,7 +32,7 @@ func (device *Device) RoutineTUNEventReader() {
tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
mtu = MaxContentSize
}
- old := atomic.SwapInt32(&device.tun.mtu, int32(mtu))
+ old := device.tun.mtu.Swap(int32(mtu))
if int(old) != mtu {
device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
}
diff --git a/device/uapi.go b/device/uapi.go
index 66ecd48..d81dae3 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -12,10 +12,10 @@ import (
"fmt"
"io"
"net"
+ "net/netip"
"strconv"
"strings"
"sync"
- "sync/atomic"
"time"
"golang.zx2c4.com/wireguard/ipc"
@@ -38,12 +38,12 @@ func (s IPCError) ErrorCode() int64 {
return s.code
}
-func ipcErrorf(code int64, msg string, args ...interface{}) *IPCError {
+func ipcErrorf(code int64, msg string, args ...any) *IPCError {
return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
}
var byteBufferPool = &sync.Pool{
- New: func() interface{} { return new(bytes.Buffer) },
+ New: func() any { return new(bytes.Buffer) },
}
// IpcGetOperation implements the WireGuard configuration protocol "get" operation.
@@ -55,7 +55,7 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
buf := byteBufferPool.Get().(*bytes.Buffer)
buf.Reset()
defer byteBufferPool.Put(buf)
- sendf := func(format string, args ...interface{}) {
+ sendf := func(format string, args ...any) {
fmt.Fprintf(buf, format, args...)
buf.WriteByte('\n')
}
@@ -72,7 +72,6 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
}
func() {
-
// lock required resources
device.net.RLock()
@@ -98,31 +97,31 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark)
}
- // serialize each peer state
-
for _, peer := range device.peers.keyMap {
- peer.RLock()
- defer peer.RUnlock()
-
+ // Serialize peer state.
+ peer.handshake.mutex.RLock()
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
+ peer.handshake.mutex.RUnlock()
sendf("protocol_version=1")
- if peer.endpoint != nil {
- sendf("endpoint=%s", peer.endpoint.DstToString())
+ peer.endpoint.Lock()
+ if peer.endpoint.val != nil {
+ sendf("endpoint=%s", peer.endpoint.val.DstToString())
}
+ peer.endpoint.Unlock()
- nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
+ nano := peer.lastHandshakeNano.Load()
secs := nano / time.Second.Nanoseconds()
nano %= time.Second.Nanoseconds()
sendf("last_handshake_time_sec=%d", secs)
sendf("last_handshake_time_nsec=%d", nano)
- sendf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))
- sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
- sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
+ sendf("tx_bytes=%d", peer.txBytes.Load())
+ sendf("rx_bytes=%d", peer.rxBytes.Load())
+ sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
- device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool {
- sendf("allowed_ip=%s/%d", ip.String(), cidr)
+ device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
+ sendf("allowed_ip=%s", prefix.String())
return true
})
}
@@ -156,14 +155,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
line := scanner.Text()
if line == "" {
// Blank line means terminate operation.
+ peer.handlePostConfig()
return nil
}
- parts := strings.Split(line, "=")
- if len(parts) != 2 {
- return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q, found %d =-separated parts, want 2", line, len(parts))
+ key, value, ok := strings.Cut(line, "=")
+ if !ok {
+ return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
}
- key := parts[0]
- value := parts[1]
if key == "public_key" {
if deviceConfig {
@@ -254,10 +252,21 @@ type ipcSetPeer struct {
*Peer // Peer is the current peer being operated on
dummy bool // dummy reports whether this peer is a temporary, placeholder peer
created bool // new reports whether this is a newly created peer
+ pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on
}
func (peer *ipcSetPeer) handlePostConfig() {
- if peer.Peer != nil && !peer.dummy && peer.Peer.device.isUp() {
+ if peer.Peer == nil || peer.dummy {
+ return
+ }
+ if peer.created {
+ peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
+ }
+ if peer.device.isUp() {
+ peer.Start()
+ if peer.pkaOn {
+ peer.SendKeepalive()
+ }
peer.SendStagedPackets()
}
}
@@ -334,9 +343,9 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
}
- peer.Lock()
- defer peer.Unlock()
- peer.endpoint = endpoint
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ peer.endpoint.val = endpoint
case "persistent_keepalive_interval":
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
@@ -346,17 +355,10 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
}
- old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
+ old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
// Send immediate keepalive if we're turning it on and before it wasn't on.
- if old == 0 && secs != 0 {
- if err != nil {
- return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err)
- }
- if device.isUp() && !peer.dummy {
- peer.SendKeepalive()
- }
- }
+ peer.pkaOn = old == 0 && secs != 0
case "replace_allowed_ips":
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
@@ -370,16 +372,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
case "allowed_ip":
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
-
- _, network, err := net.ParseCIDR(value)
+ prefix, err := netip.ParsePrefix(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
}
if peer.dummy {
return nil
}
- ones, _ := network.Mask.Size()
- device.allowedips.Insert(network.IP, uint8(ones), peer.Peer)
+ device.allowedips.Insert(prefix, peer.Peer)
case "protocol_version":
if value != "1" {
diff --git a/format_test.go b/format_test.go
index 198cd49..6f6cab7 100644
--- a/format_test.go
+++ b/format_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
diff --git a/go.mod b/go.mod
index 18769da..919dc49 100644
--- a/go.mod
+++ b/go.mod
@@ -1,9 +1,16 @@
module golang.zx2c4.com/wireguard
-go 1.16
+go 1.20
require (
- golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
- golang.org/x/net v0.0.0-20210226172049-e18ecbb05110
- golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57
+ golang.org/x/crypto v0.13.0
+ golang.org/x/net v0.15.0
+ golang.org/x/sys v0.12.0
+ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
+ gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
+)
+
+require (
+ github.com/google/btree v1.0.1 // indirect
+ golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
)
diff --git a/go.sum b/go.sum
index 85c81a0..6bcecea 100644
--- a/go.sum
+++ b/go.sum
@@ -1,22 +1,14 @@
-golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g=
-golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
-golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw=
-golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210309040221-94ec62e08169 h1:fpeMGRM6A+XFcw4RPCO8s8hH7ppgrGR22pSIjwM7YUI=
-golang.org/x/sys v0.0.0-20210309040221-94ec62e08169/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210331175145-43e1dd70ce54 h1:rF3Ohx8DRyl8h2zw9qojyLHLhrJpEMgyPOImREEryf0=
-golang.org/x/sys v0.0.0-20210331175145-43e1dd70ce54/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210402192133-700132347e07 h1:4k6HsQjxj6hVMsI2Vf0yKlzt5lXxZsMW1q0zaq2k8zY=
-golang.org/x/sys v0.0.0-20210402192133-700132347e07/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57 h1:F5Gozwx4I1xtr/sr/8CFbb57iKi3297KFs0QDbGN60A=
-golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
-golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
+github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
+golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
+golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
+golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8=
+golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
+golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
+golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
+golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
+golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
+gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
+gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
diff --git a/ipc/winpipe/file.go b/ipc/namedpipe/file.go
index 0c9abb1..ab1e13d 100644
--- a/ipc/winpipe/file.go
+++ b/ipc/namedpipe/file.go
@@ -1,12 +1,11 @@
-// +build windows
+// Copyright 2021 The Go Authors. All rights reserved.
+// Copyright 2015 Microsoft
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2005 Microsoft
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
+//go:build windows
-package winpipe
+package namedpipe
import (
"io"
@@ -22,8 +21,10 @@ import (
type timeoutChan chan struct{}
-var ioInitOnce sync.Once
-var ioCompletionPort windows.Handle
+var (
+ ioInitOnce sync.Once
+ ioCompletionPort windows.Handle
+)
// ioResult contains the result of an asynchronous IO operation
type ioResult struct {
@@ -52,7 +53,7 @@ type file struct {
handle windows.Handle
wg sync.WaitGroup
wgLock sync.RWMutex
- closing uint32 // used as atomic boolean
+ closing atomic.Bool
socket bool
readDeadline deadlineHandler
writeDeadline deadlineHandler
@@ -63,7 +64,7 @@ type deadlineHandler struct {
channel timeoutChan
channelLock sync.RWMutex
timer *time.Timer
- timedout uint32 // used as atomic boolean
+ timedout atomic.Bool
}
// makeFile makes a new file from an existing file handle
@@ -87,7 +88,7 @@ func makeFile(h windows.Handle) (*file, error) {
func (f *file) closeHandle() {
f.wgLock.Lock()
// Atomically set that we are closing, releasing the resources only once.
- if atomic.SwapUint32(&f.closing, 1) == 0 {
+ if f.closing.Swap(true) == false {
f.wgLock.Unlock()
// cancel all IO and wait for it to complete
windows.CancelIoEx(f.handle, nil)
@@ -110,7 +111,7 @@ func (f *file) Close() error {
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
func (f *file) prepareIo() (*ioOperation, error) {
f.wgLock.RLock()
- if atomic.LoadUint32(&f.closing) == 1 {
+ if f.closing.Load() {
f.wgLock.RUnlock()
return nil, os.ErrClosed
}
@@ -142,7 +143,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err
return int(bytes), err
}
- if atomic.LoadUint32(&f.closing) == 1 {
+ if f.closing.Load() {
windows.CancelIoEx(f.handle, &c.o)
}
@@ -158,7 +159,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err
case r = <-c.ch:
err = r.err
if err == windows.ERROR_OPERATION_ABORTED {
- if atomic.LoadUint32(&f.closing) == 1 {
+ if f.closing.Load() {
err = os.ErrClosed
}
} else if err != nil && f.socket {
@@ -190,7 +191,7 @@ func (f *file) Read(b []byte) (int, error) {
}
defer f.wg.Done()
- if atomic.LoadUint32(&f.readDeadline.timedout) == 1 {
+ if f.readDeadline.timedout.Load() {
return 0, os.ErrDeadlineExceeded
}
@@ -217,7 +218,7 @@ func (f *file) Write(b []byte) (int, error) {
}
defer f.wg.Done()
- if atomic.LoadUint32(&f.writeDeadline.timedout) == 1 {
+ if f.writeDeadline.timedout.Load() {
return 0, os.ErrDeadlineExceeded
}
@@ -254,7 +255,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
}
d.timer = nil
}
- atomic.StoreUint32(&d.timedout, 0)
+ d.timedout.Store(false)
select {
case <-d.channel:
@@ -269,7 +270,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
}
timeoutIO := func() {
- atomic.StoreUint32(&d.timedout, 1)
+ d.timedout.Store(true)
close(d.channel)
}
diff --git a/ipc/winpipe/winpipe.go b/ipc/namedpipe/namedpipe.go
index f02f3d8..ef3dea1 100644
--- a/ipc/winpipe/winpipe.go
+++ b/ipc/namedpipe/namedpipe.go
@@ -1,13 +1,12 @@
-// +build windows
+// Copyright 2021 The Go Authors. All rights reserved.
+// Copyright 2015 Microsoft
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2005 Microsoft
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
+//go:build windows
-// Package winpipe implements a net.Conn and net.Listener around Windows named pipes.
-package winpipe
+// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
+package namedpipe
import (
"context"
@@ -15,6 +14,7 @@ import (
"net"
"os"
"runtime"
+ "sync/atomic"
"time"
"unsafe"
@@ -28,7 +28,7 @@ type pipe struct {
type messageBytePipe struct {
pipe
- writeClosed bool
+ writeClosed atomic.Bool
readEOF bool
}
@@ -50,25 +50,26 @@ func (f *pipe) SetDeadline(t time.Time) error {
// CloseWrite closes the write side of a message pipe in byte mode.
func (f *messageBytePipe) CloseWrite() error {
- if f.writeClosed {
+ if !f.writeClosed.CompareAndSwap(false, true) {
return io.ErrClosedPipe
}
err := f.file.Flush()
if err != nil {
+ f.writeClosed.Store(false)
return err
}
_, err = f.file.Write(nil)
if err != nil {
+ f.writeClosed.Store(false)
return err
}
- f.writeClosed = true
return nil
}
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
// they are used to implement CloseWrite.
func (f *messageBytePipe) Write(b []byte) (int, error) {
- if f.writeClosed {
+ if f.writeClosed.Load() {
return 0, io.ErrClosedPipe
}
if len(b) == 0 {
@@ -142,30 +143,24 @@ type DialConfig struct {
ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID.
}
-// Dial connects to the specified named pipe by path, timing out if the connection
-// takes longer than the specified duration. If timeout is nil, then we use
-// a default timeout of 2 seconds.
-func Dial(path string, timeout *time.Duration, config *DialConfig) (net.Conn, error) {
- var absTimeout time.Time
- if timeout != nil {
- absTimeout = time.Now().Add(*timeout)
- } else {
- absTimeout = time.Now().Add(2 * time.Second)
+// DialTimeout connects to the specified named pipe by path, timing out if the
+// connection takes longer than the specified duration. If timeout is zero, then
+// we use a default timeout of 2 seconds.
+func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
+ if timeout == 0 {
+ timeout = time.Second * 2
}
+ absTimeout := time.Now().Add(timeout)
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
- conn, err := DialContext(ctx, path, config)
+ conn, err := config.DialContext(ctx, path)
if err == context.DeadlineExceeded {
return nil, os.ErrDeadlineExceeded
}
return conn, err
}
-// DialContext attempts to connect to the specified named pipe by path
-// cancellation or timeout.
-func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn, error) {
- if config == nil {
- config = &DialConfig{}
- }
+// DialContext attempts to connect to the specified named pipe by path.
+func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) {
var err error
var h windows.Handle
h, err = tryDialPipe(ctx, &path)
@@ -213,6 +208,18 @@ func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn
return &pipe{file: f, path: path}, nil
}
+var defaultDialer DialConfig
+
+// DialTimeout calls DialConfig.DialTimeout using an empty configuration.
+func DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
+ return defaultDialer.DialTimeout(path, timeout)
+}
+
+// DialContext calls DialConfig.DialContext using an empty configuration.
+func DialContext(ctx context.Context, path string) (net.Conn, error) {
+ return defaultDialer.DialContext(ctx, path)
+}
+
type acceptResponse struct {
f *file
err error
@@ -222,12 +229,12 @@ type pipeListener struct {
firstHandle windows.Handle
path string
config ListenConfig
- acceptCh chan (chan acceptResponse)
+ acceptCh chan chan acceptResponse
closeCh chan int
doneCh chan int
}
-func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, first bool) (windows.Handle, error) {
+func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) {
path16, err := windows.UTF16PtrFromString(path)
if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
@@ -247,7 +254,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
oa.ObjectName = &ntPath
// The security descriptor is only needed for the first pipe.
- if first {
+ if isFirstPipe {
if sd != nil {
oa.SecurityDescriptor = sd
} else {
@@ -257,7 +264,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
return 0, err
}
defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl)))
- sd, err := windows.NewSecurityDescriptor()
+ sd, err = windows.NewSecurityDescriptor()
if err != nil {
return 0, err
}
@@ -275,11 +282,11 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
disposition := uint32(windows.FILE_OPEN)
access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
- if first {
+ if isFirstPipe {
disposition = windows.FILE_CREATE
// By not asking for read or write access, the named pipe file system
// will put this pipe into an initially disconnected state, blocking
- // client connections until the next call with first == false.
+ // client connections until the next call with isFirstPipe == false.
access = windows.SYNCHRONIZE
}
@@ -395,10 +402,7 @@ type ListenConfig struct {
// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe.
// The pipe must not already exist.
-func Listen(path string, c *ListenConfig) (net.Listener, error) {
- if c == nil {
- c = &ListenConfig{}
- }
+func (c *ListenConfig) Listen(path string) (net.Listener, error) {
h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
if err != nil {
return nil, err
@@ -407,12 +411,12 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) {
firstHandle: h,
path: path,
config: *c,
- acceptCh: make(chan (chan acceptResponse)),
+ acceptCh: make(chan chan acceptResponse),
closeCh: make(chan int),
doneCh: make(chan int),
}
// The first connection is swallowed on Windows 7 & 8, so synthesize it.
- if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 {
+ if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) {
path16, err := windows.UTF16PtrFromString(path)
if err == nil {
h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
@@ -425,6 +429,13 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) {
return l, nil
}
+var defaultListener ListenConfig
+
+// Listen calls ListenConfig.Listen using an empty configuration.
+func Listen(path string) (net.Listener, error) {
+ return defaultListener.Listen(path)
+}
+
func connectPipe(p *file) error {
c, err := p.prepareIo()
if err != nil {
diff --git a/ipc/winpipe/winpipe_test.go b/ipc/namedpipe/namedpipe_test.go
index ee8dc8c..998453b 100644
--- a/ipc/winpipe/winpipe_test.go
+++ b/ipc/namedpipe/namedpipe_test.go
@@ -1,12 +1,11 @@
-// +build windows
+// Copyright 2021 The Go Authors. All rights reserved.
+// Copyright 2015 Microsoft
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2005 Microsoft
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
+//go:build windows
-package winpipe_test
+package namedpipe_test
import (
"bufio"
@@ -22,7 +21,7 @@ import (
"time"
"golang.org/x/sys/windows"
- "golang.zx2c4.com/wireguard/ipc/winpipe"
+ "golang.zx2c4.com/wireguard/ipc/namedpipe"
)
func randomPipePath() string {
@@ -30,7 +29,7 @@ func randomPipePath() string {
if err != nil {
panic(err)
}
- return `\\.\PIPE\go-winpipe-test-` + guid.String()
+ return `\\.\PIPE\go-namedpipe-test-` + guid.String()
}
func TestPingPong(t *testing.T) {
@@ -39,7 +38,7 @@ func TestPingPong(t *testing.T) {
pong = 24
)
pipePath := randomPipePath()
- listener, err := winpipe.Listen(pipePath, nil)
+ listener, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatalf("unable to listen on pipe: %v", err)
}
@@ -64,11 +63,12 @@ func TestPingPong(t *testing.T) {
t.Fatalf("unable to write pong to pipe: %v", err)
}
}()
- client, err := winpipe.Dial(pipePath, nil, nil)
+ client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatalf("unable to dial pipe: %v", err)
}
defer client.Close()
+ client.SetDeadline(time.Now().Add(time.Second * 5))
var data [1]byte
data[0] = ping
_, err = client.Write(data[:])
@@ -85,7 +85,7 @@ func TestPingPong(t *testing.T) {
}
func TestDialUnknownFailsImmediately(t *testing.T) {
- _, err := winpipe.Dial(randomPipePath(), nil, nil)
+ _, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0))
if !errors.Is(err, syscall.ENOENT) {
t.Fatalf("expected ENOENT got %v", err)
}
@@ -93,13 +93,15 @@ func TestDialUnknownFailsImmediately(t *testing.T) {
func TestDialListenerTimesOut(t *testing.T) {
pipePath := randomPipePath()
- l, err := winpipe.Listen(pipePath, nil)
+ l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
- d := 10 * time.Millisecond
- _, err = winpipe.Dial(pipePath, &d, nil)
+ pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond)
+ if err == nil {
+ pipe.Close()
+ }
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
@@ -107,14 +109,17 @@ func TestDialListenerTimesOut(t *testing.T) {
func TestDialContextListenerTimesOut(t *testing.T) {
pipePath := randomPipePath()
- l, err := winpipe.Listen(pipePath, nil)
+ l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
d := 10 * time.Millisecond
ctx, _ := context.WithTimeout(context.Background(), d)
- _, err = winpipe.DialContext(ctx, pipePath, nil)
+ pipe, err := namedpipe.DialContext(ctx, pipePath)
+ if err == nil {
+ pipe.Close()
+ }
if err != context.DeadlineExceeded {
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
}
@@ -123,14 +128,14 @@ func TestDialContextListenerTimesOut(t *testing.T) {
func TestDialListenerGetsCancelled(t *testing.T) {
pipePath := randomPipePath()
ctx, cancel := context.WithCancel(context.Background())
- l, err := winpipe.Listen(pipePath, nil)
+ l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
- ch := make(chan error)
defer l.Close()
+ ch := make(chan error)
go func(ctx context.Context, ch chan error) {
- _, err := winpipe.DialContext(ctx, pipePath, nil)
+ _, err := namedpipe.DialContext(ctx, pipePath)
ch <- err
}(ctx, ch)
time.Sleep(time.Millisecond * 30)
@@ -147,23 +152,28 @@ func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
}
pipePath := randomPipePath()
sd, _ := windows.SecurityDescriptorFromString("D:")
- c := winpipe.ListenConfig{
+ l, err := (&namedpipe.ListenConfig{
SecurityDescriptor: sd,
- }
- l, err := winpipe.Listen(pipePath, &c)
+ }).Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
- _, err = winpipe.Dial(pipePath, nil, nil)
+ pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err == nil {
+ pipe.Close()
+ }
if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
}
}
-func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn, err error) {
+func getConnection(cfg *namedpipe.ListenConfig) (client, server net.Conn, err error) {
pipePath := randomPipePath()
- l, err := winpipe.Listen(pipePath, cfg)
+ if cfg == nil {
+ cfg = &namedpipe.ListenConfig{}
+ }
+ l, err := cfg.Listen(pipePath)
if err != nil {
return
}
@@ -179,7 +189,7 @@ func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn,
ch <- response{c, err}
}()
- c, err := winpipe.Dial(pipePath, nil, nil)
+ c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
return
}
@@ -236,7 +246,7 @@ func server(l net.Listener, ch chan int) {
func TestFullListenDialReadWrite(t *testing.T) {
pipePath := randomPipePath()
- l, err := winpipe.Listen(pipePath, nil)
+ l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
@@ -245,7 +255,7 @@ func TestFullListenDialReadWrite(t *testing.T) {
ch := make(chan int)
go server(l, ch)
- c, err := winpipe.Dial(pipePath, nil, nil)
+ c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
@@ -275,7 +285,7 @@ func TestFullListenDialReadWrite(t *testing.T) {
func TestCloseAbortsListen(t *testing.T) {
pipePath := randomPipePath()
- l, err := winpipe.Listen(pipePath, nil)
+ l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
@@ -328,7 +338,7 @@ func TestCloseServerEOFClient(t *testing.T) {
}
func TestCloseWriteEOF(t *testing.T) {
- cfg := &winpipe.ListenConfig{
+ cfg := &namedpipe.ListenConfig{
MessageMode: true,
}
c, s, err := getConnection(cfg)
@@ -356,7 +366,7 @@ func TestCloseWriteEOF(t *testing.T) {
func TestAcceptAfterCloseFails(t *testing.T) {
pipePath := randomPipePath()
- l, err := winpipe.Listen(pipePath, nil)
+ l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
@@ -369,12 +379,15 @@ func TestAcceptAfterCloseFails(t *testing.T) {
func TestDialTimesOutByDefault(t *testing.T) {
pipePath := randomPipePath()
- l, err := winpipe.Listen(pipePath, nil)
+ l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
- _, err = winpipe.Dial(pipePath, nil, nil)
+ pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds.
+ if err == nil {
+ pipe.Close()
+ }
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
@@ -382,7 +395,7 @@ func TestDialTimesOutByDefault(t *testing.T) {
func TestTimeoutPendingRead(t *testing.T) {
pipePath := randomPipePath()
- l, err := winpipe.Listen(pipePath, nil)
+ l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
@@ -400,7 +413,7 @@ func TestTimeoutPendingRead(t *testing.T) {
close(serverDone)
}()
- client, err := winpipe.Dial(pipePath, nil, nil)
+ client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
@@ -430,7 +443,7 @@ func TestTimeoutPendingRead(t *testing.T) {
func TestTimeoutPendingWrite(t *testing.T) {
pipePath := randomPipePath()
- l, err := winpipe.Listen(pipePath, nil)
+ l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
@@ -448,7 +461,7 @@ func TestTimeoutPendingWrite(t *testing.T) {
close(serverDone)
}()
- client, err := winpipe.Dial(pipePath, nil, nil)
+ client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
@@ -480,13 +493,12 @@ type CloseWriter interface {
}
func TestEchoWithMessaging(t *testing.T) {
- c := winpipe.ListenConfig{
+ pipePath := randomPipePath()
+ l, err := (&namedpipe.ListenConfig{
MessageMode: true, // Use message mode so that CloseWrite() is supported
InputBufferSize: 65536, // Use 64KB buffers to improve performance
OutputBufferSize: 65536,
- }
- pipePath := randomPipePath()
- l, err := winpipe.Listen(pipePath, &c)
+ }).Listen(pipePath)
if err != nil {
t.Fatal(err)
}
@@ -496,19 +508,21 @@ func TestEchoWithMessaging(t *testing.T) {
clientDone := make(chan bool)
go func() {
// server echo
- conn, e := l.Accept()
- if e != nil {
- t.Fatal(e)
+ conn, err := l.Accept()
+ if err != nil {
+ t.Fatal(err)
}
defer conn.Close()
time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
- io.Copy(conn, conn)
+ _, err = io.Copy(conn, conn)
+ if err != nil {
+ t.Fatal(err)
+ }
conn.(CloseWriter).CloseWrite()
close(listenerDone)
}()
- timeout := 1 * time.Second
- client, err := winpipe.Dial(pipePath, &timeout, nil)
+ client, err := namedpipe.DialTimeout(pipePath, time.Second)
if err != nil {
t.Fatal(err)
}
@@ -521,7 +535,7 @@ func TestEchoWithMessaging(t *testing.T) {
if e != nil {
t.Fatal(e)
}
- if n != 2 {
+ if n != 2 || bytes[0] != 0 || bytes[1] != 1 {
t.Fatalf("expected 2 bytes, got %v", n)
}
close(clientDone)
@@ -545,7 +559,7 @@ func TestEchoWithMessaging(t *testing.T) {
func TestConnectRace(t *testing.T) {
pipePath := randomPipePath()
- l, err := winpipe.Listen(pipePath, nil)
+ l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
@@ -565,7 +579,7 @@ func TestConnectRace(t *testing.T) {
}()
for i := 0; i < 1000; i++ {
- c, err := winpipe.Dial(pipePath, nil, nil)
+ c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
@@ -580,7 +594,7 @@ func TestMessageReadMode(t *testing.T) {
var wg sync.WaitGroup
defer wg.Wait()
pipePath := randomPipePath()
- l, err := winpipe.Listen(pipePath, &winpipe.ListenConfig{MessageMode: true})
+ l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath)
if err != nil {
t.Fatal(err)
}
@@ -602,7 +616,7 @@ func TestMessageReadMode(t *testing.T) {
s.Close()
}()
- c, err := winpipe.Dial(pipePath, nil, nil)
+ c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
@@ -643,13 +657,13 @@ func TestListenConnectRace(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
go func() {
- c, err := winpipe.Dial(pipePath, nil, nil)
+ c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err == nil {
c.Close()
}
wg.Done()
}()
- s, err := winpipe.Listen(pipePath, nil)
+ s, err := namedpipe.Listen(pipePath)
if err != nil {
t.Error(i, err)
} else {
diff --git a/ipc/uapi_bsd.go b/ipc/uapi_bsd.go
index 5beee9e..ddcaf27 100644
--- a/ipc/uapi_bsd.go
+++ b/ipc/uapi_bsd.go
@@ -1,8 +1,8 @@
-// +build darwin freebsd openbsd
+//go:build darwin || freebsd || openbsd
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
@@ -54,7 +54,6 @@ func (l *UAPIListener) Addr() net.Addr {
}
func UAPIListen(name string, file *os.File) (net.Listener, error) {
-
// wrap file in listener
listener, err := net.FileListener(file)
@@ -104,7 +103,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
l.connErr <- err
return
}
- if kerr != nil || n != 1 {
+ if (kerr != nil || n != 1) && kerr != unix.EINTR {
if kerr != nil {
l.connErr <- kerr
} else {
diff --git a/ipc/uapi_linux.go b/ipc/uapi_linux.go
index e03a00b..1562a18 100644
--- a/ipc/uapi_linux.go
+++ b/ipc/uapi_linux.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
@@ -51,7 +51,6 @@ func (l *UAPIListener) Addr() net.Addr {
}
func UAPIListen(name string, file *os.File) (net.Listener, error) {
-
// wrap file in listener
listener, err := net.FileListener(file)
@@ -97,7 +96,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
}
go func(l *UAPIListener) {
- var buff [0]byte
+ var buf [0]byte
for {
defer uapi.inotifyRWCancel.Close()
// start with lstat to avoid race condition
@@ -105,7 +104,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
l.connErr <- err
return
}
- _, err := uapi.inotifyRWCancel.Read(buff[:])
+ _, err := uapi.inotifyRWCancel.Read(buf[:])
if err != nil {
l.connErr <- err
return
diff --git a/ipc/uapi_unix.go b/ipc/uapi_unix.go
index 544651b..e67be26 100644
--- a/ipc/uapi_unix.go
+++ b/ipc/uapi_unix.go
@@ -1,8 +1,8 @@
-// +build linux darwin freebsd openbsd
+//go:build linux || darwin || freebsd || openbsd
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
@@ -33,7 +33,7 @@ func sockPath(iface string) string {
}
func UAPIOpen(name string) (*os.File, error) {
- if err := os.MkdirAll(socketDirectory, 0755); err != nil {
+ if err := os.MkdirAll(socketDirectory, 0o755); err != nil {
return nil, err
}
@@ -43,7 +43,7 @@ func UAPIOpen(name string) (*os.File, error) {
return nil, err
}
- oldUmask := unix.Umask(0077)
+ oldUmask := unix.Umask(0o077)
defer unix.Umask(oldUmask)
listener, err := net.ListenUnix("unix", addr)
diff --git a/ipc/uapi_wasm.go b/ipc/uapi_wasm.go
new file mode 100644
index 0000000..fa84684
--- /dev/null
+++ b/ipc/uapi_wasm.go
@@ -0,0 +1,15 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package ipc
+
+// Made up sentinel error codes for {js,wasip1}/wasm.
+const (
+ IpcErrorIO = 1
+ IpcErrorInvalid = 2
+ IpcErrorPortInUse = 3
+ IpcErrorUnknown = 4
+ IpcErrorProtocol = 5
+)
diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go
index a4d68da..aa023c9 100644
--- a/ipc/uapi_windows.go
+++ b/ipc/uapi_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
@@ -9,8 +9,7 @@ import (
"net"
"golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/ipc/winpipe"
+ "golang.zx2c4.com/wireguard/ipc/namedpipe"
)
// TODO: replace these with actual standard windows error numbers from the win package
@@ -61,10 +60,9 @@ func init() {
}
func UAPIListen(name string) (net.Listener, error) {
- config := winpipe.ListenConfig{
+ listener, err := (&namedpipe.ListenConfig{
SecurityDescriptor: UAPISecurityDescriptor,
- }
- listener, err := winpipe.Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config)
+ }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name)
if err != nil {
return nil, err
}
diff --git a/main.go b/main.go
index 639d644..e016116 100644
--- a/main.go
+++ b/main.go
@@ -1,8 +1,8 @@
-// +build !windows
+//go:build !windows
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
@@ -13,8 +13,8 @@ import (
"os/signal"
"runtime"
"strconv"
- "syscall"
+ "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
@@ -111,7 +111,7 @@ func main() {
// open TUN device (or use supplied fd)
- tun, err := func() (tun.Device, error) {
+ tdev, err := func() (tun.Device, error) {
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
if tunFdStr == "" {
return tun.CreateTUN(interfaceName, device.DefaultMTU)
@@ -124,7 +124,7 @@ func main() {
return nil, err
}
- err = syscall.SetNonblock(int(fd), true)
+ err = unix.SetNonblock(int(fd), true)
if err != nil {
return nil, err
}
@@ -134,7 +134,7 @@ func main() {
}()
if err == nil {
- realInterfaceName, err2 := tun.Name()
+ realInterfaceName, err2 := tdev.Name()
if err2 == nil {
interfaceName = realInterfaceName
}
@@ -169,7 +169,6 @@ func main() {
return os.NewFile(uintptr(fd), ""), nil
}()
-
if err != nil {
logger.Errorf("UAPI listen error: %v", err)
os.Exit(ExitSetupFailed)
@@ -197,7 +196,7 @@ func main() {
files[0], // stdin
files[1], // stdout
files[2], // stderr
- tun.File(),
+ tdev.File(),
fileUAPI,
},
Dir: ".",
@@ -223,7 +222,7 @@ func main() {
return
}
- device := device.NewDevice(tun, conn.NewDefaultBind(), logger)
+ device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
logger.Verbosef("Device started")
@@ -251,7 +250,7 @@ func main() {
// wait for program to terminate
- signal.Notify(term, syscall.SIGTERM)
+ signal.Notify(term, unix.SIGTERM)
signal.Notify(term, os.Interrupt)
select {
diff --git a/main_windows.go b/main_windows.go
index 5a7b136..a4dc46f 100644
--- a/main_windows.go
+++ b/main_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
@@ -9,7 +9,8 @@ import (
"fmt"
"os"
"os/signal"
- "syscall"
+
+ "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
@@ -81,7 +82,7 @@ func main() {
signal.Notify(term, os.Interrupt)
signal.Notify(term, os.Kill)
- signal.Notify(term, syscall.SIGTERM)
+ signal.Notify(term, windows.SIGTERM)
select {
case <-term:
diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go
index 2f7aa2a..f7d05ef 100644
--- a/ratelimiter/ratelimiter.go
+++ b/ratelimiter/ratelimiter.go
@@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ratelimiter
import (
- "net"
+ "net/netip"
"sync"
"time"
)
@@ -30,8 +30,7 @@ type Ratelimiter struct {
timeNow func() time.Time
stopReset chan struct{} // send to reset, close to stop
- tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
- tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
+ table map[netip.Addr]*RatelimiterEntry
}
func (rate *Ratelimiter) Close() {
@@ -57,8 +56,7 @@ func (rate *Ratelimiter) Init() {
}
rate.stopReset = make(chan struct{})
- rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
- rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
+ rate.table = make(map[netip.Addr]*RatelimiterEntry)
stopReset := rate.stopReset // store in case Init is called again.
@@ -87,71 +85,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) {
rate.mu.Lock()
defer rate.mu.Unlock()
- for key, entry := range rate.tableIPv4 {
+ for key, entry := range rate.table {
entry.mu.Lock()
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv4, key)
+ delete(rate.table, key)
}
entry.mu.Unlock()
}
- for key, entry := range rate.tableIPv6 {
- entry.mu.Lock()
- if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv6, key)
- }
- entry.mu.Unlock()
- }
-
- return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
+ return len(rate.table) == 0
}
-func (rate *Ratelimiter) Allow(ip net.IP) bool {
+func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
var entry *RatelimiterEntry
- var keyIPv4 [net.IPv4len]byte
- var keyIPv6 [net.IPv6len]byte
-
// lookup entry
-
- IPv4 := ip.To4()
- IPv6 := ip.To16()
-
rate.mu.RLock()
-
- if IPv4 != nil {
- copy(keyIPv4[:], IPv4)
- entry = rate.tableIPv4[keyIPv4]
- } else {
- copy(keyIPv6[:], IPv6)
- entry = rate.tableIPv6[keyIPv6]
- }
-
+ entry = rate.table[ip]
rate.mu.RUnlock()
// make new entry if not found
-
if entry == nil {
entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost
entry.lastTime = rate.timeNow()
rate.mu.Lock()
- if IPv4 != nil {
- rate.tableIPv4[keyIPv4] = entry
- if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
- rate.stopReset <- struct{}{}
- }
- } else {
- rate.tableIPv6[keyIPv6] = entry
- if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
- rate.stopReset <- struct{}{}
- }
+ rate.table[ip] = entry
+ if len(rate.table) == 1 {
+ rate.stopReset <- struct{}{}
}
rate.mu.Unlock()
return true
}
// add tokens to entry
-
entry.mu.Lock()
now := rate.timeNow()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
@@ -161,7 +127,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
}
// subtract cost of packet
-
if entry.tokens > packetCost {
entry.tokens -= packetCost
entry.mu.Unlock()
diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go
index f231fe5..0bfa3af 100644
--- a/ratelimiter/ratelimiter_test.go
+++ b/ratelimiter/ratelimiter_test.go
@@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ratelimiter
import (
- "net"
+ "net/netip"
"testing"
"time"
)
@@ -71,21 +71,21 @@ func TestRatelimiter(t *testing.T) {
text: "packet following 2 packet burst",
})
- ips := []net.IP{
- net.ParseIP("127.0.0.1"),
- net.ParseIP("192.168.1.1"),
- net.ParseIP("172.167.2.3"),
- net.ParseIP("97.231.252.215"),
- net.ParseIP("248.97.91.167"),
- net.ParseIP("188.208.233.47"),
- net.ParseIP("104.2.183.179"),
- net.ParseIP("72.129.46.120"),
- net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
- net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
- net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
- net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
- net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
- net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
+ ips := []netip.Addr{
+ netip.MustParseAddr("127.0.0.1"),
+ netip.MustParseAddr("192.168.1.1"),
+ netip.MustParseAddr("172.167.2.3"),
+ netip.MustParseAddr("97.231.252.215"),
+ netip.MustParseAddr("248.97.91.167"),
+ netip.MustParseAddr("188.208.233.47"),
+ netip.MustParseAddr("104.2.183.179"),
+ netip.MustParseAddr("72.129.46.120"),
+ netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
+ netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
+ netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
+ netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
+ netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
+ netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
}
now := time.Now()
diff --git a/replay/replay.go b/replay/replay.go
index fac7ab2..8b99e23 100644
--- a/replay/replay.go
+++ b/replay/replay.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
@@ -34,7 +34,7 @@ func (f *Filter) Reset() {
// ValidateCounter checks if the counter should be accepted.
// Overlimit counters (>= limit) are always rejected.
-func (f *Filter) ValidateCounter(counter uint64, limit uint64) bool {
+func (f *Filter) ValidateCounter(counter, limit uint64) bool {
if counter >= limit {
return false
}
diff --git a/replay/replay_test.go b/replay/replay_test.go
index 28c3a0e..9a9e4a8 100644
--- a/replay/replay_test.go
+++ b/replay/replay_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package replay
diff --git a/rwcancel/rwcancel.go b/rwcancel/rwcancel.go
index 5dff446..e397c0e 100644
--- a/rwcancel/rwcancel.go
+++ b/rwcancel/rwcancel.go
@@ -1,8 +1,8 @@
-// +build !windows
+//go:build !windows && !wasm
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
// Package rwcancel implements cancelable read/write operations on
diff --git a/rwcancel/rwcancel_windows.go b/rwcancel/rwcancel_stub.go
index 0316911..2a98b2b 100644
--- a/rwcancel/rwcancel_windows.go
+++ b/rwcancel/rwcancel_stub.go
@@ -1,8 +1,9 @@
+//go:build windows || wasm
+
// SPDX-License-Identifier: MIT
package rwcancel
-type RWCancel struct {
-}
+type RWCancel struct{}
func (*RWCancel) Cancel() {}
diff --git a/tai64n/tai64n.go b/tai64n/tai64n.go
index e1fa26a..8f10b39 100644
--- a/tai64n/tai64n.go
+++ b/tai64n/tai64n.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tai64n
@@ -11,9 +11,11 @@ import (
"time"
)
-const TimestampSize = 12
-const base = uint64(0x400000000000000a)
-const whitenerMask = uint32(0x1000000 - 1)
+const (
+ TimestampSize = 12
+ base = uint64(0x400000000000000a)
+ whitenerMask = uint32(0x1000000 - 1)
+)
type Timestamp [TimestampSize]byte
diff --git a/tai64n/tai64n_test.go b/tai64n/tai64n_test.go
index 97eb26f..c70fc1a 100644
--- a/tai64n/tai64n_test.go
+++ b/tai64n/tai64n_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tai64n
diff --git a/tun/alignment_windows_test.go b/tun/alignment_windows_test.go
index b765cca..67a785e 100644
--- a/tun/alignment_windows_test.go
+++ b/tun/alignment_windows_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
diff --git a/tun/checksum.go b/tun/checksum.go
new file mode 100644
index 0000000..29a8fc8
--- /dev/null
+++ b/tun/checksum.go
@@ -0,0 +1,118 @@
+package tun
+
+import "encoding/binary"
+
+// TODO: Explore SIMD and/or other assembly optimizations.
+// TODO: Test native endian loads. See RFC 1071 section 2 part B.
+func checksumNoFold(b []byte, initial uint64) uint64 {
+ ac := initial
+
+ for len(b) >= 128 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+ ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+ ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+ ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+ ac += uint64(binary.BigEndian.Uint32(b[32:36]))
+ ac += uint64(binary.BigEndian.Uint32(b[36:40]))
+ ac += uint64(binary.BigEndian.Uint32(b[40:44]))
+ ac += uint64(binary.BigEndian.Uint32(b[44:48]))
+ ac += uint64(binary.BigEndian.Uint32(b[48:52]))
+ ac += uint64(binary.BigEndian.Uint32(b[52:56]))
+ ac += uint64(binary.BigEndian.Uint32(b[56:60]))
+ ac += uint64(binary.BigEndian.Uint32(b[60:64]))
+ ac += uint64(binary.BigEndian.Uint32(b[64:68]))
+ ac += uint64(binary.BigEndian.Uint32(b[68:72]))
+ ac += uint64(binary.BigEndian.Uint32(b[72:76]))
+ ac += uint64(binary.BigEndian.Uint32(b[76:80]))
+ ac += uint64(binary.BigEndian.Uint32(b[80:84]))
+ ac += uint64(binary.BigEndian.Uint32(b[84:88]))
+ ac += uint64(binary.BigEndian.Uint32(b[88:92]))
+ ac += uint64(binary.BigEndian.Uint32(b[92:96]))
+ ac += uint64(binary.BigEndian.Uint32(b[96:100]))
+ ac += uint64(binary.BigEndian.Uint32(b[100:104]))
+ ac += uint64(binary.BigEndian.Uint32(b[104:108]))
+ ac += uint64(binary.BigEndian.Uint32(b[108:112]))
+ ac += uint64(binary.BigEndian.Uint32(b[112:116]))
+ ac += uint64(binary.BigEndian.Uint32(b[116:120]))
+ ac += uint64(binary.BigEndian.Uint32(b[120:124]))
+ ac += uint64(binary.BigEndian.Uint32(b[124:128]))
+ b = b[128:]
+ }
+ if len(b) >= 64 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+ ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+ ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+ ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+ ac += uint64(binary.BigEndian.Uint32(b[32:36]))
+ ac += uint64(binary.BigEndian.Uint32(b[36:40]))
+ ac += uint64(binary.BigEndian.Uint32(b[40:44]))
+ ac += uint64(binary.BigEndian.Uint32(b[44:48]))
+ ac += uint64(binary.BigEndian.Uint32(b[48:52]))
+ ac += uint64(binary.BigEndian.Uint32(b[52:56]))
+ ac += uint64(binary.BigEndian.Uint32(b[56:60]))
+ ac += uint64(binary.BigEndian.Uint32(b[60:64]))
+ b = b[64:]
+ }
+ if len(b) >= 32 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+ ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+ ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+ ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+ b = b[32:]
+ }
+ if len(b) >= 16 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ b = b[16:]
+ }
+ if len(b) >= 8 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ b = b[8:]
+ }
+ if len(b) >= 4 {
+ ac += uint64(binary.BigEndian.Uint32(b))
+ b = b[4:]
+ }
+ if len(b) >= 2 {
+ ac += uint64(binary.BigEndian.Uint16(b))
+ b = b[2:]
+ }
+ if len(b) == 1 {
+ ac += uint64(b[0]) << 8
+ }
+
+ return ac
+}
+
+func checksum(b []byte, initial uint64) uint16 {
+ ac := checksumNoFold(b, initial)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ return uint16(ac)
+}
+
+func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
+ sum := checksumNoFold(srcAddr, 0)
+ sum = checksumNoFold(dstAddr, sum)
+ sum = checksumNoFold([]byte{0, protocol}, sum)
+ tmp := make([]byte, 2)
+ binary.BigEndian.PutUint16(tmp, totalLen)
+ return checksumNoFold(tmp, sum)
+}
diff --git a/tun/checksum_test.go b/tun/checksum_test.go
new file mode 100644
index 0000000..c1ccff5
--- /dev/null
+++ b/tun/checksum_test.go
@@ -0,0 +1,35 @@
+package tun
+
+import (
+ "fmt"
+ "math/rand"
+ "testing"
+)
+
+func BenchmarkChecksum(b *testing.B) {
+ lengths := []int{
+ 64,
+ 128,
+ 256,
+ 512,
+ 1024,
+ 1500,
+ 2048,
+ 4096,
+ 8192,
+ 9000,
+ 9001,
+ }
+
+ for _, length := range lengths {
+ b.Run(fmt.Sprintf("%d", length), func(b *testing.B) {
+ buf := make([]byte, length)
+ rng := rand.New(rand.NewSource(1))
+ rng.Read(buf)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ checksum(buf, 0)
+ }
+ })
+ }
+}
diff --git a/tun/errors.go b/tun/errors.go
new file mode 100644
index 0000000..75ae3a4
--- /dev/null
+++ b/tun/errors.go
@@ -0,0 +1,12 @@
+package tun
+
+import (
+ "errors"
+)
+
+var (
+ // ErrTooManySegments is returned by Device.Read() when segmentation
+ // overflows the length of supplied buffers. This error should not cause
+ // reads to cease.
+ ErrTooManySegments = errors.New("too many segments")
+)
diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go
index 39b1c6d..ccd32ed 100644
--- a/tun/netstack/examples/http_client.go
+++ b/tun/netstack/examples/http_client.go
@@ -1,8 +1,8 @@
-// +build ignore
+//go:build ignore
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
@@ -10,8 +10,8 @@ package main
import (
"io"
"log"
- "net"
"net/http"
+ "net/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
@@ -20,17 +20,17 @@ import (
func main() {
tun, tnet, err := netstack.CreateNetTUN(
- []net.IP{net.ParseIP("192.168.4.29")},
- []net.IP{net.ParseIP("8.8.8.8")},
+ []netip.Addr{netip.MustParseAddr("192.168.4.28")},
+ []netip.Addr{netip.MustParseAddr("8.8.8.8")},
1420)
if err != nil {
log.Panic(err)
}
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
- dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
-public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
-endpoint=163.172.161.0:12912
+ err = dev.IpcSet(`private_key=087ec6e14bbed210e7215cdc73468dfa23f080a1bfb8665b2fd809bd99d28379
+public_key=c4c8e984c5322c8184c72265b92b250fdb63688705f504ba003c88f03393cf28
allowed_ip=0.0.0.0/0
+endpoint=127.0.0.1:58120
`)
err = dev.Up()
if err != nil {
@@ -42,7 +42,7 @@ allowed_ip=0.0.0.0/0
DialContext: tnet.DialContext,
},
}
- resp, err := client.Get("https://www.zx2c4.com/ip")
+ resp, err := client.Get("http://192.168.4.29/")
if err != nil {
log.Panic(err)
}
diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go
index c1fc753..f5b7a8f 100644
--- a/tun/netstack/examples/http_server.go
+++ b/tun/netstack/examples/http_server.go
@@ -1,8 +1,8 @@
-// +build ignore
+//go:build ignore
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
@@ -12,6 +12,7 @@ import (
"log"
"net"
"net/http"
+ "net/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
@@ -20,18 +21,18 @@ import (
func main() {
tun, tnet, err := netstack.CreateNetTUN(
- []net.IP{net.ParseIP("192.168.4.29")},
- []net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")},
+ []netip.Addr{netip.MustParseAddr("192.168.4.29")},
+ []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
1420,
)
if err != nil {
log.Panic(err)
}
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
- dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
-public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
-endpoint=163.172.161.0:12912
-allowed_ip=0.0.0.0/0
+ dev.IpcSet(`private_key=003ed5d73b55806c30de3f8a7bdab38af13539220533055e635690b8b87ad641
+listen_port=58120
+public_key=f928d4f6c1b86c12f2562c10b07c555c5c57fd00f59e90c8d8d88767271cbf7c
+allowed_ip=192.168.4.28/32
persistent_keepalive_interval=25
`)
dev.Up()
diff --git a/tun/netstack/examples/ping_client.go b/tun/netstack/examples/ping_client.go
new file mode 100644
index 0000000..2eef0fb
--- /dev/null
+++ b/tun/netstack/examples/ping_client.go
@@ -0,0 +1,75 @@
+//go:build ignore
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package main
+
+import (
+ "bytes"
+ "log"
+ "math/rand"
+ "net/netip"
+ "time"
+
+ "golang.org/x/net/icmp"
+ "golang.org/x/net/ipv4"
+
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/device"
+ "golang.zx2c4.com/wireguard/tun/netstack"
+)
+
+func main() {
+ tun, tnet, err := netstack.CreateNetTUN(
+ []netip.Addr{netip.MustParseAddr("192.168.4.29")},
+ []netip.Addr{netip.MustParseAddr("8.8.8.8")},
+ 1420)
+ if err != nil {
+ log.Panic(err)
+ }
+ dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
+ dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
+public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
+endpoint=163.172.161.0:12912
+allowed_ip=0.0.0.0/0
+`)
+ err = dev.Up()
+ if err != nil {
+ log.Panic(err)
+ }
+
+ socket, err := tnet.Dial("ping4", "zx2c4.com")
+ if err != nil {
+ log.Panic(err)
+ }
+ requestPing := icmp.Echo{
+ Seq: rand.Intn(1 << 16),
+ Data: []byte("gopher burrow"),
+ }
+ icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
+ socket.SetReadDeadline(time.Now().Add(time.Second * 10))
+ start := time.Now()
+ _, err = socket.Write(icmpBytes)
+ if err != nil {
+ log.Panic(err)
+ }
+ n, err := socket.Read(icmpBytes[:])
+ if err != nil {
+ log.Panic(err)
+ }
+ replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
+ if err != nil {
+ log.Panic(err)
+ }
+ replyPing, ok := replyPacket.Body.(*icmp.Echo)
+ if !ok {
+ log.Panicf("invalid reply type: %v", replyPacket)
+ }
+ if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
+ log.Panicf("invalid ping reply: %v", replyPing)
+ }
+ log.Printf("Ping latency: %v", time.Since(start))
+}
diff --git a/tun/netstack/go.mod b/tun/netstack/go.mod
deleted file mode 100644
index 3123eec..0000000
--- a/tun/netstack/go.mod
+++ /dev/null
@@ -1,11 +0,0 @@
-module golang.zx2c4.com/wireguard/tun/netstack
-
-go 1.16
-
-require (
- golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6
- golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect
- golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect
- golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22
- gvisor.dev/gvisor v0.0.0-20210506004418-fbfeba3024f0
-)
diff --git a/tun/netstack/go.sum b/tun/netstack/go.sum
deleted file mode 100644
index 8227b4c..0000000
--- a/tun/netstack/go.sum
+++ /dev/null
@@ -1,605 +0,0 @@
-bazil.org/fuse v0.0.0-20160811212531-371fbbdaa898/go.mod h1:Xbm+BRKSBEpa4q4hTSxohYNQpsxXPbPry4JJWOB3LB8=
-cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
-cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
-cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU=
-cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU=
-cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY=
-cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc=
-cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0=
-cloud.google.com/go v0.50.0/go.mod h1:r9sluTvynVuxRIOHXQEHMFffphuXHOMZMycpNR5e6To=
-cloud.google.com/go v0.52.0/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4=
-cloud.google.com/go v0.53.0/go.mod h1:fp/UouUEsRkN6ryDKNW/Upv/JBKnv6WDthjR6+vze6M=
-cloud.google.com/go v0.54.0/go.mod h1:1rq2OEkV3YMf6n/9ZvGWI3GWw0VoqH/1x2nd8Is/bPc=
-cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKVk=
-cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs=
-cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc=
-cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY=
-cloud.google.com/go v0.72.0/go.mod h1:M+5Vjvlc2wnp6tjzE102Dw08nGShTscUx2nZMufOKPI=
-cloud.google.com/go v0.75.0/go.mod h1:VGuuCn7PG0dwsd5XPVm2Mm3wlh3EL55/79EKB6hlPTY=
-cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o=
-cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE=
-cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc=
-cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg=
-cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc=
-cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ=
-cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
-cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk=
-cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I=
-cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw=
-cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA=
-cloud.google.com/go/pubsub v1.3.1/go.mod h1:i+ucay31+CNRpDW4Lu78I4xXG+O1r/MAHgjpRVR+TSU=
-cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw=
-cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos=
-cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk=
-cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs=
-cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
-dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
-github.com/Azure/go-autorest/autorest v0.9.0/go.mod h1:xyHB1BMZT0cuDHU7I0+g046+BFDTQ8rEZB0s4Yfa6bI=
-github.com/Azure/go-autorest/autorest/adal v0.5.0/go.mod h1:8Z9fGy2MpX0PvDjB1pEgQTmVqjGhiHBW7RJJEciWzS0=
-github.com/Azure/go-autorest/autorest/date v0.1.0/go.mod h1:plvfp3oPSKwf2DNjlBjWF/7vwR+cUD/ELuzDCXwHUVA=
-github.com/Azure/go-autorest/autorest/mocks v0.1.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0=
-github.com/Azure/go-autorest/autorest/mocks v0.2.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0=
-github.com/Azure/go-autorest/logger v0.1.0/go.mod h1:oExouG+K6PryycPJfVSxi/koC6LSNgds39diKLz7Vrc=
-github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk=
-github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
-github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
-github.com/Microsoft/go-winio v0.4.16-0.20201130162521-d1ffc52c7331/go.mod h1:XB6nPKklQyQ7GC9LdcBEcBl8PF76WugXOPRXwdLnMv0=
-github.com/Microsoft/go-winio v0.4.16/go.mod h1:XB6nPKklQyQ7GC9LdcBEcBl8PF76WugXOPRXwdLnMv0=
-github.com/Microsoft/hcsshim v0.8.14/go.mod h1:NtVKoYxQuTLx6gEq0L96c9Ju4JbRJ4nY2ow3VK6a9Lg=
-github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb06e3pkSAbeQ52E9H9iFoQsEEwGN64994WTCIhntQ=
-github.com/PuerkitoBio/purell v1.0.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
-github.com/PuerkitoBio/urlesc v0.0.0-20160726150825-5bd2802263f2/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
-github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM=
-github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
-github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
-github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
-github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
-github.com/cilium/ebpf v0.0.0-20200110133405-4032b1d8aae3/go.mod h1:MA5e5Lr8slmEg9bt0VpxxWqJlO4iwu3FBdHUzV7wQVg=
-github.com/cilium/ebpf v0.2.0/go.mod h1:To2CFviqOWL/M0gIMsvSMlqe7em/l1ALkX1PyjrX2Qs=
-github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
-github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
-github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
-github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
-github.com/containerd/cgroups v0.0.0-20200531161412-0dbf7f05ba59/go.mod h1:pA0z1pT8KYB3TCXK/ocprsh7MAkoW8bZVzPdih9snmM=
-github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327/go.mod h1:ZJeTFisyysqgcCdecO57Dj79RfL0LNeGiFUqLYQRYLE=
-github.com/containerd/console v0.0.0-20180822173158-c12b1e7919c1/go.mod h1:Tj/on1eG8kiEhd0+fhSDzsPAFESxzBBvdyEgyryXffw=
-github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e/go.mod h1:8Pf4gM6VEbTNRIT26AyyU7hxdQU3MvAvxVI0sc00XBE=
-github.com/containerd/console v1.0.1/go.mod h1:XUsP6YE/mKtz6bxc+I8UiKKTP04qjQL4qcS3XoQ5xkw=
-github.com/containerd/containerd v1.3.2/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA=
-github.com/containerd/containerd v1.3.9/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA=
-github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y=
-github.com/containerd/continuity v0.0.0-20210208174643-50096c924a4e/go.mod h1:EXlVlkqNba9rJe3j7w3Xa924itAMLgZH4UD/Q4PExuQ=
-github.com/containerd/fifo v0.0.0-20190226154929-a9fb20d87448/go.mod h1:ODA38xgv3Kuk8dQz2ZQXpnv/UZZUHUCL7pnLehbXgQI=
-github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00/go.mod h1:jPQ2IAeZRCYxpS/Cm1495vGFww6ecHmMk1YJH2Q5ln0=
-github.com/containerd/go-runc v0.0.0-20180907222934-5a6d9f37cfa3/go.mod h1:IV7qH3hrUgRmyYrtgEeGWJfWbgcHL9CSRruz2Vqcph0=
-github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328/go.mod h1:PpyHrqVs8FTi9vpyHwPwiNEGaACDxT/N/pLcvMSRA9g=
-github.com/containerd/ttrpc v0.0.0-20190828154514-0e0f228740de/go.mod h1:PvCDdDGpgqzQIzDW1TphrGLssLDZp2GuS+X5DkEJB8o=
-github.com/containerd/ttrpc v1.0.2/go.mod h1:UAxOpgT9ziI0gJrmKvgcZivgxOp8iFPSk8httJEt98Y=
-github.com/containerd/typeurl v0.0.0-20180627222232-a93fcdb778cd/go.mod h1:Cm3kwCdlkCfMSHURc+r6fwoGH6/F1hH3S4sg0rLFWPc=
-github.com/containerd/typeurl v0.0.0-20200205145503-b45ef1f1f737/go.mod h1:TB1hUtrpaiO88KEK56ijojHS1+NeF0izUACaJW2mdXg=
-github.com/coreos/go-systemd/v22 v22.0.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk=
-github.com/coreos/go-systemd/v22 v22.1.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk=
-github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
-github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
-github.com/davecgh/go-spew v0.0.0-20151105211317-5215b55f46b2/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
-github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
-github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
-github.com/docker/go-connections v0.3.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec=
-github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c/go.mod h1:Uw6UezgYA44ePAFQYUehOuCzmy5zmg/+nl2ZfMWGkpA=
-github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
-github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZgvJUkLughtfhJv5dyTYa91l1fOUCrgjqmcifM=
-github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
-github.com/elazarl/goproxy v0.0.0-20170405201442-c4fc26588b6e/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc=
-github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs=
-github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
-github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
-github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
-github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po=
-github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
-github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
-github.com/evanphx/json-patch v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
-github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
-github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
-github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
-github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
-github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
-github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas=
-github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0=
-github.com/go-openapi/jsonreference v0.0.0-20160704190145-13c6e3589ad9/go.mod h1:W3Z9FmVs9qj+KR4zFKmDPGiLdk1D9Rlm7cyMvf57TTg=
-github.com/go-openapi/spec v0.0.0-20160808142527-6aced65f8501/go.mod h1:J8+jY1nAiCcj+friV/PDoE1/3eeccG9LYBs0tYvLOWc=
-github.com/go-openapi/swag v0.0.0-20160704191624-1d0bd113de87/go.mod h1:DXUve3Dpr1UfpPtxFw+EFuQ41HhCWZfha5jSVRG7C7I=
-github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
-github.com/gofrs/flock v0.8.0/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
-github.com/gogo/googleapis v1.4.0/go.mod h1:5YRNX2z1oM5gXdAkurHa942MDgEJyk02w4OecKY87+c=
-github.com/gogo/protobuf v1.2.2-0.20190723190241-65acae22fc9d/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
-github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
-github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
-github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
-github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
-github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
-github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
-github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
-github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
-github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
-github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
-github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
-github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
-github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4=
-github.com/golang/protobuf v0.0.0-20161109072736-4bd1920723d7/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
-github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
-github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk=
-github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
-github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
-github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
-github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
-github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
-github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
-github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
-github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
-github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
-github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
-github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
-github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
-github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
-github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
-github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-github/v32 v32.1.0/go.mod h1:rIEpZD9CTDQwDK9GDrtMTycQNA4JU3qBsCizh3q2WCI=
-github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
-github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
-github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
-github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
-github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
-github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
-github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
-github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
-github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
-github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
-github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
-github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
-github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
-github.com/google/pprof v0.0.0-20201023163331-3e6fc7fc9c4c/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
-github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
-github.com/google/pprof v0.0.0-20210115211752-39141e76b647/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
-github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
-github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
-github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
-github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
-github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d/go.mod h1:sJBsCZ4ayReDTBIg8b9dl28c5xFWyhBTVRp3pOg5EKY=
-github.com/gophercloud/gophercloud v0.1.0/go.mod h1:vxM41WHh5uqHVBMZHzuwNOHh8XEoIEcSTewFxm1c5g8=
-github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA=
-github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
-github.com/hashicorp/go-multierror v1.1.0/go.mod h1:spPvp8C1qA32ftKqdAHm4hHTbPw+vmowP0z+KUhOZdA=
-github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
-github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
-github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
-github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
-github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
-github.com/imdario/mergo v0.3.5/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA=
-github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
-github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
-github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
-github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
-github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
-github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
-github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
-github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
-github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
-github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
-github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
-github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
-github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
-github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
-github.com/mattbaird/jsonpatch v0.0.0-20171005235357-81af80346b1a/go.mod h1:M1qoD/MqPgTZIk0EWKB38wE28ACRfVcn+cU08jyArI0=
-github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
-github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
-github.com/modern-go/reflect2 v0.0.0-20180320133207-05fbef0ca5da/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
-github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
-github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
-github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
-github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
-github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
-github.com/onsi/ginkgo v0.0.0-20170829012221-11459a886d9c/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
-github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
-github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
-github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA=
-github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
-github.com/opencontainers/go-digest v0.0.0-20180430190053-c9281466c8b2/go.mod h1:cMLVZDEM3+U2I4VmLI6N8jQYUd2OVphdqWwCJHrFt2s=
-github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
-github.com/opencontainers/image-spec v1.0.1/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0=
-github.com/opencontainers/runc v0.0.0-20190115041553-12f6a991201f/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59PVA73FjuZG0U=
-github.com/opencontainers/runc v0.1.1/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59PVA73FjuZG0U=
-github.com/opencontainers/runtime-spec v1.0.1/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
-github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
-github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
-github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU=
-github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
-github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
-github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/prometheus/procfs v0.0.0-20180125133057-cb4147076ac7/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
-github.com/prometheus/procfs v0.0.0-20190522114515-bc1a522cf7b1/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
-github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
-github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
-github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
-github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
-github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
-github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
-github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk=
-github.com/spf13/cobra v0.0.2-0.20171109065643-2da4a54c5cee/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ=
-github.com/spf13/pflag v0.0.0-20170130214245-9ff6c6923cff/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
-github.com/spf13/pflag v1.0.1-0.20171106142849-4c012f6dcd95/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
-github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
-github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
-github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
-github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
-github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
-github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
-github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww=
-github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
-github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
-github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
-github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
-github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
-github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
-github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
-github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
-github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
-github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
-go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
-go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
-go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
-go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
-go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
-go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
-go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
-go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
-golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
-golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
-golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
-golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
-golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g=
-golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
-golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
-golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
-golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
-golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek=
-golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY=
-golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
-golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
-golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
-golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
-golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
-golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
-golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
-golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
-golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
-golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
-golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
-golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
-golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
-golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
-golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs=
-golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
-golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
-golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
-golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE=
-golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
-golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
-golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
-golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
-golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
-golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/net v0.0.0-20170114055629-f2499483f923/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
-golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20191004110552-13f9640d40b9/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20200222125558-5a598a2470a0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
-golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
-golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
-golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
-golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
-golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
-golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
-golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
-golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
-golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
-golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6 h1:0PC75Fz/kyMGhL0e1QnypqK2kQMqKt9csD1GnMJR+Zk=
-golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
-golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
-golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
-golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
-golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
-golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
-golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
-golang.org/x/oauth2 v0.0.0-20201109201403-9fd604954f58/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
-golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
-golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20190209173611-3b5209105503/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191022100944-742c48ecaeb7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191210023423-ac6580df4449/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200120151820-655fe14d7479/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200124204421-9fbb57f87de9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200331124033-c3d80250170d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200916030750-2334cc1a136f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210309040221-94ec62e08169/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 h1:iGu644GcxtEcrInvDsQRCwJjtCIOlT2V7IRt6ah2Whw=
-golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
-golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
-golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
-golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
-golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
-golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba h1:O8mE0/t419eoIwhTFpKVkHiTs/Igowgfkj25AcZrtiE=
-golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
-golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20181011042414-1f849cf54d09/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
-golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
-golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
-golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
-golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
-golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
-golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
-golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
-golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
-golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
-golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191216173652-a0e659d51361/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20191227053925-7b8e75db28f4/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200122220014-bf1340f18c4a/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200204074204-1cc6d1ef6c74/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200212150539-ea181f53ac56/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200224181240-023911ca70b2/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200227222343-706bc42d1f0d/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw=
-golang.org/x/tools v0.0.0-20200312045724-11d5b4c81c7d/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw=
-golang.org/x/tools v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8=
-golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
-golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
-golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
-golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
-golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
-golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
-golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
-golang.org/x/tools v0.0.0-20200904185747-39188db58858/go.mod h1:Cj7w3i3Rnn0Xh82ur9kSqwfTHTeVxaDqrfMjpcNT6bE=
-golang.org/x/tools v0.0.0-20201110124207-079ba7bd75cd/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.0.0-20201201161351-ac6f37ff4c2a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.0.0-20210108195828-e2f9c7f1fc8e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
-golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 h1:ytS28bw9HtZVDRMDxviC6ryCJuccw+zXhh04u2IRWJw=
-golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg=
-google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
-google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
-google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
-google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
-google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI=
-google.golang.org/api v0.14.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI=
-google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI=
-google.golang.org/api v0.17.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
-google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
-google.golang.org/api v0.19.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
-google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
-google.golang.org/api v0.22.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
-google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE=
-google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE=
-google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM=
-google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc=
-google.golang.org/api v0.35.0/go.mod h1:/XrVsuzM0rZmrsbjJutiuftIzeuTQcEeaYcSk/mQ1dg=
-google.golang.org/api v0.36.0/go.mod h1:+z5ficQTmoYpPn8LCUNVpK5I7hwkpjbcgqA7I34qYtE=
-google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
-google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
-google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
-google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0=
-google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
-google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
-google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
-google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
-google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
-google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
-google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
-google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
-google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
-google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
-google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8=
-google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
-google.golang.org/genproto v0.0.0-20191115194625-c23dd37a84c9/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
-google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
-google.golang.org/genproto v0.0.0-20191230161307-f3c370f40bfb/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
-google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
-google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
-google.golang.org/genproto v0.0.0-20200122232147-0452cf42e150/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
-google.golang.org/genproto v0.0.0-20200204135345-fa8e72b47b90/go.mod h1:GmwEX6Z4W5gMy59cAlVYjN9JhxgbQH6Gn+gFDQe2lzA=
-google.golang.org/genproto v0.0.0-20200212174721-66ed5ce911ce/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
-google.golang.org/genproto v0.0.0-20200224152610-e50cd9704f63/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
-google.golang.org/genproto v0.0.0-20200228133532-8c2c7df3a383/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
-google.golang.org/genproto v0.0.0-20200305110556-506484158171/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
-google.golang.org/genproto v0.0.0-20200312145019-da6875a35672/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
-google.golang.org/genproto v0.0.0-20200331122359-1ee6d9798940/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
-google.golang.org/genproto v0.0.0-20200430143042-b979b6f78d84/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
-google.golang.org/genproto v0.0.0-20200511104702-f5ebc3bea380/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
-google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587/go.mod h1:YsZOwe1myG/8QRHRsmBRE1LrgQY60beZKjly0O1fX9U=
-google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
-google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7FcilCzHH/e9qn6dsT145K34l5v+OpcnNgKAAA=
-google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20200904004341-0bd0a958aa1d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20201109203340-2640f1f9cdfb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20201201144952-b05cb90ed32e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/genproto v0.0.0-20210108203827-ffc7fda8c3d7/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
-google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
-google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
-google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
-google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
-google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
-google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
-google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
-google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
-google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
-google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKal+60=
-google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk=
-google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak=
-google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak=
-google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak=
-google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
-google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA51WJ8=
-google.golang.org/grpc v1.36.0-dev.0.20210208035533-9280052d3665/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
-google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
-google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
-google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
-google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
-google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
-google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
-google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
-google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
-google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4=
-google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
-google.golang.org/protobuf v1.25.1-0.20201020201750-d3470999428b/go.mod h1:hFxJC2f0epmp1elRCiEGJTKAWbwxZ2nvqZdHl3FQXCY=
-gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
-gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
-gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
-gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
-gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
-gvisor.dev/gvisor v0.0.0-20210506004418-fbfeba3024f0 h1:Ny53d6x76f7EgvYe7pi8SQcIzdRM69y1ckQ8rRtT6ww=
-gvisor.dev/gvisor v0.0.0-20210506004418-fbfeba3024f0/go.mod h1:ucHEMlckp+S/YzKEpwwAyGBhAh807Wxq/8Erc6gFxCE=
-honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
-honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
-honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
-honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
-honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
-honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
-honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
-honnef.co/go/tools v0.1.1/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las=
-k8s.io/api v0.16.13/go.mod h1:QWu8UWSTiuQZMMeYjwLs6ILu5O74qKSJ0c+4vrchDxs=
-k8s.io/apimachinery v0.16.13/go.mod h1:4HMHS3mDHtVttspuuhrJ1GGr/0S9B6iWYWZ57KnnZqQ=
-k8s.io/apimachinery v0.16.14-rc.0/go.mod h1:4HMHS3mDHtVttspuuhrJ1GGr/0S9B6iWYWZ57KnnZqQ=
-k8s.io/client-go v0.16.13/go.mod h1:UKvVT4cajC2iN7DCjLgT0KVY/cbY6DGdUCyRiIfws5M=
-k8s.io/gengo v0.0.0-20190128074634-0689ccc1d7d6/go.mod h1:ezvh/TsK7cY6rbqRK0oQQ8IAqLxYwwyPxAX1Pzy0ii0=
-k8s.io/klog v0.0.0-20181102134211-b9b56d5dfc92/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk=
-k8s.io/klog v0.3.0/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk=
-k8s.io/klog v1.0.0/go.mod h1:4Bi6QPql/J/LkTDqv7R/cd3hPo4k2DG6Ptcz060Ez5I=
-k8s.io/kube-openapi v0.0.0-20200410163147-594e756bea31/go.mod h1:1TqjTSzOxsLGIKfj0lK8EeCP7K1iUG65v09OM0/WG5E=
-k8s.io/utils v0.0.0-20190801114015-581e00157fb1/go.mod h1:sZAwmy6armz5eXlNoLmJcl4F1QuKu7sr+mFQ0byX7Ew=
-rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
-rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
-rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
-sigs.k8s.io/structured-merge-diff v0.0.0-20190525122527-15d366b2352e/go.mod h1:wWxsB5ozmmv/SG7nM11ayaAW51xMvak/t1r0CSlcokI=
-sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o=
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
index 9d2e90a..2b73054 100644
--- a/tun/netstack/tun.go
+++ b/tun/netstack/tun.go
@@ -1,11 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package netstack
import (
+ "bytes"
"context"
"crypto/rand"
"encoding/binary"
@@ -13,112 +14,85 @@ import (
"fmt"
"io"
"net"
+ "net/netip"
"os"
+ "regexp"
"strconv"
"strings"
+ "syscall"
"time"
"golang.zx2c4.com/wireguard/tun"
"golang.org/x/net/dns/dnsmessage"
+ "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
type netTun struct {
+ ep *channel.Endpoint
stack *stack.Stack
- dispatcher stack.NetworkDispatcher
events chan tun.Event
- incomingPacket chan buffer.VectorisedView
+ incomingPacket chan *buffer.View
mtu int
- dnsServers []net.IP
+ dnsServers []netip.Addr
hasV4, hasV6 bool
}
-type endpoint netTun
-type Net netTun
-
-func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
- e.dispatcher = dispatcher
-}
-
-func (e *endpoint) IsAttached() bool {
- return e.dispatcher != nil
-}
-
-func (e *endpoint) MTU() uint32 {
- mtu, err := (*netTun)(e).MTU()
- if err != nil {
- panic(err)
- }
- return uint32(mtu)
-}
-
-func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return stack.CapabilityNone
-}
-
-func (*endpoint) MaxHeaderLength() uint16 {
- return 0
-}
-
-func (*endpoint) LinkAddress() tcpip.LinkAddress {
- return ""
-}
-
-func (*endpoint) Wait() {}
-
-func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.incomingPacket <- buffer.NewVectorisedView(pkt.Size(), pkt.Views())
- return nil
-}
-
-func (e *endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- panic("not implemented")
-}
-func (*endpoint) ARPHardwareType() header.ARPHardwareType {
- return header.ARPHardwareNone
-}
-
-func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
-}
+type Net netTun
-func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) {
+func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
HandleLocal: true,
}
dev := &netTun{
+ ep: channel.New(1024, uint32(mtu), ""),
stack: stack.New(opts),
events: make(chan tun.Event, 10),
- incomingPacket: make(chan buffer.VectorisedView),
+ incomingPacket: make(chan *buffer.View),
dnsServers: dnsServers,
mtu: mtu,
}
- tcpipErr := dev.stack.CreateNIC(1, (*endpoint)(dev))
+ sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
+ tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
+ }
+ dev.ep.AddNotify(dev)
+ tcpipErr = dev.stack.CreateNIC(1, dev.ep)
if tcpipErr != nil {
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
for _, ip := range localAddresses {
- if ip4 := ip.To4(); ip4 != nil {
- tcpipErr = dev.stack.AddAddress(1, ipv4.ProtocolNumber, tcpip.Address(ip4))
- if tcpipErr != nil {
- return nil, nil, fmt.Errorf("AddAddress(%v): %v", ip4, tcpipErr)
- }
+ var protoNumber tcpip.NetworkProtocolNumber
+ if ip.Is4() {
+ protoNumber = ipv4.ProtocolNumber
+ } else if ip.Is6() {
+ protoNumber = ipv6.ProtocolNumber
+ }
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: protoNumber,
+ AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
+ }
+ tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
+ }
+ if ip.Is4() {
dev.hasV4 = true
- } else {
- tcpipErr = dev.stack.AddAddress(1, ipv6.ProtocolNumber, tcpip.Address(ip))
- if tcpipErr != nil {
- return nil, nil, fmt.Errorf("AddAddress(%v): %v", ip4, tcpipErr)
- }
+ } else if ip.Is6() {
dev.hasV6 = true
}
}
@@ -141,37 +115,54 @@ func (tun *netTun) File() *os.File {
return nil
}
-func (tun *netTun) Events() chan tun.Event {
+func (tun *netTun) Events() <-chan tun.Event {
return tun.events
}
-func (tun *netTun) Read(buf []byte, offset int) (int, error) {
+func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
view, ok := <-tun.incomingPacket
if !ok {
return 0, os.ErrClosed
}
- return view.Read(buf[offset:])
-}
-func (tun *netTun) Write(buf []byte, offset int) (int, error) {
- packet := buf[offset:]
- if len(packet) == 0 {
- return 0, nil
+ n, err := view.Read(buf[0][offset:])
+ if err != nil {
+ return 0, err
}
+ sizes[0] = n
+ return 1, nil
+}
- pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(len(packet), []buffer.View{buffer.NewViewFromBytes(packet)})})
- switch packet[0] >> 4 {
- case 4:
- tun.dispatcher.DeliverNetworkPacket("", "", ipv4.ProtocolNumber, pkb)
- case 6:
- tun.dispatcher.DeliverNetworkPacket("", "", ipv6.ProtocolNumber, pkb)
- }
+func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
+ for _, buf := range buf {
+ packet := buf[offset:]
+ if len(packet) == 0 {
+ continue
+ }
+ pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
+ switch packet[0] >> 4 {
+ case 4:
+ tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
+ case 6:
+ tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
+ default:
+ return 0, syscall.EAFNOSUPPORT
+ }
+ }
return len(buf), nil
}
-func (tun *netTun) Flush() error {
- return nil
+func (tun *netTun) WriteNotify() {
+ pkt := tun.ep.Read()
+ if pkt.IsNil() {
+ return
+ }
+
+ view := pkt.ToView()
+ pkt.DecRef()
+
+ tun.incomingPacket <- view
}
func (tun *netTun) Close() error {
@@ -180,9 +171,13 @@ func (tun *netTun) Close() error {
if tun.events != nil {
close(tun.events)
}
+
+ tun.ep.Close()
+
if tun.incomingPacket != nil {
close(tun.incomingPacket)
}
+
return nil
}
@@ -190,62 +185,294 @@ func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil
}
-func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
- if ip4 := ip.To4(); ip4 != nil {
- return tcpip.FullAddress{
- NIC: 1,
- Addr: tcpip.Address(ip4),
- Port: uint16(port),
- }, ipv4.ProtocolNumber
+func (tun *netTun) BatchSize() int {
+ return 1
+}
+
+func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
+ var protoNumber tcpip.NetworkProtocolNumber
+ if endpoint.Addr().Is4() {
+ protoNumber = ipv4.ProtocolNumber
} else {
- return tcpip.FullAddress{
- NIC: 1,
- Addr: tcpip.Address(ip),
- Port: uint16(port),
- }, ipv6.ProtocolNumber
+ protoNumber = ipv6.ProtocolNumber
}
+ return tcpip.FullAddress{
+ NIC: 1,
+ Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
+ Port: endpoint.Port(),
+ }, protoNumber
+}
+
+func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.DialContextTCP(ctx, net.stack, fa, pn)
}
func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil {
- panic("todo: deal with auto addr semantics for nil addr")
+ return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
}
- fa, pn := convertToFullAddr(addr.IP, addr.Port)
- return gonet.DialContextTCP(ctx, net.stack, fa, pn)
+ ip, _ := netip.AddrFromSlice(addr.IP)
+ return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
+}
+
+func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.DialTCP(net.stack, fa, pn)
}
func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil {
- panic("todo: deal with auto addr semantics for nil addr")
+ return net.DialTCPAddrPort(netip.AddrPort{})
}
- fa, pn := convertToFullAddr(addr.IP, addr.Port)
- return gonet.DialTCP(net.stack, fa, pn)
+ ip, _ := netip.AddrFromSlice(addr.IP)
+ return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
+}
+
+func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.ListenTCP(net.stack, fa, pn)
}
func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
if addr == nil {
- panic("todo: deal with auto addr semantics for nil addr")
+ return net.ListenTCPAddrPort(netip.AddrPort{})
}
- fa, pn := convertToFullAddr(addr.IP, addr.Port)
- return gonet.ListenTCP(net.stack, fa, pn)
+ ip, _ := netip.AddrFromSlice(addr.IP)
+ return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
}
-func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
+func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
var lfa, rfa *tcpip.FullAddress
var pn tcpip.NetworkProtocolNumber
- if laddr != nil {
+ if laddr.IsValid() || laddr.Port() > 0 {
var addr tcpip.FullAddress
- addr, pn = convertToFullAddr(laddr.IP, laddr.Port)
+ addr, pn = convertToFullAddr(laddr)
lfa = &addr
}
- if raddr != nil {
+ if raddr.IsValid() || raddr.Port() > 0 {
var addr tcpip.FullAddress
- addr, pn = convertToFullAddr(raddr.IP, raddr.Port)
+ addr, pn = convertToFullAddr(raddr)
rfa = &addr
}
return gonet.DialUDP(net.stack, lfa, rfa, pn)
}
+func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
+ return net.DialUDPAddrPort(laddr, netip.AddrPort{})
+}
+
+func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
+ var la, ra netip.AddrPort
+ if laddr != nil {
+ ip, _ := netip.AddrFromSlice(laddr.IP)
+ la = netip.AddrPortFrom(ip, uint16(laddr.Port))
+ }
+ if raddr != nil {
+ ip, _ := netip.AddrFromSlice(raddr.IP)
+ ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
+ }
+ return net.DialUDPAddrPort(la, ra)
+}
+
+func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
+ return net.DialUDP(laddr, nil)
+}
+
+type PingConn struct {
+ laddr PingAddr
+ raddr PingAddr
+ wq waiter.Queue
+ ep tcpip.Endpoint
+ deadline *time.Timer
+}
+
+type PingAddr struct{ addr netip.Addr }
+
+func (ia PingAddr) String() string {
+ return ia.addr.String()
+}
+
+func (ia PingAddr) Network() string {
+ if ia.addr.Is4() {
+ return "ping4"
+ } else if ia.addr.Is6() {
+ return "ping6"
+ }
+ return "ping"
+}
+
+func (ia PingAddr) Addr() netip.Addr {
+ return ia.addr
+}
+
+func PingAddrFromAddr(addr netip.Addr) *PingAddr {
+ return &PingAddr{addr}
+}
+
+func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
+ if !laddr.IsValid() && !raddr.IsValid() {
+ return nil, errors.New("ping dial: invalid address")
+ }
+ v6 := laddr.Is6() || raddr.Is6()
+ bind := laddr.IsValid()
+ if !bind {
+ if v6 {
+ laddr = netip.IPv6Unspecified()
+ } else {
+ laddr = netip.IPv4Unspecified()
+ }
+ }
+
+ tn := icmp.ProtocolNumber4
+ pn := ipv4.ProtocolNumber
+ if v6 {
+ tn = icmp.ProtocolNumber6
+ pn = ipv6.ProtocolNumber
+ }
+
+ pc := &PingConn{
+ laddr: PingAddr{laddr},
+ deadline: time.NewTimer(time.Hour << 10),
+ }
+ pc.deadline.Stop()
+
+ ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
+ if tcpipErr != nil {
+ return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
+ }
+ pc.ep = ep
+
+ if bind {
+ fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
+ if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
+ return nil, fmt.Errorf("ping bind: %s", tcpipErr)
+ }
+ }
+
+ if raddr.IsValid() {
+ pc.raddr = PingAddr{raddr}
+ fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
+ if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
+ return nil, fmt.Errorf("ping connect: %s", tcpipErr)
+ }
+ }
+
+ return pc, nil
+}
+
+func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) {
+ return net.DialPingAddr(laddr, netip.Addr{})
+}
+
+func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) {
+ var la, ra netip.Addr
+ if laddr != nil {
+ la = laddr.addr
+ }
+ if raddr != nil {
+ ra = raddr.addr
+ }
+ return net.DialPingAddr(la, ra)
+}
+
+func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) {
+ var la netip.Addr
+ if laddr != nil {
+ la = laddr.addr
+ }
+ return net.ListenPingAddr(la)
+}
+
+func (pc *PingConn) LocalAddr() net.Addr {
+ return pc.laddr
+}
+
+func (pc *PingConn) RemoteAddr() net.Addr {
+ return pc.raddr
+}
+
+func (pc *PingConn) Close() error {
+ pc.deadline.Reset(0)
+ pc.ep.Close()
+ return nil
+}
+
+func (pc *PingConn) SetWriteDeadline(t time.Time) error {
+ return errors.New("not implemented")
+}
+
+func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
+ var na netip.Addr
+ switch v := addr.(type) {
+ case *PingAddr:
+ na = v.addr
+ case *net.IPAddr:
+ na, _ = netip.AddrFromSlice(v.IP)
+ default:
+ return 0, fmt.Errorf("ping write: wrong net.Addr type")
+ }
+ if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) {
+ return 0, fmt.Errorf("ping write: mismatched protocols")
+ }
+
+ buf := bytes.NewReader(p)
+ rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0))
+ // won't block, no deadlines
+ n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{
+ To: &rfa,
+ })
+ if tcpipErr != nil {
+ return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
+ }
+
+ return int(n64), nil
+}
+
+func (pc *PingConn) Write(p []byte) (n int, err error) {
+ return pc.WriteTo(p, &pc.raddr)
+}
+
+func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
+ e, notifyCh := waiter.NewChannelEntry(waiter.EventIn)
+ pc.wq.EventRegister(&e)
+ defer pc.wq.EventUnregister(&e)
+
+ select {
+ case <-pc.deadline.C:
+ return 0, nil, os.ErrDeadlineExceeded
+ case <-notifyCh:
+ }
+
+ w := tcpip.SliceWriter(p)
+
+ res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
+ NeedRemoteAddr: true,
+ })
+ if tcpipErr != nil {
+ return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
+ }
+
+ remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice())
+ return res.Count, &PingAddr{remoteAddr}, nil
+}
+
+func (pc *PingConn) Read(p []byte) (n int, err error) {
+ n, _, err = pc.ReadFrom(p)
+ return
+}
+
+func (pc *PingConn) SetDeadline(t time.Time) error {
+ // pc.SetWriteDeadline is unimplemented
+
+ return pc.SetReadDeadline(t)
+}
+
+func (pc *PingConn) SetReadDeadline(t time.Time) error {
+ pc.deadline.Reset(time.Until(t))
+ return nil
+}
+
var (
errNoSuchHost = errors.New("no such host")
errLameReferral = errors.New("lame referral")
@@ -421,7 +648,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by
return p, h, nil
}
-func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
+func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
q.Class = dnsmessage.ClassINET
id, udpReq, tcpReq, err := newRequest(q)
if err != nil {
@@ -435,16 +662,19 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest
var c net.Conn
var err error
if useUDP {
- c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53})
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
} else {
- c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53})
+ c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
}
if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
if d, ok := ctx.Deadline(); ok && !d.IsZero() {
- c.SetDeadline(d)
+ err := c.SetDeadline(d)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
}
var p dnsmessage.Parser
var h dnsmessage.Header
@@ -588,8 +818,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
zlen = zidx
}
}
- if ip := net.ParseIP(host[:zlen]); ip != nil {
- return []string{host[:zlen]}, nil
+ if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
+ return []string{ip.String()}, nil
}
if !isDomainName(host) {
@@ -600,7 +830,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
server string
error
}
- var addrsV4, addrsV6 []net.IP
+ var addrsV4, addrsV6 []netip.Addr
lanes := 0
if tnet.hasV4 {
lanes++
@@ -655,7 +885,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
break loop
}
- addrsV4 = append(addrsV4, net.IP(a.A[:]))
+ addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
case dnsmessage.TypeAAAA:
aaaa, err := result.p.AAAAResource()
@@ -667,7 +897,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
break loop
}
- addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:]))
+ addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
default:
if err := result.p.SkipAnswer(); err != nil {
@@ -682,8 +912,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
}
}
- // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
- var addrs []net.IP
+ // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled
+ var addrs []netip.Addr
if tnet.hasV6 {
addrs = append(addrsV6, addrsV4...)
} else {
@@ -720,44 +950,48 @@ func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, er
return now.Add(timeout), nil
}
+var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
+
func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if ctx == nil {
panic("nil context")
}
- var acceptV4, acceptV6, useUDP bool
- if len(network) == 3 {
+ var acceptV4, acceptV6 bool
+ matches := protoSplitter.FindStringSubmatch(network)
+ if matches == nil {
+ return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
+ } else if len(matches[2]) == 0 {
acceptV4 = true
acceptV6 = true
- } else if len(network) == 4 {
- acceptV4 = network[3] == '4'
- acceptV6 = network[3] == '6'
- }
- if !acceptV4 && !acceptV6 {
- return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
- }
- if network[:3] == "udp" {
- useUDP = true
- } else if network[:3] != "tcp" {
- return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
- }
- host, sport, err := net.SplitHostPort(address)
- if err != nil {
- return nil, &net.OpError{Op: "dial", Err: err}
+ } else {
+ acceptV4 = matches[2][0] == '4'
+ acceptV6 = !acceptV4
}
- port, err := strconv.Atoi(sport)
- if err != nil || port < 0 || port > 65535 {
- return nil, &net.OpError{Op: "dial", Err: errNumericPort}
+ var host string
+ var port int
+ if matches[1] == "ping" {
+ host = address
+ } else {
+ var sport string
+ var err error
+ host, sport, err = net.SplitHostPort(address)
+ if err != nil {
+ return nil, &net.OpError{Op: "dial", Err: err}
+ }
+ port, err = strconv.Atoi(sport)
+ if err != nil || port < 0 || port > 65535 {
+ return nil, &net.OpError{Op: "dial", Err: errNumericPort}
+ }
}
allAddr, err := tnet.LookupContextHost(ctx, host)
if err != nil {
return nil, &net.OpError{Op: "dial", Err: err}
}
- var addrs []net.IP
+ var addrs []netip.AddrPort
for _, addr := range allAddr {
- if strings.IndexByte(addr, ':') != -1 && acceptV6 {
- addrs = append(addrs, net.ParseIP(addr))
- } else if strings.IndexByte(addr, '.') != -1 && acceptV4 {
- addrs = append(addrs, net.ParseIP(addr))
+ ip, err := netip.ParseAddr(addr)
+ if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
+ addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
}
}
if len(addrs) == 0 && len(allAddr) != 0 {
@@ -795,10 +1029,13 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
}
var c net.Conn
- if useUDP {
- c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: addr, Port: port})
- } else {
- c, err = tnet.DialContextTCP(dialCtx, &net.TCPAddr{IP: addr, Port: port})
+ switch matches[1] {
+ case "tcp":
+ c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
+ case "udp":
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
+ case "ping":
+ c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
}
if err == nil {
return c, nil
diff --git a/tun/offload_linux.go b/tun/offload_linux.go
new file mode 100644
index 0000000..9ff7fea
--- /dev/null
+++ b/tun/offload_linux.go
@@ -0,0 +1,993 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "io"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+)
+
+const tcpFlagsOffset = 13
+
+const (
+ tcpFlagFIN uint8 = 0x01
+ tcpFlagPSH uint8 = 0x08
+ tcpFlagACK uint8 = 0x10
+)
+
+// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
+// kernel symbol is virtio_net_hdr.
+type virtioNetHdr struct {
+ flags uint8
+ gsoType uint8
+ hdrLen uint16
+ gsoSize uint16
+ csumStart uint16
+ csumOffset uint16
+}
+
+func (v *virtioNetHdr) decode(b []byte) error {
+ if len(b) < virtioNetHdrLen {
+ return io.ErrShortBuffer
+ }
+ copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
+ return nil
+}
+
+func (v *virtioNetHdr) encode(b []byte) error {
+ if len(b) < virtioNetHdrLen {
+ return io.ErrShortBuffer
+ }
+ copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
+ return nil
+}
+
+const (
+ // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
+ // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
+ virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
+)
+
+// tcpFlowKey represents the key for a TCP flow.
+type tcpFlowKey struct {
+ srcAddr, dstAddr [16]byte
+ srcPort, dstPort uint16
+ rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
+ isV6 bool
+}
+
+// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO.
+type tcpGROTable struct {
+ itemsByFlow map[tcpFlowKey][]tcpGROItem
+ itemsPool [][]tcpGROItem
+}
+
+func newTCPGROTable() *tcpGROTable {
+ t := &tcpGROTable{
+ itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize),
+ itemsPool: make([][]tcpGROItem, conn.IdealBatchSize),
+ }
+ for i := range t.itemsPool {
+ t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize)
+ }
+ return t
+}
+
+func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey {
+ key := tcpFlowKey{}
+ addrSize := dstAddrOffset - srcAddrOffset
+ copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
+ copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
+ key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
+ key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
+ key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
+ key.isV6 = addrSize == 16
+ return key
+}
+
+// lookupOrInsert looks up a flow for the provided packet and metadata,
+// returning the packets found for the flow, or inserting a new one if none
+// is found.
+func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
+ key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+ items, ok := t.itemsByFlow[key]
+ if ok {
+ return items, ok
+ }
+ // TODO: insert() performs another map lookup. This could be rearranged to avoid.
+ t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
+ return nil, false
+}
+
+// insert an item in the table for the provided packet and packet metadata.
+func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
+ key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+ item := tcpGROItem{
+ key: key,
+ bufsIndex: uint16(bufsIndex),
+ gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
+ iphLen: uint8(tcphOffset),
+ tcphLen: uint8(tcphLen),
+ sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
+ pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
+ }
+ items, ok := t.itemsByFlow[key]
+ if !ok {
+ items = t.newItems()
+ }
+ items = append(items, item)
+ t.itemsByFlow[key] = items
+}
+
+func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
+ items, _ := t.itemsByFlow[item.key]
+ items[i] = item
+}
+
+func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) {
+ items, _ := t.itemsByFlow[key]
+ items = append(items[:i], items[i+1:]...)
+ t.itemsByFlow[key] = items
+}
+
+// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
+// of a GRO evaluation across a vector of packets.
+type tcpGROItem struct {
+ key tcpFlowKey
+ sentSeq uint32 // the sequence number
+ bufsIndex uint16 // the index into the original bufs slice
+ numMerged uint16 // the number of packets merged into this item
+ gsoSize uint16 // payload size
+ iphLen uint8 // ip header len
+ tcphLen uint8 // tcp header len
+ pshSet bool // psh flag is set
+}
+
+func (t *tcpGROTable) newItems() []tcpGROItem {
+ var items []tcpGROItem
+ items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
+ return items
+}
+
+func (t *tcpGROTable) reset() {
+ for k, items := range t.itemsByFlow {
+ items = items[:0]
+ t.itemsPool = append(t.itemsPool, items)
+ delete(t.itemsByFlow, k)
+ }
+}
+
+// udpFlowKey represents the key for a UDP flow.
+type udpFlowKey struct {
+ srcAddr, dstAddr [16]byte
+ srcPort, dstPort uint16
+ isV6 bool
+}
+
+// udpGROTable holds flow and coalescing information for the purposes of UDP GRO.
+type udpGROTable struct {
+ itemsByFlow map[udpFlowKey][]udpGROItem
+ itemsPool [][]udpGROItem
+}
+
+func newUDPGROTable() *udpGROTable {
+ u := &udpGROTable{
+ itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize),
+ itemsPool: make([][]udpGROItem, conn.IdealBatchSize),
+ }
+ for i := range u.itemsPool {
+ u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize)
+ }
+ return u
+}
+
+func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey {
+ key := udpFlowKey{}
+ addrSize := dstAddrOffset - srcAddrOffset
+ copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
+ copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
+ key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:])
+ key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:])
+ key.isV6 = addrSize == 16
+ return key
+}
+
+// lookupOrInsert looks up a flow for the provided packet and metadata,
+// returning the packets found for the flow, or inserting a new one if none
+// is found.
+func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) {
+ key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
+ items, ok := u.itemsByFlow[key]
+ if ok {
+ return items, ok
+ }
+ // TODO: insert() performs another map lookup. This could be rearranged to avoid.
+ u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false)
+ return nil, false
+}
+
+// insert an item in the table for the provided packet and packet metadata.
+func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) {
+ key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
+ item := udpGROItem{
+ key: key,
+ bufsIndex: uint16(bufsIndex),
+ gsoSize: uint16(len(pkt[udphOffset+udphLen:])),
+ iphLen: uint8(udphOffset),
+ cSumKnownInvalid: cSumKnownInvalid,
+ }
+ items, ok := u.itemsByFlow[key]
+ if !ok {
+ items = u.newItems()
+ }
+ items = append(items, item)
+ u.itemsByFlow[key] = items
+}
+
+func (u *udpGROTable) updateAt(item udpGROItem, i int) {
+ items, _ := u.itemsByFlow[item.key]
+ items[i] = item
+}
+
+// udpGROItem represents bookkeeping data for a UDP packet during the lifetime
+// of a GRO evaluation across a vector of packets.
+type udpGROItem struct {
+ key udpFlowKey
+ bufsIndex uint16 // the index into the original bufs slice
+ numMerged uint16 // the number of packets merged into this item
+ gsoSize uint16 // payload size
+ iphLen uint8 // ip header len
+ cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown.
+}
+
+func (u *udpGROTable) newItems() []udpGROItem {
+ var items []udpGROItem
+ items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1]
+ return items
+}
+
+func (u *udpGROTable) reset() {
+ for k, items := range u.itemsByFlow {
+ items = items[:0]
+ u.itemsPool = append(u.itemsPool, items)
+ delete(u.itemsByFlow, k)
+ }
+}
+
+// canCoalesce represents the outcome of checking if two TCP packets are
+// candidates for coalescing.
+type canCoalesce int
+
+const (
+ coalescePrepend canCoalesce = -1
+ coalesceUnavailable canCoalesce = 0
+ coalesceAppend canCoalesce = 1
+)
+
+// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB
+// meet all requirements to be merged as part of a GRO operation, otherwise it
+// returns false.
+func ipHeadersCanCoalesce(pktA, pktB []byte) bool {
+ if len(pktA) < 9 || len(pktB) < 9 {
+ return false
+ }
+ if pktA[0]>>4 == 6 {
+ if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 {
+ // cannot coalesce with unequal Traffic class values
+ return false
+ }
+ if pktA[7] != pktB[7] {
+ // cannot coalesce with unequal Hop limit values
+ return false
+ }
+ } else {
+ if pktA[1] != pktB[1] {
+ // cannot coalesce with unequal ToS values
+ return false
+ }
+ if pktA[6]>>5 != pktB[6]>>5 {
+ // cannot coalesce with unequal DF or reserved bits. MF is checked
+ // further up the stack.
+ return false
+ }
+ if pktA[8] != pktB[8] {
+ // cannot coalesce with unequal TTL values
+ return false
+ }
+ }
+ return true
+}
+
+// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
+// described by item. iphLen and gsoSize describe pkt. bufs is the vector of
+// packets involved in the current GRO evaluation. bufsOffset is the offset at
+// which packet data begins within bufs.
+func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
+ pktTarget := bufs[item.bufsIndex][bufsOffset:]
+ if !ipHeadersCanCoalesce(pkt, pktTarget) {
+ return coalesceUnavailable
+ }
+ if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 {
+ // A smaller than gsoSize packet has been appended previously.
+ // Nothing can come after a smaller packet on the end.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ return coalesceAppend
+}
+
+// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
+// described by item. This function makes considerations that match the kernel's
+// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
+func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
+ pktTarget := bufs[item.bufsIndex][bufsOffset:]
+ if tcphLen != item.tcphLen {
+ // cannot coalesce with unequal tcp options len
+ return coalesceUnavailable
+ }
+ if tcphLen > 20 {
+ if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
+ // cannot coalesce with unequal tcp options
+ return coalesceUnavailable
+ }
+ }
+ if !ipHeadersCanCoalesce(pkt, pktTarget) {
+ return coalesceUnavailable
+ }
+ // seq adjacency
+ lhsLen := item.gsoSize
+ lhsLen += item.numMerged * item.gsoSize
+ if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
+ if item.pshSet {
+ // We cannot append to a segment that has the PSH flag set, PSH
+ // can only be set on the final segment in a reassembled group.
+ return coalesceUnavailable
+ }
+ if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
+ // A smaller than gsoSize packet has been appended previously.
+ // Nothing can come after a smaller packet on the end.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ return coalesceAppend
+ } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
+ if pshSet {
+ // We cannot prepend with a segment that has the PSH flag set, PSH
+ // can only be set on the final segment in a reassembled group.
+ return coalesceUnavailable
+ }
+ if gsoSize < item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize && item.numMerged > 0 {
+ // There's at least one previous merge, and we're larger than all
+ // previous. This would put multiple smaller packets on the end.
+ return coalesceUnavailable
+ }
+ return coalescePrepend
+ }
+ return coalesceUnavailable
+}
+
+func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool {
+ srcAddrAt := ipv4SrcAddrOffset
+ addrSize := 4
+ if isV6 {
+ srcAddrAt = ipv6SrcAddrOffset
+ addrSize = 16
+ }
+ lenForPseudo := uint16(len(pkt) - int(iphLen))
+ cSum := pseudoHeaderChecksumNoFold(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo)
+ return ^checksum(pkt[iphLen:], cSum) == 0
+}
+
+// coalesceResult represents the result of attempting to coalesce two TCP
+// packets.
+type coalesceResult int
+
+const (
+ coalesceInsufficientCap coalesceResult = iota
+ coalescePSHEnding
+ coalesceItemInvalidCSum
+ coalescePktInvalidCSum
+ coalesceSuccess
+)
+
+// coalesceUDPPackets attempts to coalesce pkt with the packet described by
+// item, and returns the outcome.
+func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
+ pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front
+ headersLen := item.iphLen + udphLen
+ coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
+
+ if cap(pktHead)-bufsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if item.numMerged == 0 {
+ if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) {
+ return coalescePktInvalidCSum
+ }
+ extendBy := len(pkt) - int(headersLen)
+ bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
+ copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
+
+ item.numMerged++
+ return coalesceSuccess
+}
+
+// coalesceTCPPackets attempts to coalesce pkt with the packet described by
+// item, and returns the outcome. This function may swap bufs elements in the
+// event of a prepend as item's bufs index is already being tracked for writing
+// to a Device.
+func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
+ var pktHead []byte // the packet that will end up at the front
+ headersLen := item.iphLen + item.tcphLen
+ coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
+
+ // Copy data
+ if mode == coalescePrepend {
+ pktHead = pkt
+ if cap(pkt)-bufsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if pshSet {
+ return coalescePSHEnding
+ }
+ if item.numMerged == 0 {
+ if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
+ return coalescePktInvalidCSum
+ }
+ item.sentSeq = seq
+ extendBy := coalescedLen - len(pktHead)
+ bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
+ copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
+ // Flip the slice headers in bufs as part of prepend. The index of item
+ // is already being tracked for writing.
+ bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
+ } else {
+ pktHead = bufs[item.bufsIndex][bufsOffset:]
+ if cap(pktHead)-bufsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if item.numMerged == 0 {
+ if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
+ return coalescePktInvalidCSum
+ }
+ if pshSet {
+ // We are appending a segment with PSH set.
+ item.pshSet = pshSet
+ pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
+ }
+ extendBy := len(pkt) - int(headersLen)
+ bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
+ copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
+ }
+
+ if gsoSize > item.gsoSize {
+ item.gsoSize = gsoSize
+ }
+
+ item.numMerged++
+ return coalesceSuccess
+}
+
+const (
+ ipv4FlagMoreFragments uint8 = 0x20
+)
+
+const (
+ ipv4SrcAddrOffset = 12
+ ipv6SrcAddrOffset = 8
+ maxUint16 = 1<<16 - 1
+)
+
+type groResult int
+
+const (
+ groResultNoop groResult = iota
+ groResultTableInsert
+ groResultCoalesced
+)
+
+// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
+// existing packets tracked in table. It returns a groResultNoop when no
+// action was taken, groResultTableInsert when the evaluated packet was
+// inserted into table, and groResultCoalesced when the evaluated packet was
+// coalesced with another packet in table.
+func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult {
+ pkt := bufs[pktI][offset:]
+ if len(pkt) > maxUint16 {
+ // A valid IPv4 or IPv6 packet will never exceed this.
+ return groResultNoop
+ }
+ iphLen := int((pkt[0] & 0x0F) * 4)
+ if isV6 {
+ iphLen = 40
+ ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
+ if ipv6HPayloadLen != len(pkt)-iphLen {
+ return groResultNoop
+ }
+ } else {
+ totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
+ if totalLen != len(pkt) {
+ return groResultNoop
+ }
+ }
+ if len(pkt) < iphLen {
+ return groResultNoop
+ }
+ tcphLen := int((pkt[iphLen+12] >> 4) * 4)
+ if tcphLen < 20 || tcphLen > 60 {
+ return groResultNoop
+ }
+ if len(pkt) < iphLen+tcphLen {
+ return groResultNoop
+ }
+ if !isV6 {
+ if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
+ // no GRO support for fragmented segments for now
+ return groResultNoop
+ }
+ }
+ tcpFlags := pkt[iphLen+tcpFlagsOffset]
+ var pshSet bool
+ // not a candidate if any non-ACK flags (except PSH+ACK) are set
+ if tcpFlags != tcpFlagACK {
+ if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
+ return groResultNoop
+ }
+ pshSet = true
+ }
+ gsoSize := uint16(len(pkt) - tcphLen - iphLen)
+ // not a candidate if payload len is 0
+ if gsoSize < 1 {
+ return groResultNoop
+ }
+ seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
+ srcAddrOffset := ipv4SrcAddrOffset
+ addrLen := 4
+ if isV6 {
+ srcAddrOffset = ipv6SrcAddrOffset
+ addrLen = 16
+ }
+ items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
+ if !existing {
+ return groResultTableInsert
+ }
+ for i := len(items) - 1; i >= 0; i-- {
+ // In the best case of packets arriving in order iterating in reverse is
+ // more efficient if there are multiple items for a given flow. This
+ // also enables a natural table.deleteAt() in the
+ // coalesceItemInvalidCSum case without the need for index tracking.
+ // This algorithm makes a best effort to coalesce in the event of
+ // unordered packets, where pkt may land anywhere in items from a
+ // sequence number perspective, however once an item is inserted into
+ // the table it is never compared across other items later.
+ item := items[i]
+ can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
+ if can != coalesceUnavailable {
+ result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
+ switch result {
+ case coalesceSuccess:
+ table.updateAt(item, i)
+ return groResultCoalesced
+ case coalesceItemInvalidCSum:
+ // delete the item with an invalid csum
+ table.deleteAt(item.key, i)
+ case coalescePktInvalidCSum:
+ // no point in inserting an item that we can't coalesce
+ return groResultNoop
+ default:
+ }
+ }
+ }
+ // failed to coalesce with any other packets; store the item in the flow
+ table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
+ return groResultTableInsert
+}
+
+// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the
+// metadata found in table.
+func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error {
+ for _, items := range table.itemsByFlow {
+ for _, item := range items {
+ if item.numMerged > 0 {
+ hdr := virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
+ hdrLen: uint16(item.iphLen + item.tcphLen),
+ gsoSize: item.gsoSize,
+ csumStart: uint16(item.iphLen),
+ csumOffset: 16,
+ }
+ pkt := bufs[item.bufsIndex][offset:]
+
+ // Recalculate the total len (IPv4) or payload len (IPv6).
+ // Recalculate the (IPv4) header checksum.
+ if item.key.isV6 {
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
+ binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
+ } else {
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
+ pkt[10], pkt[11] = 0, 0
+ binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
+ iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum
+ binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
+ }
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+
+ // Calculate the pseudo header checksum and place it at the TCP
+ // checksum offset. Downstream checksum offloading will combine
+ // this with computation of the tcp header and payload checksum.
+ addrLen := 4
+ addrOffset := ipv4SrcAddrOffset
+ if item.key.isV6 {
+ addrLen = 16
+ addrOffset = ipv6SrcAddrOffset
+ }
+ srcAddrAt := offset + addrOffset
+ srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
+ dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
+ psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
+ binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
+ } else {
+ hdr := virtioNetHdr{}
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+ return nil
+}
+
+// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the
+// metadata found in table.
+func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error {
+ for _, items := range table.itemsByFlow {
+ for _, item := range items {
+ if item.numMerged > 0 {
+ hdr := virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
+ hdrLen: uint16(item.iphLen + udphLen),
+ gsoSize: item.gsoSize,
+ csumStart: uint16(item.iphLen),
+ csumOffset: 6,
+ }
+ pkt := bufs[item.bufsIndex][offset:]
+
+ // Recalculate the total len (IPv4) or payload len (IPv6).
+ // Recalculate the (IPv4) header checksum.
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
+ if item.key.isV6 {
+ binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
+ } else {
+ pkt[10], pkt[11] = 0, 0
+ binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
+ iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum
+ binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
+ }
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+
+ // Recalculate the UDP len field value
+ binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:])))
+
+ // Calculate the pseudo header checksum and place it at the UDP
+ // checksum offset. Downstream checksum offloading will combine
+ // this with computation of the udp header and payload checksum.
+ addrLen := 4
+ addrOffset := ipv4SrcAddrOffset
+ if item.key.isV6 {
+ addrLen = 16
+ addrOffset = ipv6SrcAddrOffset
+ }
+ srcAddrAt := offset + addrOffset
+ srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
+ dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
+ psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
+ binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
+ } else {
+ hdr := virtioNetHdr{}
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+ return nil
+}
+
+type groCandidateType uint8
+
+const (
+ notGROCandidate groCandidateType = iota
+ tcp4GROCandidate
+ tcp6GROCandidate
+ udp4GROCandidate
+ udp6GROCandidate
+)
+
+func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType {
+ if len(b) < 28 {
+ return notGROCandidate
+ }
+ if b[0]>>4 == 4 {
+ if b[0]&0x0F != 5 {
+ // IPv4 packets w/IP options do not coalesce
+ return notGROCandidate
+ }
+ if b[9] == unix.IPPROTO_TCP && len(b) >= 40 {
+ return tcp4GROCandidate
+ }
+ if b[9] == unix.IPPROTO_UDP && canUDPGRO {
+ return udp4GROCandidate
+ }
+ } else if b[0]>>4 == 6 {
+ if b[6] == unix.IPPROTO_TCP && len(b) >= 60 {
+ return tcp6GROCandidate
+ }
+ if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO {
+ return udp6GROCandidate
+ }
+ }
+ return notGROCandidate
+}
+
+const (
+ udphLen = 8
+)
+
+// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with
+// existing packets tracked in table. It returns a groResultNoop when no
+// action was taken, groResultTableInsert when the evaluated packet was
+// inserted into table, and groResultCoalesced when the evaluated packet was
+// coalesced with another packet in table.
+func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult {
+ pkt := bufs[pktI][offset:]
+ if len(pkt) > maxUint16 {
+ // A valid IPv4 or IPv6 packet will never exceed this.
+ return groResultNoop
+ }
+ iphLen := int((pkt[0] & 0x0F) * 4)
+ if isV6 {
+ iphLen = 40
+ ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
+ if ipv6HPayloadLen != len(pkt)-iphLen {
+ return groResultNoop
+ }
+ } else {
+ totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
+ if totalLen != len(pkt) {
+ return groResultNoop
+ }
+ }
+ if len(pkt) < iphLen {
+ return groResultNoop
+ }
+ if len(pkt) < iphLen+udphLen {
+ return groResultNoop
+ }
+ if !isV6 {
+ if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
+ // no GRO support for fragmented segments for now
+ return groResultNoop
+ }
+ }
+ gsoSize := uint16(len(pkt) - udphLen - iphLen)
+ // not a candidate if payload len is 0
+ if gsoSize < 1 {
+ return groResultNoop
+ }
+ srcAddrOffset := ipv4SrcAddrOffset
+ addrLen := 4
+ if isV6 {
+ srcAddrOffset = ipv6SrcAddrOffset
+ addrLen = 16
+ }
+ items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI)
+ if !existing {
+ return groResultTableInsert
+ }
+ // With UDP we only check the last item, otherwise we could reorder packets
+ // for a given flow. We must also always insert a new item, or successfully
+ // coalesce with an existing item, for the same reason.
+ item := items[len(items)-1]
+ can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset)
+ var pktCSumKnownInvalid bool
+ if can == coalesceAppend {
+ result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6)
+ switch result {
+ case coalesceSuccess:
+ table.updateAt(item, len(items)-1)
+ return groResultCoalesced
+ case coalesceItemInvalidCSum:
+ // If the existing item has an invalid csum we take no action. A new
+ // item will be stored after it, and the existing item will never be
+ // revisited as part of future coalescing candidacy checks.
+ case coalescePktInvalidCSum:
+ // We must insert a new item, but we also mark it as invalid csum
+ // to prevent a repeat checksum validation.
+ pktCSumKnownInvalid = true
+ default:
+ }
+ }
+ // failed to coalesce with any other packets; store the item in the flow
+ table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid)
+ return groResultTableInsert
+}
+
+// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
+// packets into toWrite. toWrite, tcpTable, and udpTable should initially be
+// empty (but non-nil), and are passed in to save allocs as the caller may reset
+// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is
+// supported.
+func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error {
+ for i := range bufs {
+ if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
+ return errors.New("invalid offset")
+ }
+ var result groResult
+ switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) {
+ case tcp4GROCandidate:
+ result = tcpGRO(bufs, offset, i, tcpTable, false)
+ case tcp6GROCandidate:
+ result = tcpGRO(bufs, offset, i, tcpTable, true)
+ case udp4GROCandidate:
+ result = udpGRO(bufs, offset, i, udpTable, false)
+ case udp6GROCandidate:
+ result = udpGRO(bufs, offset, i, udpTable, true)
+ }
+ switch result {
+ case groResultNoop:
+ hdr := virtioNetHdr{}
+ err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+ fallthrough
+ case groResultTableInsert:
+ *toWrite = append(*toWrite, i)
+ }
+ }
+ errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable)
+ errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable)
+ return errors.Join(errTCP, errUDP)
+}
+
+// gsoSplit splits packets from in into outBuffs, writing the size of each
+// element into sizes. It returns the number of buffers populated, and/or an
+// error.
+func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) {
+ iphLen := int(hdr.csumStart)
+ srcAddrOffset := ipv6SrcAddrOffset
+ addrLen := 16
+ if !isV6 {
+ in[10], in[11] = 0, 0 // clear ipv4 header checksum
+ srcAddrOffset = ipv4SrcAddrOffset
+ addrLen = 4
+ }
+ transportCsumAt := int(hdr.csumStart + hdr.csumOffset)
+ in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum
+ var firstTCPSeqNum uint32
+ var protocol uint8
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+ protocol = unix.IPPROTO_TCP
+ firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:])
+ } else {
+ protocol = unix.IPPROTO_UDP
+ }
+ nextSegmentDataAt := int(hdr.hdrLen)
+ i := 0
+ for ; nextSegmentDataAt < len(in); i++ {
+ if i == len(outBuffs) {
+ return i - 1, ErrTooManySegments
+ }
+ nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
+ if nextSegmentEnd > len(in) {
+ nextSegmentEnd = len(in)
+ }
+ segmentDataLen := nextSegmentEnd - nextSegmentDataAt
+ totalLen := int(hdr.hdrLen) + segmentDataLen
+ sizes[i] = totalLen
+ out := outBuffs[i][outOffset:]
+
+ copy(out, in[:iphLen])
+ if !isV6 {
+ // For IPv4 we are responsible for incrementing the ID field,
+ // updating the total len field, and recalculating the header
+ // checksum.
+ if i > 0 {
+ id := binary.BigEndian.Uint16(out[4:])
+ id += uint16(i)
+ binary.BigEndian.PutUint16(out[4:], id)
+ }
+ binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
+ ipv4CSum := ^checksum(out[:iphLen], 0)
+ binary.BigEndian.PutUint16(out[10:], ipv4CSum)
+ } else {
+ // For IPv6 we are responsible for updating the payload length field.
+ binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
+ }
+
+ // copy transport header
+ copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
+
+ if protocol == unix.IPPROTO_TCP {
+ // set TCP seq and adjust TCP flags
+ tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
+ binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
+ if nextSegmentEnd != len(in) {
+ // FIN and PSH should only be set on last segment
+ clearFlags := tcpFlagFIN | tcpFlagPSH
+ out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
+ }
+ } else {
+ // set UDP header len
+ binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart))
+ }
+
+ // payload
+ copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
+
+ // transport checksum
+ transportHeaderLen := int(hdr.hdrLen - hdr.csumStart)
+ lenForPseudo := uint16(transportHeaderLen + segmentDataLen)
+ transportCSumNoFold := pseudoHeaderChecksumNoFold(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo)
+ transportCSum := ^checksum(out[hdr.csumStart:totalLen], transportCSumNoFold)
+ binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum)
+
+ nextSegmentDataAt += int(hdr.gsoSize)
+ }
+ return i, nil
+}
+
+func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
+ cSumAt := cSumStart + cSumOffset
+ // The initial value at the checksum offset should be summed with the
+ // checksum we compute. This is typically the pseudo-header checksum.
+ initial := binary.BigEndian.Uint16(in[cSumAt:])
+ in[cSumAt], in[cSumAt+1] = 0, 0
+ binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
+ return nil
+}
diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go
new file mode 100644
index 0000000..ae55c8c
--- /dev/null
+++ b/tun/offload_linux_test.go
@@ -0,0 +1,752 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "net/netip"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+ offset = virtioNetHdrLen
+)
+
+var (
+ ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
+ ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
+ ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
+ ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
+ ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
+ ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
+)
+
+func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte {
+ totalLen := 28 + payloadLen
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv4H := header.IPv4(b[offset:])
+ srcAs4 := srcIPPort.Addr().As4()
+ dstAs4 := dstIPPort.Addr().As4()
+ ipFields := &header.IPv4Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
+ Protocol: unix.IPPROTO_UDP,
+ TTL: 64,
+ TotalLength: uint16(totalLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv4H.Encode(ipFields)
+ udpH := header.UDP(b[offset+20:])
+ udpH.Encode(&header.UDPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ Length: uint16(payloadLen + udphLen),
+ })
+ ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen))
+ udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
+ return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
+}
+
+func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte {
+ totalLen := 48 + payloadLen
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv6H := header.IPv6(b[offset:])
+ srcAs16 := srcIPPort.Addr().As16()
+ dstAs16 := dstIPPort.Addr().As16()
+ ipFields := &header.IPv6Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
+ TransportProtocol: unix.IPPROTO_UDP,
+ HopLimit: 64,
+ PayloadLength: uint16(payloadLen + udphLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv6H.Encode(ipFields)
+ udpH := header.UDP(b[offset+40:])
+ udpH.Encode(&header.UDPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ Length: uint16(payloadLen + udphLen),
+ })
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen))
+ udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
+ return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
+}
+
+func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte {
+ totalLen := 40 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv4H := header.IPv4(b[offset:])
+ srcAs4 := srcIPPort.Addr().As4()
+ dstAs4 := dstIPPort.Addr().As4()
+ ipFields := &header.IPv4Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
+ Protocol: unix.IPPROTO_TCP,
+ TTL: 64,
+ TotalLength: uint16(totalLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv4H.Encode(ipFields)
+ tcpH := header.TCP(b[offset+20:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
+}
+
+func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte {
+ totalLen := 60 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv6H := header.IPv6(b[offset:])
+ srcAs16 := srcIPPort.Addr().As16()
+ dstAs16 := dstIPPort.Addr().As16()
+ ipFields := &header.IPv6Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
+ TransportProtocol: unix.IPPROTO_TCP,
+ HopLimit: 64,
+ PayloadLength: uint16(segmentSize + 20),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv6H.Encode(ipFields)
+ tcpH := header.TCP(b[offset+40:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
+}
+
+func Test_handleVirtioRead(t *testing.T) {
+ tests := []struct {
+ name string
+ hdr virtioNetHdr
+ pktIn []byte
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "tcp4",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
+ gsoSize: 100,
+ hdrLen: 40,
+ csumStart: 20,
+ csumOffset: 16,
+ },
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{140, 140},
+ false,
+ },
+ {
+ "tcp6",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
+ gsoSize: 100,
+ hdrLen: 60,
+ csumStart: 40,
+ csumOffset: 16,
+ },
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{160, 160},
+ false,
+ },
+ {
+ "udp4",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
+ gsoSize: 100,
+ hdrLen: 28,
+ csumStart: 20,
+ csumOffset: 6,
+ },
+ udp4Packet(ip4PortA, ip4PortB, 200),
+ []int{128, 128},
+ false,
+ },
+ {
+ "udp6",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
+ gsoSize: 100,
+ hdrLen: 48,
+ csumStart: 40,
+ csumOffset: 6,
+ },
+ udp6Packet(ip6PortA, ip6PortB, 200),
+ []int{148, 148},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ out := make([][]byte, conn.IdealBatchSize)
+ sizes := make([]int, conn.IdealBatchSize)
+ for i := range out {
+ out[i] = make([]byte, 65535)
+ }
+ tt.hdr.encode(tt.pktIn)
+ n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if n != len(tt.wantLens) {
+ t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
+ }
+ for i := range tt.wantLens {
+ if tt.wantLens[i] != sizes[i] {
+ t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
+ }
+ }
+ })
+ }
+}
+
+func flipTCP4Checksum(b []byte) []byte {
+ at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
+ b[at] ^= 0xFF
+ b[at+1] ^= 0xFF
+ return b
+}
+
+func flipUDP4Checksum(b []byte) []byte {
+ at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6
+ b[at] ^= 0xFF
+ b[at+1] ^= 0xFF
+ return b
+}
+
+func Fuzz_handleGRO(f *testing.F) {
+ pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
+ pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
+ pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
+ pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
+ pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
+ pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
+ pkt6 := udp4Packet(ip4PortA, ip4PortB, 100)
+ pkt7 := udp4Packet(ip4PortA, ip4PortB, 100)
+ pkt8 := udp4Packet(ip4PortA, ip4PortC, 100)
+ pkt9 := udp6Packet(ip6PortA, ip6PortB, 100)
+ pkt10 := udp6Packet(ip6PortA, ip6PortB, 100)
+ pkt11 := udp6Packet(ip6PortA, ip6PortC, 100)
+ f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset)
+ f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) {
+ pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11}
+ toWrite := make([]int, 0, len(pkts))
+ handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite)
+ if len(toWrite) > len(pkts) {
+ t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
+ }
+ seenWriteI := make(map[int]bool)
+ for _, writeI := range toWrite {
+ if writeI < 0 || writeI > len(pkts)-1 {
+ t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
+ }
+ if seenWriteI[writeI] {
+ t.Errorf("duplicate toWrite value: %d", writeI)
+ }
+ seenWriteI[writeI] = true
+ }
+ })
+}
+
+func Test_handleGRO(t *testing.T) {
+ tests := []struct {
+ name string
+ pktsIn [][]byte
+ canUDPGRO bool
+ wantToWrite []int
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "multiple protocols and flows",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
+ tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ },
+ true,
+ []int{0, 1, 2, 4, 5, 7, 9},
+ []int{240, 228, 128, 140, 260, 160, 248},
+ false,
+ },
+ {
+ "multiple protocols and flows no UDP GRO",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
+ tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ },
+ false,
+ []int{0, 1, 2, 4, 5, 7, 8, 9, 10},
+ []int{240, 128, 128, 140, 260, 160, 128, 148, 148},
+ false,
+ },
+ {
+ "PSH interleaved",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1
+ },
+ true,
+ []int{0, 2, 4, 6},
+ []int{240, 240, 260, 260},
+ false,
+ },
+ {
+ "coalesceItemInvalidCSum",
+ [][]byte{
+ flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ },
+ true,
+ []int{0, 1, 3, 4},
+ []int{140, 240, 128, 228},
+ false,
+ },
+ {
+ "out of order",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ },
+ true,
+ []int{0},
+ []int{340},
+ false,
+ },
+ {
+ "unequal TTL",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.TTL++
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.TTL++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal ToS",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.TOS++
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.TOS++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal flags more fragments set",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.Flags = 1
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.Flags = 1
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal flags DF set",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.Flags = 2
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.Flags = 2
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "ipv6 unequal hop limit",
+ [][]byte{
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
+ tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
+ fields.HopLimit++
+ }),
+ udp6Packet(ip6PortA, ip6PortB, 100),
+ udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
+ fields.HopLimit++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{160, 160, 148, 148},
+ false,
+ },
+ {
+ "ipv6 unequal traffic class",
+ [][]byte{
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
+ tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
+ fields.TrafficClass++
+ }),
+ udp6Packet(ip6PortA, ip6PortB, 100),
+ udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
+ fields.TrafficClass++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{160, 160, 148, 148},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ toWrite := make([]int, 0, len(tt.pktsIn))
+ err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if len(toWrite) != len(tt.wantToWrite) {
+ t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
+ }
+ for i, pktI := range tt.wantToWrite {
+ if tt.wantToWrite[i] != toWrite[i] {
+ t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
+ }
+ if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
+ t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
+ }
+ }
+ })
+ }
+}
+
+func Test_packetIsGROCandidate(t *testing.T) {
+ tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
+ tcp4TooShort := tcp4[:39]
+ ip4InvalidHeaderLen := make([]byte, len(tcp4))
+ copy(ip4InvalidHeaderLen, tcp4)
+ ip4InvalidHeaderLen[0] = 0x46
+ ip4InvalidProtocol := make([]byte, len(tcp4))
+ copy(ip4InvalidProtocol, tcp4)
+ ip4InvalidProtocol[9] = unix.IPPROTO_GRE
+
+ tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
+ tcp6TooShort := tcp6[:59]
+ ip6InvalidProtocol := make([]byte, len(tcp6))
+ copy(ip6InvalidProtocol, tcp6)
+ ip6InvalidProtocol[6] = unix.IPPROTO_GRE
+
+ udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:]
+ udp4TooShort := udp4[:27]
+
+ udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:]
+ udp6TooShort := udp6[:47]
+
+ tests := []struct {
+ name string
+ b []byte
+ canUDPGRO bool
+ want groCandidateType
+ }{
+ {
+ "tcp4",
+ tcp4,
+ true,
+ tcp4GROCandidate,
+ },
+ {
+ "tcp6",
+ tcp6,
+ true,
+ tcp6GROCandidate,
+ },
+ {
+ "udp4",
+ udp4,
+ true,
+ udp4GROCandidate,
+ },
+ {
+ "udp4 no support",
+ udp4,
+ false,
+ notGROCandidate,
+ },
+ {
+ "udp6",
+ udp6,
+ true,
+ udp6GROCandidate,
+ },
+ {
+ "udp6 no support",
+ udp6,
+ false,
+ notGROCandidate,
+ },
+ {
+ "udp4 too short",
+ udp4TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "udp6 too short",
+ udp6TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "tcp4 too short",
+ tcp4TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "tcp6 too short",
+ tcp6TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "invalid IP version",
+ []byte{0x00},
+ true,
+ notGROCandidate,
+ },
+ {
+ "invalid IP header len",
+ ip4InvalidHeaderLen,
+ true,
+ notGROCandidate,
+ },
+ {
+ "ip4 invalid protocol",
+ ip4InvalidProtocol,
+ true,
+ notGROCandidate,
+ },
+ {
+ "ip6 invalid protocol",
+ ip6InvalidProtocol,
+ true,
+ notGROCandidate,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want {
+ t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func Test_udpPacketsCanCoalesce(t *testing.T) {
+ udp4a := udp4Packet(ip4PortA, ip4PortB, 100)
+ udp4b := udp4Packet(ip4PortA, ip4PortB, 100)
+ udp4c := udp4Packet(ip4PortA, ip4PortB, 110)
+
+ type args struct {
+ pkt []byte
+ iphLen uint8
+ gsoSize uint16
+ item udpGROItem
+ bufs [][]byte
+ bufsOffset int
+ }
+ tests := []struct {
+ name string
+ args args
+ want canCoalesce
+ }{
+ {
+ "coalesceAppend equal gso",
+ args{
+ pkt: udp4a[offset:],
+ iphLen: 20,
+ gsoSize: 100,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceAppend,
+ },
+ {
+ "coalesceAppend smaller gso",
+ args{
+ pkt: udp4a[offset : len(udp4a)-90],
+ iphLen: 20,
+ gsoSize: 10,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceAppend,
+ },
+ {
+ "coalesceUnavailable smaller gso previously appended",
+ args{
+ pkt: udp4a[offset:],
+ iphLen: 20,
+ gsoSize: 100,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4c,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceUnavailable,
+ },
+ {
+ "coalesceUnavailable larger following smaller",
+ args{
+ pkt: udp4c[offset:],
+ iphLen: 20,
+ gsoSize: 110,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4c,
+ },
+ bufsOffset: offset,
+ },
+ coalesceUnavailable,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want {
+ t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/tun/operateonfd.go b/tun/operateonfd.go
index a9dcaef..f1beb6d 100644
--- a/tun/operateonfd.go
+++ b/tun/operateonfd.go
@@ -1,8 +1,8 @@
-// +build !windows
+//go:build darwin || freebsd
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
diff --git a/tun/tun.go b/tun/tun.go
index 5521cb7..0ae53d0 100644
--- a/tun/tun.go
+++ b/tun/tun.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -18,12 +18,36 @@ const (
)
type Device interface {
- File() *os.File // returns the file descriptor of the device
- Read([]byte, int) (int, error) // read a packet from the device (without any additional headers)
- Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
- Flush() error // flush all previous writes to the device
- MTU() (int, error) // returns the MTU of the device
- Name() (string, error) // fetches and returns the current name
- Events() chan Event // returns a constant channel of events related to the device
- Close() error // stops the device and closes the event channel
+ // File returns the file descriptor of the device.
+ File() *os.File
+
+ // Read one or more packets from the Device (without any additional headers).
+ // On a successful read it returns the number of packets read, and sets
+ // packet lengths within the sizes slice. len(sizes) must be >= len(bufs).
+ // A nonzero offset can be used to instruct the Device on where to begin
+ // reading into each element of the bufs slice.
+ Read(bufs [][]byte, sizes []int, offset int) (n int, err error)
+
+ // Write one or more packets to the device (without any additional headers).
+ // On a successful write it returns the number of packets written. A nonzero
+ // offset can be used to instruct the Device on where to begin writing from
+ // each packet contained within the bufs slice.
+ Write(bufs [][]byte, offset int) (int, error)
+
+ // MTU returns the MTU of the Device.
+ MTU() (int, error)
+
+ // Name returns the current name of the Device.
+ Name() (string, error)
+
+ // Events returns a channel of type Event, which is fed Device events.
+ Events() <-chan Event
+
+ // Close stops the Device and closes the Event channel.
+ Close() error
+
+ // BatchSize returns the preferred/max number of packets that can be read or
+ // written in a single read/write call. BatchSize must not change over the
+ // lifetime of a Device.
+ BatchSize() int
}
diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go
index a703c8c..c9a6c0b 100644
--- a/tun/tun_darwin.go
+++ b/tun/tun_darwin.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -8,6 +8,7 @@ package tun
import (
"errors"
"fmt"
+ "io"
"net"
"os"
"sync"
@@ -15,7 +16,6 @@ import (
"time"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)
@@ -33,7 +33,7 @@ type NativeTun struct {
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
for i := 0; i < 20; i++ {
iface, err = net.InterfaceByIndex(index)
- if err != nil && errors.Is(err, syscall.ENOMEM) {
+ if err != nil && errors.Is(err, unix.ENOMEM) {
time.Sleep(time.Duration(i) * time.Second / 3)
continue
}
@@ -55,7 +55,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
retry:
n, err := unix.Read(tun.routeSocket, data)
if err != nil {
- if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
+ if errno, ok := err.(unix.Errno); ok && errno == unix.EINTR {
goto retry
}
tun.errors <- err
@@ -107,8 +107,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
}
}
- fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2)
-
+ fd, err := socketCloexec(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2)
if err != nil {
return nil, err
}
@@ -117,6 +116,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
copy(ctlInfo.Name[:], []byte(utunControlName))
err = unix.IoctlCtlInfo(fd, ctlInfo)
if err != nil {
+ unix.Close(fd)
return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err)
}
@@ -127,11 +127,13 @@ func CreateTUN(name string, mtu int) (Device, error) {
err = unix.Connect(fd, sc)
if err != nil {
+ unix.Close(fd)
return nil, err
}
- err = syscall.SetNonblock(fd, true)
+ err = unix.SetNonblock(fd, true)
if err != nil {
+ unix.Close(fd)
return nil, err
}
tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu)
@@ -139,7 +141,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
if err == nil && name == "utun" {
fname := os.Getenv("WG_TUN_NAME_FILE")
if fname != "" {
- os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0400)
+ os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400)
}
}
@@ -171,7 +173,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err
}
- tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+ tun.routeSocket, err = socketCloexec(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
tun.tunFile.Close()
return nil, err
@@ -211,50 +213,50 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+ // TODO: the BSDs look very similar in Read() and Write(). They should be
+ // collapsed, with platform-specific files containing the varying parts of
+ // their implementations.
select {
case err := <-tun.errors:
return 0, err
default:
- buff := buff[offset-4:]
- n, err := tun.tunFile.Read(buff[:])
+ buf := bufs[0][offset-4:]
+ n, err := tun.tunFile.Read(buf[:])
if n < 4 {
return 0, err
}
- return n - 4, err
+ sizes[0] = n - 4
+ return 1, err
}
}
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
-
- // reserve space for header
-
- buff = buff[offset-4:]
-
- // add packet information header
-
- buff[0] = 0x00
- buff[1] = 0x00
- buff[2] = 0x00
-
- if buff[4]>>4 == ipv6.Version {
- buff[3] = unix.AF_INET6
- } else {
- buff[3] = unix.AF_INET
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ if offset < 4 {
+ return 0, io.ErrShortBuffer
}
-
- // write
-
- return tun.tunFile.Write(buff)
-}
-
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ for i, buf := range bufs {
+ buf = buf[offset-4:]
+ buf[0] = 0x00
+ buf[1] = 0x00
+ buf[2] = 0x00
+ switch buf[4] >> 4 {
+ case 4:
+ buf[3] = unix.AF_INET
+ case 6:
+ buf[3] = unix.AF_INET6
+ default:
+ return i, unix.EAFNOSUPPORT
+ }
+ if _, err := tun.tunFile.Write(buf); err != nil {
+ return i, err
+ }
+ }
+ return len(bufs), nil
}
func (tun *NativeTun) Close() error {
@@ -275,12 +277,11 @@ func (tun *NativeTun) Close() error {
}
func (tun *NativeTun) setMTU(n int) error {
- fd, err := unix.Socket(
+ fd, err := socketCloexec(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
-
if err != nil {
return err
}
@@ -299,12 +300,11 @@ func (tun *NativeTun) setMTU(n int) error {
}
func (tun *NativeTun) MTU() (int, error) {
- fd, err := unix.Socket(
+ fd, err := socketCloexec(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
-
if err != nil {
return 0, err
}
@@ -318,3 +318,19 @@ func (tun *NativeTun) MTU() (int, error) {
return int(ifr.MTU), nil
}
+
+func (tun *NativeTun) BatchSize() int {
+ return 1
+}
+
+func socketCloexec(family, sotype, proto int) (fd int, err error) {
+ // See go/src/net/sys_cloexec.go for background.
+ syscall.ForkLock.RLock()
+ defer syscall.ForkLock.RUnlock()
+
+ fd, err = unix.Socket(family, sotype, proto)
+ if err == nil {
+ unix.CloseOnExec(fd)
+ }
+ return
+}
diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go
index dc9eb3e..7c65fd9 100644
--- a/tun/tun_freebsd.go
+++ b/tun/tun_freebsd.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -143,7 +143,7 @@ func tunName(fd uintptr) (string, error) {
// Destroy a named system interface
func tunDestroy(name string) error {
- fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
+ fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
return err
}
@@ -170,7 +170,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
return nil, fmt.Errorf("interface %s already exists", name)
}
- tunFile, err := os.OpenFile("/dev/tun", unix.O_RDWR, 0)
+ tunFile, err := os.OpenFile("/dev/tun", unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil {
return nil, err
}
@@ -213,7 +213,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
// Disable link-local v6, not just because WireGuard doesn't do that anyway, but
// also because there are serious races with attaching and detaching LLv6 addresses
// in relation to interface lifetime within the FreeBSD kernel.
- confd6, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, 0)
+ confd6, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
tunFile.Close()
tunDestroy(assignedName)
@@ -238,7 +238,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
}
if name != "" {
- confd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
+ confd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
tunFile.Close()
tunDestroy(assignedName)
@@ -295,7 +295,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err
}
- tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+ tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC)
if err != nil {
tun.tunFile.Close()
return nil, err
@@ -329,49 +329,50 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
select {
case err := <-tun.errors:
return 0, err
default:
- buff := buff[offset-4:]
- n, err := tun.tunFile.Read(buff[:])
+ buf := bufs[0][offset-4:]
+ n, err := tun.tunFile.Read(buf[:])
if n < 4 {
return 0, err
}
- return n - 4, err
+ sizes[0] = n - 4
+ return 1, err
}
}
-func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
if offset < 4 {
return 0, io.ErrShortBuffer
}
- buf = buf[offset-4:]
- if len(buf) < 5 {
- return 0, io.ErrShortBuffer
- }
- buf[0] = 0x00
- buf[1] = 0x00
- buf[2] = 0x00
- switch buf[4] >> 4 {
- case 4:
- buf[3] = unix.AF_INET
- case 6:
- buf[3] = unix.AF_INET6
- default:
- return 0, unix.EAFNOSUPPORT
+ for i, buf := range bufs {
+ buf = buf[offset-4:]
+ if len(buf) < 5 {
+ return i, io.ErrShortBuffer
+ }
+ buf[0] = 0x00
+ buf[1] = 0x00
+ buf[2] = 0x00
+ switch buf[4] >> 4 {
+ case 4:
+ buf[3] = unix.AF_INET
+ case 6:
+ buf[3] = unix.AF_INET6
+ default:
+ return i, unix.EAFNOSUPPORT
+ }
+ if _, err := tun.tunFile.Write(buf); err != nil {
+ return i, err
+ }
}
- return tun.tunFile.Write(buf)
-}
-
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ return len(bufs), nil
}
func (tun *NativeTun) Close() error {
@@ -397,7 +398,7 @@ func (tun *NativeTun) Close() error {
}
func (tun *NativeTun) setMTU(n int) error {
- fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
+ fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
return err
}
@@ -414,7 +415,7 @@ func (tun *NativeTun) setMTU(n int) error {
}
func (tun *NativeTun) MTU() (int, error) {
- fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
+ fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
return 0, err
}
@@ -428,3 +429,7 @@ func (tun *NativeTun) MTU() (int, error) {
}
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
}
+
+func (tun *NativeTun) BatchSize() int {
+ return 1
+}
diff --git a/tun/tun_linux.go b/tun/tun_linux.go
index 466a805..bd69cb5 100644
--- a/tun/tun_linux.go
+++ b/tun/tun_linux.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -9,7 +9,6 @@ package tun
*/
import (
- "bytes"
"errors"
"fmt"
"os"
@@ -18,9 +17,8 @@ import (
"time"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
-
+ "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
@@ -34,17 +32,27 @@ type NativeTun struct {
index int32 // if index
errors chan error // async error handling
events chan Event // device related events
- nopi bool // the device was passed IFF_NO_PI
netlinkSock int
netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{}
+ batchSize int
+ vnetHdr bool
+ udpGSO bool
closeOnce sync.Once
nameOnce sync.Once // guards calling initNameCache, which sets following fields
nameCache string // name of interface
nameErr error
+
+ readOpMu sync.Mutex // readOpMu guards readBuff
+ readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
+
+ writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable
+ toWrite []int
+ tcpGROTable *tcpGROTable
+ udpGROTable *udpGROTable
}
func (tun *NativeTun) File() *os.File {
@@ -100,7 +108,7 @@ func (tun *NativeTun) routineHackListener() {
}
func createNetlinkSocket() (int, error) {
- sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
+ sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
@@ -195,7 +203,7 @@ func (tun *NativeTun) routineNetlinkListener() {
func getIFIndex(name string) (int32, error) {
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
if err != nil {
@@ -229,10 +237,9 @@ func (tun *NativeTun) setMTU(n int) error {
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
-
if err != nil {
return err
}
@@ -266,10 +273,9 @@ func (tun *NativeTun) MTU() (int, error) {
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
-
if err != nil {
return 0, err
}
@@ -323,67 +329,153 @@ func (tun *NativeTun) nameSlow() (string, error) {
if errno != 0 {
return "", fmt.Errorf("failed to get name of TUN device: %w", errno)
}
- name := ifr[:]
- if i := bytes.IndexByte(name, 0); i != -1 {
- name = name[:i]
- }
- return string(name), nil
+ return unix.ByteSliceToString(ifr[:]), nil
}
-func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
- if tun.nopi {
- buf = buf[offset:]
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ tun.writeOpMu.Lock()
+ defer func() {
+ tun.tcpGROTable.reset()
+ tun.udpGROTable.reset()
+ tun.writeOpMu.Unlock()
+ }()
+ var (
+ errs error
+ total int
+ )
+ tun.toWrite = tun.toWrite[:0]
+ if tun.vnetHdr {
+ err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite)
+ if err != nil {
+ return 0, err
+ }
+ offset -= virtioNetHdrLen
} else {
- // reserve space for header
- buf = buf[offset-4:]
-
- // add packet information header
- buf[0] = 0x00
- buf[1] = 0x00
- if buf[4]>>4 == ipv6.Version {
- buf[2] = 0x86
- buf[3] = 0xdd
+ for i := range bufs {
+ tun.toWrite = append(tun.toWrite, i)
+ }
+ }
+ for _, bufsI := range tun.toWrite {
+ n, err := tun.tunFile.Write(bufs[bufsI][offset:])
+ if errors.Is(err, syscall.EBADFD) {
+ return total, os.ErrClosed
+ }
+ if err != nil {
+ errs = errors.Join(errs, err)
} else {
- buf[2] = 0x08
- buf[3] = 0x00
+ total += n
}
}
+ return total, errs
+}
- n, err := tun.tunFile.Write(buf)
- if errors.Is(err, syscall.EBADFD) {
- err = os.ErrClosed
+// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
+// each buffer. It mutates sizes to reflect the size of each element of bufs,
+// and returns the number of packets read.
+func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
+ var hdr virtioNetHdr
+ err := hdr.decode(in)
+ if err != nil {
+ return 0, err
+ }
+ in = in[virtioNetHdrLen:]
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
+ if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
+ // This means CHECKSUM_PARTIAL in skb context. We are responsible
+ // for computing the checksum starting at hdr.csumStart and placing
+ // at hdr.csumOffset.
+ err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
+ if err != nil {
+ return 0, err
+ }
+ }
+ if len(in) > len(bufs[0][offset:]) {
+ return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
+ }
+ n := copy(bufs[0][offset:], in)
+ sizes[0] = n
+ return 1, nil
+ }
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
}
- return n, err
-}
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ ipVersion := in[0] >> 4
+ switch ipVersion {
+ case 4:
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+ }
+ case 6:
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+ }
+ default:
+ return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
+ }
+
+ // Don't trust hdr.hdrLen from the kernel as it can be equal to the length
+ // of the entire first packet when the kernel is handling it as part of a
+ // FORWARD path. Instead, parse the transport header length and add it onto
+ // csumStart, which is synonymous for IP header length.
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ hdr.hdrLen = hdr.csumStart + 8
+ } else {
+ if len(in) <= int(hdr.csumStart+12) {
+ return 0, errors.New("packet is too short")
+ }
+
+ tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
+ if tcpHLen < 20 || tcpHLen > 60 {
+ // A TCP header must be between 20 and 60 bytes in length.
+ return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
+ }
+ hdr.hdrLen = hdr.csumStart + tcpHLen
+ }
+
+ if len(in) < int(hdr.hdrLen) {
+ return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
+ }
+
+ if hdr.hdrLen < hdr.csumStart {
+ return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
+ }
+ cSumAt := int(hdr.csumStart + hdr.csumOffset)
+ if cSumAt+1 >= len(in) {
+ return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
+ }
+
+ return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6)
}
-func (tun *NativeTun) Read(buf []byte, offset int) (n int, err error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+ tun.readOpMu.Lock()
+ defer tun.readOpMu.Unlock()
select {
- case err = <-tun.errors:
+ case err := <-tun.errors:
+ return 0, err
default:
- if tun.nopi {
- n, err = tun.tunFile.Read(buf[offset:])
+ readInto := bufs[0][offset:]
+ if tun.vnetHdr {
+ readInto = tun.readBuff[:]
+ }
+ n, err := tun.tunFile.Read(readInto)
+ if errors.Is(err, syscall.EBADFD) {
+ err = os.ErrClosed
+ }
+ if err != nil {
+ return 0, err
+ }
+ if tun.vnetHdr {
+ return handleVirtioRead(readInto[:n], bufs, sizes, offset)
} else {
- buff := buf[offset-4:]
- n, err = tun.tunFile.Read(buff[:])
- if errors.Is(err, syscall.EBADFD) {
- err = os.ErrClosed
- }
- if n < 4 {
- n = 0
- } else {
- n -= 4
- }
+ sizes[0] = n
+ return 1, nil
}
}
- return
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
@@ -406,8 +498,58 @@ func (tun *NativeTun) Close() error {
return err2
}
+func (tun *NativeTun) BatchSize() int {
+ return tun.batchSize
+}
+
+const (
+ // TODO: support TSO with ECN bits
+ tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
+ tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
+)
+
+func (tun *NativeTun) initFromFlags(name string) error {
+ sc, err := tun.tunFile.SyscallConn()
+ if err != nil {
+ return err
+ }
+ if e := sc.Control(func(fd uintptr) {
+ var (
+ ifr *unix.Ifreq
+ )
+ ifr, err = unix.NewIfreq(name)
+ if err != nil {
+ return
+ }
+ err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
+ if err != nil {
+ return
+ }
+ got := ifr.Uint16()
+ if got&unix.IFF_VNET_HDR != 0 {
+ // tunTCPOffloads were added in Linux v2.6. We require their support
+ // if IFF_VNET_HDR is set.
+ err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads)
+ if err != nil {
+ return
+ }
+ tun.vnetHdr = true
+ tun.batchSize = conn.IdealBatchSize
+ // tunUDPOffloads were added in Linux v6.2. We do not return an
+ // error if they are unsupported at runtime.
+ tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil
+ } else {
+ tun.batchSize = 1
+ }
+ }); e != nil {
+ return e
+ }
+ return err
+}
+
+// CreateTUN creates a Device with the provided name and MTU.
func CreateTUN(name string, mtu int) (Device, error) {
- nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0)
+ nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil {
if os.IsNotExist(err) {
return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
@@ -415,43 +557,40 @@ func CreateTUN(name string, mtu int) (Device, error) {
return nil, err
}
- var ifr [ifReqSize]byte
- var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
- nameBytes := []byte(name)
- if len(nameBytes) >= unix.IFNAMSIZ {
- return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG)
+ ifr, err := unix.NewIfreq(name)
+ if err != nil {
+ return nil, err
}
- copy(ifr[:], nameBytes)
- *(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags
-
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(nfd),
- uintptr(unix.TUNSETIFF),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
- if errno != 0 {
- return nil, errno
+ // IFF_VNET_HDR enables the "tun status hack" via routineHackListener()
+ // where a null write will return EINVAL indicating the TUN is up.
+ ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
+ err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr)
+ if err != nil {
+ return nil, err
}
- err = unix.SetNonblock(nfd, true)
- // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line.
-
- fd := os.NewFile(uintptr(nfd), cloneDevicePath)
+ err = unix.SetNonblock(nfd, true)
if err != nil {
+ unix.Close(nfd)
return nil, err
}
+ // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line.
+
+ fd := os.NewFile(uintptr(nfd), cloneDevicePath)
return CreateTUNFromFile(fd, mtu)
}
+// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 5),
errors: make(chan error, 5),
statusListenersShutdown: make(chan struct{}),
- nopi: false,
+ tcpGROTable: newTCPGROTable(),
+ udpGROTable: newUDPGROTable(),
+ toWrite: make([]int, 0, conn.IdealBatchSize),
}
name, err := tun.Name()
@@ -459,8 +598,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err
}
- // start event listener
+ err = tun.initFromFlags(name)
+ if err != nil {
+ return nil, err
+ }
+ // start event listener
tun.index, err = getIFIndex(name)
if err != nil {
return nil, err
@@ -489,6 +632,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return tun, nil
}
+// CreateUnmonitoredTUNFromFD creates a Device from the provided file
+// descriptor.
func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
err := unix.SetNonblock(fd, true)
if err != nil {
@@ -496,14 +641,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
}
file := os.NewFile(uintptr(fd), "/dev/tun")
tun := &NativeTun{
- tunFile: file,
- events: make(chan Event, 5),
- errors: make(chan error, 5),
- nopi: true,
+ tunFile: file,
+ events: make(chan Event, 5),
+ errors: make(chan error, 5),
+ tcpGROTable: newTCPGROTable(),
+ udpGROTable: newUDPGROTable(),
+ toWrite: make([]int, 0, conn.IdealBatchSize),
}
name, err := tun.Name()
if err != nil {
return nil, "", err
}
- return tun, name, nil
+ err = tun.initFromFlags(name)
+ if err != nil {
+ return nil, "", err
+ }
+ return tun, name, err
}
diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go
index 7ef62f4..ae571b9 100644
--- a/tun/tun_openbsd.go
+++ b/tun/tun_openbsd.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -8,13 +8,13 @@ package tun
import (
"errors"
"fmt"
+ "io"
"net"
"os"
"sync"
"syscall"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)
@@ -114,10 +114,10 @@ func CreateTUN(name string, mtu int) (Device, error) {
var err error
if ifIndex != -1 {
- tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0)
+ tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0)
} else {
for ifIndex = 0; ifIndex < 256; ifIndex++ {
- tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0)
+ tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0)
if err == nil || !errors.Is(err, syscall.EBUSY) {
break
}
@@ -133,7 +133,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
if err == nil && name == "tun" {
fname := os.Getenv("WG_TUN_NAME_FILE")
if fname != "" {
- os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0400)
+ os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400)
}
}
@@ -165,7 +165,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err
}
- tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+ tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC)
if err != nil {
tun.tunFile.Close()
return nil, err
@@ -200,50 +200,47 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
select {
case err := <-tun.errors:
return 0, err
default:
- buff := buff[offset-4:]
- n, err := tun.tunFile.Read(buff[:])
+ buf := bufs[0][offset-4:]
+ n, err := tun.tunFile.Read(buf[:])
if n < 4 {
return 0, err
}
- return n - 4, err
+ sizes[0] = n - 4
+ return 1, err
}
}
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
-
- // reserve space for header
-
- buff = buff[offset-4:]
-
- // add packet information header
-
- buff[0] = 0x00
- buff[1] = 0x00
- buff[2] = 0x00
-
- if buff[4]>>4 == ipv6.Version {
- buff[3] = unix.AF_INET6
- } else {
- buff[3] = unix.AF_INET
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ if offset < 4 {
+ return 0, io.ErrShortBuffer
}
-
- // write
-
- return tun.tunFile.Write(buff)
-}
-
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ for i, buf := range bufs {
+ buf = buf[offset-4:]
+ buf[0] = 0x00
+ buf[1] = 0x00
+ buf[2] = 0x00
+ switch buf[4] >> 4 {
+ case 4:
+ buf[3] = unix.AF_INET
+ case 6:
+ buf[3] = unix.AF_INET6
+ default:
+ return i, unix.EAFNOSUPPORT
+ }
+ if _, err := tun.tunFile.Write(buf); err != nil {
+ return i, err
+ }
+ }
+ return len(bufs), nil
}
func (tun *NativeTun) Close() error {
@@ -271,10 +268,9 @@ func (tun *NativeTun) setMTU(n int) error {
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
-
if err != nil {
return err
}
@@ -306,10 +302,9 @@ func (tun *NativeTun) MTU() (int, error) {
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
-
if err != nil {
return 0, err
}
@@ -332,3 +327,7 @@ func (tun *NativeTun) MTU() (int, error) {
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
}
+
+func (tun *NativeTun) BatchSize() int {
+ return 1
+}
diff --git a/tun/tun_windows.go b/tun/tun_windows.go
index ff16e2f..2af8e3e 100644
--- a/tun/tun_windows.go
+++ b/tun/tun_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -8,7 +8,6 @@ package tun
import (
"errors"
"fmt"
- "log"
"os"
"sync"
"sync/atomic"
@@ -16,8 +15,7 @@ import (
_ "unsafe"
"golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/tun/wintun"
+ "golang.zx2c4.com/wintun"
)
const (
@@ -27,14 +25,15 @@ const (
)
type rateJuggler struct {
- current uint64
- nextByteCount uint64
- nextStartTime int64
- changing int32
+ current atomic.Uint64
+ nextByteCount atomic.Uint64
+ nextStartTime atomic.Int64
+ changing atomic.Bool
}
type NativeTun struct {
wt *wintun.Adapter
+ name string
handle windows.Handle
rate rateJuggler
session wintun.Session
@@ -42,12 +41,15 @@ type NativeTun struct {
events chan Event
running sync.WaitGroup
closeOnce sync.Once
- close int32
+ close atomic.Bool
forcedMTU int
+ outSizes []int
}
-var WintunPool, _ = wintun.MakePool("WireGuard")
-var WintunStaticRequestedGUID *windows.GUID
+var (
+ WintunTunnelType = "WireGuard"
+ WintunStaticRequestedGUID *windows.GUID
+)
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
@@ -55,38 +57,19 @@ func procyield(cycles uint32)
//go:linkname nanotime runtime.nanotime
func nanotime() int64
-//
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused.
-//
func CreateTUN(ifname string, mtu int) (Device, error) {
return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
}
-//
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
-//
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
- var err error
- var wt *wintun.Adapter
-
- // Does an interface with this name already exist?
- wt, err = WintunPool.OpenAdapter(ifname)
- if err == nil {
- // If so, we delete it, in case it has weird residual configuration.
- _, err = wt.Delete(true)
- if err != nil {
- return nil, fmt.Errorf("Error deleting already existing interface: %w", err)
- }
- }
- wt, rebootRequired, err := WintunPool.CreateAdapter(ifname, requestedGUID)
+ wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
if err != nil {
return nil, fmt.Errorf("Error creating interface: %w", err)
}
- if rebootRequired {
- log.Println("Windows indicated a reboot is required.")
- }
forcedMTU := 1420
if mtu > 0 {
@@ -95,6 +78,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
tun := &NativeTun{
wt: wt,
+ name: ifname,
handle: windows.InvalidHandle,
events: make(chan Event, 10),
forcedMTU: forcedMTU,
@@ -102,7 +86,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
if err != nil {
- tun.wt.Delete(false)
+ tun.wt.Close()
close(tun.events)
return nil, fmt.Errorf("Error starting session: %w", err)
}
@@ -111,31 +95,26 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
}
func (tun *NativeTun) Name() (string, error) {
- tun.running.Add(1)
- defer tun.running.Done()
- if atomic.LoadInt32(&tun.close) == 1 {
- return "", os.ErrClosed
- }
- return tun.wt.Name()
+ return tun.name, nil
}
func (tun *NativeTun) File() *os.File {
return nil
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
func (tun *NativeTun) Close() error {
var err error
tun.closeOnce.Do(func() {
- atomic.StoreInt32(&tun.close, 1)
+ tun.close.Store(true)
windows.SetEvent(tun.readWait)
tun.running.Wait()
tun.session.End()
if tun.wt != nil {
- _, err = tun.wt.Delete(false)
+ tun.wt.Close()
}
close(tun.events)
})
@@ -148,6 +127,9 @@ func (tun *NativeTun) MTU() (int, error) {
// TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
func (tun *NativeTun) ForceMTU(mtu int) {
+ if tun.close.Load() {
+ return
+ }
update := tun.forcedMTU != mtu
tun.forcedMTU = mtu
if update {
@@ -155,29 +137,34 @@ func (tun *NativeTun) ForceMTU(mtu int) {
}
}
+func (tun *NativeTun) BatchSize() int {
+ // TODO: implement batching with wintun
+ return 1
+}
+
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
tun.running.Add(1)
defer tun.running.Done()
retry:
- if atomic.LoadInt32(&tun.close) == 1 {
+ if tun.close.Load() {
return 0, os.ErrClosed
}
start := nanotime()
- shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
+ shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
for {
- if atomic.LoadInt32(&tun.close) == 1 {
+ if tun.close.Load() {
return 0, os.ErrClosed
}
packet, err := tun.session.ReceivePacket()
switch err {
case nil:
- packetSize := len(packet)
- copy(buff[offset:], packet)
+ n := copy(bufs[0][offset:], packet)
+ sizes[0] = n
tun.session.ReleaseReceivePacket(packet)
- tun.rate.update(uint64(packetSize))
- return packetSize, nil
+ tun.rate.update(uint64(n))
+ return 1, nil
case windows.ERROR_NO_MORE_ITEMS:
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
@@ -194,40 +181,40 @@ retry:
}
}
-func (tun *NativeTun) Flush() error {
- return nil
-}
-
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
tun.running.Add(1)
defer tun.running.Done()
- if atomic.LoadInt32(&tun.close) == 1 {
+ if tun.close.Load() {
return 0, os.ErrClosed
}
- packetSize := len(buff) - offset
- tun.rate.update(uint64(packetSize))
+ for i, buf := range bufs {
+ packetSize := len(buf) - offset
+ tun.rate.update(uint64(packetSize))
- packet, err := tun.session.AllocateSendPacket(packetSize)
- if err == nil {
- copy(packet, buff[offset:])
- tun.session.SendPacket(packet)
- return packetSize, nil
- }
- switch err {
- case windows.ERROR_HANDLE_EOF:
- return 0, os.ErrClosed
- case windows.ERROR_BUFFER_OVERFLOW:
- return 0, nil // Dropping when ring is full.
+ packet, err := tun.session.AllocateSendPacket(packetSize)
+ switch err {
+ case nil:
+ // TODO: Explore options to eliminate this copy.
+ copy(packet, buf[offset:])
+ tun.session.SendPacket(packet)
+ continue
+ case windows.ERROR_HANDLE_EOF:
+ return i, os.ErrClosed
+ case windows.ERROR_BUFFER_OVERFLOW:
+ continue // Dropping when ring is full.
+ default:
+ return i, fmt.Errorf("Write failed: %w", err)
+ }
}
- return 0, fmt.Errorf("Write failed: %w", err)
+ return len(bufs), nil
}
// LUID returns Windows interface instance ID.
func (tun *NativeTun) LUID() uint64 {
tun.running.Add(1)
defer tun.running.Done()
- if atomic.LoadInt32(&tun.close) == 1 {
+ if tun.close.Load() {
return 0
}
return tun.wt.LUID()
@@ -240,15 +227,15 @@ func (tun *NativeTun) RunningVersion() (version uint32, err error) {
func (rate *rateJuggler) update(packetLen uint64) {
now := nanotime()
- total := atomic.AddUint64(&rate.nextByteCount, packetLen)
- period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
+ total := rate.nextByteCount.Add(packetLen)
+ period := uint64(now - rate.nextStartTime.Load())
if period >= rateMeasurementGranularity {
- if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
+ if !rate.changing.CompareAndSwap(false, true) {
return
}
- atomic.StoreInt64(&rate.nextStartTime, now)
- atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
- atomic.StoreUint64(&rate.nextByteCount, 0)
- atomic.StoreInt32(&rate.changing, 0)
+ rate.nextStartTime.Store(now)
+ rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
+ rate.nextByteCount.Store(0)
+ rate.changing.Store(false)
}
}
diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go
index d89db71..d07e860 100644
--- a/tun/tuntest/tuntest.go
+++ b/tun/tuntest/tuntest.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tuntest
@@ -8,13 +8,13 @@ package tuntest
import (
"encoding/binary"
"io"
- "net"
+ "net/netip"
"os"
"golang.zx2c4.com/wireguard/tun"
)
-func Ping(dst, src net.IP) []byte {
+func Ping(dst, src netip.Addr) []byte {
localPort := uint16(1337)
seq := uint16(0)
@@ -40,7 +40,7 @@ func checksum(buf []byte, initial uint16) uint16 {
return ^uint16(v)
}
-func genICMPv4(payload []byte, dst, src net.IP) []byte {
+func genICMPv4(payload []byte, dst, src netip.Addr) []byte {
const (
icmpv4ProtocolNumber = 1
icmpv4Echo = 8
@@ -70,8 +70,8 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte {
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
ip[8] = ttl
ip[9] = icmpv4ProtocolNumber
- copy(ip[12:], src.To4())
- copy(ip[16:], dst.To4())
+ copy(ip[12:], src.AsSlice())
+ copy(ip[16:], dst.AsSlice())
chksum = ^checksum(ip[:], 0)
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
@@ -110,38 +110,45 @@ type chTun struct {
func (t *chTun) File() *os.File { return nil }
-func (t *chTun) Read(data []byte, offset int) (int, error) {
+func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) {
select {
case <-t.c.closed:
return 0, os.ErrClosed
case msg := <-t.c.Outbound:
- return copy(data[offset:], msg), nil
+ n := copy(packets[0][offset:], msg)
+ sizes[0] = n
+ return 1, nil
}
}
// Write is called by the wireguard device to deliver a packet for routing.
-func (t *chTun) Write(data []byte, offset int) (int, error) {
+func (t *chTun) Write(packets [][]byte, offset int) (int, error) {
if offset == -1 {
close(t.c.closed)
close(t.c.events)
return 0, io.EOF
}
- msg := make([]byte, len(data)-offset)
- copy(msg, data[offset:])
- select {
- case <-t.c.closed:
- return 0, os.ErrClosed
- case t.c.Inbound <- msg:
- return len(data) - offset, nil
+ for i, data := range packets {
+ msg := make([]byte, len(data)-offset)
+ copy(msg, data[offset:])
+ select {
+ case <-t.c.closed:
+ return i, os.ErrClosed
+ case t.c.Inbound <- msg:
+ }
}
+ return len(packets), nil
+}
+
+func (t *chTun) BatchSize() int {
+ return 1
}
const DefaultMTU = 1420
-func (t *chTun) Flush() error { return nil }
-func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
-func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
-func (t *chTun) Events() chan tun.Event { return t.c.events }
+func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
+func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
+func (t *chTun) Events() <-chan tun.Event { return t.c.events }
func (t *chTun) Close() error {
t.Write(nil, -1)
return nil
diff --git a/tun/wintun/dll_fromfile_windows.go b/tun/wintun/dll_fromfile_windows.go
deleted file mode 100644
index f40f8b3..0000000
--- a/tun/wintun/dll_fromfile_windows.go
+++ /dev/null
@@ -1,54 +0,0 @@
-// +build !load_wintun_from_rsrc
-
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package wintun
-
-import (
- "fmt"
- "sync"
- "sync/atomic"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-type lazyDLL struct {
- Name string
- mu sync.Mutex
- module windows.Handle
- onLoad func(d *lazyDLL)
-}
-
-func (d *lazyDLL) Load() error {
- if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
- return nil
- }
- d.mu.Lock()
- defer d.mu.Unlock()
- if d.module != 0 {
- return nil
- }
-
- const (
- LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200
- LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
- )
- module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32)
- if err != nil {
- return fmt.Errorf("Unable to load library: %w", err)
- }
-
- atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
- if d.onLoad != nil {
- d.onLoad(d)
- }
- return nil
-}
-
-func (p *lazyProc) nameToAddr() (uintptr, error) {
- return windows.GetProcAddress(p.dll.module, p.Name)
-}
diff --git a/tun/wintun/dll_fromrsrc_windows.go b/tun/wintun/dll_fromrsrc_windows.go
deleted file mode 100644
index dc70486..0000000
--- a/tun/wintun/dll_fromrsrc_windows.go
+++ /dev/null
@@ -1,61 +0,0 @@
-// +build load_wintun_from_rsrc
-
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package wintun
-
-import (
- "fmt"
- "sync"
- "sync/atomic"
- "unsafe"
-
- "golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/tun/wintun/memmod"
-)
-
-type lazyDLL struct {
- Name string
- mu sync.Mutex
- module *memmod.Module
- onLoad func(d *lazyDLL)
-}
-
-func (d *lazyDLL) Load() error {
- if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
- return nil
- }
- d.mu.Lock()
- defer d.mu.Unlock()
- if d.module != nil {
- return nil
- }
-
- const ourModule windows.Handle = 0
- resInfo, err := windows.FindResource(ourModule, d.Name, windows.RT_RCDATA)
- if err != nil {
- return fmt.Errorf("Unable to find \"%v\" RCDATA resource: %w", d.Name, err)
- }
- data, err := windows.LoadResourceData(ourModule, resInfo)
- if err != nil {
- return fmt.Errorf("Unable to load resource: %w", err)
- }
- module, err := memmod.LoadLibrary(data)
- if err != nil {
- return fmt.Errorf("Unable to load library: %w", err)
- }
-
- atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
- if d.onLoad != nil {
- d.onLoad(d)
- }
- return nil
-}
-
-func (p *lazyProc) nameToAddr() (uintptr, error) {
- return p.dll.module.ProcAddressByName(p.Name)
-}
diff --git a/tun/wintun/dll_windows.go b/tun/wintun/dll_windows.go
deleted file mode 100644
index 1ecd6d0..0000000
--- a/tun/wintun/dll_windows.go
+++ /dev/null
@@ -1,59 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package wintun
-
-import (
- "fmt"
- "sync"
- "sync/atomic"
- "unsafe"
-)
-
-func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL {
- return &lazyDLL{Name: name, onLoad: onLoad}
-}
-
-func (d *lazyDLL) NewProc(name string) *lazyProc {
- return &lazyProc{dll: d, Name: name}
-}
-
-type lazyProc struct {
- Name string
- mu sync.Mutex
- dll *lazyDLL
- addr uintptr
-}
-
-func (p *lazyProc) Find() error {
- if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr))) != nil {
- return nil
- }
- p.mu.Lock()
- defer p.mu.Unlock()
- if p.addr != 0 {
- return nil
- }
-
- err := p.dll.Load()
- if err != nil {
- return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err)
- }
- addr, err := p.nameToAddr()
- if err != nil {
- return fmt.Errorf("Error getting %v address: %w", p.Name, err)
- }
-
- atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr)), unsafe.Pointer(addr))
- return nil
-}
-
-func (p *lazyProc) Addr() uintptr {
- err := p.Find()
- if err != nil {
- panic(err)
- }
- return p.addr
-}
diff --git a/tun/wintun/memmod/memmod_windows.go b/tun/wintun/memmod/memmod_windows.go
deleted file mode 100644
index 075c03a..0000000
--- a/tun/wintun/memmod/memmod_windows.go
+++ /dev/null
@@ -1,635 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package memmod
-
-import (
- "errors"
- "fmt"
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-type addressList struct {
- next *addressList
- address uintptr
-}
-
-func (head *addressList) free() {
- for node := head; node != nil; node = node.next {
- windows.VirtualFree(node.address, 0, windows.MEM_RELEASE)
- }
-}
-
-type Module struct {
- headers *IMAGE_NT_HEADERS
- codeBase uintptr
- modules []windows.Handle
- initialized bool
- isDLL bool
- isRelocated bool
- nameExports map[string]uint16
- entry uintptr
- blockedMemory *addressList
-}
-
-func (module *Module) headerDirectory(idx int) *IMAGE_DATA_DIRECTORY {
- return &module.headers.OptionalHeader.DataDirectory[idx]
-}
-
-func (module *Module) copySections(address uintptr, size uintptr, oldHeaders *IMAGE_NT_HEADERS) error {
- sections := module.headers.Sections()
- for i := range sections {
- if sections[i].SizeOfRawData == 0 {
- // Section doesn't contain data in the dll itself, but may define uninitialized data.
- sectionSize := oldHeaders.OptionalHeader.SectionAlignment
- if sectionSize == 0 {
- continue
- }
- dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress),
- uintptr(sectionSize),
- windows.MEM_COMMIT,
- windows.PAGE_READWRITE)
- if err != nil {
- return fmt.Errorf("Error allocating section: %w", err)
- }
-
- // Always use position from file to support alignments smaller than page size (allocation above will align to page size).
- dest = module.codeBase + uintptr(sections[i].VirtualAddress)
- // NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
- sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
- var dst []byte
- unsafeSlice(unsafe.Pointer(&dst), a2p(dest), int(sectionSize))
- for j := range dst {
- dst[j] = 0
- }
- continue
- }
-
- if size < uintptr(sections[i].PointerToRawData+sections[i].SizeOfRawData) {
- return errors.New("Incomplete section")
- }
-
- // Commit memory block and copy data from dll.
- dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress),
- uintptr(sections[i].SizeOfRawData),
- windows.MEM_COMMIT,
- windows.PAGE_READWRITE)
- if err != nil {
- return fmt.Errorf("Error allocating memory block: %w", err)
- }
-
- // Always use position from file to support alignments smaller than page size (allocation above will align to page size).
- memcpy(
- module.codeBase+uintptr(sections[i].VirtualAddress),
- address+uintptr(sections[i].PointerToRawData),
- uintptr(sections[i].SizeOfRawData))
- // NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
- sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
- }
-
- return nil
-}
-
-func (module *Module) realSectionSize(section *IMAGE_SECTION_HEADER) uintptr {
- size := section.SizeOfRawData
- if size != 0 {
- return uintptr(size)
- }
- if (section.Characteristics & IMAGE_SCN_CNT_INITIALIZED_DATA) != 0 {
- return uintptr(module.headers.OptionalHeader.SizeOfInitializedData)
- }
- if (section.Characteristics & IMAGE_SCN_CNT_UNINITIALIZED_DATA) != 0 {
- return uintptr(module.headers.OptionalHeader.SizeOfUninitializedData)
- }
- return 0
-}
-
-type sectionFinalizeData struct {
- address uintptr
- alignedAddress uintptr
- size uintptr
- characteristics uint32
- last bool
-}
-
-func (module *Module) finalizeSection(sectionData *sectionFinalizeData) error {
- if sectionData.size == 0 {
- return nil
- }
-
- if (sectionData.characteristics & IMAGE_SCN_MEM_DISCARDABLE) != 0 {
- // Section is not needed any more and can safely be freed.
- if sectionData.address == sectionData.alignedAddress &&
- (sectionData.last ||
- (sectionData.size%uintptr(module.headers.OptionalHeader.SectionAlignment)) == 0) {
- // Only allowed to decommit whole pages.
- windows.VirtualFree(sectionData.address, sectionData.size, windows.MEM_DECOMMIT)
- }
- return nil
- }
-
- // determine protection flags based on characteristics
- var ProtectionFlags = [8]uint32{
- windows.PAGE_NOACCESS, // not writeable, not readable, not executable
- windows.PAGE_EXECUTE, // not writeable, not readable, executable
- windows.PAGE_READONLY, // not writeable, readable, not executable
- windows.PAGE_EXECUTE_READ, // not writeable, readable, executable
- windows.PAGE_WRITECOPY, // writeable, not readable, not executable
- windows.PAGE_EXECUTE_WRITECOPY, // writeable, not readable, executable
- windows.PAGE_READWRITE, // writeable, readable, not executable
- windows.PAGE_EXECUTE_READWRITE, // writeable, readable, executable
- }
- protect := ProtectionFlags[sectionData.characteristics>>29]
- if (sectionData.characteristics & IMAGE_SCN_MEM_NOT_CACHED) != 0 {
- protect |= windows.PAGE_NOCACHE
- }
-
- // Change memory access flags.
- var oldProtect uint32
- err := windows.VirtualProtect(sectionData.address, sectionData.size, protect, &oldProtect)
- if err != nil {
- return fmt.Errorf("Error protecting memory page: %w", err)
- }
-
- return nil
-}
-
-var rtlAddFunctionTable = windows.NewLazySystemDLL("ntdll.dll").NewProc("RtlAddFunctionTable")
-
-func (module *Module) registerExceptionHandlers() {
- directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXCEPTION)
- if directory.Size == 0 || directory.VirtualAddress == 0 {
- return
- }
- rtlAddFunctionTable.Call(module.codeBase+uintptr(directory.VirtualAddress), uintptr(directory.Size)/unsafe.Sizeof(IMAGE_RUNTIME_FUNCTION_ENTRY{}), module.codeBase)
-}
-
-func (module *Module) finalizeSections() error {
- sections := module.headers.Sections()
- imageOffset := module.headers.OptionalHeader.imageOffset()
- sectionData := sectionFinalizeData{}
- sectionData.address = uintptr(sections[0].PhysicalAddress()) | imageOffset
- sectionData.alignedAddress = alignDown(sectionData.address, uintptr(module.headers.OptionalHeader.SectionAlignment))
- sectionData.size = module.realSectionSize(&sections[0])
- sections[0].SetVirtualSize(uint32(sectionData.size))
- sectionData.characteristics = sections[0].Characteristics
-
- // Loop through all sections and change access flags.
- for i := uint16(1); i < module.headers.FileHeader.NumberOfSections; i++ {
- sectionAddress := uintptr(sections[i].PhysicalAddress()) | imageOffset
- alignedAddress := alignDown(sectionAddress, uintptr(module.headers.OptionalHeader.SectionAlignment))
- sectionSize := module.realSectionSize(&sections[i])
- sections[i].SetVirtualSize(uint32(sectionSize))
- // Combine access flags of all sections that share a page.
- // TODO: We currently share flags of a trailing large section with the page of a first small section. This should be optimized.
- if sectionData.alignedAddress == alignedAddress || sectionData.address+sectionData.size > alignedAddress {
- // Section shares page with previous.
- if (sections[i].Characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 || (sectionData.characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 {
- sectionData.characteristics = (sectionData.characteristics | sections[i].Characteristics) &^ IMAGE_SCN_MEM_DISCARDABLE
- } else {
- sectionData.characteristics |= sections[i].Characteristics
- }
- sectionData.size = sectionAddress + sectionSize - sectionData.address
- continue
- }
-
- err := module.finalizeSection(&sectionData)
- if err != nil {
- return fmt.Errorf("Error finalizing section: %w", err)
- }
- sectionData.address = sectionAddress
- sectionData.alignedAddress = alignedAddress
- sectionData.size = sectionSize
- sectionData.characteristics = sections[i].Characteristics
- }
- sectionData.last = true
- err := module.finalizeSection(&sectionData)
- if err != nil {
- return fmt.Errorf("Error finalizing section: %w", err)
- }
- return nil
-}
-
-func (module *Module) executeTLS() {
- directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_TLS)
- if directory.VirtualAddress == 0 {
- return
- }
-
- tls := (*IMAGE_TLS_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
- callback := tls.AddressOfCallbacks
- if callback != 0 {
- for {
- f := *(*uintptr)(a2p(callback))
- if f == 0 {
- break
- }
- syscall.Syscall(f, 3, module.codeBase, uintptr(DLL_PROCESS_ATTACH), uintptr(0))
- callback += unsafe.Sizeof(f)
- }
- }
-}
-
-func (module *Module) performBaseRelocation(delta uintptr) (relocated bool, err error) {
- directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_BASERELOC)
- if directory.Size == 0 {
- return delta == 0, nil
- }
-
- relocationHdr := (*IMAGE_BASE_RELOCATION)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
- for relocationHdr.VirtualAddress > 0 {
- dest := module.codeBase + uintptr(relocationHdr.VirtualAddress)
-
- var relInfos []uint16
- unsafeSlice(
- unsafe.Pointer(&relInfos),
- a2p(uintptr(unsafe.Pointer(relocationHdr))+unsafe.Sizeof(*relocationHdr)),
- int((uintptr(relocationHdr.SizeOfBlock)-unsafe.Sizeof(*relocationHdr))/unsafe.Sizeof(relInfos[0])))
- for _, relInfo := range relInfos {
- // The upper 4 bits define the type of relocation.
- relType := relInfo >> 12
- // The lower 12 bits define the offset.
- relOffset := uintptr(relInfo & 0xfff)
-
- switch relType {
- case IMAGE_REL_BASED_ABSOLUTE:
- // Skip relocation.
-
- case IMAGE_REL_BASED_LOW:
- *(*uint16)(a2p(dest + relOffset)) += uint16(delta & 0xffff)
- break
-
- case IMAGE_REL_BASED_HIGH:
- *(*uint16)(a2p(dest + relOffset)) += uint16(uint32(delta) >> 16)
- break
-
- case IMAGE_REL_BASED_HIGHLOW:
- *(*uint32)(a2p(dest + relOffset)) += uint32(delta)
-
- case IMAGE_REL_BASED_DIR64:
- *(*uint64)(a2p(dest + relOffset)) += uint64(delta)
-
- case IMAGE_REL_BASED_THUMB_MOV32:
- inst := *(*uint32)(a2p(dest + relOffset))
- imm16 := ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) +
- ((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff)
- if (inst & 0x8000fbf0) != 0x0000f240 {
- return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVW", inst)
- }
- imm16 += uint32(delta) & 0xffff
- hiDelta := (uint32(delta&0xffff0000) >> 16) + ((imm16 & 0xffff0000) >> 16)
- *(*uint32)(a2p(dest + relOffset)) = (inst & 0x8f00fbf0) + ((imm16 >> 1) & 0x0400) +
- ((imm16 >> 12) & 0x000f) +
- ((imm16 << 20) & 0x70000000) +
- ((imm16 << 16) & 0xff0000)
- if hiDelta != 0 {
- inst = *(*uint32)(a2p(dest + relOffset + 4))
- imm16 = ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) +
- ((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff)
- if (inst & 0x8000fbf0) != 0x0000f2c0 {
- return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVT", inst)
- }
- imm16 += hiDelta
- if imm16 > 0xffff {
- return false, fmt.Errorf("Resulting immediate value won't fit: %08x", imm16)
- }
- *(*uint32)(a2p(dest + relOffset + 4)) = (inst & 0x8f00fbf0) +
- ((imm16 >> 1) & 0x0400) +
- ((imm16 >> 12) & 0x000f) +
- ((imm16 << 20) & 0x70000000) +
- ((imm16 << 16) & 0xff0000)
- }
-
- default:
- return false, fmt.Errorf("Unsupported relocation: %v", relType)
- }
- }
-
- // Advance to next relocation block.
- relocationHdr = (*IMAGE_BASE_RELOCATION)(a2p(uintptr(unsafe.Pointer(relocationHdr)) + uintptr(relocationHdr.SizeOfBlock)))
- }
- return true, nil
-}
-
-func (module *Module) buildImportTable() error {
- directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_IMPORT)
- if directory.Size == 0 {
- return nil
- }
-
- module.modules = make([]windows.Handle, 0, 16)
- importDesc := (*IMAGE_IMPORT_DESCRIPTOR)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
- for importDesc.Name != 0 {
- handle, err := windows.LoadLibraryEx(windows.BytePtrToString((*byte)(a2p(module.codeBase+uintptr(importDesc.Name)))), 0, windows.LOAD_LIBRARY_SEARCH_SYSTEM32)
- if err != nil {
- return fmt.Errorf("Error loading module: %w", err)
- }
- var thunkRef, funcRef *uintptr
- if importDesc.OriginalFirstThunk() != 0 {
- thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.OriginalFirstThunk())))
- funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
- } else {
- // No hint table.
- thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
- funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
- }
- for *thunkRef != 0 {
- if IMAGE_SNAP_BY_ORDINAL(*thunkRef) {
- *funcRef, err = windows.GetProcAddressByOrdinal(handle, IMAGE_ORDINAL(*thunkRef))
- } else {
- thunkData := (*IMAGE_IMPORT_BY_NAME)(a2p(module.codeBase + *thunkRef))
- *funcRef, err = windows.GetProcAddress(handle, windows.BytePtrToString(&thunkData.Name[0]))
- }
- if err != nil {
- windows.FreeLibrary(handle)
- return fmt.Errorf("Error getting function address: %w", err)
- }
- thunkRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(thunkRef)) + unsafe.Sizeof(*thunkRef)))
- funcRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(funcRef)) + unsafe.Sizeof(*funcRef)))
- }
- module.modules = append(module.modules, handle)
- importDesc = (*IMAGE_IMPORT_DESCRIPTOR)(a2p(uintptr(unsafe.Pointer(importDesc)) + unsafe.Sizeof(*importDesc)))
- }
- return nil
-}
-
-func (module *Module) buildNameExports() error {
- directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
- if directory.Size == 0 {
- return errors.New("No export table found")
- }
- exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
- if exports.NumberOfNames == 0 || exports.NumberOfFunctions == 0 {
- return errors.New("No functions exported")
- }
- if exports.NumberOfNames == 0 {
- return errors.New("No functions exported by name")
- }
- var nameRefs []uint32
- unsafeSlice(unsafe.Pointer(&nameRefs), a2p(module.codeBase+uintptr(exports.AddressOfNames)), int(exports.NumberOfNames))
- var ordinals []uint16
- unsafeSlice(unsafe.Pointer(&ordinals), a2p(module.codeBase+uintptr(exports.AddressOfNameOrdinals)), int(exports.NumberOfNames))
- module.nameExports = make(map[string]uint16)
- for i := range nameRefs {
- nameArray := windows.BytePtrToString((*byte)(a2p(module.codeBase + uintptr(nameRefs[i]))))
- module.nameExports[nameArray] = ordinals[i]
- }
- return nil
-}
-
-// LoadLibrary loads module image to memory.
-func LoadLibrary(data []byte) (module *Module, err error) {
- addr := uintptr(unsafe.Pointer(&data[0]))
- size := uintptr(len(data))
- if size < unsafe.Sizeof(IMAGE_DOS_HEADER{}) {
- return nil, errors.New("Incomplete IMAGE_DOS_HEADER")
- }
- dosHeader := (*IMAGE_DOS_HEADER)(a2p(addr))
- if dosHeader.E_magic != IMAGE_DOS_SIGNATURE {
- return nil, fmt.Errorf("Not an MS-DOS binary (provided: %x, expected: %x)", dosHeader.E_magic, IMAGE_DOS_SIGNATURE)
- }
- if (size < uintptr(dosHeader.E_lfanew)+unsafe.Sizeof(IMAGE_NT_HEADERS{})) {
- return nil, errors.New("Incomplete IMAGE_NT_HEADERS")
- }
- oldHeader := (*IMAGE_NT_HEADERS)(a2p(addr + uintptr(dosHeader.E_lfanew)))
- if oldHeader.Signature != IMAGE_NT_SIGNATURE {
- return nil, fmt.Errorf("Not an NT binary (provided: %x, expected: %x)", oldHeader.Signature, IMAGE_NT_SIGNATURE)
- }
- if oldHeader.FileHeader.Machine != imageFileProcess {
- return nil, fmt.Errorf("Foreign platform (provided: %x, expected: %x)", oldHeader.FileHeader.Machine, imageFileProcess)
- }
- if (oldHeader.OptionalHeader.SectionAlignment & 1) != 0 {
- return nil, errors.New("Unaligned section")
- }
- lastSectionEnd := uintptr(0)
- sections := oldHeader.Sections()
- optionalSectionSize := oldHeader.OptionalHeader.SectionAlignment
- for i := range sections {
- var endOfSection uintptr
- if sections[i].SizeOfRawData == 0 {
- // Section without data in the DLL
- endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(optionalSectionSize)
- } else {
- endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(sections[i].SizeOfRawData)
- }
- if endOfSection > lastSectionEnd {
- lastSectionEnd = endOfSection
- }
- }
- alignedImageSize := alignUp(uintptr(oldHeader.OptionalHeader.SizeOfImage), uintptr(oldHeader.OptionalHeader.SectionAlignment))
- if alignedImageSize != alignUp(lastSectionEnd, uintptr(oldHeader.OptionalHeader.SectionAlignment)) {
- return nil, errors.New("Section is not page-aligned")
- }
-
- module = &Module{isDLL: (oldHeader.FileHeader.Characteristics & IMAGE_FILE_DLL) != 0}
- defer func() {
- if err != nil {
- module.Free()
- module = nil
- }
- }()
-
- // Reserve memory for image of library.
- // TODO: Is it correct to commit the complete memory region at once? Calling DllEntry raises an exception if we don't.
- module.codeBase, err = windows.VirtualAlloc(oldHeader.OptionalHeader.ImageBase,
- alignedImageSize,
- windows.MEM_RESERVE|windows.MEM_COMMIT,
- windows.PAGE_READWRITE)
- if err != nil {
- // Try to allocate memory at arbitrary position.
- module.codeBase, err = windows.VirtualAlloc(0,
- alignedImageSize,
- windows.MEM_RESERVE|windows.MEM_COMMIT,
- windows.PAGE_READWRITE)
- if err != nil {
- err = fmt.Errorf("Error allocating code: %w", err)
- return
- }
- }
- err = module.check4GBBoundaries(alignedImageSize)
- if err != nil {
- err = fmt.Errorf("Error reallocating code: %w", err)
- return
- }
-
- if size < uintptr(oldHeader.OptionalHeader.SizeOfHeaders) {
- err = errors.New("Incomplete headers")
- return
- }
- // Commit memory for headers.
- headers, err := windows.VirtualAlloc(module.codeBase,
- uintptr(oldHeader.OptionalHeader.SizeOfHeaders),
- windows.MEM_COMMIT,
- windows.PAGE_READWRITE)
- if err != nil {
- err = fmt.Errorf("Error allocating headers: %w", err)
- return
- }
- // Copy PE header to code.
- memcpy(headers, addr, uintptr(oldHeader.OptionalHeader.SizeOfHeaders))
- module.headers = (*IMAGE_NT_HEADERS)(a2p(headers + uintptr(dosHeader.E_lfanew)))
-
- // Update position.
- module.headers.OptionalHeader.ImageBase = module.codeBase
-
- // Copy sections from DLL file block to new memory location.
- err = module.copySections(addr, size, oldHeader)
- if err != nil {
- err = fmt.Errorf("Error copying sections: %w", err)
- return
- }
-
- // Adjust base address of imported data.
- locationDelta := module.headers.OptionalHeader.ImageBase - oldHeader.OptionalHeader.ImageBase
- if locationDelta != 0 {
- module.isRelocated, err = module.performBaseRelocation(locationDelta)
- if err != nil {
- err = fmt.Errorf("Error relocating module: %w", err)
- return
- }
- } else {
- module.isRelocated = true
- }
-
- // Load required dlls and adjust function table of imports.
- err = module.buildImportTable()
- if err != nil {
- err = fmt.Errorf("Error building import table: %w", err)
- return
- }
-
- // Mark memory pages depending on section headers and release sections that are marked as "discardable".
- err = module.finalizeSections()
- if err != nil {
- err = fmt.Errorf("Error finalizing sections: %w", err)
- return
- }
-
- // Register exception tables, if they exist.
- module.registerExceptionHandlers()
-
- // TLS callbacks are executed BEFORE the main loading.
- module.executeTLS()
-
- // Get entry point of loaded module.
- if module.headers.OptionalHeader.AddressOfEntryPoint != 0 {
- module.entry = module.codeBase + uintptr(module.headers.OptionalHeader.AddressOfEntryPoint)
- if module.isDLL {
- // Notify library about attaching to process.
- r0, _, _ := syscall.Syscall(module.entry, 3, module.codeBase, uintptr(DLL_PROCESS_ATTACH), 0)
- successful := r0 != 0
- if !successful {
- err = windows.ERROR_DLL_INIT_FAILED
- return
- }
- module.initialized = true
- }
- }
-
- module.buildNameExports()
- return
-}
-
-// Free releases module resources and unloads it.
-func (module *Module) Free() {
- if module.initialized {
- // Notify library about detaching from process.
- syscall.Syscall(module.entry, 3, module.codeBase, uintptr(DLL_PROCESS_DETACH), 0)
- module.initialized = false
- }
- if module.modules != nil {
- // Free previously opened libraries.
- for _, handle := range module.modules {
- windows.FreeLibrary(handle)
- }
- module.modules = nil
- }
- if module.codeBase != 0 {
- windows.VirtualFree(module.codeBase, 0, windows.MEM_RELEASE)
- module.codeBase = 0
- }
- if module.blockedMemory != nil {
- module.blockedMemory.free()
- module.blockedMemory = nil
- }
-}
-
-// ProcAddressByName returns function address by exported name.
-func (module *Module) ProcAddressByName(name string) (uintptr, error) {
- directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
- if directory.Size == 0 {
- return 0, errors.New("No export table found")
- }
- exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
- if module.nameExports == nil {
- return 0, errors.New("No functions exported by name")
- }
- if idx, ok := module.nameExports[name]; ok {
- if uint32(idx) > exports.NumberOfFunctions {
- return 0, errors.New("Ordinal number too high")
- }
- // AddressOfFunctions contains the RVAs to the "real" functions.
- return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil
- }
- return 0, errors.New("Function not found by name")
-}
-
-// ProcAddressByOrdinal returns function address by exported ordinal.
-func (module *Module) ProcAddressByOrdinal(ordinal uint16) (uintptr, error) {
- directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
- if directory.Size == 0 {
- return 0, errors.New("No export table found")
- }
- exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
- if uint32(ordinal) < exports.Base {
- return 0, errors.New("Ordinal number too low")
- }
- idx := ordinal - uint16(exports.Base)
- if uint32(idx) > exports.NumberOfFunctions {
- return 0, errors.New("Ordinal number too high")
- }
- // AddressOfFunctions contains the RVAs to the "real" functions.
- return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil
-}
-
-func alignDown(value, alignment uintptr) uintptr {
- return value & ^(alignment - 1)
-}
-
-func alignUp(value, alignment uintptr) uintptr {
- return (value + alignment - 1) & ^(alignment - 1)
-}
-
-func a2p(addr uintptr) unsafe.Pointer {
- return unsafe.Pointer(addr)
-}
-
-func memcpy(dst, src, size uintptr) {
- var d, s []byte
- unsafeSlice(unsafe.Pointer(&d), a2p(dst), int(size))
- unsafeSlice(unsafe.Pointer(&s), a2p(src), int(size))
- copy(d, s)
-}
-
-// unsafeSlice updates the slice slicePtr to be a slice
-// referencing the provided data with its length & capacity set to
-// lenCap.
-//
-// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
-// update callers to use unsafe.Slice instead of this.
-func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
- type sliceHeader struct {
- Data unsafe.Pointer
- Len int
- Cap int
- }
- h := (*sliceHeader)(slicePtr)
- h.Data = data
- h.Len = lenCap
- h.Cap = lenCap
-}
diff --git a/tun/wintun/memmod/memmod_windows_32.go b/tun/wintun/memmod/memmod_windows_32.go
deleted file mode 100644
index ac76bdc..0000000
--- a/tun/wintun/memmod/memmod_windows_32.go
+++ /dev/null
@@ -1,16 +0,0 @@
-// +build windows,386 windows,arm
-
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package memmod
-
-func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr {
- return 0
-}
-
-func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) {
- return
-}
diff --git a/tun/wintun/memmod/memmod_windows_386.go b/tun/wintun/memmod/memmod_windows_386.go
deleted file mode 100644
index 475c5c5..0000000
--- a/tun/wintun/memmod/memmod_windows_386.go
+++ /dev/null
@@ -1,8 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package memmod
-
-const imageFileProcess = IMAGE_FILE_MACHINE_I386
diff --git a/tun/wintun/memmod/memmod_windows_64.go b/tun/wintun/memmod/memmod_windows_64.go
deleted file mode 100644
index a620368..0000000
--- a/tun/wintun/memmod/memmod_windows_64.go
+++ /dev/null
@@ -1,36 +0,0 @@
-// +build windows,amd64 windows,arm64
-
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package memmod
-
-import (
- "fmt"
-
- "golang.org/x/sys/windows"
-)
-
-func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr {
- return uintptr(opthdr.ImageBase & 0xffffffff00000000)
-}
-
-func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) {
- for (module.codeBase >> 32) < ((module.codeBase + alignedImageSize) >> 32) {
- node := &addressList{
- next: module.blockedMemory,
- address: module.codeBase,
- }
- module.blockedMemory = node
- module.codeBase, err = windows.VirtualAlloc(0,
- alignedImageSize,
- windows.MEM_RESERVE|windows.MEM_COMMIT,
- windows.PAGE_READWRITE)
- if err != nil {
- return fmt.Errorf("Error allocating memory block: %w", err)
- }
- }
- return
-}
diff --git a/tun/wintun/memmod/memmod_windows_amd64.go b/tun/wintun/memmod/memmod_windows_amd64.go
deleted file mode 100644
index a021a63..0000000
--- a/tun/wintun/memmod/memmod_windows_amd64.go
+++ /dev/null
@@ -1,8 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package memmod
-
-const imageFileProcess = IMAGE_FILE_MACHINE_AMD64
diff --git a/tun/wintun/memmod/memmod_windows_arm.go b/tun/wintun/memmod/memmod_windows_arm.go
deleted file mode 100644
index 4637a01..0000000
--- a/tun/wintun/memmod/memmod_windows_arm.go
+++ /dev/null
@@ -1,8 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package memmod
-
-const imageFileProcess = IMAGE_FILE_MACHINE_ARMNT
diff --git a/tun/wintun/memmod/memmod_windows_arm64.go b/tun/wintun/memmod/memmod_windows_arm64.go
deleted file mode 100644
index b8f1259..0000000
--- a/tun/wintun/memmod/memmod_windows_arm64.go
+++ /dev/null
@@ -1,8 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package memmod
-
-const imageFileProcess = IMAGE_FILE_MACHINE_ARM64
diff --git a/tun/wintun/memmod/syscall_windows.go b/tun/wintun/memmod/syscall_windows.go
deleted file mode 100644
index a111f92..0000000
--- a/tun/wintun/memmod/syscall_windows.go
+++ /dev/null
@@ -1,398 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package memmod
-
-import "unsafe"
-
-const (
- IMAGE_DOS_SIGNATURE = 0x5A4D // MZ
- IMAGE_OS2_SIGNATURE = 0x454E // NE
- IMAGE_OS2_SIGNATURE_LE = 0x454C // LE
- IMAGE_VXD_SIGNATURE = 0x454C // LE
- IMAGE_NT_SIGNATURE = 0x00004550 // PE00
-)
-
-// DOS .EXE header
-type IMAGE_DOS_HEADER struct {
- E_magic uint16 // Magic number
- E_cblp uint16 // Bytes on last page of file
- E_cp uint16 // Pages in file
- E_crlc uint16 // Relocations
- E_cparhdr uint16 // Size of header in paragraphs
- E_minalloc uint16 // Minimum extra paragraphs needed
- E_maxalloc uint16 // Maximum extra paragraphs needed
- E_ss uint16 // Initial (relative) SS value
- E_sp uint16 // Initial SP value
- E_csum uint16 // Checksum
- E_ip uint16 // Initial IP value
- E_cs uint16 // Initial (relative) CS value
- E_lfarlc uint16 // File address of relocation table
- E_ovno uint16 // Overlay number
- E_res [4]uint16 // Reserved words
- E_oemid uint16 // OEM identifier (for e_oeminfo)
- E_oeminfo uint16 // OEM information; e_oemid specific
- E_res2 [10]uint16 // Reserved words
- E_lfanew int32 // File address of new exe header
-}
-
-// File header format
-type IMAGE_FILE_HEADER struct {
- Machine uint16
- NumberOfSections uint16
- TimeDateStamp uint32
- PointerToSymbolTable uint32
- NumberOfSymbols uint32
- SizeOfOptionalHeader uint16
- Characteristics uint16
-}
-
-const (
- IMAGE_SIZEOF_FILE_HEADER = 20
-
- IMAGE_FILE_RELOCS_STRIPPED = 0x0001 // Relocation info stripped from file.
- IMAGE_FILE_EXECUTABLE_IMAGE = 0x0002 // File is executable (i.e. no unresolved external references).
- IMAGE_FILE_LINE_NUMS_STRIPPED = 0x0004 // Line nunbers stripped from file.
- IMAGE_FILE_LOCAL_SYMS_STRIPPED = 0x0008 // Local symbols stripped from file.
- IMAGE_FILE_AGGRESIVE_WS_TRIM = 0x0010 // Aggressively trim working set
- IMAGE_FILE_LARGE_ADDRESS_AWARE = 0x0020 // App can handle >2gb addresses
- IMAGE_FILE_BYTES_REVERSED_LO = 0x0080 // Bytes of machine word are reversed.
- IMAGE_FILE_32BIT_MACHINE = 0x0100 // 32 bit word machine.
- IMAGE_FILE_DEBUG_STRIPPED = 0x0200 // Debugging info stripped from file in .DBG file
- IMAGE_FILE_REMOVABLE_RUN_FROM_SWAP = 0x0400 // If Image is on removable media, copy and run from the swap file.
- IMAGE_FILE_NET_RUN_FROM_SWAP = 0x0800 // If Image is on Net, copy and run from the swap file.
- IMAGE_FILE_SYSTEM = 0x1000 // System File.
- IMAGE_FILE_DLL = 0x2000 // File is a DLL.
- IMAGE_FILE_UP_SYSTEM_ONLY = 0x4000 // File should only be run on a UP machine
- IMAGE_FILE_BYTES_REVERSED_HI = 0x8000 // Bytes of machine word are reversed.
-
- IMAGE_FILE_MACHINE_UNKNOWN = 0
- IMAGE_FILE_MACHINE_TARGET_HOST = 0x0001 // Useful for indicating we want to interact with the host and not a WoW guest.
- IMAGE_FILE_MACHINE_I386 = 0x014c // Intel 386.
- IMAGE_FILE_MACHINE_R3000 = 0x0162 // MIPS little-endian, 0x160 big-endian
- IMAGE_FILE_MACHINE_R4000 = 0x0166 // MIPS little-endian
- IMAGE_FILE_MACHINE_R10000 = 0x0168 // MIPS little-endian
- IMAGE_FILE_MACHINE_WCEMIPSV2 = 0x0169 // MIPS little-endian WCE v2
- IMAGE_FILE_MACHINE_ALPHA = 0x0184 // Alpha_AXP
- IMAGE_FILE_MACHINE_SH3 = 0x01a2 // SH3 little-endian
- IMAGE_FILE_MACHINE_SH3DSP = 0x01a3
- IMAGE_FILE_MACHINE_SH3E = 0x01a4 // SH3E little-endian
- IMAGE_FILE_MACHINE_SH4 = 0x01a6 // SH4 little-endian
- IMAGE_FILE_MACHINE_SH5 = 0x01a8 // SH5
- IMAGE_FILE_MACHINE_ARM = 0x01c0 // ARM Little-Endian
- IMAGE_FILE_MACHINE_THUMB = 0x01c2 // ARM Thumb/Thumb-2 Little-Endian
- IMAGE_FILE_MACHINE_ARMNT = 0x01c4 // ARM Thumb-2 Little-Endian
- IMAGE_FILE_MACHINE_AM33 = 0x01d3
- IMAGE_FILE_MACHINE_POWERPC = 0x01F0 // IBM PowerPC Little-Endian
- IMAGE_FILE_MACHINE_POWERPCFP = 0x01f1
- IMAGE_FILE_MACHINE_IA64 = 0x0200 // Intel 64
- IMAGE_FILE_MACHINE_MIPS16 = 0x0266 // MIPS
- IMAGE_FILE_MACHINE_ALPHA64 = 0x0284 // ALPHA64
- IMAGE_FILE_MACHINE_MIPSFPU = 0x0366 // MIPS
- IMAGE_FILE_MACHINE_MIPSFPU16 = 0x0466 // MIPS
- IMAGE_FILE_MACHINE_AXP64 = IMAGE_FILE_MACHINE_ALPHA64
- IMAGE_FILE_MACHINE_TRICORE = 0x0520 // Infineon
- IMAGE_FILE_MACHINE_CEF = 0x0CEF
- IMAGE_FILE_MACHINE_EBC = 0x0EBC // EFI Byte Code
- IMAGE_FILE_MACHINE_AMD64 = 0x8664 // AMD64 (K8)
- IMAGE_FILE_MACHINE_M32R = 0x9041 // M32R little-endian
- IMAGE_FILE_MACHINE_ARM64 = 0xAA64 // ARM64 Little-Endian
- IMAGE_FILE_MACHINE_CEE = 0xC0EE
-)
-
-// Directory format
-type IMAGE_DATA_DIRECTORY struct {
- VirtualAddress uint32
- Size uint32
-}
-
-const IMAGE_NUMBEROF_DIRECTORY_ENTRIES = 16
-
-type IMAGE_NT_HEADERS struct {
- Signature uint32
- FileHeader IMAGE_FILE_HEADER
- OptionalHeader IMAGE_OPTIONAL_HEADER
-}
-
-func (ntheader *IMAGE_NT_HEADERS) Sections() []IMAGE_SECTION_HEADER {
- return (*[0xffff]IMAGE_SECTION_HEADER)(unsafe.Pointer(
- (uintptr)(unsafe.Pointer(ntheader)) +
- unsafe.Offsetof(ntheader.OptionalHeader) +
- uintptr(ntheader.FileHeader.SizeOfOptionalHeader)))[:ntheader.FileHeader.NumberOfSections]
-}
-
-const (
- IMAGE_DIRECTORY_ENTRY_EXPORT = 0 // Export Directory
- IMAGE_DIRECTORY_ENTRY_IMPORT = 1 // Import Directory
- IMAGE_DIRECTORY_ENTRY_RESOURCE = 2 // Resource Directory
- IMAGE_DIRECTORY_ENTRY_EXCEPTION = 3 // Exception Directory
- IMAGE_DIRECTORY_ENTRY_SECURITY = 4 // Security Directory
- IMAGE_DIRECTORY_ENTRY_BASERELOC = 5 // Base Relocation Table
- IMAGE_DIRECTORY_ENTRY_DEBUG = 6 // Debug Directory
- IMAGE_DIRECTORY_ENTRY_COPYRIGHT = 7 // (X86 usage)
- IMAGE_DIRECTORY_ENTRY_ARCHITECTURE = 7 // Architecture Specific Data
- IMAGE_DIRECTORY_ENTRY_GLOBALPTR = 8 // RVA of GP
- IMAGE_DIRECTORY_ENTRY_TLS = 9 // TLS Directory
- IMAGE_DIRECTORY_ENTRY_LOAD_CONFIG = 10 // Load Configuration Directory
- IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT = 11 // Bound Import Directory in headers
- IMAGE_DIRECTORY_ENTRY_IAT = 12 // Import Address Table
- IMAGE_DIRECTORY_ENTRY_DELAY_IMPORT = 13 // Delay Load Import Descriptors
- IMAGE_DIRECTORY_ENTRY_COM_DESCRIPTOR = 14 // COM Runtime descriptor
-)
-
-const IMAGE_SIZEOF_SHORT_NAME = 8
-
-// Section header format
-type IMAGE_SECTION_HEADER struct {
- Name [IMAGE_SIZEOF_SHORT_NAME]byte
- physicalAddressOrVirtualSize uint32
- VirtualAddress uint32
- SizeOfRawData uint32
- PointerToRawData uint32
- PointerToRelocations uint32
- PointerToLinenumbers uint32
- NumberOfRelocations uint16
- NumberOfLinenumbers uint16
- Characteristics uint32
-}
-
-func (ishdr *IMAGE_SECTION_HEADER) PhysicalAddress() uint32 {
- return ishdr.physicalAddressOrVirtualSize
-}
-
-func (ishdr *IMAGE_SECTION_HEADER) SetPhysicalAddress(addr uint32) {
- ishdr.physicalAddressOrVirtualSize = addr
-}
-
-func (ishdr *IMAGE_SECTION_HEADER) VirtualSize() uint32 {
- return ishdr.physicalAddressOrVirtualSize
-}
-
-func (ishdr *IMAGE_SECTION_HEADER) SetVirtualSize(addr uint32) {
- ishdr.physicalAddressOrVirtualSize = addr
-}
-
-const (
- // Dll characteristics.
- IMAGE_DLL_CHARACTERISTICS_HIGH_ENTROPY_VA = 0x0020
- IMAGE_DLL_CHARACTERISTICS_DYNAMIC_BASE = 0x0040
- IMAGE_DLL_CHARACTERISTICS_FORCE_INTEGRITY = 0x0080
- IMAGE_DLL_CHARACTERISTICS_NX_COMPAT = 0x0100
- IMAGE_DLL_CHARACTERISTICS_NO_ISOLATION = 0x0200
- IMAGE_DLL_CHARACTERISTICS_NO_SEH = 0x0400
- IMAGE_DLL_CHARACTERISTICS_NO_BIND = 0x0800
- IMAGE_DLL_CHARACTERISTICS_APPCONTAINER = 0x1000
- IMAGE_DLL_CHARACTERISTICS_WDM_DRIVER = 0x2000
- IMAGE_DLL_CHARACTERISTICS_GUARD_CF = 0x4000
- IMAGE_DLL_CHARACTERISTICS_TERMINAL_SERVER_AWARE = 0x8000
-)
-
-const (
- // Section characteristics.
- IMAGE_SCN_TYPE_REG = 0x00000000 // Reserved.
- IMAGE_SCN_TYPE_DSECT = 0x00000001 // Reserved.
- IMAGE_SCN_TYPE_NOLOAD = 0x00000002 // Reserved.
- IMAGE_SCN_TYPE_GROUP = 0x00000004 // Reserved.
- IMAGE_SCN_TYPE_NO_PAD = 0x00000008 // Reserved.
- IMAGE_SCN_TYPE_COPY = 0x00000010 // Reserved.
-
- IMAGE_SCN_CNT_CODE = 0x00000020 // Section contains code.
- IMAGE_SCN_CNT_INITIALIZED_DATA = 0x00000040 // Section contains initialized data.
- IMAGE_SCN_CNT_UNINITIALIZED_DATA = 0x00000080 // Section contains uninitialized data.
-
- IMAGE_SCN_LNK_OTHER = 0x00000100 // Reserved.
- IMAGE_SCN_LNK_INFO = 0x00000200 // Section contains comments or some other type of information.
- IMAGE_SCN_TYPE_OVER = 0x00000400 // Reserved.
- IMAGE_SCN_LNK_REMOVE = 0x00000800 // Section contents will not become part of image.
- IMAGE_SCN_LNK_COMDAT = 0x00001000 // Section contents comdat.
- IMAGE_SCN_MEM_PROTECTED = 0x00004000 // Obsolete.
- IMAGE_SCN_NO_DEFER_SPEC_EXC = 0x00004000 // Reset speculative exceptions handling bits in the TLB entries for this section.
- IMAGE_SCN_GPREL = 0x00008000 // Section content can be accessed relative to GP
- IMAGE_SCN_MEM_FARDATA = 0x00008000
- IMAGE_SCN_MEM_SYSHEAP = 0x00010000 // Obsolete.
- IMAGE_SCN_MEM_PURGEABLE = 0x00020000
- IMAGE_SCN_MEM_16BIT = 0x00020000
- IMAGE_SCN_MEM_LOCKED = 0x00040000
- IMAGE_SCN_MEM_PRELOAD = 0x00080000
-
- IMAGE_SCN_ALIGN_1BYTES = 0x00100000 //
- IMAGE_SCN_ALIGN_2BYTES = 0x00200000 //
- IMAGE_SCN_ALIGN_4BYTES = 0x00300000 //
- IMAGE_SCN_ALIGN_8BYTES = 0x00400000 //
- IMAGE_SCN_ALIGN_16BYTES = 0x00500000 // Default alignment if no others are specified.
- IMAGE_SCN_ALIGN_32BYTES = 0x00600000 //
- IMAGE_SCN_ALIGN_64BYTES = 0x00700000 //
- IMAGE_SCN_ALIGN_128BYTES = 0x00800000 //
- IMAGE_SCN_ALIGN_256BYTES = 0x00900000 //
- IMAGE_SCN_ALIGN_512BYTES = 0x00A00000 //
- IMAGE_SCN_ALIGN_1024BYTES = 0x00B00000 //
- IMAGE_SCN_ALIGN_2048BYTES = 0x00C00000 //
- IMAGE_SCN_ALIGN_4096BYTES = 0x00D00000 //
- IMAGE_SCN_ALIGN_8192BYTES = 0x00E00000 //
- IMAGE_SCN_ALIGN_MASK = 0x00F00000
-
- IMAGE_SCN_LNK_NRELOC_OVFL = 0x01000000 // Section contains extended relocations.
- IMAGE_SCN_MEM_DISCARDABLE = 0x02000000 // Section can be discarded.
- IMAGE_SCN_MEM_NOT_CACHED = 0x04000000 // Section is not cachable.
- IMAGE_SCN_MEM_NOT_PAGED = 0x08000000 // Section is not pageable.
- IMAGE_SCN_MEM_SHARED = 0x10000000 // Section is shareable.
- IMAGE_SCN_MEM_EXECUTE = 0x20000000 // Section is executable.
- IMAGE_SCN_MEM_READ = 0x40000000 // Section is readable.
- IMAGE_SCN_MEM_WRITE = 0x80000000 // Section is writeable.
-
- // TLS Characteristic Flags
- IMAGE_SCN_SCALE_INDEX = 0x00000001 // Tls index is scaled.
-)
-
-// Based relocation format
-type IMAGE_BASE_RELOCATION struct {
- VirtualAddress uint32
- SizeOfBlock uint32
-}
-
-const (
- IMAGE_REL_BASED_ABSOLUTE = 0
- IMAGE_REL_BASED_HIGH = 1
- IMAGE_REL_BASED_LOW = 2
- IMAGE_REL_BASED_HIGHLOW = 3
- IMAGE_REL_BASED_HIGHADJ = 4
- IMAGE_REL_BASED_MACHINE_SPECIFIC_5 = 5
- IMAGE_REL_BASED_RESERVED = 6
- IMAGE_REL_BASED_MACHINE_SPECIFIC_7 = 7
- IMAGE_REL_BASED_MACHINE_SPECIFIC_8 = 8
- IMAGE_REL_BASED_MACHINE_SPECIFIC_9 = 9
- IMAGE_REL_BASED_DIR64 = 10
-
- IMAGE_REL_BASED_IA64_IMM64 = 9
-
- IMAGE_REL_BASED_MIPS_JMPADDR = 5
- IMAGE_REL_BASED_MIPS_JMPADDR16 = 9
-
- IMAGE_REL_BASED_ARM_MOV32 = 5
- IMAGE_REL_BASED_THUMB_MOV32 = 7
-)
-
-// Export Format
-type IMAGE_EXPORT_DIRECTORY struct {
- Characteristics uint32
- TimeDateStamp uint32
- MajorVersion uint16
- MinorVersion uint16
- Name uint32
- Base uint32
- NumberOfFunctions uint32
- NumberOfNames uint32
- AddressOfFunctions uint32 // RVA from base of image
- AddressOfNames uint32 // RVA from base of image
- AddressOfNameOrdinals uint32 // RVA from base of image
-}
-
-type IMAGE_IMPORT_BY_NAME struct {
- Hint uint16
- Name [1]byte
-}
-
-func IMAGE_ORDINAL(ordinal uintptr) uintptr {
- return ordinal & 0xffff
-}
-
-func IMAGE_SNAP_BY_ORDINAL(ordinal uintptr) bool {
- return (ordinal & IMAGE_ORDINAL_FLAG) != 0
-}
-
-// Thread Local Storage
-type IMAGE_TLS_DIRECTORY struct {
- StartAddressOfRawData uintptr
- EndAddressOfRawData uintptr
- AddressOfIndex uintptr // PDWORD
- AddressOfCallbacks uintptr // PIMAGE_TLS_CALLBACK *;
- SizeOfZeroFill uint32
- Characteristics uint32
-}
-
-type IMAGE_IMPORT_DESCRIPTOR struct {
- characteristicsOrOriginalFirstThunk uint32 // 0 for terminating null import descriptor
- // RVA to original unbound IAT (PIMAGE_THUNK_DATA)
- TimeDateStamp uint32 // 0 if not bound,
- // -1 if bound, and real date\time stamp
- // in IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT (new BIND)
- // O.W. date/time stamp of DLL bound to (Old BIND)
- ForwarderChain uint32 // -1 if no forwarders
- Name uint32
- FirstThunk uint32 // RVA to IAT (if bound this IAT has actual addresses)
-}
-
-func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) Characteristics() uint32 {
- return imgimpdesc.characteristicsOrOriginalFirstThunk
-}
-
-func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) OriginalFirstThunk() uint32 {
- return imgimpdesc.characteristicsOrOriginalFirstThunk
-}
-
-type IMAGE_DELAYLOAD_DESCRIPTOR struct {
- Attributes uint32
- DllNameRVA uint32
- ModuleHandleRVA uint32
- ImportAddressTableRVA uint32
- ImportNameTableRVA uint32
- BoundImportAddressTableRVA uint32
- UnloadInformationTableRVA uint32
- TimeDateStamp uint32
-}
-
-type IMAGE_LOAD_CONFIG_CODE_INTEGRITY struct {
- Flags uint16
- Catalog uint16
- CatalogOffset uint32
- Reserved uint32
-}
-
-const (
- IMAGE_GUARD_CF_INSTRUMENTED = 0x00000100
- IMAGE_GUARD_CFW_INSTRUMENTED = 0x00000200
- IMAGE_GUARD_CF_FUNCTION_TABLE_PRESENT = 0x00000400
- IMAGE_GUARD_SECURITY_COOKIE_UNUSED = 0x00000800
- IMAGE_GUARD_PROTECT_DELAYLOAD_IAT = 0x00001000
- IMAGE_GUARD_DELAYLOAD_IAT_IN_ITS_OWN_SECTION = 0x00002000
- IMAGE_GUARD_CF_EXPORT_SUPPRESSION_INFO_PRESENT = 0x00004000
- IMAGE_GUARD_CF_ENABLE_EXPORT_SUPPRESSION = 0x00008000
- IMAGE_GUARD_CF_LONGJUMP_TABLE_PRESENT = 0x00010000
- IMAGE_GUARD_RF_INSTRUMENTED = 0x00020000
- IMAGE_GUARD_RF_ENABLE = 0x00040000
- IMAGE_GUARD_RF_STRICT = 0x00080000
- IMAGE_GUARD_RETPOLINE_PRESENT = 0x00100000
- IMAGE_GUARD_EH_CONTINUATION_TABLE_PRESENT = 0x00400000
- IMAGE_GUARD_XFG_ENABLED = 0x00800000
- IMAGE_GUARD_CF_FUNCTION_TABLE_SIZE_MASK = 0xF0000000
- IMAGE_GUARD_CF_FUNCTION_TABLE_SIZE_SHIFT = 28
-)
-
-type IMAGE_RUNTIME_FUNCTION_ENTRY struct {
- BeginAddress uint32
- EndAddress uint32
- UnwindInfoAddress uint32
-}
-
-const (
- DLL_PROCESS_ATTACH = 1
- DLL_THREAD_ATTACH = 2
- DLL_THREAD_DETACH = 3
- DLL_PROCESS_DETACH = 0
-)
-
-type SYSTEM_INFO struct {
- ProcessorArchitecture uint16
- Reserved uint16
- PageSize uint32
- MinimumApplicationAddress uintptr
- MaximumApplicationAddress uintptr
- ActiveProcessorMask uintptr
- NumberOfProcessors uint32
- ProcessorType uint32
- AllocationGranularity uint32
- ProcessorLevel uint16
- ProcessorRevision uint16
-}
diff --git a/tun/wintun/memmod/syscall_windows_32.go b/tun/wintun/memmod/syscall_windows_32.go
deleted file mode 100644
index 7abbac9..0000000
--- a/tun/wintun/memmod/syscall_windows_32.go
+++ /dev/null
@@ -1,96 +0,0 @@
-// +build windows,386 windows,arm
-
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package memmod
-
-// Optional header format
-type IMAGE_OPTIONAL_HEADER struct {
- Magic uint16
- MajorLinkerVersion uint8
- MinorLinkerVersion uint8
- SizeOfCode uint32
- SizeOfInitializedData uint32
- SizeOfUninitializedData uint32
- AddressOfEntryPoint uint32
- BaseOfCode uint32
- BaseOfData uint32
- ImageBase uintptr
- SectionAlignment uint32
- FileAlignment uint32
- MajorOperatingSystemVersion uint16
- MinorOperatingSystemVersion uint16
- MajorImageVersion uint16
- MinorImageVersion uint16
- MajorSubsystemVersion uint16
- MinorSubsystemVersion uint16
- Win32VersionValue uint32
- SizeOfImage uint32
- SizeOfHeaders uint32
- CheckSum uint32
- Subsystem uint16
- DllCharacteristics uint16
- SizeOfStackReserve uintptr
- SizeOfStackCommit uintptr
- SizeOfHeapReserve uintptr
- SizeOfHeapCommit uintptr
- LoaderFlags uint32
- NumberOfRvaAndSizes uint32
- DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY
-}
-
-const IMAGE_ORDINAL_FLAG uintptr = 0x80000000
-
-type IMAGE_LOAD_CONFIG_DIRECTORY struct {
- Size uint32
- TimeDateStamp uint32
- MajorVersion uint16
- MinorVersion uint16
- GlobalFlagsClear uint32
- GlobalFlagsSet uint32
- CriticalSectionDefaultTimeout uint32
- DeCommitFreeBlockThreshold uint32
- DeCommitTotalFreeThreshold uint32
- LockPrefixTable uint32
- MaximumAllocationSize uint32
- VirtualMemoryThreshold uint32
- ProcessHeapFlags uint32
- ProcessAffinityMask uint32
- CSDVersion uint16
- DependentLoadFlags uint16
- EditList uint32
- SecurityCookie uint32
- SEHandlerTable uint32
- SEHandlerCount uint32
- GuardCFCheckFunctionPointer uint32
- GuardCFDispatchFunctionPointer uint32
- GuardCFFunctionTable uint32
- GuardCFFunctionCount uint32
- GuardFlags uint32
- CodeIntegrity IMAGE_LOAD_CONFIG_CODE_INTEGRITY
- GuardAddressTakenIatEntryTable uint32
- GuardAddressTakenIatEntryCount uint32
- GuardLongJumpTargetTable uint32
- GuardLongJumpTargetCount uint32
- DynamicValueRelocTable uint32
- CHPEMetadataPointer uint32
- GuardRFFailureRoutine uint32
- GuardRFFailureRoutineFunctionPointer uint32
- DynamicValueRelocTableOffset uint32
- DynamicValueRelocTableSection uint16
- Reserved2 uint16
- GuardRFVerifyStackPointerFunctionPointer uint32
- HotPatchTableOffset uint32
- Reserved3 uint32
- EnclaveConfigurationPointer uint32
- VolatileMetadataPointer uint32
- GuardEHContinuationTable uint32
- GuardEHContinuationCount uint32
- GuardXFGCheckFunctionPointer uint32
- GuardXFGDispatchFunctionPointer uint32
- GuardXFGTableDispatchFunctionPointer uint32
- CastGuardOsDeterminedFailureMode uint32
-}
diff --git a/tun/wintun/memmod/syscall_windows_64.go b/tun/wintun/memmod/syscall_windows_64.go
deleted file mode 100644
index 10c6533..0000000
--- a/tun/wintun/memmod/syscall_windows_64.go
+++ /dev/null
@@ -1,95 +0,0 @@
-// +build windows,amd64 windows,arm64
-
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package memmod
-
-// Optional header format
-type IMAGE_OPTIONAL_HEADER struct {
- Magic uint16
- MajorLinkerVersion uint8
- MinorLinkerVersion uint8
- SizeOfCode uint32
- SizeOfInitializedData uint32
- SizeOfUninitializedData uint32
- AddressOfEntryPoint uint32
- BaseOfCode uint32
- ImageBase uintptr
- SectionAlignment uint32
- FileAlignment uint32
- MajorOperatingSystemVersion uint16
- MinorOperatingSystemVersion uint16
- MajorImageVersion uint16
- MinorImageVersion uint16
- MajorSubsystemVersion uint16
- MinorSubsystemVersion uint16
- Win32VersionValue uint32
- SizeOfImage uint32
- SizeOfHeaders uint32
- CheckSum uint32
- Subsystem uint16
- DllCharacteristics uint16
- SizeOfStackReserve uintptr
- SizeOfStackCommit uintptr
- SizeOfHeapReserve uintptr
- SizeOfHeapCommit uintptr
- LoaderFlags uint32
- NumberOfRvaAndSizes uint32
- DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY
-}
-
-const IMAGE_ORDINAL_FLAG uintptr = 0x8000000000000000
-
-type IMAGE_LOAD_CONFIG_DIRECTORY struct {
- Size uint32
- TimeDateStamp uint32
- MajorVersion uint16
- MinorVersion uint16
- GlobalFlagsClear uint32
- GlobalFlagsSet uint32
- CriticalSectionDefaultTimeout uint32
- DeCommitFreeBlockThreshold uint64
- DeCommitTotalFreeThreshold uint64
- LockPrefixTable uint64
- MaximumAllocationSize uint64
- VirtualMemoryThreshold uint64
- ProcessAffinityMask uint64
- ProcessHeapFlags uint32
- CSDVersion uint16
- DependentLoadFlags uint16
- EditList uint64
- SecurityCookie uint64
- SEHandlerTable uint64
- SEHandlerCount uint64
- GuardCFCheckFunctionPointer uint64
- GuardCFDispatchFunctionPointer uint64
- GuardCFFunctionTable uint64
- GuardCFFunctionCount uint64
- GuardFlags uint32
- CodeIntegrity IMAGE_LOAD_CONFIG_CODE_INTEGRITY
- GuardAddressTakenIatEntryTable uint64
- GuardAddressTakenIatEntryCount uint64
- GuardLongJumpTargetTable uint64
- GuardLongJumpTargetCount uint64
- DynamicValueRelocTable uint64
- CHPEMetadataPointer uint64
- GuardRFFailureRoutine uint64
- GuardRFFailureRoutineFunctionPointer uint64
- DynamicValueRelocTableOffset uint32
- DynamicValueRelocTableSection uint16
- Reserved2 uint16
- GuardRFVerifyStackPointerFunctionPointer uint64
- HotPatchTableOffset uint32
- Reserved3 uint32
- EnclaveConfigurationPointer uint64
- VolatileMetadataPointer uint64
- GuardEHContinuationTable uint64
- GuardEHContinuationCount uint64
- GuardXFGCheckFunctionPointer uint64
- GuardXFGDispatchFunctionPointer uint64
- GuardXFGTableDispatchFunctionPointer uint64
- CastGuardOsDeterminedFailureMode uint64
-}
diff --git a/tun/wintun/session_windows.go b/tun/wintun/session_windows.go
deleted file mode 100644
index ffa9380..0000000
--- a/tun/wintun/session_windows.go
+++ /dev/null
@@ -1,108 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package wintun
-
-import (
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-type Session struct {
- handle uintptr
-}
-
-const (
- PacketSizeMax = 0xffff // Maximum packet size
- RingCapacityMin = 0x20000 // Minimum ring capacity (128 kiB)
- RingCapacityMax = 0x4000000 // Maximum ring capacity (64 MiB)
-)
-
-// Packet with data
-type Packet struct {
- Next *Packet // Pointer to next packet in queue
- Size uint32 // Size of packet (max WINTUN_MAX_IP_PACKET_SIZE)
- Data *[PacketSizeMax]byte // Pointer to layer 3 IPv4 or IPv6 packet
-}
-
-var (
- procWintunAllocateSendPacket = modwintun.NewProc("WintunAllocateSendPacket")
- procWintunEndSession = modwintun.NewProc("WintunEndSession")
- procWintunGetReadWaitEvent = modwintun.NewProc("WintunGetReadWaitEvent")
- procWintunReceivePacket = modwintun.NewProc("WintunReceivePacket")
- procWintunReleaseReceivePacket = modwintun.NewProc("WintunReleaseReceivePacket")
- procWintunSendPacket = modwintun.NewProc("WintunSendPacket")
- procWintunStartSession = modwintun.NewProc("WintunStartSession")
-)
-
-func (wintun *Adapter) StartSession(capacity uint32) (session Session, err error) {
- r0, _, e1 := syscall.Syscall(procWintunStartSession.Addr(), 2, uintptr(wintun.handle), uintptr(capacity), 0)
- if r0 == 0 {
- err = e1
- } else {
- session = Session{r0}
- }
- return
-}
-
-func (session Session) End() {
- syscall.Syscall(procWintunEndSession.Addr(), 1, session.handle, 0, 0)
- session.handle = 0
-}
-
-func (session Session) ReadWaitEvent() (handle windows.Handle) {
- r0, _, _ := syscall.Syscall(procWintunGetReadWaitEvent.Addr(), 1, session.handle, 0, 0)
- handle = windows.Handle(r0)
- return
-}
-
-func (session Session) ReceivePacket() (packet []byte, err error) {
- var packetSize uint32
- r0, _, e1 := syscall.Syscall(procWintunReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packetSize)), 0)
- if r0 == 0 {
- err = e1
- return
- }
- unsafeSlice(unsafe.Pointer(&packet), unsafe.Pointer(r0), int(packetSize))
- return
-}
-
-func (session Session) ReleaseReceivePacket(packet []byte) {
- syscall.Syscall(procWintunReleaseReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0)
-}
-
-func (session Session) AllocateSendPacket(packetSize int) (packet []byte, err error) {
- r0, _, e1 := syscall.Syscall(procWintunAllocateSendPacket.Addr(), 2, session.handle, uintptr(packetSize), 0)
- if r0 == 0 {
- err = e1
- return
- }
- unsafeSlice(unsafe.Pointer(&packet), unsafe.Pointer(r0), int(packetSize))
- return
-}
-
-func (session Session) SendPacket(packet []byte) {
- syscall.Syscall(procWintunSendPacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0)
-}
-
-// unsafeSlice updates the slice slicePtr to be a slice
-// referencing the provided data with its length & capacity set to
-// lenCap.
-//
-// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
-// update callers to use unsafe.Slice instead of this.
-func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
- type sliceHeader struct {
- Data unsafe.Pointer
- Len int
- Cap int
- }
- h := (*sliceHeader)(slicePtr)
- h.Data = data
- h.Len = lenCap
- h.Cap = lenCap
-}
diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go
deleted file mode 100644
index 6c5a00d..0000000
--- a/tun/wintun/wintun_windows.go
+++ /dev/null
@@ -1,215 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package wintun
-
-import (
- "errors"
- "log"
- "runtime"
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-type loggerLevel int
-
-const (
- logInfo loggerLevel = iota
- logWarn
- logErr
-)
-
-const (
- PoolNameMax = 256
- AdapterNameMax = 128
-)
-
-type Pool [PoolNameMax]uint16
-type Adapter struct {
- handle uintptr
-}
-
-var (
- modwintun = newLazyDLL("wintun.dll", setupLogger)
-
- procWintunCreateAdapter = modwintun.NewProc("WintunCreateAdapter")
- procWintunDeleteAdapter = modwintun.NewProc("WintunDeleteAdapter")
- procWintunDeletePoolDriver = modwintun.NewProc("WintunDeletePoolDriver")
- procWintunEnumAdapters = modwintun.NewProc("WintunEnumAdapters")
- procWintunFreeAdapter = modwintun.NewProc("WintunFreeAdapter")
- procWintunOpenAdapter = modwintun.NewProc("WintunOpenAdapter")
- procWintunGetAdapterLUID = modwintun.NewProc("WintunGetAdapterLUID")
- procWintunGetAdapterName = modwintun.NewProc("WintunGetAdapterName")
- procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion")
- procWintunSetAdapterName = modwintun.NewProc("WintunSetAdapterName")
-)
-
-func setupLogger(dll *lazyDLL) {
- syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, windows.NewCallback(func(level loggerLevel, msg *uint16) int {
- log.Println("[Wintun]", windows.UTF16PtrToString(msg))
- return 0
- }), 0, 0)
-}
-
-func MakePool(poolName string) (pool *Pool, err error) {
- poolName16, err := windows.UTF16FromString(poolName)
- if err != nil {
- return
- }
- if len(poolName16) > PoolNameMax {
- err = errors.New("Pool name too long")
- return
- }
- pool = &Pool{}
- copy(pool[:], poolName16)
- return
-}
-
-func (pool *Pool) String() string {
- return windows.UTF16ToString(pool[:])
-}
-
-func freeAdapter(wintun *Adapter) {
- syscall.Syscall(procWintunFreeAdapter.Addr(), 1, uintptr(wintun.handle), 0, 0)
-}
-
-// OpenAdapter finds a Wintun adapter by its name. This function returns the adapter if found, or
-// windows.ERROR_FILE_NOT_FOUND otherwise. If the adapter is found but not a Wintun-class or a
-// member of the pool, this function returns windows.ERROR_ALREADY_EXISTS. The adapter must be
-// released after use.
-func (pool *Pool) OpenAdapter(ifname string) (wintun *Adapter, err error) {
- ifname16, err := windows.UTF16PtrFromString(ifname)
- if err != nil {
- return nil, err
- }
- r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), 0)
- if r0 == 0 {
- err = e1
- return
- }
- wintun = &Adapter{r0}
- runtime.SetFinalizer(wintun, freeAdapter)
- return
-}
-
-// CreateAdapter creates a Wintun adapter. ifname is the requested name of the adapter, while
-// requestedGUID is the GUID of the created network adapter, which then influences NLA generation
-// deterministically. If it is set to nil, the GUID is chosen by the system at random, and hence a
-// new NLA entry is created for each new adapter. It is called "requested" GUID because the API it
-// uses is completely undocumented, and so there could be minor interesting complications with its
-// usage. This function returns the network adapter ID and a flag if reboot is required.
-func (pool *Pool) CreateAdapter(ifname string, requestedGUID *windows.GUID) (wintun *Adapter, rebootRequired bool, err error) {
- var ifname16 *uint16
- ifname16, err = windows.UTF16PtrFromString(ifname)
- if err != nil {
- return
- }
- var _p0 uint32
- r0, _, e1 := syscall.Syscall6(procWintunCreateAdapter.Addr(), 4, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), uintptr(unsafe.Pointer(requestedGUID)), uintptr(unsafe.Pointer(&_p0)), 0, 0)
- rebootRequired = _p0 != 0
- if r0 == 0 {
- err = e1
- return
- }
- wintun = &Adapter{r0}
- runtime.SetFinalizer(wintun, freeAdapter)
- return
-}
-
-// Delete deletes a Wintun adapter. This function succeeds if the adapter was not found. It returns
-// a bool indicating whether a reboot is required.
-func (wintun *Adapter) Delete(forceCloseSessions bool) (rebootRequired bool, err error) {
- var _p0 uint32
- if forceCloseSessions {
- _p0 = 1
- }
- var _p1 uint32
- r1, _, e1 := syscall.Syscall(procWintunDeleteAdapter.Addr(), 3, uintptr(wintun.handle), uintptr(_p0), uintptr(unsafe.Pointer(&_p1)))
- rebootRequired = _p1 != 0
- if r1 == 0 {
- err = e1
- }
- return
-}
-
-// DeleteMatchingAdapters deletes all Wintun adapters, which match
-// given criteria, and returns which ones it deleted, whether a reboot
-// is required after, and which errors occurred during the process.
-func (pool *Pool) DeleteMatchingAdapters(matches func(adapter *Adapter) bool, forceCloseSessions bool) (rebootRequired bool, errors []error) {
- cb := func(handle uintptr, _ uintptr) int {
- adapter := &Adapter{handle}
- if !matches(adapter) {
- return 1
- }
- rebootRequired2, err := adapter.Delete(forceCloseSessions)
- if err != nil {
- errors = append(errors, err)
- return 1
- }
- rebootRequired = rebootRequired || rebootRequired2
- return 1
- }
- r1, _, e1 := syscall.Syscall(procWintunEnumAdapters.Addr(), 3, uintptr(unsafe.Pointer(pool)), uintptr(windows.NewCallback(cb)), 0)
- if r1 == 0 {
- errors = append(errors, e1)
- }
- return
-}
-
-// Name returns the name of the Wintun adapter.
-func (wintun *Adapter) Name() (ifname string, err error) {
- var ifname16 [AdapterNameMax]uint16
- r1, _, e1 := syscall.Syscall(procWintunGetAdapterName.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&ifname16[0])), 0)
- if r1 == 0 {
- err = e1
- return
- }
- ifname = windows.UTF16ToString(ifname16[:])
- return
-}
-
-// DeleteDriver deletes all Wintun adapters in a pool and if there are no more adapters in any other
-// pools, also removes Wintun from the driver store, usually called by uninstallers.
-func (pool *Pool) DeleteDriver() (rebootRequired bool, err error) {
- var _p0 uint32
- r1, _, e1 := syscall.Syscall(procWintunDeletePoolDriver.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(&_p0)), 0)
- rebootRequired = _p0 != 0
- if r1 == 0 {
- err = e1
- }
- return
-
-}
-
-// SetName sets name of the Wintun adapter.
-func (wintun *Adapter) SetName(ifname string) (err error) {
- ifname16, err := windows.UTF16FromString(ifname)
- if err != nil {
- return err
- }
- r1, _, e1 := syscall.Syscall(procWintunSetAdapterName.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&ifname16[0])), 0)
- if r1 == 0 {
- err = e1
- }
- return
-}
-
-// RunningVersion returns the version of the running Wintun driver.
-func RunningVersion() (version uint32, err error) {
- r0, _, e1 := syscall.Syscall(procWintunGetRunningDriverVersion.Addr(), 0, 0, 0, 0)
- version = uint32(r0)
- if version == 0 {
- err = e1
- }
- return
-}
-
-// LUID returns the LUID of the adapter.
-func (wintun *Adapter) LUID() (luid uint64) {
- syscall.Syscall(procWintunGetAdapterLUID.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&luid)), 0)
- return
-}
diff --git a/version.go b/version.go
index e1d4a03..db75bb9 100644
--- a/version.go
+++ b/version.go
@@ -1,3 +1,3 @@
package main
-const Version = "0.0.20210424"
+const Version = "0.0.20230223"