aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--README.md4
-rw-r--r--conn/bind_linux.go562
-rw-r--r--conn/bind_std.go522
-rw-r--r--conn/bind_std_test.go250
-rw-r--r--conn/bind_windows.go73
-rw-r--r--conn/bindtest/bindtest.go41
-rw-r--r--conn/boundif_android.go10
-rw-r--r--conn/conn.go26
-rw-r--r--conn/conn_test.go24
-rw-r--r--conn/controlfns.go43
-rw-r--r--conn/controlfns_linux.go69
-rw-r--r--conn/controlfns_unix.go35
-rw-r--r--conn/controlfns_windows.go23
-rw-r--r--conn/default.go4
-rw-r--r--conn/errors_default.go12
-rw-r--r--conn/errors_linux.go26
-rw-r--r--conn/features_default.go15
-rw-r--r--conn/features_linux.go29
-rw-r--r--conn/gso_default.go21
-rw-r--r--conn/gso_linux.go65
-rw-r--r--conn/mark_default.go4
-rw-r--r--conn/mark_unix.go12
-rw-r--r--conn/sticky_default.go42
-rw-r--r--conn/sticky_linux.go112
-rw-r--r--conn/sticky_linux_test.go266
-rw-r--r--conn/winrio/rio_windows.go2
-rw-r--r--device/allowedips.go2
-rw-r--r--device/allowedips_rand_test.go2
-rw-r--r--device/allowedips_test.go2
-rw-r--r--device/bind_test.go12
-rw-r--r--device/channels.go40
-rw-r--r--device/constants.go2
-rw-r--r--device/cookie.go2
-rw-r--r--device/cookie_test.go2
-rw-r--r--device/device.go45
-rw-r--r--device/device_test.go71
-rw-r--r--device/endpoint_test.go2
-rw-r--r--device/indextable.go2
-rw-r--r--device/ip.go2
-rw-r--r--device/kdf_test.go2
-rw-r--r--device/keypair.go2
-rw-r--r--device/logger.go2
-rw-r--r--device/mobilequirks.go8
-rw-r--r--device/noise-helpers.go12
-rw-r--r--device/noise-protocol.go65
-rw-r--r--device/noise-types.go2
-rw-r--r--device/noise_test.go8
-rw-r--r--device/peer.go78
-rw-r--r--device/pools.go38
-rw-r--r--device/pools_test.go50
-rw-r--r--device/queueconstants_android.go8
-rw-r--r--device/queueconstants_default.go6
-rw-r--r--device/queueconstants_ios.go2
-rw-r--r--device/queueconstants_windows.go2
-rw-r--r--device/race_disabled_test.go2
-rw-r--r--device/race_enabled_test.go2
-rw-r--r--device/receive.go373
-rw-r--r--device/send.go341
-rw-r--r--device/sticky_linux.go43
-rw-r--r--device/timers.go14
-rw-r--r--device/tun.go2
-rw-r--r--device/uapi.go62
-rw-r--r--format_test.go2
-rw-r--r--go.mod14
-rw-r--r--go.sum24
-rw-r--r--ipc/namedpipe/file.go1
-rw-r--r--ipc/namedpipe/namedpipe.go1
-rw-r--r--ipc/namedpipe/namedpipe_test.go1
-rw-r--r--ipc/uapi_bsd.go2
-rw-r--r--ipc/uapi_linux.go6
-rw-r--r--ipc/uapi_unix.go2
-rw-r--r--ipc/uapi_wasm.go (renamed from ipc/uapi_js.go)4
-rw-r--r--ipc/uapi_windows.go2
-rw-r--r--main.go16
-rw-r--r--main_windows.go7
-rw-r--r--ratelimiter/ratelimiter.go2
-rw-r--r--ratelimiter/ratelimiter_test.go2
-rw-r--r--replay/replay.go2
-rw-r--r--replay/replay_test.go2
-rw-r--r--rwcancel/rwcancel.go4
-rw-r--r--rwcancel/rwcancel_stub.go2
-rw-r--r--tai64n/tai64n.go2
-rw-r--r--tai64n/tai64n_test.go2
-rw-r--r--tun/alignment_windows_test.go2
-rw-r--r--tun/checksum.go118
-rw-r--r--tun/checksum_test.go35
-rw-r--r--tun/errors.go12
-rw-r--r--tun/netstack/examples/http_client.go13
-rw-r--r--tun/netstack/examples/http_server.go11
-rw-r--r--tun/netstack/examples/ping_client.go3
-rw-r--r--tun/netstack/tun.go72
-rw-r--r--tun/offload_linux.go993
-rw-r--r--tun/offload_linux_test.go752
-rw-r--r--tun/operateonfd.go2
-rw-r--r--tun/tun.go42
-rw-r--r--tun/tun_darwin.go71
-rw-r--r--tun/tun_freebsd.go59
-rw-r--r--tun/tun_linux.go289
-rw-r--r--tun/tun_openbsd.go64
-rw-r--r--tun/tun_windows.go62
-rw-r--r--tun/tuntest/tuntest.go37
-rw-r--r--version.go2
102 files changed, 4828 insertions, 1513 deletions
diff --git a/README.md b/README.md
index 30ea3ac..074f7ec 100644
--- a/README.md
+++ b/README.md
@@ -46,7 +46,7 @@ This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapp
## Building
-This requires an installation of [go](https://golang.org) ≥ 1.18.
+This requires an installation of the latest version of [Go](https://go.dev/).
```
$ git clone https://git.zx2c4.com/wireguard-go
@@ -56,7 +56,7 @@ $ make
## License
- Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
diff --git a/conn/bind_linux.go b/conn/bind_linux.go
deleted file mode 100644
index 03e8707..0000000
--- a/conn/bind_linux.go
+++ /dev/null
@@ -1,562 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package conn
-
-import (
- "errors"
- "net"
- "net/netip"
- "strconv"
- "sync"
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/unix"
-)
-
-type ipv4Source struct {
- Src [4]byte
- Ifindex int32
-}
-
-type ipv6Source struct {
- src [16]byte
- // ifindex belongs in dst.ZoneId
-}
-
-type LinuxSocketEndpoint struct {
- mu sync.Mutex
- dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
- src [unsafe.Sizeof(ipv6Source{})]byte
- isV6 bool
-}
-
-func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() }
-func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
-func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 }
-
-func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
- return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source {
- return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 {
- return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
- return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
-}
-
-// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
-type LinuxSocketBind struct {
- // mu guards sock4 and sock6 and the associated fds.
- // As long as someone holds mu (read or write), the associated fds are valid.
- mu sync.RWMutex
- sock4 int
- sock6 int
-}
-
-func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
-func NewDefaultBind() Bind { return NewLinuxSocketBind() }
-
-var (
- _ Endpoint = (*LinuxSocketEndpoint)(nil)
- _ Bind = (*LinuxSocketBind)(nil)
-)
-
-func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
- var end LinuxSocketEndpoint
- e, err := netip.ParseAddrPort(s)
- if err != nil {
- return nil, err
- }
-
- if e.Addr().Is4() {
- dst := end.dst4()
- end.isV6 = false
- dst.Port = int(e.Port())
- dst.Addr = e.Addr().As4()
- end.ClearSrc()
- return &end, nil
- }
-
- if e.Addr().Is6() {
- zone, err := zoneToUint32(e.Addr().Zone())
- if err != nil {
- return nil, err
- }
- dst := end.dst6()
- end.isV6 = true
- dst.Port = int(e.Port())
- dst.ZoneId = zone
- dst.Addr = e.Addr().As16()
- end.ClearSrc()
- return &end, nil
- }
-
- return nil, errors.New("invalid IP address")
-}
-
-func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
- bind.mu.Lock()
- defer bind.mu.Unlock()
-
- var err error
- var newPort uint16
- var tries int
-
- if bind.sock4 != -1 || bind.sock6 != -1 {
- return nil, 0, ErrBindAlreadyOpen
- }
-
- originalPort := port
-
-again:
- port = originalPort
- var sock4, sock6 int
- // Attempt ipv6 bind, update port if successful.
- sock6, newPort, err = create6(port)
- if err != nil {
- if !errors.Is(err, syscall.EAFNOSUPPORT) {
- return nil, 0, err
- }
- } else {
- port = newPort
- }
-
- // Attempt ipv4 bind, update port if successful.
- sock4, newPort, err = create4(port)
- if err != nil {
- if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
- unix.Close(sock6)
- tries++
- goto again
- }
- if !errors.Is(err, syscall.EAFNOSUPPORT) {
- unix.Close(sock6)
- return nil, 0, err
- }
- } else {
- port = newPort
- }
-
- var fns []ReceiveFunc
- if sock4 != -1 {
- bind.sock4 = sock4
- fns = append(fns, bind.receiveIPv4)
- }
- if sock6 != -1 {
- bind.sock6 = sock6
- fns = append(fns, bind.receiveIPv6)
- }
- if len(fns) == 0 {
- return nil, 0, syscall.EAFNOSUPPORT
- }
- return fns, port, nil
-}
-
-func (bind *LinuxSocketBind) SetMark(value uint32) error {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
-
- if bind.sock6 != -1 {
- err := unix.SetsockoptInt(
- bind.sock6,
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
- if err != nil {
- return err
- }
- }
-
- if bind.sock4 != -1 {
- err := unix.SetsockoptInt(
- bind.sock4,
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
- if err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (bind *LinuxSocketBind) Close() error {
- // Take a readlock to shut down the sockets...
- bind.mu.RLock()
- if bind.sock6 != -1 {
- unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
- }
- if bind.sock4 != -1 {
- unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
- }
- bind.mu.RUnlock()
- // ...and a write lock to close the fd.
- // This ensures that no one else is using the fd.
- bind.mu.Lock()
- defer bind.mu.Unlock()
- var err1, err2 error
- if bind.sock6 != -1 {
- err1 = unix.Close(bind.sock6)
- bind.sock6 = -1
- }
- if bind.sock4 != -1 {
- err2 = unix.Close(bind.sock4)
- bind.sock4 = -1
- }
-
- if err1 != nil {
- return err1
- }
- return err2
-}
-
-func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if bind.sock4 == -1 {
- return 0, nil, net.ErrClosed
- }
- var end LinuxSocketEndpoint
- n, err := receive4(bind.sock4, buf, &end)
- return n, &end, err
-}
-
-func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if bind.sock6 == -1 {
- return 0, nil, net.ErrClosed
- }
- var end LinuxSocketEndpoint
- n, err := receive6(bind.sock6, buf, &end)
- return n, &end, err
-}
-
-func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
- nend, ok := end.(*LinuxSocketEndpoint)
- if !ok {
- return ErrWrongEndpointType
- }
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if !nend.isV6 {
- if bind.sock4 == -1 {
- return net.ErrClosed
- }
- return send4(bind.sock4, nend, buff)
- } else {
- if bind.sock6 == -1 {
- return net.ErrClosed
- }
- return send6(bind.sock6, nend, buff)
- }
-}
-
-func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
- if !end.isV6 {
- return netip.AddrFrom4(end.src4().Src)
- } else {
- return netip.AddrFrom16(end.src6().src)
- }
-}
-
-func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
- if !end.isV6 {
- return netip.AddrFrom4(end.dst4().Addr)
- } else {
- return netip.AddrFrom16(end.dst6().Addr)
- }
-}
-
-func (end *LinuxSocketEndpoint) DstToBytes() []byte {
- if !end.isV6 {
- return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
- } else {
- return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
- }
-}
-
-func (end *LinuxSocketEndpoint) SrcToString() string {
- return end.SrcIP().String()
-}
-
-func (end *LinuxSocketEndpoint) DstToString() string {
- var port int
- if !end.isV6 {
- port = end.dst4().Port
- } else {
- port = end.dst6().Port
- }
- return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
-}
-
-func (end *LinuxSocketEndpoint) ClearDst() {
- for i := range end.dst {
- end.dst[i] = 0
- }
-}
-
-func (end *LinuxSocketEndpoint) ClearSrc() {
- for i := range end.src {
- end.src[i] = 0
- }
-}
-
-func zoneToUint32(zone string) (uint32, error) {
- if zone == "" {
- return 0, nil
- }
- if intr, err := net.InterfaceByName(zone); err == nil {
- return uint32(intr.Index), nil
- }
- n, err := strconv.ParseUint(zone, 10, 32)
- return uint32(n), err
-}
-
-func create4(port uint16) (int, uint16, error) {
- // create socket
-
- fd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
- 0,
- )
- if err != nil {
- return -1, 0, err
- }
-
- addr := unix.SockaddrInet4{
- Port: int(port),
- }
-
- // set sockopts and bind
-
- if err := func() error {
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IP,
- unix.IP_PKTINFO,
- 1,
- ); err != nil {
- return err
- }
-
- return unix.Bind(fd, &addr)
- }(); err != nil {
- unix.Close(fd)
- return -1, 0, err
- }
-
- sa, err := unix.Getsockname(fd)
- if err == nil {
- addr.Port = sa.(*unix.SockaddrInet4).Port
- }
-
- return fd, uint16(addr.Port), err
-}
-
-func create6(port uint16) (int, uint16, error) {
- // create socket
-
- fd, err := unix.Socket(
- unix.AF_INET6,
- unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
- 0,
- )
- if err != nil {
- return -1, 0, err
- }
-
- // set sockopts and bind
-
- addr := unix.SockaddrInet6{
- Port: int(port),
- }
-
- if err := func() error {
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IPV6,
- unix.IPV6_RECVPKTINFO,
- 1,
- ); err != nil {
- return err
- }
-
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IPV6,
- unix.IPV6_V6ONLY,
- 1,
- ); err != nil {
- return err
- }
-
- return unix.Bind(fd, &addr)
- }(); err != nil {
- unix.Close(fd)
- return -1, 0, err
- }
-
- sa, err := unix.Getsockname(fd)
- if err == nil {
- addr.Port = sa.(*unix.SockaddrInet6).Port
- }
-
- return fd, uint16(addr.Port), err
-}
-
-func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error {
- // construct message header
-
- cmsg := struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
- }{
- unix.Cmsghdr{
- Level: unix.IPPROTO_IP,
- Type: unix.IP_PKTINFO,
- Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
- },
- unix.Inet4Pktinfo{
- Spec_dst: end.src4().Src,
- Ifindex: end.src4().Ifindex,
- },
- }
-
- end.mu.Lock()
- _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
- end.mu.Unlock()
-
- if err == nil {
- return nil
- }
-
- // clear src and retry
-
- if err == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet4Pktinfo{}
- end.mu.Lock()
- _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
- end.mu.Unlock()
- }
-
- return err
-}
-
-func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error {
- // construct message header
-
- cmsg := struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
- }{
- unix.Cmsghdr{
- Level: unix.IPPROTO_IPV6,
- Type: unix.IPV6_PKTINFO,
- Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
- },
- unix.Inet6Pktinfo{
- Addr: end.src6().src,
- Ifindex: end.dst6().ZoneId,
- },
- }
-
- if cmsg.pktinfo.Addr == [16]byte{} {
- cmsg.pktinfo.Ifindex = 0
- }
-
- end.mu.Lock()
- _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
- end.mu.Unlock()
-
- if err == nil {
- return nil
- }
-
- // clear src and retry
-
- if err == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet6Pktinfo{}
- end.mu.Lock()
- _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
- end.mu.Unlock()
- }
-
- return err
-}
-
-func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
- // construct message header
-
- var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
- }
-
- size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
- if err != nil {
- return 0, err
- }
- end.isV6 = false
-
- if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
- *end.dst4() = *newDst4
- }
-
- // update source cache
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
- cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
- end.src4().Src = cmsg.pktinfo.Spec_dst
- end.src4().Ifindex = cmsg.pktinfo.Ifindex
- }
-
- return size, nil
-}
-
-func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
- // construct message header
-
- var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
- }
-
- size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
- if err != nil {
- return 0, err
- }
- end.isV6 = true
-
- if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
- *end.dst6() = *newDst6
- }
-
- // update source cache
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
- cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
- end.src6().src = cmsg.pktinfo.Addr
- end.dst6().ZoneId = cmsg.pktinfo.Ifindex
- }
-
- return size, nil
-}
diff --git a/conn/bind_std.go b/conn/bind_std.go
index b6a7ab3..46df7fd 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -1,69 +1,126 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
+ "context"
"errors"
+ "fmt"
"net"
"net/netip"
+ "runtime"
+ "strconv"
"sync"
"syscall"
+
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
)
-// StdNetBind is meant to be a temporary solution on platforms for which
-// the sticky socket / source caching behavior has not yet been implemented.
-// It uses the Go's net package to implement networking.
-// See LinuxSocketBind for a proper implementation on the Linux platform.
+var (
+ _ Bind = (*StdNetBind)(nil)
+)
+
+// StdNetBind implements Bind for all platforms. While Windows has its own Bind
+// (see bind_windows.go), it may fall back to StdNetBind.
+// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
+// methods for sending and receiving multiple datagrams per-syscall. See the
+// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
- mu sync.Mutex // protects following fields
- ipv4 *net.UDPConn
- ipv6 *net.UDPConn
+ mu sync.Mutex // protects all fields except as specified
+ ipv4 *net.UDPConn
+ ipv6 *net.UDPConn
+ ipv4PC *ipv4.PacketConn // will be nil on non-Linux
+ ipv6PC *ipv6.PacketConn // will be nil on non-Linux
+ ipv4TxOffload bool
+ ipv4RxOffload bool
+ ipv6TxOffload bool
+ ipv6RxOffload bool
+
+ // these two fields are not guarded by mu
+ udpAddrPool sync.Pool
+ msgsPool sync.Pool
+
blackhole4 bool
blackhole6 bool
}
-func NewStdNetBind() Bind { return &StdNetBind{} }
+func NewStdNetBind() Bind {
+ return &StdNetBind{
+ udpAddrPool: sync.Pool{
+ New: func() any {
+ return &net.UDPAddr{
+ IP: make([]byte, 16),
+ }
+ },
+ },
-type StdNetEndpoint netip.AddrPort
+ msgsPool: sync.Pool{
+ New: func() any {
+ // ipv6.Message and ipv4.Message are interchangeable as they are
+ // both aliases for x/net/internal/socket.Message.
+ msgs := make([]ipv6.Message, IdealBatchSize)
+ for i := range msgs {
+ msgs[i].Buffers = make(net.Buffers, 1)
+ msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
+ }
+ return &msgs
+ },
+ },
+ }
+}
+
+type StdNetEndpoint struct {
+ // AddrPort is the endpoint destination.
+ netip.AddrPort
+ // src is the current sticky source address and interface index, if
+ // supported. Typically this is a PKTINFO structure from/for control
+ // messages, see unix.PKTINFO for an example.
+ src []byte
+}
var (
_ Bind = (*StdNetBind)(nil)
- _ Endpoint = StdNetEndpoint{}
+ _ Endpoint = &StdNetEndpoint{}
)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
e, err := netip.ParseAddrPort(s)
- return asEndpoint(e), err
+ if err != nil {
+ return nil, err
+ }
+ return &StdNetEndpoint{
+ AddrPort: e,
+ }, nil
}
-func (StdNetEndpoint) ClearSrc() {}
-
-func (e StdNetEndpoint) DstIP() netip.Addr {
- return (netip.AddrPort)(e).Addr()
+func (e *StdNetEndpoint) ClearSrc() {
+ if e.src != nil {
+ // Truncate src, no need to reallocate.
+ e.src = e.src[:0]
+ }
}
-func (e StdNetEndpoint) SrcIP() netip.Addr {
- return netip.Addr{} // not supported
+func (e *StdNetEndpoint) DstIP() netip.Addr {
+ return e.AddrPort.Addr()
}
-func (e StdNetEndpoint) DstToBytes() []byte {
- b, _ := (netip.AddrPort)(e).MarshalBinary()
- return b
-}
+// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
-func (e StdNetEndpoint) DstToString() string {
- return (netip.AddrPort)(e).String()
+func (e *StdNetEndpoint) DstToBytes() []byte {
+ b, _ := e.AddrPort.MarshalBinary()
+ return b
}
-func (e StdNetEndpoint) SrcToString() string {
- return ""
+func (e *StdNetEndpoint) DstToString() string {
+ return e.AddrPort.String()
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
- conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
+ conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil {
return nil, 0, err
}
@@ -77,17 +134,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
if err != nil {
return nil, 0, err
}
- return conn, uaddr.Port, nil
+ return conn.(*net.UDPConn), uaddr.Port, nil
}
-func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
- bind.mu.Lock()
- defer bind.mu.Unlock()
+func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
var err error
var tries int
- if bind.ipv4 != nil || bind.ipv6 != nil {
+ if s.ipv4 != nil || s.ipv6 != nil {
return nil, 0, ErrBindAlreadyOpen
}
@@ -95,90 +152,207 @@ func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
// If uport is 0, we can retry on failure.
again:
port := int(uport)
- var ipv4, ipv6 *net.UDPConn
+ var v4conn, v6conn *net.UDPConn
+ var v4pc *ipv4.PacketConn
+ var v6pc *ipv6.PacketConn
- ipv4, port, err = listenNet("udp4", port)
+ v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
- ipv6, port, err = listenNet("udp6", port)
+ v6conn, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
- ipv4.Close()
+ v4conn.Close()
tries++
goto again
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
- ipv4.Close()
+ v4conn.Close()
return nil, 0, err
}
var fns []ReceiveFunc
- if ipv4 != nil {
- fns = append(fns, bind.makeReceiveIPv4(ipv4))
- bind.ipv4 = ipv4
+ if v4conn != nil {
+ s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ v4pc = ipv4.NewPacketConn(v4conn)
+ s.ipv4PC = v4pc
+ }
+ fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
+ s.ipv4 = v4conn
}
- if ipv6 != nil {
- fns = append(fns, bind.makeReceiveIPv6(ipv6))
- bind.ipv6 = ipv6
+ if v6conn != nil {
+ s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ v6pc = ipv6.NewPacketConn(v6conn)
+ s.ipv6PC = v6pc
+ }
+ fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
+ s.ipv6 = v6conn
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
+
return fns, uint16(port), nil
}
-func (bind *StdNetBind) Close() error {
- bind.mu.Lock()
- defer bind.mu.Unlock()
+func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
+ for i := range *msgs {
+ (*msgs)[i].OOB = (*msgs)[i].OOB[:0]
+ (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
+ }
+ s.msgsPool.Put(msgs)
+}
+
+func (s *StdNetBind) getMessages() *[]ipv6.Message {
+ return s.msgsPool.Get().(*[]ipv6.Message)
+}
+
+var (
+ // If compilation fails here these are no longer the same underlying type.
+ _ ipv6.Message = ipv4.Message{}
+)
+
+type batchReader interface {
+ ReadBatch([]ipv6.Message, int) (int, error)
+}
+
+type batchWriter interface {
+ WriteBatch([]ipv6.Message, int) (int, error)
+}
+
+func (s *StdNetBind) receiveIP(
+ br batchReader,
+ conn *net.UDPConn,
+ rxOffload bool,
+ bufs [][]byte,
+ sizes []int,
+ eps []Endpoint,
+) (n int, err error) {
+ msgs := s.getMessages()
+ for i := range bufs {
+ (*msgs)[i].Buffers[0] = bufs[i]
+ (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
+ }
+ defer s.putMessages(msgs)
+ var numMsgs int
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ if rxOffload {
+ readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
+ numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
+ if err != nil {
+ return 0, err
+ }
+ numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
+ if err != nil {
+ return 0, err
+ }
+ } else {
+ numMsgs, err = br.ReadBatch(*msgs, 0)
+ if err != nil {
+ return 0, err
+ }
+ }
+ } else {
+ msg := &(*msgs)[0]
+ msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+ if err != nil {
+ return 0, err
+ }
+ numMsgs = 1
+ }
+ for i := 0; i < numMsgs; i++ {
+ msg := &(*msgs)[i]
+ sizes[i] = msg.N
+ if sizes[i] == 0 {
+ continue
+ }
+ addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+ ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
+ getSrcFromControl(msg.OOB[:msg.NN], ep)
+ eps[i] = ep
+ }
+ return numMsgs, nil
+}
+
+func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
+ return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
+ }
+}
+
+func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
+ return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
+ }
+}
+
+// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
+// rename the IdealBatchSize constant to BatchSize.
+func (s *StdNetBind) BatchSize() int {
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ return IdealBatchSize
+ }
+ return 1
+}
+
+func (s *StdNetBind) Close() error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
var err1, err2 error
- if bind.ipv4 != nil {
- err1 = bind.ipv4.Close()
- bind.ipv4 = nil
+ if s.ipv4 != nil {
+ err1 = s.ipv4.Close()
+ s.ipv4 = nil
+ s.ipv4PC = nil
}
- if bind.ipv6 != nil {
- err2 = bind.ipv6.Close()
- bind.ipv6 = nil
+ if s.ipv6 != nil {
+ err2 = s.ipv6.Close()
+ s.ipv6 = nil
+ s.ipv6PC = nil
}
- bind.blackhole4 = false
- bind.blackhole6 = false
+ s.blackhole4 = false
+ s.blackhole6 = false
+ s.ipv4TxOffload = false
+ s.ipv4RxOffload = false
+ s.ipv6TxOffload = false
+ s.ipv6RxOffload = false
if err1 != nil {
return err1
}
return err2
}
-func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
- return func(buff []byte) (int, Endpoint, error) {
- n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
- return n, asEndpoint(endpoint), err
- }
+type ErrUDPGSODisabled struct {
+ onLaddr string
+ RetryErr error
}
-func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
- return func(buff []byte) (int, Endpoint, error) {
- n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
- return n, asEndpoint(endpoint), err
- }
+func (e ErrUDPGSODisabled) Error() string {
+ return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
}
-func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
- var err error
- nend, ok := endpoint.(StdNetEndpoint)
- if !ok {
- return ErrWrongEndpointType
- }
- addrPort := netip.AddrPort(nend)
+func (e ErrUDPGSODisabled) Unwrap() error {
+ return e.RetryErr
+}
- bind.mu.Lock()
- blackhole := bind.blackhole4
- conn := bind.ipv4
- if addrPort.Addr().Is6() {
- blackhole = bind.blackhole6
- conn = bind.ipv6
+func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
+ s.mu.Lock()
+ blackhole := s.blackhole4
+ conn := s.ipv4
+ offload := s.ipv4TxOffload
+ br := batchWriter(s.ipv4PC)
+ is6 := false
+ if endpoint.DstIP().Is6() {
+ blackhole = s.blackhole6
+ conn = s.ipv6
+ br = s.ipv6PC
+ is6 = true
+ offload = s.ipv6TxOffload
}
- bind.mu.Unlock()
+ s.mu.Unlock()
if blackhole {
return nil
@@ -186,27 +360,185 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
if conn == nil {
return syscall.EAFNOSUPPORT
}
- _, err = conn.WriteToUDPAddrPort(buff, addrPort)
+
+ msgs := s.getMessages()
+ defer s.putMessages(msgs)
+ ua := s.udpAddrPool.Get().(*net.UDPAddr)
+ defer s.udpAddrPool.Put(ua)
+ if is6 {
+ as16 := endpoint.DstIP().As16()
+ copy(ua.IP, as16[:])
+ ua.IP = ua.IP[:16]
+ } else {
+ as4 := endpoint.DstIP().As4()
+ copy(ua.IP, as4[:])
+ ua.IP = ua.IP[:4]
+ }
+ ua.Port = int(endpoint.(*StdNetEndpoint).Port())
+ var (
+ retried bool
+ err error
+ )
+retry:
+ if offload {
+ n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
+ err = s.send(conn, br, (*msgs)[:n])
+ if err != nil && offload && errShouldDisableUDPGSO(err) {
+ offload = false
+ s.mu.Lock()
+ if is6 {
+ s.ipv6TxOffload = false
+ } else {
+ s.ipv4TxOffload = false
+ }
+ s.mu.Unlock()
+ retried = true
+ goto retry
+ }
+ } else {
+ for i := range bufs {
+ (*msgs)[i].Addr = ua
+ (*msgs)[i].Buffers[0] = bufs[i]
+ setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
+ }
+ err = s.send(conn, br, (*msgs)[:len(bufs)])
+ }
+ if retried {
+ return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
+ }
+ return err
+}
+
+func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
+ var (
+ n int
+ err error
+ start int
+ )
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ for {
+ n, err = pc.WriteBatch(msgs[start:], 0)
+ if err != nil || n == len(msgs[start:]) {
+ break
+ }
+ start += n
+ }
+ } else {
+ for _, msg := range msgs {
+ _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
+ if err != nil {
+ break
+ }
+ }
+ }
return err
}
-// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
-// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
-// but Endpoints are immutable, so we can re-use them.
-var endpointPool = sync.Pool{
- New: func() any {
- return make(map[netip.AddrPort]Endpoint)
- },
+const (
+ // Exceeding these values results in EMSGSIZE. They account for layer3 and
+ // layer4 headers. IPv6 does not need to account for itself as the payload
+ // length field is self excluding.
+ maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
+ maxIPv6PayloadLen = 1<<16 - 1 - 8
+
+ // This is a hard limit imposed by the kernel.
+ udpSegmentMaxDatagrams = 64
+)
+
+type setGSOFunc func(control *[]byte, gsoSize uint16)
+
+func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
+ var (
+ base = -1 // index of msg we are currently coalescing into
+ gsoSize int // segmentation size of msgs[base]
+ dgramCnt int // number of dgrams coalesced into msgs[base]
+ endBatch bool // tracking flag to start a new batch on next iteration of bufs
+ )
+ maxPayloadLen := maxIPv4PayloadLen
+ if ep.DstIP().Is6() {
+ maxPayloadLen = maxIPv6PayloadLen
+ }
+ for i, buf := range bufs {
+ if i > 0 {
+ msgLen := len(buf)
+ baseLenBefore := len(msgs[base].Buffers[0])
+ freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
+ if msgLen+baseLenBefore <= maxPayloadLen &&
+ msgLen <= gsoSize &&
+ msgLen <= freeBaseCap &&
+ dgramCnt < udpSegmentMaxDatagrams &&
+ !endBatch {
+ msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
+ if i == len(bufs)-1 {
+ setGSO(&msgs[base].OOB, uint16(gsoSize))
+ }
+ dgramCnt++
+ if msgLen < gsoSize {
+ // A smaller than gsoSize packet on the tail is legal, but
+ // it must end the batch.
+ endBatch = true
+ }
+ continue
+ }
+ }
+ if dgramCnt > 1 {
+ setGSO(&msgs[base].OOB, uint16(gsoSize))
+ }
+ // Reset prior to incrementing base since we are preparing to start a
+ // new potential batch.
+ endBatch = false
+ base++
+ gsoSize = len(buf)
+ setSrcControl(&msgs[base].OOB, ep)
+ msgs[base].Buffers[0] = buf
+ msgs[base].Addr = addr
+ dgramCnt = 1
+ }
+ return base + 1
}
-// asEndpoint returns an Endpoint containing ap.
-func asEndpoint(ap netip.AddrPort) Endpoint {
- m := endpointPool.Get().(map[netip.AddrPort]Endpoint)
- defer endpointPool.Put(m)
- e, ok := m[ap]
- if !ok {
- e = Endpoint(StdNetEndpoint(ap))
- m[ap] = e
+type getGSOFunc func(control []byte) (int, error)
+
+func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
+ for i := firstMsgAt; i < len(msgs); i++ {
+ msg := &msgs[i]
+ if msg.N == 0 {
+ return n, err
+ }
+ var (
+ gsoSize int
+ start int
+ end = msg.N
+ numToSplit = 1
+ )
+ gsoSize, err = getGSO(msg.OOB[:msg.NN])
+ if err != nil {
+ return n, err
+ }
+ if gsoSize > 0 {
+ numToSplit = (msg.N + gsoSize - 1) / gsoSize
+ end = gsoSize
+ }
+ for j := 0; j < numToSplit; j++ {
+ if n > i {
+ return n, errors.New("splitting coalesced packet resulted in overflow")
+ }
+ copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
+ msgs[n].N = copied
+ msgs[n].Addr = msg.Addr
+ start = end
+ end += gsoSize
+ if end > msg.N {
+ end = msg.N
+ }
+ n++
+ }
+ if i != n-1 {
+ // It is legal for bytes to move within msg.Buffers[0] as a result
+ // of splitting, so we only zero the source msg len when it is not
+ // the destination of the last split operation above.
+ msg.N = 0
+ }
}
- return e
+ return n, nil
}
diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go
new file mode 100644
index 0000000..34a3c9a
--- /dev/null
+++ b/conn/bind_std_test.go
@@ -0,0 +1,250 @@
+package conn
+
+import (
+ "encoding/binary"
+ "net"
+ "testing"
+
+ "golang.org/x/net/ipv6"
+)
+
+func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
+ bind := NewStdNetBind().(*StdNetBind)
+ fns, _, err := bind.Open(0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ bind.Close()
+ bufs := make([][]byte, 1)
+ bufs[0] = make([]byte, 1)
+ sizes := make([]int, 1)
+ eps := make([]Endpoint, 1)
+ for _, fn := range fns {
+ // The ReceiveFuncs must not access conn-related fields on StdNetBind
+ // unguarded. Close() nils the conn-related fields resulting in a panic
+ // if they violate the mutex.
+ fn(bufs, sizes, eps)
+ }
+}
+
+func mockSetGSOSize(control *[]byte, gsoSize uint16) {
+ *control = (*control)[:cap(*control)]
+ binary.LittleEndian.PutUint16(*control, gsoSize)
+}
+
+func Test_coalesceMessages(t *testing.T) {
+ cases := []struct {
+ name string
+ buffs [][]byte
+ wantLens []int
+ wantGSO []int
+ }{
+ {
+ name: "one message no coalesce",
+ buffs: [][]byte{
+ make([]byte, 1, 1),
+ },
+ wantLens: []int{1},
+ wantGSO: []int{0},
+ },
+ {
+ name: "two messages equal len coalesce",
+ buffs: [][]byte{
+ make([]byte, 1, 2),
+ make([]byte, 1, 1),
+ },
+ wantLens: []int{2},
+ wantGSO: []int{1},
+ },
+ {
+ name: "two messages unequal len coalesce",
+ buffs: [][]byte{
+ make([]byte, 2, 3),
+ make([]byte, 1, 1),
+ },
+ wantLens: []int{3},
+ wantGSO: []int{2},
+ },
+ {
+ name: "three messages second unequal len coalesce",
+ buffs: [][]byte{
+ make([]byte, 2, 3),
+ make([]byte, 1, 1),
+ make([]byte, 2, 2),
+ },
+ wantLens: []int{3, 2},
+ wantGSO: []int{2, 0},
+ },
+ {
+ name: "three messages limited cap coalesce",
+ buffs: [][]byte{
+ make([]byte, 2, 4),
+ make([]byte, 2, 2),
+ make([]byte, 2, 2),
+ },
+ wantLens: []int{4, 2},
+ wantGSO: []int{2, 0},
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ addr := &net.UDPAddr{
+ IP: net.ParseIP("127.0.0.1").To4(),
+ Port: 1,
+ }
+ msgs := make([]ipv6.Message, len(tt.buffs))
+ for i := range msgs {
+ msgs[i].Buffers = make([][]byte, 1)
+ msgs[i].OOB = make([]byte, 0, 2)
+ }
+ got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
+ if got != len(tt.wantLens) {
+ t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
+ }
+ for i := 0; i < got; i++ {
+ if msgs[i].Addr != addr {
+ t.Errorf("msgs[%d].Addr != passed addr", i)
+ }
+ gotLen := len(msgs[i].Buffers[0])
+ if gotLen != tt.wantLens[i] {
+ t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
+ }
+ gotGSO, err := mockGetGSOSize(msgs[i].OOB)
+ if err != nil {
+ t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
+ }
+ if gotGSO != tt.wantGSO[i] {
+ t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
+ }
+ }
+ })
+ }
+}
+
+func mockGetGSOSize(control []byte) (int, error) {
+ if len(control) < 2 {
+ return 0, nil
+ }
+ return int(binary.LittleEndian.Uint16(control)), nil
+}
+
+func Test_splitCoalescedMessages(t *testing.T) {
+ newMsg := func(n, gso int) ipv6.Message {
+ msg := ipv6.Message{
+ Buffers: [][]byte{make([]byte, 1<<16-1)},
+ N: n,
+ OOB: make([]byte, 2),
+ }
+ binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
+ if gso > 0 {
+ msg.NN = 2
+ }
+ return msg
+ }
+
+ cases := []struct {
+ name string
+ msgs []ipv6.Message
+ firstMsgAt int
+ wantNumEval int
+ wantMsgLens []int
+ wantErr bool
+ }{
+ {
+ name: "second last split last empty",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(3, 1),
+ newMsg(0, 0),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 3,
+ wantMsgLens: []int{1, 1, 1, 0},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last empty",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(0, 0),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 1,
+ wantMsgLens: []int{1, 0, 0, 0},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last no split",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(1, 0),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 2,
+ wantMsgLens: []int{1, 1, 0, 0},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last split",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(3, 1),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 4,
+ wantMsgLens: []int{1, 1, 1, 1},
+ wantErr: false,
+ },
+ {
+ name: "second last split last split",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(2, 1),
+ newMsg(2, 1),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 4,
+ wantMsgLens: []int{1, 1, 1, 1},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last split overflow",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(4, 1),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 4,
+ wantMsgLens: []int{1, 1, 1, 1},
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
+ if err != nil && !tt.wantErr {
+ t.Fatalf("err: %v", err)
+ }
+ if got != tt.wantNumEval {
+ t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
+ }
+ for i, msg := range tt.msgs {
+ if msg.N != tt.wantMsgLens[i] {
+ t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
+ }
+ }
+ })
+ }
+}
diff --git a/conn/bind_windows.go b/conn/bind_windows.go
index c066efa..d5095e0 100644
--- a/conn/bind_windows.go
+++ b/conn/bind_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
@@ -164,7 +164,7 @@ func (e *WinRingEndpoint) DstToBytes() []byte {
func (e *WinRingEndpoint) DstToString() string {
switch e.family {
case windows.AF_INET:
- netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
+ return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
case windows.AF_INET6:
var zone string
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
@@ -321,6 +321,13 @@ func (bind *WinRingBind) Close() error {
return nil
}
+// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
+// rename the IdealBatchSize constant to BatchSize.
+func (bind *WinRingBind) BatchSize() int {
+ // TODO: implement batching in and out of the ring
+ return 1
+}
+
func (bind *WinRingBind) SetMark(mark uint32) error {
return nil
}
@@ -409,16 +416,22 @@ retry:
return n, &ep, nil
}
-func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
+func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
- return bind.v4.Receive(buf, &bind.isOpen)
+ n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
+ sizes[0] = n
+ eps[0] = ep
+ return 1, err
}
-func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
+func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
- return bind.v6.Receive(buf, &bind.isOpen)
+ n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
+ sizes[0] = n
+ eps[0] = ep
+ return 1, err
}
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
@@ -473,32 +486,38 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
-func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
+func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
nend, ok := endpoint.(*WinRingEndpoint)
if !ok {
return ErrWrongEndpointType
}
bind.mu.RLock()
defer bind.mu.RUnlock()
- switch nend.family {
- case windows.AF_INET:
- if bind.v4.blackhole {
- return nil
- }
- return bind.v4.Send(buf, nend, &bind.isOpen)
- case windows.AF_INET6:
- if bind.v6.blackhole {
- return nil
+ for _, buf := range bufs {
+ switch nend.family {
+ case windows.AF_INET:
+ if bind.v4.blackhole {
+ continue
+ }
+ if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
+ return err
+ }
+ case windows.AF_INET6:
+ if bind.v6.blackhole {
+ continue
+ }
+ if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
+ return err
+ }
}
- return bind.v6.Send(buf, nend, &bind.isOpen)
}
return nil
}
-func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
- bind.mu.Lock()
- defer bind.mu.Unlock()
- sysconn, err := bind.ipv4.SyscallConn()
+func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ sysconn, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
@@ -511,14 +530,14 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
if err != nil {
return err
}
- bind.blackhole4 = blackhole
+ s.blackhole4 = blackhole
return nil
}
-func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
- bind.mu.Lock()
- defer bind.mu.Unlock()
- sysconn, err := bind.ipv6.SyscallConn()
+func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ sysconn, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
@@ -531,7 +550,7 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole
if err != nil {
return err
}
- bind.blackhole6 = blackhole
+ s.blackhole6 = blackhole
return nil
}
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
index b38cae6..74e7add 100644
--- a/conn/bindtest/bindtest.go
+++ b/conn/bindtest/bindtest.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bindtest
@@ -89,32 +89,39 @@ func (c *ChannelBind) Close() error {
return nil
}
+func (c *ChannelBind) BatchSize() int { return 1 }
+
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
- return func(b []byte) (n int, ep conn.Endpoint, err error) {
+ return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
select {
case <-c.closeSignal:
- return 0, nil, net.ErrClosed
+ return 0, net.ErrClosed
case rx := <-ch:
- return copy(b, rx), c.target6, nil
+ copied := copy(bufs[0], rx)
+ sizes[0] = copied
+ eps[0] = c.target6
+ return 1, nil
}
}
}
-func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
- select {
- case <-c.closeSignal:
- return net.ErrClosed
- default:
- bc := make([]byte, len(b))
- copy(bc, b)
- if ep.(ChannelEndpoint) == c.target4 {
- *c.tx4 <- bc
- } else if ep.(ChannelEndpoint) == c.target6 {
- *c.tx6 <- bc
- } else {
- return os.ErrInvalid
+func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
+ for _, b := range bufs {
+ select {
+ case <-c.closeSignal:
+ return net.ErrClosed
+ default:
+ bc := make([]byte, len(b))
+ copy(bc, b)
+ if ep.(ChannelEndpoint) == c.target4 {
+ *c.tx4 <- bc
+ } else if ep.(ChannelEndpoint) == c.target6 {
+ *c.tx6 <- bc
+ } else {
+ return os.ErrInvalid
+ }
}
}
return nil
diff --git a/conn/boundif_android.go b/conn/boundif_android.go
index 8b82bfc..dd3ca5b 100644
--- a/conn/boundif_android.go
+++ b/conn/boundif_android.go
@@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
-func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
- sysconn, err := bind.ipv4.SyscallConn()
+func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
+ sysconn, err := s.ipv4.SyscallConn()
if err != nil {
return -1, err
}
@@ -19,8 +19,8 @@ func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
return
}
-func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
- sysconn, err := bind.ipv6.SyscallConn()
+func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
+ sysconn, err := s.ipv6.SyscallConn()
if err != nil {
return -1, err
}
diff --git a/conn/conn.go b/conn/conn.go
index 5a93b2b..a1f57d2 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
// Package conn implements WireGuard's network connections.
@@ -15,10 +15,17 @@ import (
"strings"
)
-// A ReceiveFunc receives a single inbound packet from the network.
-// It writes the data into b. n is the length of the packet.
-// ep is the remote endpoint.
-type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
+const (
+ IdealBatchSize = 128 // maximum number of packets handled per read and write
+)
+
+// A ReceiveFunc receives at least one packet from the network and writes them
+// into packets. On a successful read it returns the number of elements of
+// sizes, packets, and endpoints that should be evaluated. Some elements of
+// sizes may be zero, and callers should ignore them. Callers must pass a sizes
+// and eps slice with a length greater than or equal to the length of packets.
+// These lengths must not exceed the length of the associated Bind.BatchSize().
+type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
//
@@ -38,11 +45,16 @@ type Bind interface {
// This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error
- // Send writes a packet b to address ep.
- Send(b []byte, ep Endpoint) error
+ // Send writes one or more packets in bufs to address ep. The length of
+ // bufs must not exceed BatchSize().
+ Send(bufs [][]byte, ep Endpoint) error
// ParseEndpoint creates a new endpoint from a string.
ParseEndpoint(s string) (Endpoint, error)
+
+ // BatchSize is the number of buffers expected to be passed to
+ // the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
+ BatchSize() int
}
// BindSocketToInterface is implemented by Bind objects that support being
diff --git a/conn/conn_test.go b/conn/conn_test.go
new file mode 100644
index 0000000..c6194ee
--- /dev/null
+++ b/conn/conn_test.go
@@ -0,0 +1,24 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "testing"
+)
+
+func TestPrettyName(t *testing.T) {
+ var (
+ recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
+ )
+
+ const want = "TestPrettyName"
+
+ t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
+ if got := recvFunc.PrettyName(); got != want {
+ t.Errorf("PrettyName() = %v, want %v", got, want)
+ }
+ })
+}
diff --git a/conn/controlfns.go b/conn/controlfns.go
new file mode 100644
index 0000000..4f7d90f
--- /dev/null
+++ b/conn/controlfns.go
@@ -0,0 +1,43 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net"
+ "syscall"
+)
+
+// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
+// the max supported by a default configuration of macOS. Some platforms will
+// silently clamp the value to other maximums, such as linux clamping to
+// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
+// around this limitation)
+const socketBufferSize = 7 << 20
+
+// controlFn is the callback function signature from net.ListenConfig.Control.
+// It is used to apply platform specific configuration to the socket prior to
+// bind.
+type controlFn func(network, address string, c syscall.RawConn) error
+
+// controlFns is a list of functions that are called from the listen config
+// that can apply socket options.
+var controlFns = []controlFn{}
+
+// listenConfig returns a net.ListenConfig that applies the controlFns to the
+// socket prior to bind. This is used to apply socket buffer sizing and packet
+// information OOB configuration for sticky sockets.
+func listenConfig() *net.ListenConfig {
+ return &net.ListenConfig{
+ Control: func(network, address string, c syscall.RawConn) error {
+ for _, fn := range controlFns {
+ if err := fn(network, address, c); err != nil {
+ return err
+ }
+ }
+ return nil
+ },
+ }
+}
diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go
new file mode 100644
index 0000000..f6ab1d2
--- /dev/null
+++ b/conn/controlfns_linux.go
@@ -0,0 +1,69 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "fmt"
+ "runtime"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ controlFns = append(controlFns,
+
+ // Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
+ // using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
+ // fail silently - the result of failure is lower performance on very fast
+ // links or high latency links.
+ func(network, address string, c syscall.RawConn) error {
+ return c.Control(func(fd uintptr) {
+ // Set up to *mem_max
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
+ // Set beyond *mem_max if CAP_NET_ADMIN
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
+ })
+ },
+
+ // Enable receiving of the packet information (IP_PKTINFO for IPv4,
+ // IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
+ func(network, address string, c syscall.RawConn) error {
+ var err error
+ switch network {
+ case "udp4":
+ if runtime.GOOS != "android" {
+ c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
+ })
+ }
+ case "udp6":
+ c.Control(func(fd uintptr) {
+ if runtime.GOOS != "android" {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
+ if err != nil {
+ return
+ }
+ }
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
+ })
+ default:
+ err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
+ }
+ return err
+ },
+
+ // Attempt to enable UDP_GRO
+ func(network, address string, c syscall.RawConn) error {
+ c.Control(func(fd uintptr) {
+ _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
+ })
+ return nil
+ },
+ )
+}
diff --git a/conn/controlfns_unix.go b/conn/controlfns_unix.go
new file mode 100644
index 0000000..91692c0
--- /dev/null
+++ b/conn/controlfns_unix.go
@@ -0,0 +1,35 @@
+//go:build !windows && !linux && !wasm
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ controlFns = append(controlFns,
+ func(network, address string, c syscall.RawConn) error {
+ return c.Control(func(fd uintptr) {
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
+ })
+ },
+
+ func(network, address string, c syscall.RawConn) error {
+ var err error
+ if network == "udp6" {
+ c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
+ })
+ }
+ return err
+ },
+ )
+}
diff --git a/conn/controlfns_windows.go b/conn/controlfns_windows.go
new file mode 100644
index 0000000..c3bdf7d
--- /dev/null
+++ b/conn/controlfns_windows.go
@@ -0,0 +1,23 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/windows"
+)
+
+func init() {
+ controlFns = append(controlFns,
+ func(network, address string, c syscall.RawConn) error {
+ return c.Control(func(fd uintptr) {
+ _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize)
+ _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize)
+ })
+ },
+ )
+}
diff --git a/conn/default.go b/conn/default.go
index e65bb74..b6f761b 100644
--- a/conn/default.go
+++ b/conn/default.go
@@ -1,8 +1,8 @@
-//go:build !linux && !windows
+//go:build !windows
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
diff --git a/conn/errors_default.go b/conn/errors_default.go
new file mode 100644
index 0000000..f1e5b90
--- /dev/null
+++ b/conn/errors_default.go
@@ -0,0 +1,12 @@
+//go:build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+func errShouldDisableUDPGSO(err error) bool {
+ return false
+}
diff --git a/conn/errors_linux.go b/conn/errors_linux.go
new file mode 100644
index 0000000..8e61000
--- /dev/null
+++ b/conn/errors_linux.go
@@ -0,0 +1,26 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "errors"
+ "os"
+
+ "golang.org/x/sys/unix"
+)
+
+func errShouldDisableUDPGSO(err error) bool {
+ var serr *os.SyscallError
+ if errors.As(err, &serr) {
+ // EIO is returned by udp_send_skb() if the device driver does not have
+ // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
+ // See:
+ // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
+ // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
+ return serr.Err == unix.EIO
+ }
+ return false
+}
diff --git a/conn/features_default.go b/conn/features_default.go
new file mode 100644
index 0000000..d53ff5f
--- /dev/null
+++ b/conn/features_default.go
@@ -0,0 +1,15 @@
+//go:build !linux
+// +build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import "net"
+
+func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
+ return
+}
diff --git a/conn/features_linux.go b/conn/features_linux.go
new file mode 100644
index 0000000..8959d93
--- /dev/null
+++ b/conn/features_linux.go
@@ -0,0 +1,29 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net"
+
+ "golang.org/x/sys/unix"
+)
+
+func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
+ rc, err := conn.SyscallConn()
+ if err != nil {
+ return
+ }
+ err = rc.Control(func(fd uintptr) {
+ _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
+ txOffload = errSyscall == nil
+ opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
+ rxOffload = errSyscall == nil && opt == 1
+ })
+ if err != nil {
+ return false, false
+ }
+ return txOffload, rxOffload
+}
diff --git a/conn/gso_default.go b/conn/gso_default.go
new file mode 100644
index 0000000..57780db
--- /dev/null
+++ b/conn/gso_default.go
@@ -0,0 +1,21 @@
+//go:build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
+func getGSOSize(control []byte) (int, error) {
+ return 0, nil
+}
+
+// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
+func setGSOSize(control *[]byte, gsoSize uint16) {
+}
+
+// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
+// offloading control data.
+const gsoControlSize = 0
diff --git a/conn/gso_linux.go b/conn/gso_linux.go
new file mode 100644
index 0000000..8596b29
--- /dev/null
+++ b/conn/gso_linux.go
@@ -0,0 +1,65 @@
+//go:build linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "fmt"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+const (
+ sizeOfGSOData = 2
+)
+
+// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
+func getGSOSize(control []byte) (int, error) {
+ var (
+ hdr unix.Cmsghdr
+ data []byte
+ rem = control
+ err error
+ )
+
+ for len(rem) > unix.SizeofCmsghdr {
+ hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
+ if err != nil {
+ return 0, fmt.Errorf("error parsing socket control message: %w", err)
+ }
+ if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
+ var gso uint16
+ copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
+ return int(gso), nil
+ }
+ }
+ return 0, nil
+}
+
+// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
+// data in control untouched.
+func setGSOSize(control *[]byte, gsoSize uint16) {
+ existingLen := len(*control)
+ avail := cap(*control) - existingLen
+ space := unix.CmsgSpace(sizeOfGSOData)
+ if avail < space {
+ return
+ }
+ *control = (*control)[:cap(*control)]
+ gsoControl := (*control)[existingLen:]
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
+ hdr.Level = unix.SOL_UDP
+ hdr.Type = unix.UDP_SEGMENT
+ hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
+ copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
+ *control = (*control)[:existingLen+space]
+}
+
+// gsoControlSize returns the recommended buffer size for pooling UDP
+// offloading control data.
+var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
diff --git a/conn/mark_default.go b/conn/mark_default.go
index 6e01b0d..3102384 100644
--- a/conn/mark_default.go
+++ b/conn/mark_default.go
@@ -2,11 +2,11 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
-func (bind *StdNetBind) SetMark(mark uint32) error {
+func (s *StdNetBind) SetMark(mark uint32) error {
return nil
}
diff --git a/conn/mark_unix.go b/conn/mark_unix.go
index fec154c..d9e46ee 100644
--- a/conn/mark_unix.go
+++ b/conn/mark_unix.go
@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
@@ -26,13 +26,13 @@ func init() {
}
}
-func (bind *StdNetBind) SetMark(mark uint32) error {
+func (s *StdNetBind) SetMark(mark uint32) error {
var operr error
if fwmarkIoctl == 0 {
return nil
}
- if bind.ipv4 != nil {
- fd, err := bind.ipv4.SyscallConn()
+ if s.ipv4 != nil {
+ fd, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
@@ -46,8 +46,8 @@ func (bind *StdNetBind) SetMark(mark uint32) error {
return err
}
}
- if bind.ipv6 != nil {
- fd, err := bind.ipv6.SyscallConn()
+ if s.ipv6 != nil {
+ fd, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
diff --git a/conn/sticky_default.go b/conn/sticky_default.go
new file mode 100644
index 0000000..0b21386
--- /dev/null
+++ b/conn/sticky_default.go
@@ -0,0 +1,42 @@
+//go:build !linux || android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import "net/netip"
+
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+ return netip.Addr{}
+}
+
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+ return 0
+}
+
+func (e *StdNetEndpoint) SrcToString() string {
+ return ""
+}
+
+// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
+// {get,set}srcControl feature set, but use alternatively named flags and need
+// ports and require testing.
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+}
+
+// setSrcControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+}
+
+// stickyControlSize returns the recommended buffer size for pooling sticky
+// offloading control data.
+const stickyControlSize = 0
+
+const StdNetSupportsStickySockets = false
diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go
new file mode 100644
index 0000000..8e206e9
--- /dev/null
+++ b/conn/sticky_linux.go
@@ -0,0 +1,112 @@
+//go:build linux && !android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net/netip"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+ switch len(e.src) {
+ case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ return netip.AddrFrom4(info.Spec_dst)
+ case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ // TODO: set zone. in order to do so we need to check if the address is
+ // link local, and if it is perform a syscall to turn the ifindex into a
+ // zone string because netip uses string zones.
+ return netip.AddrFrom16(info.Addr)
+ }
+ return netip.Addr{}
+}
+
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+ switch len(e.src) {
+ case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ return info.Ifindex
+ case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ return int32(info.Ifindex)
+ }
+ return 0
+}
+
+func (e *StdNetEndpoint) SrcToString() string {
+ return e.SrcIP().String()
+}
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+ ep.ClearSrc()
+
+ var (
+ hdr unix.Cmsghdr
+ data []byte
+ rem []byte = control
+ err error
+ )
+
+ for len(rem) > unix.SizeofCmsghdr {
+ hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
+ if err != nil {
+ return
+ }
+
+ if hdr.Level == unix.IPPROTO_IP &&
+ hdr.Type == unix.IP_PKTINFO {
+
+ if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
+ ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ }
+ ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
+
+ hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
+ copy(ep.src, hdrBuf)
+ copy(ep.src[unix.CmsgLen(0):], data)
+ return
+ }
+
+ if hdr.Level == unix.IPPROTO_IPV6 &&
+ hdr.Type == unix.IPV6_PKTINFO {
+
+ if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
+ ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
+ }
+
+ ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
+
+ hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
+ copy(ep.src, hdrBuf)
+ copy(ep.src[unix.CmsgLen(0):], data)
+ return
+ }
+ }
+}
+
+// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
+// and source ifindex found in ep. control's len will be set to 0 in the event
+// that ep is a default value.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+ if cap(*control) < len(ep.src) {
+ return
+ }
+ *control = (*control)[:0]
+ *control = append(*control, ep.src...)
+}
+
+// stickyControlSize returns the recommended buffer size for pooling sticky
+// offloading control data.
+var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
+
+const StdNetSupportsStickySockets = true
diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go
new file mode 100644
index 0000000..d2bd584
--- /dev/null
+++ b/conn/sticky_linux_test.go
@@ -0,0 +1,266 @@
+//go:build linux && !android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "context"
+ "net"
+ "net/netip"
+ "runtime"
+ "testing"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
+ var buf []byte
+ if addr.Is4() {
+ buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ hdr := unix.Cmsghdr{
+ Level: unix.IPPROTO_IP,
+ Type: unix.IP_PKTINFO,
+ }
+ hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
+ copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
+
+ info := unix.Inet4Pktinfo{
+ Ifindex: ifidx,
+ Spec_dst: addr.As4(),
+ }
+ copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
+ } else {
+ buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
+ hdr := unix.Cmsghdr{
+ Level: unix.IPPROTO_IPV6,
+ Type: unix.IPV6_PKTINFO,
+ }
+ hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
+ copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
+
+ info := unix.Inet6Pktinfo{
+ Ifindex: uint32(ifidx),
+ Addr: addr.As16(),
+ }
+ copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
+ }
+
+ ep.src = buf
+}
+
+func Test_setSrcControl(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ ep := &StdNetEndpoint{
+ AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
+ }
+ setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
+
+ control := make([]byte, stickyControlSize)
+
+ setSrcControl(&control, ep)
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ if hdr.Level != unix.IPPROTO_IP {
+ t.Errorf("unexpected level: %d", hdr.Level)
+ }
+ if hdr.Type != unix.IP_PKTINFO {
+ t.Errorf("unexpected type: %d", hdr.Type)
+ }
+ if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
+ t.Errorf("unexpected length: %d", hdr.Len)
+ }
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
+ t.Errorf("unexpected address: %v", info.Spec_dst)
+ }
+ if info.Ifindex != 5 {
+ t.Errorf("unexpected ifindex: %d", info.Ifindex)
+ }
+ })
+
+ t.Run("IPv6", func(t *testing.T) {
+ ep := &StdNetEndpoint{
+ AddrPort: netip.MustParseAddrPort("[::1]:1234"),
+ }
+ setSrc(ep, netip.MustParseAddr("::1"), 5)
+
+ control := make([]byte, stickyControlSize)
+
+ setSrcControl(&control, ep)
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ if hdr.Level != unix.IPPROTO_IPV6 {
+ t.Errorf("unexpected level: %d", hdr.Level)
+ }
+ if hdr.Type != unix.IPV6_PKTINFO {
+ t.Errorf("unexpected type: %d", hdr.Type)
+ }
+ if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
+ t.Errorf("unexpected length: %d", hdr.Len)
+ }
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ if info.Addr != ep.SrcIP().As16() {
+ t.Errorf("unexpected address: %v", info.Addr)
+ }
+ if info.Ifindex != 5 {
+ t.Errorf("unexpected ifindex: %d", info.Ifindex)
+ }
+ })
+
+ t.Run("ClearOnNoSrc", func(t *testing.T) {
+ control := make([]byte, stickyControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = 1
+ hdr.Type = 2
+ hdr.Len = 3
+
+ setSrcControl(&control, &StdNetEndpoint{})
+
+ if len(control) != 0 {
+ t.Errorf("unexpected control: %v", control)
+ }
+ })
+}
+
+func Test_getSrcFromControl(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ control := make([]byte, stickyControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IP
+ hdr.Type = unix.IP_PKTINFO
+ hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Spec_dst = [4]byte{127, 0, 0, 1}
+ info.Ifindex = 5
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(control, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+ t.Run("IPv6", func(t *testing.T) {
+ control := make([]byte, stickyControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IPV6
+ hdr.Type = unix.IPV6_PKTINFO
+ hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
+ info.Ifindex = 5
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(control, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("::1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+ t.Run("ClearOnEmpty", func(t *testing.T) {
+ var control []byte
+ ep := &StdNetEndpoint{}
+ setSrc(ep, netip.MustParseAddr("::1"), 5)
+
+ getSrcFromControl(control, ep)
+ if ep.SrcIP().IsValid() {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 0 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+ t.Run("Multiple", func(t *testing.T) {
+ zeroControl := make([]byte, unix.CmsgSpace(0))
+ zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
+ zeroHdr.SetLen(unix.CmsgLen(0))
+
+ control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IP
+ hdr.Type = unix.IP_PKTINFO
+ hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Spec_dst = [4]byte{127, 0, 0, 1}
+ info.Ifindex = 5
+
+ combined := make([]byte, 0)
+ combined = append(combined, zeroControl...)
+ combined = append(combined, control...)
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(combined, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+}
+
+func Test_listenConfig(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ sc, err := conn.(*net.UDPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if runtime.GOOS == "linux" {
+ var i int
+ sc.Control(func(fd uintptr) {
+ i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if i != 1 {
+ t.Error("IP_PKTINFO not set!")
+ }
+ } else {
+ t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
+ }
+ })
+ t.Run("IPv6", func(t *testing.T) {
+ conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ sc, err := conn.(*net.UDPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if runtime.GOOS == "linux" {
+ var i int
+ sc.Control(func(fd uintptr) {
+ i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if i != 1 {
+ t.Error("IPV6_PKTINFO not set!")
+ }
+ } else {
+ t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
+ }
+ })
+}
diff --git a/conn/winrio/rio_windows.go b/conn/winrio/rio_windows.go
index 0911998..d1037bb 100644
--- a/conn/winrio/rio_windows.go
+++ b/conn/winrio/rio_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package winrio
diff --git a/device/allowedips.go b/device/allowedips.go
index 3cac694..fa46f97 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
index 0d3eecb..07065c3 100644
--- a/device/allowedips_rand_test.go
+++ b/device/allowedips_rand_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
index 225c788..cde068e 100644
--- a/device/allowedips_test.go
+++ b/device/allowedips_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/bind_test.go b/device/bind_test.go
index 1c0d247..302a521 100644
--- a/device/bind_test.go
+++ b/device/bind_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -26,21 +26,21 @@ func (b *DummyBind) SetMark(v uint32) error {
return nil
}
-func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) {
+func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in6
if !ok {
return 0, nil, errors.New("closed")
}
- copy(buff, datagram.msg)
+ copy(buf, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil
}
-func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) {
+func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in4
if !ok {
return 0, nil, errors.New("closed")
}
- copy(buff, datagram.msg)
+ copy(buf, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil
}
@@ -51,6 +51,6 @@ func (b *DummyBind) Close() error {
return nil
}
-func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error {
+func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error {
return nil
}
diff --git a/device/channels.go b/device/channels.go
index 6f4370a..e526f6b 100644
--- a/device/channels.go
+++ b/device/channels.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -19,13 +19,13 @@ import (
// call wg.Done to remove the initial reference.
// When the refcount hits 0, the queue's channel is closed.
type outboundQueue struct {
- c chan *QueueOutboundElement
+ c chan *QueueOutboundElementsContainer
wg sync.WaitGroup
}
func newOutboundQueue() *outboundQueue {
q := &outboundQueue{
- c: make(chan *QueueOutboundElement, QueueOutboundSize),
+ c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
}
q.wg.Add(1)
go func() {
@@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue {
// A inboundQueue is similar to an outboundQueue; see those docs.
type inboundQueue struct {
- c chan *QueueInboundElement
+ c chan *QueueInboundElementsContainer
wg sync.WaitGroup
}
func newInboundQueue() *inboundQueue {
q := &inboundQueue{
- c: make(chan *QueueInboundElement, QueueInboundSize),
+ c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
}
q.wg.Add(1)
go func() {
@@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue {
}
type autodrainingInboundQueue struct {
- c chan *QueueInboundElement
+ c chan *QueueInboundElementsContainer
}
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
@@ -81,7 +81,7 @@ type autodrainingInboundQueue struct {
// some other means, such as sending a sentinel nil values.
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
q := &autodrainingInboundQueue{
- c: make(chan *QueueInboundElement, QueueInboundSize),
+ c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
}
runtime.SetFinalizer(q, device.flushInboundQueue)
return q
@@ -90,10 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
for {
select {
- case elem := <-q.c:
- elem.Lock()
- device.PutMessageBuffer(elem.buffer)
- device.PutInboundElement(elem)
+ case elemsContainer := <-q.c:
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ device.PutInboundElementsContainer(elemsContainer)
default:
return
}
@@ -101,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
}
type autodrainingOutboundQueue struct {
- c chan *QueueOutboundElement
+ c chan *QueueOutboundElementsContainer
}
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
@@ -111,7 +114,7 @@ type autodrainingOutboundQueue struct {
// All sends to the channel must be best-effort, because there may be no receivers.
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
q := &autodrainingOutboundQueue{
- c: make(chan *QueueOutboundElement, QueueOutboundSize),
+ c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
}
runtime.SetFinalizer(q, device.flushOutboundQueue)
return q
@@ -120,10 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
for {
select {
- case elem := <-q.c:
- elem.Lock()
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ case elemsContainer := <-q.c:
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsContainer(elemsContainer)
default:
return
}
diff --git a/device/constants.go b/device/constants.go
index 53a2526..59854a1 100644
--- a/device/constants.go
+++ b/device/constants.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/cookie.go b/device/cookie.go
index 840f709..876f05d 100644
--- a/device/cookie.go
+++ b/device/cookie.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/cookie_test.go b/device/cookie_test.go
index 06f306f..4f1e50a 100644
--- a/device/cookie_test.go
+++ b/device/cookie_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/device.go b/device/device.go
index f96e277..83c33ee 100644
--- a/device/device.go
+++ b/device/device.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -68,9 +68,11 @@ type Device struct {
cookieChecker CookieChecker
pool struct {
- messageBuffers *WaitPool
- inboundElements *WaitPool
- outboundElements *WaitPool
+ inboundElementsContainer *WaitPool
+ outboundElementsContainer *WaitPool
+ messageBuffers *WaitPool
+ inboundElements *WaitPool
+ outboundElements *WaitPool
}
queue struct {
@@ -265,7 +267,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
+ handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
expiredPeers = append(expiredPeers, peer)
}
@@ -295,6 +297,7 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init()
device.indexTable.Init()
+
device.PopulatePools()
// create queues
@@ -322,6 +325,19 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
return device
}
+// BatchSize returns the BatchSize for the device as a whole which is the max of
+// the bind batch size and the tun batch size. The batch size reported by device
+// is the size used to construct memory pools, and is the allowed batch size for
+// the lifetime of the device.
+func (device *Device) BatchSize() int {
+ size := device.net.bind.BatchSize()
+ dSize := device.tun.device.BatchSize()
+ if size < dSize {
+ size = dSize
+ }
+ return size
+}
+
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.RLock()
defer device.peers.RUnlock()
@@ -354,6 +370,8 @@ func (device *Device) RemoveAllPeers() {
func (device *Device) Close() {
device.state.Lock()
defer device.state.Unlock()
+ device.ipcMutex.Lock()
+ defer device.ipcMutex.Unlock()
if device.isClosed() {
return
}
@@ -443,11 +461,7 @@ func (device *Device) BindSetMark(mark uint32) error {
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
- peer.Lock()
- defer peer.Unlock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
+ peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
@@ -472,11 +486,13 @@ func (device *Device) BindUpdate() error {
var err error
var recvFns []conn.ReceiveFunc
netc := &device.net
+
recvFns, netc.port, err = netc.bind.Open(netc.port)
if err != nil {
netc.port = 0
return err
}
+
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
if err != nil {
netc.bind.Close()
@@ -495,11 +511,7 @@ func (device *Device) BindUpdate() error {
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
- peer.Lock()
- defer peer.Unlock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
+ peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
@@ -507,8 +519,9 @@ func (device *Device) BindUpdate() error {
device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
+ batchSize := netc.bind.BatchSize()
for _, fn := range recvFns {
- go device.RoutineReceiveIncoming(fn)
+ go device.RoutineReceiveIncoming(batchSize, fn)
}
device.log.Verbosef("UDP bind has been updated")
diff --git a/device/device_test.go b/device/device_test.go
index 8cffe08..fff172b 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -12,6 +12,7 @@ import (
"io"
"math/rand"
"net/netip"
+ "os"
"runtime"
"runtime/pprof"
"sync"
@@ -21,6 +22,7 @@ import (
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/conn/bindtest"
+ "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/tuntest"
)
@@ -307,6 +309,17 @@ func TestConcurrencySafety(t *testing.T) {
}
})
+ // Perform bind updates and keepalive sends concurrently with tunnel use.
+ t.Run("bindUpdate and keepalive", func(t *testing.T) {
+ const iters = 10
+ for i := 0; i < iters; i++ {
+ for _, peer := range pair {
+ peer.dev.BindUpdate()
+ peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
+ }
+ }
+ })
+
close(done)
}
@@ -405,3 +418,59 @@ func goroutineLeakCheck(t *testing.T) {
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
})
}
+
+type fakeBindSized struct {
+ size int
+}
+
+func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+ return nil, 0, nil
+}
+func (b *fakeBindSized) Close() error { return nil }
+func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
+func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
+func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
+func (b *fakeBindSized) BatchSize() int { return b.size }
+
+type fakeTUNDeviceSized struct {
+ size int
+}
+
+func (t *fakeTUNDeviceSized) File() *os.File { return nil }
+func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
+ return 0, nil
+}
+func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
+func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
+func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
+func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
+func (t *fakeTUNDeviceSized) Close() error { return nil }
+func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
+
+func TestBatchSize(t *testing.T) {
+ d := Device{}
+
+ d.net.bind = &fakeBindSized{1}
+ d.tun.device = &fakeTUNDeviceSized{1}
+ if want, got := 1, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{1}
+ d.tun.device = &fakeTUNDeviceSized{128}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{128}
+ d.tun.device = &fakeTUNDeviceSized{1}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{128}
+ d.tun.device = &fakeTUNDeviceSized{128}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+}
diff --git a/device/endpoint_test.go b/device/endpoint_test.go
index b265be6..93a4998 100644
--- a/device/endpoint_test.go
+++ b/device/endpoint_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/indextable.go b/device/indextable.go
index 15b8b0d..00ade7d 100644
--- a/device/indextable.go
+++ b/device/indextable.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/ip.go b/device/ip.go
index 50c5040..eaf2363 100644
--- a/device/ip.go
+++ b/device/ip.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/kdf_test.go b/device/kdf_test.go
index 872195c..f9c76d6 100644
--- a/device/kdf_test.go
+++ b/device/kdf_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/keypair.go b/device/keypair.go
index 206d7a9..e3540d7 100644
--- a/device/keypair.go
+++ b/device/keypair.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/logger.go b/device/logger.go
index 29073e9..22b0df0 100644
--- a/device/logger.go
+++ b/device/logger.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/mobilequirks.go b/device/mobilequirks.go
index 680f5c7..0a0080e 100644
--- a/device/mobilequirks.go
+++ b/device/mobilequirks.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -11,9 +11,9 @@ func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
device.net.brokenRoaming = true
device.peers.RLock()
for _, peer := range device.peers.keyMap {
- peer.Lock()
- peer.disableRoaming = peer.endpoint != nil
- peer.Unlock()
+ peer.endpoint.Lock()
+ peer.endpoint.disableRoaming = peer.endpoint.val != nil
+ peer.endpoint.Unlock()
}
device.peers.RUnlock()
}
diff --git a/device/noise-helpers.go b/device/noise-helpers.go
index b6b51fd..c2f356b 100644
--- a/device/noise-helpers.go
+++ b/device/noise-helpers.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -9,6 +9,7 @@ import (
"crypto/hmac"
"crypto/rand"
"crypto/subtle"
+ "errors"
"hash"
"golang.org/x/crypto/blake2s"
@@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
return
}
-func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
+var errInvalidPublicKey = errors.New("invalid public key")
+
+func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarMult(&ss, ask, apk)
- return ss
+ if isZero(ss[:]) {
+ return ss, errInvalidPublicKey
+ }
+ return ss, nil
}
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index 410926e..e8f6145 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -175,8 +175,6 @@ func init() {
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
- errZeroECDHResult := errors.New("ECDH returned all zeros")
-
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -204,9 +202,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(msg.Ephemeral[:])
// encrypt static key
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- if isZero(ss[:]) {
- return nil, errZeroECDHResult
+ ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if err != nil {
+ return nil, err
}
var key [chacha20poly1305.KeySize]byte
KDF2(
@@ -221,7 +219,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
// encrypt timestamp
if isZero(handshake.precomputedStaticStatic[:]) {
- return nil, errZeroECDHResult
+ return nil, errInvalidPublicKey
}
KDF2(
&handshake.chainKey,
@@ -264,11 +262,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key
- var err error
var peerPK NoisePublicKey
var key [chacha20poly1305.KeySize]byte
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- if isZero(ss[:]) {
+ ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if err != nil {
return nil
}
KDF2(&chainKey, &key, chainKey[:], ss[:])
@@ -384,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:])
- func() {
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
- handshake.mixKey(ss[:])
- ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- handshake.mixKey(ss[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
+ if err != nil {
+ return nil, err
+ }
+ handshake.mixKey(ss[:])
+ ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if err != nil {
+ return nil, err
+ }
+ handshake.mixKey(ss[:])
// add preshared key
@@ -406,11 +407,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(tau[:])
- func() {
- aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
- handshake.mixHash(msg.Empty[:])
- }()
+ aead, _ := chacha20poly1305.New(key[:])
+ aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
+ handshake.mixHash(msg.Empty[:])
handshake.state = handshakeResponseCreated
@@ -455,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
- func() {
- ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
- mixKey(&chainKey, &chainKey, ss[:])
- setZero(ss[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return false
+ }
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
- func() {
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- mixKey(&chainKey, &chainKey, ss[:])
- setZero(ss[:])
- }()
+ ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return false
+ }
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
// add preshared key (psk)
@@ -483,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// authenticate transcript
aead, _ := chacha20poly1305.New(key[:])
- _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
+ _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil {
return false
}
diff --git a/device/noise-types.go b/device/noise-types.go
index 6e850e7..e850359 100644
--- a/device/noise-types.go
+++ b/device/noise-types.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/noise_test.go b/device/noise_test.go
index 7c84efc..2dd5324 100644
--- a/device/noise_test.go
+++ b/device/noise_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -24,10 +24,10 @@ func TestCurveWrappers(t *testing.T) {
pk1 := sk1.publicKey()
pk2 := sk2.publicKey()
- ss1 := sk1.sharedSecret(pk2)
- ss2 := sk2.sharedSecret(pk1)
+ ss1, err1 := sk1.sharedSecret(pk2)
+ ss2, err2 := sk2.sharedSecret(pk1)
- if ss1 != ss2 {
+ if ss1 != ss2 || err1 != nil || err2 != nil {
t.Fatal("Failed to compute shared secet")
}
}
diff --git a/device/peer.go b/device/peer.go
index 79feae7..47a2f14 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -17,17 +17,20 @@ import (
type Peer struct {
isRunning atomic.Bool
- sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
keypairs Keypairs
handshake Handshake
device *Device
- endpoint conn.Endpoint
stopping sync.WaitGroup // routines pending stop
txBytes atomic.Uint64 // bytes send to peer (endpoint)
rxBytes atomic.Uint64 // bytes received from peer
lastHandshakeNano atomic.Int64 // nano seconds since epoch
- disableRoaming bool
+ endpoint struct {
+ sync.Mutex
+ val conn.Endpoint
+ clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
+ disableRoaming bool
+ }
timers struct {
retransmitHandshake *Timer
@@ -45,9 +48,9 @@ type Peer struct {
}
queue struct {
- staged chan *QueueOutboundElement // staged packets before a handshake is available
- outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
- inbound *autodrainingInboundQueue // sequential ordering of tun writing
+ staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
+ outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
+ inbound *autodrainingInboundQueue // sequential ordering of tun writing
}
cookieGenerator CookieGenerator
@@ -74,14 +77,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// create peer
peer := new(Peer)
- peer.Lock()
- defer peer.Unlock()
peer.cookieGenerator.Init(pk)
peer.device = device
peer.queue.outbound = newAutodrainingOutboundQueue(device)
peer.queue.inbound = newAutodrainingInboundQueue(device)
- peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize)
+ peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
// map public key
_, ok := device.peers.keyMap[pk]
@@ -92,12 +93,16 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// pre-compute DH
handshake := &peer.handshake
handshake.mutex.Lock()
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
+ handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.remoteStatic = pk
handshake.mutex.Unlock()
// reset endpoint
- peer.endpoint = nil
+ peer.endpoint.Lock()
+ peer.endpoint.val = nil
+ peer.endpoint.disableRoaming = false
+ peer.endpoint.clearSrcOnTx = false
+ peer.endpoint.Unlock()
// init timers
peer.timersInit()
@@ -108,7 +113,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return peer, nil
}
-func (peer *Peer) SendBuffer(buffer []byte) error {
+func (peer *Peer) SendBuffers(buffers [][]byte) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()
@@ -116,16 +121,25 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
return nil
}
- peer.RLock()
- defer peer.RUnlock()
-
- if peer.endpoint == nil {
+ peer.endpoint.Lock()
+ endpoint := peer.endpoint.val
+ if endpoint == nil {
+ peer.endpoint.Unlock()
return errors.New("no known endpoint for peer")
}
+ if peer.endpoint.clearSrcOnTx {
+ endpoint.ClearSrc()
+ peer.endpoint.clearSrcOnTx = false
+ }
+ peer.endpoint.Unlock()
- err := peer.device.net.bind.Send(buffer, peer.endpoint)
+ err := peer.device.net.bind.Send(buffers, endpoint)
if err == nil {
- peer.txBytes.Add(uint64(len(buffer)))
+ var totalLen uint64
+ for _, b := range buffers {
+ totalLen += uint64(len(b))
+ }
+ peer.txBytes.Add(totalLen)
}
return err
}
@@ -187,8 +201,12 @@ func (peer *Peer) Start() {
device.flushInboundQueue(peer.queue.inbound)
device.flushOutboundQueue(peer.queue.outbound)
- go peer.RoutineSequentialSender()
- go peer.RoutineSequentialReceiver()
+
+ // Use the device batch size, not the bind batch size, as the device size is
+ // the size of the batch pools.
+ batchSize := peer.device.BatchSize()
+ go peer.RoutineSequentialSender(batchSize)
+ go peer.RoutineSequentialReceiver(batchSize)
peer.isRunning.Store(true)
}
@@ -259,10 +277,20 @@ func (peer *Peer) Stop() {
}
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
- if peer.disableRoaming {
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ if peer.endpoint.disableRoaming {
+ return
+ }
+ peer.endpoint.clearSrcOnTx = false
+ peer.endpoint.val = endpoint
+}
+
+func (peer *Peer) markEndpointSrcForClearing() {
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ if peer.endpoint.val == nil {
return
}
- peer.Lock()
- peer.endpoint = endpoint
- peer.Unlock()
+ peer.endpoint.clearSrcOnTx = true
}
diff --git a/device/pools.go b/device/pools.go
index 9da0f79..94f3dc7 100644
--- a/device/pools.go
+++ b/device/pools.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -46,6 +46,14 @@ func (p *WaitPool) Put(x any) {
}
func (device *Device) PopulatePools() {
+ device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ s := make([]*QueueInboundElement, 0, device.BatchSize())
+ return &QueueInboundElementsContainer{elems: s}
+ })
+ device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ s := make([]*QueueOutboundElement, 0, device.BatchSize())
+ return &QueueOutboundElementsContainer{elems: s}
+ })
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new([MaxMessageSize]byte)
})
@@ -57,6 +65,34 @@ func (device *Device) PopulatePools() {
})
}
+func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
+ c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
+ c.Mutex = sync.Mutex{}
+ return c
+}
+
+func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
+ for i := range c.elems {
+ c.elems[i] = nil
+ }
+ c.elems = c.elems[:0]
+ device.pool.inboundElementsContainer.Put(c)
+}
+
+func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
+ c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
+ c.Mutex = sync.Mutex{}
+ return c
+}
+
+func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
+ for i := range c.elems {
+ c.elems[i] = nil
+ }
+ c.elems = c.elems[:0]
+ device.pool.outboundElementsContainer.Put(c)
+}
+
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
}
diff --git a/device/pools_test.go b/device/pools_test.go
index 48a98b0..82d7493 100644
--- a/device/pools_test.go
+++ b/device/pools_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -89,3 +89,51 @@ func BenchmarkWaitPool(b *testing.B) {
}
wg.Wait()
}
+
+func BenchmarkWaitPoolEmpty(b *testing.B) {
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ b.Skip("Not enough cores")
+ }
+ p := NewWaitPool(0, func() any { return make([]byte, 16) })
+ wg.Add(workers)
+ b.ResetTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
+ x := p.Get()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ p.Put(x)
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func BenchmarkSyncPool(b *testing.B) {
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ b.Skip("Not enough cores")
+ }
+ p := sync.Pool{New: func() any { return make([]byte, 16) }}
+ wg.Add(workers)
+ b.ResetTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
+ x := p.Get()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ p.Put(x)
+ }
+ }()
+ }
+ wg.Wait()
+}
diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go
index 7abc33a..25f700a 100644
--- a/device/queueconstants_android.go
+++ b/device/queueconstants_android.go
@@ -1,17 +1,19 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
+import "golang.zx2c4.com/wireguard/conn"
+
/* Reduce memory consumption for Android */
const (
- QueueStagedSize = 128
+ QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
- MaxSegmentSize = 2200
+ MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
PreallocatedBuffersPerPool = 4096
)
diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go
index 66ef600..ea763d0 100644
--- a/device/queueconstants_default.go
+++ b/device/queueconstants_default.go
@@ -2,13 +2,15 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
+import "golang.zx2c4.com/wireguard/conn"
+
const (
- QueueStagedSize = 128
+ QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
diff --git a/device/queueconstants_ios.go b/device/queueconstants_ios.go
index 854e4c2..acd3cec 100644
--- a/device/queueconstants_ios.go
+++ b/device/queueconstants_ios.go
@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/queueconstants_windows.go b/device/queueconstants_windows.go
index e330a6b..1eee32b 100644
--- a/device/queueconstants_windows.go
+++ b/device/queueconstants_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/race_disabled_test.go b/device/race_disabled_test.go
index b3db3a1..bb5c450 100644
--- a/device/race_disabled_test.go
+++ b/device/race_disabled_test.go
@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/race_enabled_test.go b/device/race_enabled_test.go
index 1565100..4e9daea 100644
--- a/device/race_enabled_test.go
+++ b/device/race_enabled_test.go
@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/receive.go b/device/receive.go
index 4dbf1e8..1ab3e29 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -27,7 +27,6 @@ type QueueHandshakeElement struct {
}
type QueueInboundElement struct {
- sync.Mutex
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
@@ -35,6 +34,11 @@ type QueueInboundElement struct {
endpoint conn.Endpoint
}
+type QueueInboundElementsContainer struct {
+ sync.Mutex
+ elems []*QueueInboundElement
+}
+
// clearPointers clears elem fields that contain pointers.
// This makes the garbage collector's life easier and
// avoids accidentally keeping other objects around unnecessarily.
@@ -66,7 +70,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
-func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
+func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
recvName := recv.PrettyName()
defer func() {
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
@@ -79,20 +83,33 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
// receive datagrams until conn is closed
- buffer := device.GetMessageBuffer()
-
var (
+ bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
+ bufs = make([][]byte, maxBatchSize)
err error
- size int
- endpoint conn.Endpoint
+ sizes = make([]int, maxBatchSize)
+ count int
+ endpoints = make([]conn.Endpoint, maxBatchSize)
deathSpiral int
+ elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
)
- for {
- size, endpoint, err = recv(buffer[:])
+ for i := range bufsArrs {
+ bufsArrs[i] = device.GetMessageBuffer()
+ bufs[i] = bufsArrs[i][:]
+ }
+
+ defer func() {
+ for i := 0; i < maxBatchSize; i++ {
+ if bufsArrs[i] != nil {
+ device.PutMessageBuffer(bufsArrs[i])
+ }
+ }
+ }()
+ for {
+ count, err = recv(bufs, sizes, endpoints)
if err != nil {
- device.PutMessageBuffer(buffer)
if errors.Is(err, net.ErrClosed) {
return
}
@@ -103,101 +120,119 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
if deathSpiral < 10 {
deathSpiral++
time.Sleep(time.Second / 3)
- buffer = device.GetMessageBuffer()
continue
}
return
}
deathSpiral = 0
- if size < MinMessageSize {
- continue
- }
+ // handle each packet in the batch
+ for i, size := range sizes[:count] {
+ if size < MinMessageSize {
+ continue
+ }
- // check size of packet
+ // check size of packet
- packet := buffer[:size]
- msgType := binary.LittleEndian.Uint32(packet[:4])
+ packet := bufsArrs[i][:size]
+ msgType := binary.LittleEndian.Uint32(packet[:4])
- var okay bool
+ switch msgType {
- switch msgType {
+ // check if transport
- // check if transport
+ case MessageTransportType:
- case MessageTransportType:
+ // check size
- // check size
+ if len(packet) < MessageTransportSize {
+ continue
+ }
- if len(packet) < MessageTransportSize {
- continue
- }
+ // lookup key pair
- // lookup key pair
+ receiver := binary.LittleEndian.Uint32(
+ packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
+ )
+ value := device.indexTable.Lookup(receiver)
+ keypair := value.keypair
+ if keypair == nil {
+ continue
+ }
- receiver := binary.LittleEndian.Uint32(
- packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
- )
- value := device.indexTable.Lookup(receiver)
- keypair := value.keypair
- if keypair == nil {
- continue
- }
+ // check keypair expiry
- // check keypair expiry
+ if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
+ continue
+ }
- if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
+ // create work element
+ peer := value.peer
+ elem := device.GetInboundElement()
+ elem.packet = packet
+ elem.buffer = bufsArrs[i]
+ elem.keypair = keypair
+ elem.endpoint = endpoints[i]
+ elem.counter = 0
+
+ elemsForPeer, ok := elemsByPeer[peer]
+ if !ok {
+ elemsForPeer = device.GetInboundElementsContainer()
+ elemsForPeer.Lock()
+ elemsByPeer[peer] = elemsForPeer
+ }
+ elemsForPeer.elems = append(elemsForPeer.elems, elem)
+ bufsArrs[i] = device.GetMessageBuffer()
+ bufs[i] = bufsArrs[i][:]
continue
- }
-
- // create work element
- peer := value.peer
- elem := device.GetInboundElement()
- elem.packet = packet
- elem.buffer = buffer
- elem.keypair = keypair
- elem.endpoint = endpoint
- elem.counter = 0
- elem.Mutex = sync.Mutex{}
- elem.Lock()
-
- // add to decryption queues
- if peer.isRunning.Load() {
- peer.queue.inbound.c <- elem
- device.queue.decryption.c <- elem
- buffer = device.GetMessageBuffer()
- } else {
- device.PutInboundElement(elem)
- }
- continue
- // otherwise it is a fixed size & handshake related packet
+ // otherwise it is a fixed size & handshake related packet
- case MessageInitiationType:
- okay = len(packet) == MessageInitiationSize
+ case MessageInitiationType:
+ if len(packet) != MessageInitiationSize {
+ continue
+ }
- case MessageResponseType:
- okay = len(packet) == MessageResponseSize
+ case MessageResponseType:
+ if len(packet) != MessageResponseSize {
+ continue
+ }
- case MessageCookieReplyType:
- okay = len(packet) == MessageCookieReplySize
+ case MessageCookieReplyType:
+ if len(packet) != MessageCookieReplySize {
+ continue
+ }
- default:
- device.log.Verbosef("Received message with unknown type")
- }
+ default:
+ device.log.Verbosef("Received message with unknown type")
+ continue
+ }
- if okay {
select {
case device.queue.handshake.c <- QueueHandshakeElement{
msgType: msgType,
- buffer: buffer,
+ buffer: bufsArrs[i],
packet: packet,
- endpoint: endpoint,
+ endpoint: endpoints[i],
}:
- buffer = device.GetMessageBuffer()
+ bufsArrs[i] = device.GetMessageBuffer()
+ bufs[i] = bufsArrs[i][:]
default:
}
}
+ for peer, elemsContainer := range elemsByPeer {
+ if peer.isRunning.Load() {
+ peer.queue.inbound.c <- elemsContainer
+ device.queue.decryption.c <- elemsContainer
+ } else {
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ device.PutInboundElementsContainer(elemsContainer)
+ }
+ delete(elemsByPeer, peer)
+ }
}
}
@@ -207,26 +242,28 @@ func (device *Device) RoutineDecryption(id int) {
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
device.log.Verbosef("Routine: decryption worker %d - started", id)
- for elem := range device.queue.decryption.c {
- // split message into fields
- counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
- content := elem.packet[MessageTransportOffsetContent:]
-
- // decrypt and release to consumer
- var err error
- elem.counter = binary.LittleEndian.Uint64(counter)
- // copy counter to nonce
- binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
- elem.packet, err = elem.keypair.receive.Open(
- content[:0],
- nonce[:],
- content,
- nil,
- )
- if err != nil {
- elem.packet = nil
+ for elemsContainer := range device.queue.decryption.c {
+ for _, elem := range elemsContainer.elems {
+ // split message into fields
+ counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
+ content := elem.packet[MessageTransportOffsetContent:]
+
+ // decrypt and release to consumer
+ var err error
+ elem.counter = binary.LittleEndian.Uint64(counter)
+ // copy counter to nonce
+ binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
+ elem.packet, err = elem.keypair.receive.Open(
+ content[:0],
+ nonce[:],
+ content,
+ nil,
+ )
+ if err != nil {
+ elem.packet = nil
+ }
}
- elem.Unlock()
+ elemsContainer.Unlock()
}
}
@@ -393,7 +430,7 @@ func (device *Device) RoutineHandshake(id int) {
}
}
-func (peer *Peer) RoutineSequentialReceiver() {
+func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
device := peer.device
defer func() {
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
@@ -401,89 +438,103 @@ func (peer *Peer) RoutineSequentialReceiver() {
}()
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
- for elem := range peer.queue.inbound.c {
- if elem == nil {
+ bufs := make([][]byte, 0, maxBatchSize)
+
+ for elemsContainer := range peer.queue.inbound.c {
+ if elemsContainer == nil {
return
}
- var err error
- elem.Lock()
- if elem.packet == nil {
- // decryption failed
- goto skip
- }
+ elemsContainer.Lock()
+ validTailPacket := -1
+ dataPacketReceived := false
+ rxBytesLen := uint64(0)
+ for i, elem := range elemsContainer.elems {
+ if elem.packet == nil {
+ // decryption failed
+ continue
+ }
- if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
- goto skip
- }
+ if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
+ continue
+ }
- peer.SetEndpointFromPacket(elem.endpoint)
- if peer.ReceivedWithKeypair(elem.keypair) {
- peer.timersHandshakeComplete()
- peer.SendStagedPackets()
- }
+ validTailPacket = i
+ if peer.ReceivedWithKeypair(elem.keypair) {
+ peer.SetEndpointFromPacket(elem.endpoint)
+ peer.timersHandshakeComplete()
+ peer.SendStagedPackets()
+ }
+ rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
- peer.keepKeyFreshReceiving()
- peer.timersAnyAuthenticatedPacketTraversal()
- peer.timersAnyAuthenticatedPacketReceived()
- peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
+ if len(elem.packet) == 0 {
+ device.log.Verbosef("%v - Receiving keepalive packet", peer)
+ continue
+ }
+ dataPacketReceived = true
- if len(elem.packet) == 0 {
- device.log.Verbosef("%v - Receiving keepalive packet", peer)
- goto skip
- }
- peer.timersDataReceived()
+ switch elem.packet[0] >> 4 {
+ case 4:
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
+ field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
+ length := binary.BigEndian.Uint16(field)
+ if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+ continue
+ }
+ elem.packet = elem.packet[:length]
+ src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
+ if device.allowedips.Lookup(src) != peer {
+ device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
+ continue
+ }
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
- if len(elem.packet) < ipv4.HeaderLen {
- goto skip
- }
- field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
- length := binary.BigEndian.Uint16(field)
- if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
- goto skip
- }
- elem.packet = elem.packet[:length]
- src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
- if device.allowedips.Lookup(src) != peer {
- device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
- goto skip
- }
+ case 6:
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
+ field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
+ length := binary.BigEndian.Uint16(field)
+ length += ipv6.HeaderLen
+ if int(length) > len(elem.packet) {
+ continue
+ }
+ elem.packet = elem.packet[:length]
+ src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
+ if device.allowedips.Lookup(src) != peer {
+ device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
+ continue
+ }
- case ipv6.Version:
- if len(elem.packet) < ipv6.HeaderLen {
- goto skip
- }
- field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
- length := binary.BigEndian.Uint16(field)
- length += ipv6.HeaderLen
- if int(length) > len(elem.packet) {
- goto skip
- }
- elem.packet = elem.packet[:length]
- src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
- if device.allowedips.Lookup(src) != peer {
- device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
- goto skip
+ default:
+ device.log.Verbosef("Packet with invalid IP version from %v", peer)
+ continue
}
- default:
- device.log.Verbosef("Packet with invalid IP version from %v", peer)
- goto skip
+ bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
}
- _, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
- if err != nil && !device.isClosed() {
- device.log.Errorf("Failed to write packet to TUN device: %v", err)
+ peer.rxBytes.Add(rxBytesLen)
+ if validTailPacket >= 0 {
+ peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
+ peer.keepKeyFreshReceiving()
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketReceived()
}
- if len(peer.queue.inbound.c) == 0 {
- err = device.tun.device.Flush()
- if err != nil {
- peer.device.log.Errorf("Unable to flush packets: %v", err)
+ if dataPacketReceived {
+ peer.timersDataReceived()
+ }
+ if len(bufs) > 0 {
+ _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
+ if err != nil && !device.isClosed() {
+ device.log.Errorf("Failed to write packets to TUN device: %v", err)
}
}
- skip:
- device.PutMessageBuffer(elem.buffer)
- device.PutInboundElement(elem)
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ bufs = bufs[:0]
+ device.PutInboundElementsContainer(elemsContainer)
}
}
diff --git a/device/send.go b/device/send.go
index 471c51c..769720a 100644
--- a/device/send.go
+++ b/device/send.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -17,6 +17,8 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/tun"
)
/* Outbound flow
@@ -44,7 +46,6 @@ import (
*/
type QueueOutboundElement struct {
- sync.Mutex
buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption
@@ -52,10 +53,14 @@ type QueueOutboundElement struct {
peer *Peer // related peer
}
+type QueueOutboundElementsContainer struct {
+ sync.Mutex
+ elems []*QueueOutboundElement
+}
+
func (device *Device) NewOutboundElement() *QueueOutboundElement {
elem := device.GetOutboundElement()
elem.buffer = device.GetMessageBuffer()
- elem.Mutex = sync.Mutex{}
elem.nonce = 0
// keypair and peer were cleared (if necessary) by clearPointers.
return elem
@@ -77,12 +82,15 @@ func (elem *QueueOutboundElement) clearPointers() {
func (peer *Peer) SendKeepalive() {
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
elem := peer.device.NewOutboundElement()
+ elemsContainer := peer.device.GetOutboundElementsContainer()
+ elemsContainer.elems = append(elemsContainer.elems, elem)
select {
- case peer.queue.staged <- elem:
+ case peer.queue.staged <- elemsContainer:
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
default:
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
+ peer.device.PutOutboundElementsContainer(elemsContainer)
}
}
peer.SendStagedPackets()
@@ -116,8 +124,8 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
return err
}
- var buff [MessageInitiationSize]byte
- writer := bytes.NewBuffer(buff[:0])
+ var buf [MessageInitiationSize]byte
+ writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
@@ -125,7 +133,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- err = peer.SendBuffer(packet)
+ err = peer.SendBuffers([][]byte{packet})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
@@ -147,8 +155,8 @@ func (peer *Peer) SendHandshakeResponse() error {
return err
}
- var buff [MessageResponseSize]byte
- writer := bytes.NewBuffer(buff[:0])
+ var buf [MessageResponseSize]byte
+ writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, response)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
@@ -163,7 +171,8 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- err = peer.SendBuffer(packet)
+ // TODO: allocation could be avoided
+ err = peer.SendBuffers([][]byte{packet})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
}
@@ -180,10 +189,11 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement)
return err
}
- var buff [MessageCookieReplySize]byte
- writer := bytes.NewBuffer(buff[:0])
+ var buf [MessageCookieReplySize]byte
+ writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply)
- device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
+ // TODO: allocation could be avoided
+ device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
return nil
}
@@ -198,11 +208,6 @@ func (peer *Peer) keepKeyFreshSending() {
}
}
-/* Reads packets from the TUN and inserts
- * into staged queue for peer
- *
- * Obs. Single instance per TUN device
- */
func (device *Device) RoutineReadFromTUN() {
defer func() {
device.log.Verbosef("Routine: TUN reader - stopped")
@@ -212,81 +217,123 @@ func (device *Device) RoutineReadFromTUN() {
device.log.Verbosef("Routine: TUN reader - started")
- var elem *QueueOutboundElement
+ var (
+ batchSize = device.BatchSize()
+ readErr error
+ elems = make([]*QueueOutboundElement, batchSize)
+ bufs = make([][]byte, batchSize)
+ elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
+ count = 0
+ sizes = make([]int, batchSize)
+ offset = MessageTransportHeaderSize
+ )
+
+ for i := range elems {
+ elems[i] = device.NewOutboundElement()
+ bufs[i] = elems[i].buffer[:]
+ }
- for {
- if elem != nil {
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ defer func() {
+ for _, elem := range elems {
+ if elem != nil {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
}
- elem = device.NewOutboundElement()
-
- // read packet
+ }()
- offset := MessageTransportHeaderSize
- size, err := device.tun.device.Read(elem.buffer[:], offset)
- if err != nil {
- if !device.isClosed() {
- if !errors.Is(err, os.ErrClosed) {
- device.log.Errorf("Failed to read packet from TUN device: %v", err)
- }
- go device.Close()
+ for {
+ // read packets
+ count, readErr = device.tun.device.Read(bufs, sizes, offset)
+ for i := 0; i < count; i++ {
+ if sizes[i] < 1 {
+ continue
}
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
- return
- }
- if size == 0 || size > MaxContentSize {
- continue
- }
+ elem := elems[i]
+ elem.packet = bufs[i][offset : offset+sizes[i]]
- elem.packet = elem.buffer[offset : offset+size]
+ // lookup peer
+ var peer *Peer
+ switch elem.packet[0] >> 4 {
+ case 4:
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
+ dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+ peer = device.allowedips.Lookup(dst)
- // lookup peer
+ case 6:
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
+ dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+ peer = device.allowedips.Lookup(dst)
- var peer *Peer
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
- if len(elem.packet) < ipv4.HeaderLen {
- continue
+ default:
+ device.log.Verbosef("Received packet with unknown IP version")
}
- dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
- peer = device.allowedips.Lookup(dst)
- case ipv6.Version:
- if len(elem.packet) < ipv6.HeaderLen {
+ if peer == nil {
continue
}
- dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
- peer = device.allowedips.Lookup(dst)
-
- default:
- device.log.Verbosef("Received packet with unknown IP version")
+ elemsForPeer, ok := elemsByPeer[peer]
+ if !ok {
+ elemsForPeer = device.GetOutboundElementsContainer()
+ elemsByPeer[peer] = elemsForPeer
+ }
+ elemsForPeer.elems = append(elemsForPeer.elems, elem)
+ elems[i] = device.NewOutboundElement()
+ bufs[i] = elems[i].buffer[:]
}
- if peer == nil {
- continue
+ for peer, elemsForPeer := range elemsByPeer {
+ if peer.isRunning.Load() {
+ peer.StagePackets(elemsForPeer)
+ peer.SendStagedPackets()
+ } else {
+ for _, elem := range elemsForPeer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsContainer(elemsForPeer)
+ }
+ delete(elemsByPeer, peer)
}
- if peer.isRunning.Load() {
- peer.StagePacket(elem)
- elem = nil
- peer.SendStagedPackets()
+
+ if readErr != nil {
+ if errors.Is(readErr, tun.ErrTooManySegments) {
+ // TODO: record stat for this
+ // This will happen if MSS is surprisingly small (< 576)
+ // coincident with reasonably high throughput.
+ device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
+ continue
+ }
+ if !device.isClosed() {
+ if !errors.Is(readErr, os.ErrClosed) {
+ device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
+ }
+ go device.Close()
+ }
+ return
}
}
}
-func (peer *Peer) StagePacket(elem *QueueOutboundElement) {
+func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
for {
select {
- case peer.queue.staged <- elem:
+ case peer.queue.staged <- elems:
return
default:
}
select {
case tooOld := <-peer.queue.staged:
- peer.device.PutMessageBuffer(tooOld.buffer)
- peer.device.PutOutboundElement(tooOld)
+ for _, elem := range tooOld.elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ }
+ peer.device.PutOutboundElementsContainer(tooOld)
default:
}
}
@@ -305,26 +352,53 @@ top:
}
for {
+ var elemsContainerOOO *QueueOutboundElementsContainer
select {
- case elem := <-peer.queue.staged:
- elem.peer = peer
- elem.nonce = keypair.sendNonce.Add(1) - 1
- if elem.nonce >= RejectAfterMessages {
- keypair.sendNonce.Store(RejectAfterMessages)
- peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans
- goto top
+ case elemsContainer := <-peer.queue.staged:
+ i := 0
+ for _, elem := range elemsContainer.elems {
+ elem.peer = peer
+ elem.nonce = keypair.sendNonce.Add(1) - 1
+ if elem.nonce >= RejectAfterMessages {
+ keypair.sendNonce.Store(RejectAfterMessages)
+ if elemsContainerOOO == nil {
+ elemsContainerOOO = peer.device.GetOutboundElementsContainer()
+ }
+ elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
+ continue
+ } else {
+ elemsContainer.elems[i] = elem
+ i++
+ }
+
+ elem.keypair = keypair
}
+ elemsContainer.Lock()
+ elemsContainer.elems = elemsContainer.elems[:i]
- elem.keypair = keypair
- elem.Lock()
+ if elemsContainerOOO != nil {
+ peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
+ }
+
+ if len(elemsContainer.elems) == 0 {
+ peer.device.PutOutboundElementsContainer(elemsContainer)
+ goto top
+ }
// add to parallel and sequential queue
if peer.isRunning.Load() {
- peer.queue.outbound.c <- elem
- peer.device.queue.encryption.c <- elem
+ peer.queue.outbound.c <- elemsContainer
+ peer.device.queue.encryption.c <- elemsContainer
} else {
- peer.device.PutMessageBuffer(elem.buffer)
- peer.device.PutOutboundElement(elem)
+ for _, elem := range elemsContainer.elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ }
+ peer.device.PutOutboundElementsContainer(elemsContainer)
+ }
+
+ if elemsContainerOOO != nil {
+ goto top
}
default:
return
@@ -335,9 +409,12 @@ top:
func (peer *Peer) FlushStagedPackets() {
for {
select {
- case elem := <-peer.queue.staged:
- peer.device.PutMessageBuffer(elem.buffer)
- peer.device.PutOutboundElement(elem)
+ case elemsContainer := <-peer.queue.staged:
+ for _, elem := range elemsContainer.elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ }
+ peer.device.PutOutboundElementsContainer(elemsContainer)
default:
return
}
@@ -371,41 +448,38 @@ func (device *Device) RoutineEncryption(id int) {
defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
device.log.Verbosef("Routine: encryption worker %d - started", id)
- for elem := range device.queue.encryption.c {
- // populate header fields
- header := elem.buffer[:MessageTransportHeaderSize]
+ for elemsContainer := range device.queue.encryption.c {
+ for _, elem := range elemsContainer.elems {
+ // populate header fields
+ header := elem.buffer[:MessageTransportHeaderSize]
- fieldType := header[0:4]
- fieldReceiver := header[4:8]
- fieldNonce := header[8:16]
+ fieldType := header[0:4]
+ fieldReceiver := header[4:8]
+ fieldNonce := header[8:16]
- binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
- binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
- binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
+ binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
+ binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
+ binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
- // pad content to multiple of 16
- paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
- elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
+ // pad content to multiple of 16
+ paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
+ elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
- // encrypt content and release to consumer
+ // encrypt content and release to consumer
- binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
- elem.packet = elem.keypair.send.Seal(
- header,
- nonce[:],
- elem.packet,
- nil,
- )
- elem.Unlock()
+ binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
+ elem.packet = elem.keypair.send.Seal(
+ header,
+ nonce[:],
+ elem.packet,
+ nil,
+ )
+ }
+ elemsContainer.Unlock()
}
}
-/* Sequentially reads packets from queue and sends to endpoint
- *
- * Obs. Single instance per peer.
- * The routine terminates then the outbound queue is closed.
- */
-func (peer *Peer) RoutineSequentialSender() {
+func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
device := peer.device
defer func() {
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
@@ -413,36 +487,57 @@ func (peer *Peer) RoutineSequentialSender() {
}()
device.log.Verbosef("%v - Routine: sequential sender - started", peer)
- for elem := range peer.queue.outbound.c {
- if elem == nil {
+ bufs := make([][]byte, 0, maxBatchSize)
+
+ for elemsContainer := range peer.queue.outbound.c {
+ bufs = bufs[:0]
+ if elemsContainer == nil {
return
}
- elem.Lock()
if !peer.isRunning.Load() {
// peer has been stopped; return re-usable elems to the shared pool.
// This is an optimization only. It is possible for the peer to be stopped
// immediately after this check, in which case, elem will get processed.
- // The timers and SendBuffer code are resilient to a few stragglers.
+ // The timers and SendBuffers code are resilient to a few stragglers.
// TODO: rework peer shutdown order to ensure
// that we never accidentally keep timers alive longer than necessary.
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
continue
}
+ dataSent := false
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ if len(elem.packet) != MessageKeepaliveSize {
+ dataSent = true
+ }
+ bufs = append(bufs, elem.packet)
+ }
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- // send message and return buffer to pool
-
- err := peer.SendBuffer(elem.packet)
- if len(elem.packet) != MessageKeepaliveSize {
+ err := peer.SendBuffers(bufs)
+ if dataSent {
peer.timersDataSent()
}
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsContainer(elemsContainer)
+ if err != nil {
+ var errGSO conn.ErrUDPGSODisabled
+ if errors.As(err, &errGSO) {
+ device.log.Verbosef(err.Error())
+ err = errGSO.RetryErr
+ }
+ }
if err != nil {
- device.log.Errorf("%v - Failed to send data packet: %v", peer, err)
+ device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
continue
}
diff --git a/device/sticky_linux.go b/device/sticky_linux.go
index 97c14b5..6057ff1 100644
--- a/device/sticky_linux.go
+++ b/device/sticky_linux.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
@@ -25,7 +25,10 @@ import (
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
- if _, ok := bind.(*conn.LinuxSocketBind); !ok {
+ if !conn.StdNetSupportsStickySockets {
+ return nil, nil
+ }
+ if _, ok := bind.(*conn.StdNetBind); !ok {
return nil, nil
}
@@ -107,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
if !ok {
break
}
- pePtr.peer.Lock()
- if &pePtr.peer.endpoint != pePtr.endpoint {
- pePtr.peer.Unlock()
+ pePtr.peer.endpoint.Lock()
+ if &pePtr.peer.endpoint.val != pePtr.endpoint {
+ pePtr.peer.endpoint.Unlock()
break
}
- if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx {
- pePtr.peer.Unlock()
+ if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
+ pePtr.peer.endpoint.Unlock()
break
}
- pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc()
- pePtr.peer.Unlock()
+ pePtr.peer.endpoint.clearSrcOnTx = true
+ pePtr.peer.endpoint.Unlock()
}
attr = attr[attrhdr.Len:]
}
@@ -131,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
device.peers.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
- peer.RLock()
- if peer.endpoint == nil {
- peer.RUnlock()
+ peer.endpoint.Lock()
+ if peer.endpoint.val == nil {
+ peer.endpoint.Unlock()
continue
}
- nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint)
+ nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
if nativeEP == nil {
- peer.RUnlock()
+ peer.endpoint.Unlock()
continue
}
- if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
- peer.RUnlock()
+ if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
+ peer.endpoint.Unlock()
break
}
nlmsg := struct {
@@ -169,12 +172,12 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
Len: 8,
Type: unix.RTA_DST,
},
- nativeEP.Dst4().Addr,
+ nativeEP.DstIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
- nativeEP.Src4().Src,
+ nativeEP.SrcIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_MARK,
@@ -185,10 +188,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
- endpoint: &peer.endpoint,
+ endpoint: &peer.endpoint.val,
}
reqPeerLock.Unlock()
- peer.RUnlock()
+ peer.endpoint.Unlock()
i++
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {
diff --git a/device/timers.go b/device/timers.go
index c8ef887..d4a4ed4 100644
--- a/device/timers.go
+++ b/device/timers.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*
* This is based heavily on timers.c from the kernel implementation.
*/
@@ -100,11 +100,7 @@ func expiredRetransmitHandshake(peer *Peer) {
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */
- peer.Lock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- peer.Unlock()
+ peer.markEndpointSrcForClearing()
peer.SendHandshakeInitiation(true)
}
@@ -123,11 +119,7 @@ func expiredSendKeepalive(peer *Peer) {
func expiredNewHandshake(peer *Peer) {
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
/* We clear the endpoint address src address, in case this is the cause of trouble. */
- peer.Lock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- peer.Unlock()
+ peer.markEndpointSrcForClearing()
peer.SendHandshakeInitiation(false)
}
diff --git a/device/tun.go b/device/tun.go
index d94bde1..2a2ace9 100644
--- a/device/tun.go
+++ b/device/tun.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/uapi.go b/device/uapi.go
index 550a032..d81dae3 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -99,33 +99,31 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
for _, peer := range device.peers.keyMap {
// Serialize peer state.
- // Do the work in an anonymous function so that we can use defer.
- func() {
- peer.RLock()
- defer peer.RUnlock()
-
- keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
- keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
- sendf("protocol_version=1")
- if peer.endpoint != nil {
- sendf("endpoint=%s", peer.endpoint.DstToString())
- }
-
- nano := peer.lastHandshakeNano.Load()
- secs := nano / time.Second.Nanoseconds()
- nano %= time.Second.Nanoseconds()
-
- sendf("last_handshake_time_sec=%d", secs)
- sendf("last_handshake_time_nsec=%d", nano)
- sendf("tx_bytes=%d", peer.txBytes.Load())
- sendf("rx_bytes=%d", peer.rxBytes.Load())
- sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
-
- device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
- sendf("allowed_ip=%s", prefix.String())
- return true
- })
- }()
+ peer.handshake.mutex.RLock()
+ keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
+ keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
+ peer.handshake.mutex.RUnlock()
+ sendf("protocol_version=1")
+ peer.endpoint.Lock()
+ if peer.endpoint.val != nil {
+ sendf("endpoint=%s", peer.endpoint.val.DstToString())
+ }
+ peer.endpoint.Unlock()
+
+ nano := peer.lastHandshakeNano.Load()
+ secs := nano / time.Second.Nanoseconds()
+ nano %= time.Second.Nanoseconds()
+
+ sendf("last_handshake_time_sec=%d", secs)
+ sendf("last_handshake_time_nsec=%d", nano)
+ sendf("tx_bytes=%d", peer.txBytes.Load())
+ sendf("rx_bytes=%d", peer.rxBytes.Load())
+ sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
+
+ device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
+ sendf("allowed_ip=%s", prefix.String())
+ return true
+ })
}
}()
@@ -262,7 +260,7 @@ func (peer *ipcSetPeer) handlePostConfig() {
return
}
if peer.created {
- peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil
+ peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
}
if peer.device.isUp() {
peer.Start()
@@ -345,9 +343,9 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
}
- peer.Lock()
- defer peer.Unlock()
- peer.endpoint = endpoint
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ peer.endpoint.val = endpoint
case "persistent_keepalive_interval":
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
diff --git a/format_test.go b/format_test.go
index 198cd49..6f6cab7 100644
--- a/format_test.go
+++ b/format_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
diff --git a/go.mod b/go.mod
index c180d1b..919dc49 100644
--- a/go.mod
+++ b/go.mod
@@ -1,16 +1,16 @@
module golang.zx2c4.com/wireguard
-go 1.19
+go 1.20
require (
- golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd
- golang.org/x/net v0.0.0-20220225172249-27dd8689420f
- golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86
- golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
- gvisor.dev/gvisor v0.0.0-20220817001344-846276b3dbc5
+ golang.org/x/crypto v0.13.0
+ golang.org/x/net v0.15.0
+ golang.org/x/sys v0.12.0
+ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
+ gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
)
require (
github.com/google/btree v1.0.1 // indirect
- golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
+ golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
)
diff --git a/go.sum b/go.sum
index 5f417a7..6bcecea 100644
--- a/go.sum
+++ b/go.sum
@@ -1,14 +1,14 @@
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
-golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38=
-golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
-golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc=
-golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
-golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86 h1:A9i04dxx7Cribqbs8jf3FQLogkL/CV2YN7hj9KWJCkc=
-golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
-golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
-golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
-golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
-gvisor.dev/gvisor v0.0.0-20220817001344-846276b3dbc5 h1:cv/zaNV0nr1mJzaeo4S5mHIm5va1W0/9J3/5prlsuRM=
-gvisor.dev/gvisor v0.0.0-20220817001344-846276b3dbc5/go.mod h1:TIvkJD0sxe8pIob3p6T8IzxXunlp6yfgktvTNp+DGNM=
+golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
+golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
+golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8=
+golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
+golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
+golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
+golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
+golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
+gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
+gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
diff --git a/ipc/namedpipe/file.go b/ipc/namedpipe/file.go
index ec9b8d4..ab1e13d 100644
--- a/ipc/namedpipe/file.go
+++ b/ipc/namedpipe/file.go
@@ -4,7 +4,6 @@
// license that can be found in the LICENSE file.
//go:build windows
-// +build windows
package namedpipe
diff --git a/ipc/namedpipe/namedpipe.go b/ipc/namedpipe/namedpipe.go
index 92cc1ee..ef3dea1 100644
--- a/ipc/namedpipe/namedpipe.go
+++ b/ipc/namedpipe/namedpipe.go
@@ -4,7 +4,6 @@
// license that can be found in the LICENSE file.
//go:build windows
-// +build windows
// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
package namedpipe
diff --git a/ipc/namedpipe/namedpipe_test.go b/ipc/namedpipe/namedpipe_test.go
index 84d2987..998453b 100644
--- a/ipc/namedpipe/namedpipe_test.go
+++ b/ipc/namedpipe/namedpipe_test.go
@@ -4,7 +4,6 @@
// license that can be found in the LICENSE file.
//go:build windows
-// +build windows
package namedpipe_test
diff --git a/ipc/uapi_bsd.go b/ipc/uapi_bsd.go
index 303adeb..ddcaf27 100644
--- a/ipc/uapi_bsd.go
+++ b/ipc/uapi_bsd.go
@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
diff --git a/ipc/uapi_linux.go b/ipc/uapi_linux.go
index 30a8eb4..1562a18 100644
--- a/ipc/uapi_linux.go
+++ b/ipc/uapi_linux.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
@@ -96,7 +96,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
}
go func(l *UAPIListener) {
- var buff [0]byte
+ var buf [0]byte
for {
defer uapi.inotifyRWCancel.Close()
// start with lstat to avoid race condition
@@ -104,7 +104,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
l.connErr <- err
return
}
- _, err := uapi.inotifyRWCancel.Read(buff[:])
+ _, err := uapi.inotifyRWCancel.Read(buf[:])
if err != nil {
l.connErr <- err
return
diff --git a/ipc/uapi_unix.go b/ipc/uapi_unix.go
index 6f1ee47..e67be26 100644
--- a/ipc/uapi_unix.go
+++ b/ipc/uapi_unix.go
@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
diff --git a/ipc/uapi_js.go b/ipc/uapi_wasm.go
index be36b5d..fa84684 100644
--- a/ipc/uapi_js.go
+++ b/ipc/uapi_wasm.go
@@ -1,11 +1,11 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
-// Made up sentinel error codes for the js/wasm platform.
+// Made up sentinel error codes for {js,wasip1}/wasm.
const (
IpcErrorIO = 1
IpcErrorInvalid = 2
diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go
index a1bfbd1..aa023c9 100644
--- a/ipc/uapi_windows.go
+++ b/ipc/uapi_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
diff --git a/main.go b/main.go
index b35ac29..e016116 100644
--- a/main.go
+++ b/main.go
@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
@@ -13,8 +13,8 @@ import (
"os/signal"
"runtime"
"strconv"
- "syscall"
+ "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
@@ -111,7 +111,7 @@ func main() {
// open TUN device (or use supplied fd)
- tun, err := func() (tun.Device, error) {
+ tdev, err := func() (tun.Device, error) {
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
if tunFdStr == "" {
return tun.CreateTUN(interfaceName, device.DefaultMTU)
@@ -124,7 +124,7 @@ func main() {
return nil, err
}
- err = syscall.SetNonblock(int(fd), true)
+ err = unix.SetNonblock(int(fd), true)
if err != nil {
return nil, err
}
@@ -134,7 +134,7 @@ func main() {
}()
if err == nil {
- realInterfaceName, err2 := tun.Name()
+ realInterfaceName, err2 := tdev.Name()
if err2 == nil {
interfaceName = realInterfaceName
}
@@ -196,7 +196,7 @@ func main() {
files[0], // stdin
files[1], // stdout
files[2], // stderr
- tun.File(),
+ tdev.File(),
fileUAPI,
},
Dir: ".",
@@ -222,7 +222,7 @@ func main() {
return
}
- device := device.NewDevice(tun, conn.NewDefaultBind(), logger)
+ device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
logger.Verbosef("Device started")
@@ -250,7 +250,7 @@ func main() {
// wait for program to terminate
- signal.Notify(term, syscall.SIGTERM)
+ signal.Notify(term, unix.SIGTERM)
signal.Notify(term, os.Interrupt)
select {
diff --git a/main_windows.go b/main_windows.go
index 5a7b136..a4dc46f 100644
--- a/main_windows.go
+++ b/main_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
@@ -9,7 +9,8 @@ import (
"fmt"
"os"
"os/signal"
- "syscall"
+
+ "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
@@ -81,7 +82,7 @@ func main() {
signal.Notify(term, os.Interrupt)
signal.Notify(term, os.Kill)
- signal.Notify(term, syscall.SIGTERM)
+ signal.Notify(term, windows.SIGTERM)
select {
case <-term:
diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go
index 1e3c252..f7d05ef 100644
--- a/ratelimiter/ratelimiter.go
+++ b/ratelimiter/ratelimiter.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ratelimiter
diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go
index ca7db72..0bfa3af 100644
--- a/ratelimiter/ratelimiter_test.go
+++ b/ratelimiter/ratelimiter_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ratelimiter
diff --git a/replay/replay.go b/replay/replay.go
index f140272..8b99e23 100644
--- a/replay/replay.go
+++ b/replay/replay.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
diff --git a/replay/replay_test.go b/replay/replay_test.go
index 28c3a0e..9a9e4a8 100644
--- a/replay/replay_test.go
+++ b/replay/replay_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package replay
diff --git a/rwcancel/rwcancel.go b/rwcancel/rwcancel.go
index c3b2299..e397c0e 100644
--- a/rwcancel/rwcancel.go
+++ b/rwcancel/rwcancel.go
@@ -1,8 +1,8 @@
-//go:build !windows && !js
+//go:build !windows && !wasm
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
// Package rwcancel implements cancelable read/write operations on
diff --git a/rwcancel/rwcancel_stub.go b/rwcancel/rwcancel_stub.go
index 182940b..2a98b2b 100644
--- a/rwcancel/rwcancel_stub.go
+++ b/rwcancel/rwcancel_stub.go
@@ -1,4 +1,4 @@
-//go:build windows || js
+//go:build windows || wasm
// SPDX-License-Identifier: MIT
diff --git a/tai64n/tai64n.go b/tai64n/tai64n.go
index de0f5bf..8f10b39 100644
--- a/tai64n/tai64n.go
+++ b/tai64n/tai64n.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tai64n
diff --git a/tai64n/tai64n_test.go b/tai64n/tai64n_test.go
index 97eb26f..c70fc1a 100644
--- a/tai64n/tai64n_test.go
+++ b/tai64n/tai64n_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tai64n
diff --git a/tun/alignment_windows_test.go b/tun/alignment_windows_test.go
index b765cca..67a785e 100644
--- a/tun/alignment_windows_test.go
+++ b/tun/alignment_windows_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
diff --git a/tun/checksum.go b/tun/checksum.go
new file mode 100644
index 0000000..29a8fc8
--- /dev/null
+++ b/tun/checksum.go
@@ -0,0 +1,118 @@
+package tun
+
+import "encoding/binary"
+
+// TODO: Explore SIMD and/or other assembly optimizations.
+// TODO: Test native endian loads. See RFC 1071 section 2 part B.
+func checksumNoFold(b []byte, initial uint64) uint64 {
+ ac := initial
+
+ for len(b) >= 128 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+ ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+ ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+ ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+ ac += uint64(binary.BigEndian.Uint32(b[32:36]))
+ ac += uint64(binary.BigEndian.Uint32(b[36:40]))
+ ac += uint64(binary.BigEndian.Uint32(b[40:44]))
+ ac += uint64(binary.BigEndian.Uint32(b[44:48]))
+ ac += uint64(binary.BigEndian.Uint32(b[48:52]))
+ ac += uint64(binary.BigEndian.Uint32(b[52:56]))
+ ac += uint64(binary.BigEndian.Uint32(b[56:60]))
+ ac += uint64(binary.BigEndian.Uint32(b[60:64]))
+ ac += uint64(binary.BigEndian.Uint32(b[64:68]))
+ ac += uint64(binary.BigEndian.Uint32(b[68:72]))
+ ac += uint64(binary.BigEndian.Uint32(b[72:76]))
+ ac += uint64(binary.BigEndian.Uint32(b[76:80]))
+ ac += uint64(binary.BigEndian.Uint32(b[80:84]))
+ ac += uint64(binary.BigEndian.Uint32(b[84:88]))
+ ac += uint64(binary.BigEndian.Uint32(b[88:92]))
+ ac += uint64(binary.BigEndian.Uint32(b[92:96]))
+ ac += uint64(binary.BigEndian.Uint32(b[96:100]))
+ ac += uint64(binary.BigEndian.Uint32(b[100:104]))
+ ac += uint64(binary.BigEndian.Uint32(b[104:108]))
+ ac += uint64(binary.BigEndian.Uint32(b[108:112]))
+ ac += uint64(binary.BigEndian.Uint32(b[112:116]))
+ ac += uint64(binary.BigEndian.Uint32(b[116:120]))
+ ac += uint64(binary.BigEndian.Uint32(b[120:124]))
+ ac += uint64(binary.BigEndian.Uint32(b[124:128]))
+ b = b[128:]
+ }
+ if len(b) >= 64 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+ ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+ ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+ ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+ ac += uint64(binary.BigEndian.Uint32(b[32:36]))
+ ac += uint64(binary.BigEndian.Uint32(b[36:40]))
+ ac += uint64(binary.BigEndian.Uint32(b[40:44]))
+ ac += uint64(binary.BigEndian.Uint32(b[44:48]))
+ ac += uint64(binary.BigEndian.Uint32(b[48:52]))
+ ac += uint64(binary.BigEndian.Uint32(b[52:56]))
+ ac += uint64(binary.BigEndian.Uint32(b[56:60]))
+ ac += uint64(binary.BigEndian.Uint32(b[60:64]))
+ b = b[64:]
+ }
+ if len(b) >= 32 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+ ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+ ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+ ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+ b = b[32:]
+ }
+ if len(b) >= 16 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ b = b[16:]
+ }
+ if len(b) >= 8 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ b = b[8:]
+ }
+ if len(b) >= 4 {
+ ac += uint64(binary.BigEndian.Uint32(b))
+ b = b[4:]
+ }
+ if len(b) >= 2 {
+ ac += uint64(binary.BigEndian.Uint16(b))
+ b = b[2:]
+ }
+ if len(b) == 1 {
+ ac += uint64(b[0]) << 8
+ }
+
+ return ac
+}
+
+func checksum(b []byte, initial uint64) uint16 {
+ ac := checksumNoFold(b, initial)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ return uint16(ac)
+}
+
+func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
+ sum := checksumNoFold(srcAddr, 0)
+ sum = checksumNoFold(dstAddr, sum)
+ sum = checksumNoFold([]byte{0, protocol}, sum)
+ tmp := make([]byte, 2)
+ binary.BigEndian.PutUint16(tmp, totalLen)
+ return checksumNoFold(tmp, sum)
+}
diff --git a/tun/checksum_test.go b/tun/checksum_test.go
new file mode 100644
index 0000000..c1ccff5
--- /dev/null
+++ b/tun/checksum_test.go
@@ -0,0 +1,35 @@
+package tun
+
+import (
+ "fmt"
+ "math/rand"
+ "testing"
+)
+
+func BenchmarkChecksum(b *testing.B) {
+ lengths := []int{
+ 64,
+ 128,
+ 256,
+ 512,
+ 1024,
+ 1500,
+ 2048,
+ 4096,
+ 8192,
+ 9000,
+ 9001,
+ }
+
+ for _, length := range lengths {
+ b.Run(fmt.Sprintf("%d", length), func(b *testing.B) {
+ buf := make([]byte, length)
+ rng := rand.New(rand.NewSource(1))
+ rng.Read(buf)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ checksum(buf, 0)
+ }
+ })
+ }
+}
diff --git a/tun/errors.go b/tun/errors.go
new file mode 100644
index 0000000..75ae3a4
--- /dev/null
+++ b/tun/errors.go
@@ -0,0 +1,12 @@
+package tun
+
+import (
+ "errors"
+)
+
+var (
+ // ErrTooManySegments is returned by Device.Read() when segmentation
+ // overflows the length of supplied buffers. This error should not cause
+ // reads to cease.
+ ErrTooManySegments = errors.New("too many segments")
+)
diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go
index 352c1e4..ccd32ed 100644
--- a/tun/netstack/examples/http_client.go
+++ b/tun/netstack/examples/http_client.go
@@ -1,9 +1,8 @@
//go:build ignore
-// +build ignore
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
@@ -21,17 +20,17 @@ import (
func main() {
tun, tnet, err := netstack.CreateNetTUN(
- []netip.Addr{netip.MustParseAddr("192.168.4.29")},
+ []netip.Addr{netip.MustParseAddr("192.168.4.28")},
[]netip.Addr{netip.MustParseAddr("8.8.8.8")},
1420)
if err != nil {
log.Panic(err)
}
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
- dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
-public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
-endpoint=163.172.161.0:12912
+ err = dev.IpcSet(`private_key=087ec6e14bbed210e7215cdc73468dfa23f080a1bfb8665b2fd809bd99d28379
+public_key=c4c8e984c5322c8184c72265b92b250fdb63688705f504ba003c88f03393cf28
allowed_ip=0.0.0.0/0
+endpoint=127.0.0.1:58120
`)
err = dev.Up()
if err != nil {
@@ -43,7 +42,7 @@ allowed_ip=0.0.0.0/0
DialContext: tnet.DialContext,
},
}
- resp, err := client.Get("https://www.zx2c4.com/ip")
+ resp, err := client.Get("http://192.168.4.29/")
if err != nil {
log.Panic(err)
}
diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go
index 0fdf4cd..f5b7a8f 100644
--- a/tun/netstack/examples/http_server.go
+++ b/tun/netstack/examples/http_server.go
@@ -1,9 +1,8 @@
//go:build ignore
-// +build ignore
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
@@ -30,10 +29,10 @@ func main() {
log.Panic(err)
}
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
- dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
-public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
-endpoint=163.172.161.0:12912
-allowed_ip=0.0.0.0/0
+ dev.IpcSet(`private_key=003ed5d73b55806c30de3f8a7bdab38af13539220533055e635690b8b87ad641
+listen_port=58120
+public_key=f928d4f6c1b86c12f2562c10b07c555c5c57fd00f59e90c8d8d88767271cbf7c
+allowed_ip=192.168.4.28/32
persistent_keepalive_interval=25
`)
dev.Up()
diff --git a/tun/netstack/examples/ping_client.go b/tun/netstack/examples/ping_client.go
index a1bc7f8..2eef0fb 100644
--- a/tun/netstack/examples/ping_client.go
+++ b/tun/netstack/examples/ping_client.go
@@ -1,9 +1,8 @@
//go:build ignore
-// +build ignore
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
index f232ca3..2b73054 100644
--- a/tun/netstack/tun.go
+++ b/tun/netstack/tun.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package netstack
@@ -19,12 +19,13 @@ import (
"regexp"
"strconv"
"strings"
+ "syscall"
"time"
"golang.zx2c4.com/wireguard/tun"
"golang.org/x/net/dns/dnsmessage"
- "gvisor.dev/gvisor/pkg/bufferv2"
+ "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -42,7 +43,7 @@ type netTun struct {
ep *channel.Endpoint
stack *stack.Stack
events chan tun.Event
- incomingPacket chan *bufferv2.View
+ incomingPacket chan *buffer.View
mtu int
dnsServers []netip.Addr
hasV4, hasV6 bool
@@ -60,12 +61,17 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device,
ep: channel.New(1024, uint32(mtu), ""),
stack: stack.New(opts),
events: make(chan tun.Event, 10),
- incomingPacket: make(chan *bufferv2.View),
+ incomingPacket: make(chan *buffer.View),
dnsServers: dnsServers,
mtu: mtu,
}
+ sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
+ tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
+ }
dev.ep.AddNotify(dev)
- tcpipErr := dev.stack.CreateNIC(1, dev.ep)
+ tcpipErr = dev.stack.CreateNIC(1, dev.ep)
if tcpipErr != nil {
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
@@ -78,7 +84,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device,
}
protoAddr := tcpip.ProtocolAddress{
Protocol: protoNumber,
- AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
+ AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
}
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
if tcpipErr != nil {
@@ -109,39 +115,47 @@ func (tun *netTun) File() *os.File {
return nil
}
-func (tun *netTun) Events() chan tun.Event {
+func (tun *netTun) Events() <-chan tun.Event {
return tun.events
}
-func (tun *netTun) Read(buf []byte, offset int) (int, error) {
+func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
view, ok := <-tun.incomingPacket
if !ok {
return 0, os.ErrClosed
}
- return view.Read(buf[offset:])
+ n, err := view.Read(buf[0][offset:])
+ if err != nil {
+ return 0, err
+ }
+ sizes[0] = n
+ return 1, nil
}
-func (tun *netTun) Write(buf []byte, offset int) (int, error) {
- packet := buf[offset:]
- if len(packet) == 0 {
- return 0, nil
- }
+func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
+ for _, buf := range buf {
+ packet := buf[offset:]
+ if len(packet) == 0 {
+ continue
+ }
- pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)})
- switch packet[0] >> 4 {
- case 4:
- tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
- case 6:
- tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
+ pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
+ switch packet[0] >> 4 {
+ case 4:
+ tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
+ case 6:
+ tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
+ default:
+ return 0, syscall.EAFNOSUPPORT
+ }
}
-
return len(buf), nil
}
func (tun *netTun) WriteNotify() {
pkt := tun.ep.Read()
- if pkt == nil {
+ if pkt.IsNil() {
return
}
@@ -151,10 +165,6 @@ func (tun *netTun) WriteNotify() {
tun.incomingPacket <- view
}
-func (tun *netTun) Flush() error {
- return nil
-}
-
func (tun *netTun) Close() error {
tun.stack.RemoveNIC(1)
@@ -175,6 +185,10 @@ func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil
}
+func (tun *netTun) BatchSize() int {
+ return 1
+}
+
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
var protoNumber tcpip.NetworkProtocolNumber
if endpoint.Addr().Is4() {
@@ -184,7 +198,7 @@ func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.Networ
}
return tcpip.FullAddress{
NIC: 1,
- Addr: tcpip.Address(endpoint.Addr().AsSlice()),
+ Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
Port: endpoint.Port(),
}, protoNumber
}
@@ -439,7 +453,7 @@ func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
}
- remoteAddr, _ := netip.AddrFromSlice([]byte(res.RemoteAddr.Addr))
+ remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice())
return res.Count, &PingAddr{remoteAddr}, nil
}
@@ -898,7 +912,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
}
}
- // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
+ // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled
var addrs []netip.Addr
if tnet.hasV6 {
addrs = append(addrsV6, addrsV4...)
diff --git a/tun/offload_linux.go b/tun/offload_linux.go
new file mode 100644
index 0000000..9ff7fea
--- /dev/null
+++ b/tun/offload_linux.go
@@ -0,0 +1,993 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "io"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+)
+
+const tcpFlagsOffset = 13
+
+const (
+ tcpFlagFIN uint8 = 0x01
+ tcpFlagPSH uint8 = 0x08
+ tcpFlagACK uint8 = 0x10
+)
+
+// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
+// kernel symbol is virtio_net_hdr.
+type virtioNetHdr struct {
+ flags uint8
+ gsoType uint8
+ hdrLen uint16
+ gsoSize uint16
+ csumStart uint16
+ csumOffset uint16
+}
+
+func (v *virtioNetHdr) decode(b []byte) error {
+ if len(b) < virtioNetHdrLen {
+ return io.ErrShortBuffer
+ }
+ copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
+ return nil
+}
+
+func (v *virtioNetHdr) encode(b []byte) error {
+ if len(b) < virtioNetHdrLen {
+ return io.ErrShortBuffer
+ }
+ copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
+ return nil
+}
+
+const (
+ // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
+ // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
+ virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
+)
+
+// tcpFlowKey represents the key for a TCP flow.
+type tcpFlowKey struct {
+ srcAddr, dstAddr [16]byte
+ srcPort, dstPort uint16
+ rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
+ isV6 bool
+}
+
+// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO.
+type tcpGROTable struct {
+ itemsByFlow map[tcpFlowKey][]tcpGROItem
+ itemsPool [][]tcpGROItem
+}
+
+func newTCPGROTable() *tcpGROTable {
+ t := &tcpGROTable{
+ itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize),
+ itemsPool: make([][]tcpGROItem, conn.IdealBatchSize),
+ }
+ for i := range t.itemsPool {
+ t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize)
+ }
+ return t
+}
+
+func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey {
+ key := tcpFlowKey{}
+ addrSize := dstAddrOffset - srcAddrOffset
+ copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
+ copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
+ key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
+ key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
+ key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
+ key.isV6 = addrSize == 16
+ return key
+}
+
+// lookupOrInsert looks up a flow for the provided packet and metadata,
+// returning the packets found for the flow, or inserting a new one if none
+// is found.
+func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
+ key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+ items, ok := t.itemsByFlow[key]
+ if ok {
+ return items, ok
+ }
+ // TODO: insert() performs another map lookup. This could be rearranged to avoid.
+ t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
+ return nil, false
+}
+
+// insert an item in the table for the provided packet and packet metadata.
+func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
+ key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+ item := tcpGROItem{
+ key: key,
+ bufsIndex: uint16(bufsIndex),
+ gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
+ iphLen: uint8(tcphOffset),
+ tcphLen: uint8(tcphLen),
+ sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
+ pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
+ }
+ items, ok := t.itemsByFlow[key]
+ if !ok {
+ items = t.newItems()
+ }
+ items = append(items, item)
+ t.itemsByFlow[key] = items
+}
+
+func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
+ items, _ := t.itemsByFlow[item.key]
+ items[i] = item
+}
+
+func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) {
+ items, _ := t.itemsByFlow[key]
+ items = append(items[:i], items[i+1:]...)
+ t.itemsByFlow[key] = items
+}
+
+// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
+// of a GRO evaluation across a vector of packets.
+type tcpGROItem struct {
+ key tcpFlowKey
+ sentSeq uint32 // the sequence number
+ bufsIndex uint16 // the index into the original bufs slice
+ numMerged uint16 // the number of packets merged into this item
+ gsoSize uint16 // payload size
+ iphLen uint8 // ip header len
+ tcphLen uint8 // tcp header len
+ pshSet bool // psh flag is set
+}
+
+func (t *tcpGROTable) newItems() []tcpGROItem {
+ var items []tcpGROItem
+ items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
+ return items
+}
+
+func (t *tcpGROTable) reset() {
+ for k, items := range t.itemsByFlow {
+ items = items[:0]
+ t.itemsPool = append(t.itemsPool, items)
+ delete(t.itemsByFlow, k)
+ }
+}
+
+// udpFlowKey represents the key for a UDP flow.
+type udpFlowKey struct {
+ srcAddr, dstAddr [16]byte
+ srcPort, dstPort uint16
+ isV6 bool
+}
+
+// udpGROTable holds flow and coalescing information for the purposes of UDP GRO.
+type udpGROTable struct {
+ itemsByFlow map[udpFlowKey][]udpGROItem
+ itemsPool [][]udpGROItem
+}
+
+func newUDPGROTable() *udpGROTable {
+ u := &udpGROTable{
+ itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize),
+ itemsPool: make([][]udpGROItem, conn.IdealBatchSize),
+ }
+ for i := range u.itemsPool {
+ u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize)
+ }
+ return u
+}
+
+func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey {
+ key := udpFlowKey{}
+ addrSize := dstAddrOffset - srcAddrOffset
+ copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
+ copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
+ key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:])
+ key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:])
+ key.isV6 = addrSize == 16
+ return key
+}
+
+// lookupOrInsert looks up a flow for the provided packet and metadata,
+// returning the packets found for the flow, or inserting a new one if none
+// is found.
+func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) {
+ key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
+ items, ok := u.itemsByFlow[key]
+ if ok {
+ return items, ok
+ }
+ // TODO: insert() performs another map lookup. This could be rearranged to avoid.
+ u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false)
+ return nil, false
+}
+
+// insert an item in the table for the provided packet and packet metadata.
+func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) {
+ key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
+ item := udpGROItem{
+ key: key,
+ bufsIndex: uint16(bufsIndex),
+ gsoSize: uint16(len(pkt[udphOffset+udphLen:])),
+ iphLen: uint8(udphOffset),
+ cSumKnownInvalid: cSumKnownInvalid,
+ }
+ items, ok := u.itemsByFlow[key]
+ if !ok {
+ items = u.newItems()
+ }
+ items = append(items, item)
+ u.itemsByFlow[key] = items
+}
+
+func (u *udpGROTable) updateAt(item udpGROItem, i int) {
+ items, _ := u.itemsByFlow[item.key]
+ items[i] = item
+}
+
+// udpGROItem represents bookkeeping data for a UDP packet during the lifetime
+// of a GRO evaluation across a vector of packets.
+type udpGROItem struct {
+ key udpFlowKey
+ bufsIndex uint16 // the index into the original bufs slice
+ numMerged uint16 // the number of packets merged into this item
+ gsoSize uint16 // payload size
+ iphLen uint8 // ip header len
+ cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown.
+}
+
+func (u *udpGROTable) newItems() []udpGROItem {
+ var items []udpGROItem
+ items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1]
+ return items
+}
+
+func (u *udpGROTable) reset() {
+ for k, items := range u.itemsByFlow {
+ items = items[:0]
+ u.itemsPool = append(u.itemsPool, items)
+ delete(u.itemsByFlow, k)
+ }
+}
+
+// canCoalesce represents the outcome of checking if two TCP packets are
+// candidates for coalescing.
+type canCoalesce int
+
+const (
+ coalescePrepend canCoalesce = -1
+ coalesceUnavailable canCoalesce = 0
+ coalesceAppend canCoalesce = 1
+)
+
+// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB
+// meet all requirements to be merged as part of a GRO operation, otherwise it
+// returns false.
+func ipHeadersCanCoalesce(pktA, pktB []byte) bool {
+ if len(pktA) < 9 || len(pktB) < 9 {
+ return false
+ }
+ if pktA[0]>>4 == 6 {
+ if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 {
+ // cannot coalesce with unequal Traffic class values
+ return false
+ }
+ if pktA[7] != pktB[7] {
+ // cannot coalesce with unequal Hop limit values
+ return false
+ }
+ } else {
+ if pktA[1] != pktB[1] {
+ // cannot coalesce with unequal ToS values
+ return false
+ }
+ if pktA[6]>>5 != pktB[6]>>5 {
+ // cannot coalesce with unequal DF or reserved bits. MF is checked
+ // further up the stack.
+ return false
+ }
+ if pktA[8] != pktB[8] {
+ // cannot coalesce with unequal TTL values
+ return false
+ }
+ }
+ return true
+}
+
+// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
+// described by item. iphLen and gsoSize describe pkt. bufs is the vector of
+// packets involved in the current GRO evaluation. bufsOffset is the offset at
+// which packet data begins within bufs.
+func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
+ pktTarget := bufs[item.bufsIndex][bufsOffset:]
+ if !ipHeadersCanCoalesce(pkt, pktTarget) {
+ return coalesceUnavailable
+ }
+ if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 {
+ // A smaller than gsoSize packet has been appended previously.
+ // Nothing can come after a smaller packet on the end.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ return coalesceAppend
+}
+
+// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
+// described by item. This function makes considerations that match the kernel's
+// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
+func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
+ pktTarget := bufs[item.bufsIndex][bufsOffset:]
+ if tcphLen != item.tcphLen {
+ // cannot coalesce with unequal tcp options len
+ return coalesceUnavailable
+ }
+ if tcphLen > 20 {
+ if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
+ // cannot coalesce with unequal tcp options
+ return coalesceUnavailable
+ }
+ }
+ if !ipHeadersCanCoalesce(pkt, pktTarget) {
+ return coalesceUnavailable
+ }
+ // seq adjacency
+ lhsLen := item.gsoSize
+ lhsLen += item.numMerged * item.gsoSize
+ if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
+ if item.pshSet {
+ // We cannot append to a segment that has the PSH flag set, PSH
+ // can only be set on the final segment in a reassembled group.
+ return coalesceUnavailable
+ }
+ if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
+ // A smaller than gsoSize packet has been appended previously.
+ // Nothing can come after a smaller packet on the end.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ return coalesceAppend
+ } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
+ if pshSet {
+ // We cannot prepend with a segment that has the PSH flag set, PSH
+ // can only be set on the final segment in a reassembled group.
+ return coalesceUnavailable
+ }
+ if gsoSize < item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize && item.numMerged > 0 {
+ // There's at least one previous merge, and we're larger than all
+ // previous. This would put multiple smaller packets on the end.
+ return coalesceUnavailable
+ }
+ return coalescePrepend
+ }
+ return coalesceUnavailable
+}
+
+func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool {
+ srcAddrAt := ipv4SrcAddrOffset
+ addrSize := 4
+ if isV6 {
+ srcAddrAt = ipv6SrcAddrOffset
+ addrSize = 16
+ }
+ lenForPseudo := uint16(len(pkt) - int(iphLen))
+ cSum := pseudoHeaderChecksumNoFold(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo)
+ return ^checksum(pkt[iphLen:], cSum) == 0
+}
+
+// coalesceResult represents the result of attempting to coalesce two TCP
+// packets.
+type coalesceResult int
+
+const (
+ coalesceInsufficientCap coalesceResult = iota
+ coalescePSHEnding
+ coalesceItemInvalidCSum
+ coalescePktInvalidCSum
+ coalesceSuccess
+)
+
+// coalesceUDPPackets attempts to coalesce pkt with the packet described by
+// item, and returns the outcome.
+func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
+ pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front
+ headersLen := item.iphLen + udphLen
+ coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
+
+ if cap(pktHead)-bufsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if item.numMerged == 0 {
+ if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) {
+ return coalescePktInvalidCSum
+ }
+ extendBy := len(pkt) - int(headersLen)
+ bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
+ copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
+
+ item.numMerged++
+ return coalesceSuccess
+}
+
+// coalesceTCPPackets attempts to coalesce pkt with the packet described by
+// item, and returns the outcome. This function may swap bufs elements in the
+// event of a prepend as item's bufs index is already being tracked for writing
+// to a Device.
+func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
+ var pktHead []byte // the packet that will end up at the front
+ headersLen := item.iphLen + item.tcphLen
+ coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
+
+ // Copy data
+ if mode == coalescePrepend {
+ pktHead = pkt
+ if cap(pkt)-bufsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if pshSet {
+ return coalescePSHEnding
+ }
+ if item.numMerged == 0 {
+ if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
+ return coalescePktInvalidCSum
+ }
+ item.sentSeq = seq
+ extendBy := coalescedLen - len(pktHead)
+ bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
+ copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
+ // Flip the slice headers in bufs as part of prepend. The index of item
+ // is already being tracked for writing.
+ bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
+ } else {
+ pktHead = bufs[item.bufsIndex][bufsOffset:]
+ if cap(pktHead)-bufsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if item.numMerged == 0 {
+ if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
+ return coalescePktInvalidCSum
+ }
+ if pshSet {
+ // We are appending a segment with PSH set.
+ item.pshSet = pshSet
+ pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
+ }
+ extendBy := len(pkt) - int(headersLen)
+ bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
+ copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
+ }
+
+ if gsoSize > item.gsoSize {
+ item.gsoSize = gsoSize
+ }
+
+ item.numMerged++
+ return coalesceSuccess
+}
+
+const (
+ ipv4FlagMoreFragments uint8 = 0x20
+)
+
+const (
+ ipv4SrcAddrOffset = 12
+ ipv6SrcAddrOffset = 8
+ maxUint16 = 1<<16 - 1
+)
+
+type groResult int
+
+const (
+ groResultNoop groResult = iota
+ groResultTableInsert
+ groResultCoalesced
+)
+
+// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
+// existing packets tracked in table. It returns a groResultNoop when no
+// action was taken, groResultTableInsert when the evaluated packet was
+// inserted into table, and groResultCoalesced when the evaluated packet was
+// coalesced with another packet in table.
+func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult {
+ pkt := bufs[pktI][offset:]
+ if len(pkt) > maxUint16 {
+ // A valid IPv4 or IPv6 packet will never exceed this.
+ return groResultNoop
+ }
+ iphLen := int((pkt[0] & 0x0F) * 4)
+ if isV6 {
+ iphLen = 40
+ ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
+ if ipv6HPayloadLen != len(pkt)-iphLen {
+ return groResultNoop
+ }
+ } else {
+ totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
+ if totalLen != len(pkt) {
+ return groResultNoop
+ }
+ }
+ if len(pkt) < iphLen {
+ return groResultNoop
+ }
+ tcphLen := int((pkt[iphLen+12] >> 4) * 4)
+ if tcphLen < 20 || tcphLen > 60 {
+ return groResultNoop
+ }
+ if len(pkt) < iphLen+tcphLen {
+ return groResultNoop
+ }
+ if !isV6 {
+ if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
+ // no GRO support for fragmented segments for now
+ return groResultNoop
+ }
+ }
+ tcpFlags := pkt[iphLen+tcpFlagsOffset]
+ var pshSet bool
+ // not a candidate if any non-ACK flags (except PSH+ACK) are set
+ if tcpFlags != tcpFlagACK {
+ if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
+ return groResultNoop
+ }
+ pshSet = true
+ }
+ gsoSize := uint16(len(pkt) - tcphLen - iphLen)
+ // not a candidate if payload len is 0
+ if gsoSize < 1 {
+ return groResultNoop
+ }
+ seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
+ srcAddrOffset := ipv4SrcAddrOffset
+ addrLen := 4
+ if isV6 {
+ srcAddrOffset = ipv6SrcAddrOffset
+ addrLen = 16
+ }
+ items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
+ if !existing {
+ return groResultTableInsert
+ }
+ for i := len(items) - 1; i >= 0; i-- {
+ // In the best case of packets arriving in order iterating in reverse is
+ // more efficient if there are multiple items for a given flow. This
+ // also enables a natural table.deleteAt() in the
+ // coalesceItemInvalidCSum case without the need for index tracking.
+ // This algorithm makes a best effort to coalesce in the event of
+ // unordered packets, where pkt may land anywhere in items from a
+ // sequence number perspective, however once an item is inserted into
+ // the table it is never compared across other items later.
+ item := items[i]
+ can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
+ if can != coalesceUnavailable {
+ result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
+ switch result {
+ case coalesceSuccess:
+ table.updateAt(item, i)
+ return groResultCoalesced
+ case coalesceItemInvalidCSum:
+ // delete the item with an invalid csum
+ table.deleteAt(item.key, i)
+ case coalescePktInvalidCSum:
+ // no point in inserting an item that we can't coalesce
+ return groResultNoop
+ default:
+ }
+ }
+ }
+ // failed to coalesce with any other packets; store the item in the flow
+ table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
+ return groResultTableInsert
+}
+
+// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the
+// metadata found in table.
+func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error {
+ for _, items := range table.itemsByFlow {
+ for _, item := range items {
+ if item.numMerged > 0 {
+ hdr := virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
+ hdrLen: uint16(item.iphLen + item.tcphLen),
+ gsoSize: item.gsoSize,
+ csumStart: uint16(item.iphLen),
+ csumOffset: 16,
+ }
+ pkt := bufs[item.bufsIndex][offset:]
+
+ // Recalculate the total len (IPv4) or payload len (IPv6).
+ // Recalculate the (IPv4) header checksum.
+ if item.key.isV6 {
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
+ binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
+ } else {
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
+ pkt[10], pkt[11] = 0, 0
+ binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
+ iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum
+ binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
+ }
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+
+ // Calculate the pseudo header checksum and place it at the TCP
+ // checksum offset. Downstream checksum offloading will combine
+ // this with computation of the tcp header and payload checksum.
+ addrLen := 4
+ addrOffset := ipv4SrcAddrOffset
+ if item.key.isV6 {
+ addrLen = 16
+ addrOffset = ipv6SrcAddrOffset
+ }
+ srcAddrAt := offset + addrOffset
+ srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
+ dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
+ psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
+ binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
+ } else {
+ hdr := virtioNetHdr{}
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+ return nil
+}
+
+// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the
+// metadata found in table.
+func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error {
+ for _, items := range table.itemsByFlow {
+ for _, item := range items {
+ if item.numMerged > 0 {
+ hdr := virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
+ hdrLen: uint16(item.iphLen + udphLen),
+ gsoSize: item.gsoSize,
+ csumStart: uint16(item.iphLen),
+ csumOffset: 6,
+ }
+ pkt := bufs[item.bufsIndex][offset:]
+
+ // Recalculate the total len (IPv4) or payload len (IPv6).
+ // Recalculate the (IPv4) header checksum.
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
+ if item.key.isV6 {
+ binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
+ } else {
+ pkt[10], pkt[11] = 0, 0
+ binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
+ iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum
+ binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
+ }
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+
+ // Recalculate the UDP len field value
+ binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:])))
+
+ // Calculate the pseudo header checksum and place it at the UDP
+ // checksum offset. Downstream checksum offloading will combine
+ // this with computation of the udp header and payload checksum.
+ addrLen := 4
+ addrOffset := ipv4SrcAddrOffset
+ if item.key.isV6 {
+ addrLen = 16
+ addrOffset = ipv6SrcAddrOffset
+ }
+ srcAddrAt := offset + addrOffset
+ srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
+ dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
+ psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
+ binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
+ } else {
+ hdr := virtioNetHdr{}
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+ return nil
+}
+
+type groCandidateType uint8
+
+const (
+ notGROCandidate groCandidateType = iota
+ tcp4GROCandidate
+ tcp6GROCandidate
+ udp4GROCandidate
+ udp6GROCandidate
+)
+
+func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType {
+ if len(b) < 28 {
+ return notGROCandidate
+ }
+ if b[0]>>4 == 4 {
+ if b[0]&0x0F != 5 {
+ // IPv4 packets w/IP options do not coalesce
+ return notGROCandidate
+ }
+ if b[9] == unix.IPPROTO_TCP && len(b) >= 40 {
+ return tcp4GROCandidate
+ }
+ if b[9] == unix.IPPROTO_UDP && canUDPGRO {
+ return udp4GROCandidate
+ }
+ } else if b[0]>>4 == 6 {
+ if b[6] == unix.IPPROTO_TCP && len(b) >= 60 {
+ return tcp6GROCandidate
+ }
+ if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO {
+ return udp6GROCandidate
+ }
+ }
+ return notGROCandidate
+}
+
+const (
+ udphLen = 8
+)
+
+// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with
+// existing packets tracked in table. It returns a groResultNoop when no
+// action was taken, groResultTableInsert when the evaluated packet was
+// inserted into table, and groResultCoalesced when the evaluated packet was
+// coalesced with another packet in table.
+func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult {
+ pkt := bufs[pktI][offset:]
+ if len(pkt) > maxUint16 {
+ // A valid IPv4 or IPv6 packet will never exceed this.
+ return groResultNoop
+ }
+ iphLen := int((pkt[0] & 0x0F) * 4)
+ if isV6 {
+ iphLen = 40
+ ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
+ if ipv6HPayloadLen != len(pkt)-iphLen {
+ return groResultNoop
+ }
+ } else {
+ totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
+ if totalLen != len(pkt) {
+ return groResultNoop
+ }
+ }
+ if len(pkt) < iphLen {
+ return groResultNoop
+ }
+ if len(pkt) < iphLen+udphLen {
+ return groResultNoop
+ }
+ if !isV6 {
+ if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
+ // no GRO support for fragmented segments for now
+ return groResultNoop
+ }
+ }
+ gsoSize := uint16(len(pkt) - udphLen - iphLen)
+ // not a candidate if payload len is 0
+ if gsoSize < 1 {
+ return groResultNoop
+ }
+ srcAddrOffset := ipv4SrcAddrOffset
+ addrLen := 4
+ if isV6 {
+ srcAddrOffset = ipv6SrcAddrOffset
+ addrLen = 16
+ }
+ items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI)
+ if !existing {
+ return groResultTableInsert
+ }
+ // With UDP we only check the last item, otherwise we could reorder packets
+ // for a given flow. We must also always insert a new item, or successfully
+ // coalesce with an existing item, for the same reason.
+ item := items[len(items)-1]
+ can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset)
+ var pktCSumKnownInvalid bool
+ if can == coalesceAppend {
+ result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6)
+ switch result {
+ case coalesceSuccess:
+ table.updateAt(item, len(items)-1)
+ return groResultCoalesced
+ case coalesceItemInvalidCSum:
+ // If the existing item has an invalid csum we take no action. A new
+ // item will be stored after it, and the existing item will never be
+ // revisited as part of future coalescing candidacy checks.
+ case coalescePktInvalidCSum:
+ // We must insert a new item, but we also mark it as invalid csum
+ // to prevent a repeat checksum validation.
+ pktCSumKnownInvalid = true
+ default:
+ }
+ }
+ // failed to coalesce with any other packets; store the item in the flow
+ table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid)
+ return groResultTableInsert
+}
+
+// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
+// packets into toWrite. toWrite, tcpTable, and udpTable should initially be
+// empty (but non-nil), and are passed in to save allocs as the caller may reset
+// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is
+// supported.
+func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error {
+ for i := range bufs {
+ if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
+ return errors.New("invalid offset")
+ }
+ var result groResult
+ switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) {
+ case tcp4GROCandidate:
+ result = tcpGRO(bufs, offset, i, tcpTable, false)
+ case tcp6GROCandidate:
+ result = tcpGRO(bufs, offset, i, tcpTable, true)
+ case udp4GROCandidate:
+ result = udpGRO(bufs, offset, i, udpTable, false)
+ case udp6GROCandidate:
+ result = udpGRO(bufs, offset, i, udpTable, true)
+ }
+ switch result {
+ case groResultNoop:
+ hdr := virtioNetHdr{}
+ err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+ fallthrough
+ case groResultTableInsert:
+ *toWrite = append(*toWrite, i)
+ }
+ }
+ errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable)
+ errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable)
+ return errors.Join(errTCP, errUDP)
+}
+
+// gsoSplit splits packets from in into outBuffs, writing the size of each
+// element into sizes. It returns the number of buffers populated, and/or an
+// error.
+func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) {
+ iphLen := int(hdr.csumStart)
+ srcAddrOffset := ipv6SrcAddrOffset
+ addrLen := 16
+ if !isV6 {
+ in[10], in[11] = 0, 0 // clear ipv4 header checksum
+ srcAddrOffset = ipv4SrcAddrOffset
+ addrLen = 4
+ }
+ transportCsumAt := int(hdr.csumStart + hdr.csumOffset)
+ in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum
+ var firstTCPSeqNum uint32
+ var protocol uint8
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+ protocol = unix.IPPROTO_TCP
+ firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:])
+ } else {
+ protocol = unix.IPPROTO_UDP
+ }
+ nextSegmentDataAt := int(hdr.hdrLen)
+ i := 0
+ for ; nextSegmentDataAt < len(in); i++ {
+ if i == len(outBuffs) {
+ return i - 1, ErrTooManySegments
+ }
+ nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
+ if nextSegmentEnd > len(in) {
+ nextSegmentEnd = len(in)
+ }
+ segmentDataLen := nextSegmentEnd - nextSegmentDataAt
+ totalLen := int(hdr.hdrLen) + segmentDataLen
+ sizes[i] = totalLen
+ out := outBuffs[i][outOffset:]
+
+ copy(out, in[:iphLen])
+ if !isV6 {
+ // For IPv4 we are responsible for incrementing the ID field,
+ // updating the total len field, and recalculating the header
+ // checksum.
+ if i > 0 {
+ id := binary.BigEndian.Uint16(out[4:])
+ id += uint16(i)
+ binary.BigEndian.PutUint16(out[4:], id)
+ }
+ binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
+ ipv4CSum := ^checksum(out[:iphLen], 0)
+ binary.BigEndian.PutUint16(out[10:], ipv4CSum)
+ } else {
+ // For IPv6 we are responsible for updating the payload length field.
+ binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
+ }
+
+ // copy transport header
+ copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
+
+ if protocol == unix.IPPROTO_TCP {
+ // set TCP seq and adjust TCP flags
+ tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
+ binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
+ if nextSegmentEnd != len(in) {
+ // FIN and PSH should only be set on last segment
+ clearFlags := tcpFlagFIN | tcpFlagPSH
+ out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
+ }
+ } else {
+ // set UDP header len
+ binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart))
+ }
+
+ // payload
+ copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
+
+ // transport checksum
+ transportHeaderLen := int(hdr.hdrLen - hdr.csumStart)
+ lenForPseudo := uint16(transportHeaderLen + segmentDataLen)
+ transportCSumNoFold := pseudoHeaderChecksumNoFold(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo)
+ transportCSum := ^checksum(out[hdr.csumStart:totalLen], transportCSumNoFold)
+ binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum)
+
+ nextSegmentDataAt += int(hdr.gsoSize)
+ }
+ return i, nil
+}
+
+func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
+ cSumAt := cSumStart + cSumOffset
+ // The initial value at the checksum offset should be summed with the
+ // checksum we compute. This is typically the pseudo-header checksum.
+ initial := binary.BigEndian.Uint16(in[cSumAt:])
+ in[cSumAt], in[cSumAt+1] = 0, 0
+ binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
+ return nil
+}
diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go
new file mode 100644
index 0000000..ae55c8c
--- /dev/null
+++ b/tun/offload_linux_test.go
@@ -0,0 +1,752 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "net/netip"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+ offset = virtioNetHdrLen
+)
+
+var (
+ ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
+ ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
+ ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
+ ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
+ ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
+ ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
+)
+
+func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte {
+ totalLen := 28 + payloadLen
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv4H := header.IPv4(b[offset:])
+ srcAs4 := srcIPPort.Addr().As4()
+ dstAs4 := dstIPPort.Addr().As4()
+ ipFields := &header.IPv4Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
+ Protocol: unix.IPPROTO_UDP,
+ TTL: 64,
+ TotalLength: uint16(totalLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv4H.Encode(ipFields)
+ udpH := header.UDP(b[offset+20:])
+ udpH.Encode(&header.UDPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ Length: uint16(payloadLen + udphLen),
+ })
+ ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen))
+ udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
+ return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
+}
+
+func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte {
+ totalLen := 48 + payloadLen
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv6H := header.IPv6(b[offset:])
+ srcAs16 := srcIPPort.Addr().As16()
+ dstAs16 := dstIPPort.Addr().As16()
+ ipFields := &header.IPv6Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
+ TransportProtocol: unix.IPPROTO_UDP,
+ HopLimit: 64,
+ PayloadLength: uint16(payloadLen + udphLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv6H.Encode(ipFields)
+ udpH := header.UDP(b[offset+40:])
+ udpH.Encode(&header.UDPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ Length: uint16(payloadLen + udphLen),
+ })
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen))
+ udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
+ return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
+}
+
+func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte {
+ totalLen := 40 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv4H := header.IPv4(b[offset:])
+ srcAs4 := srcIPPort.Addr().As4()
+ dstAs4 := dstIPPort.Addr().As4()
+ ipFields := &header.IPv4Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
+ Protocol: unix.IPPROTO_TCP,
+ TTL: 64,
+ TotalLength: uint16(totalLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv4H.Encode(ipFields)
+ tcpH := header.TCP(b[offset+20:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
+}
+
+func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte {
+ totalLen := 60 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv6H := header.IPv6(b[offset:])
+ srcAs16 := srcIPPort.Addr().As16()
+ dstAs16 := dstIPPort.Addr().As16()
+ ipFields := &header.IPv6Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
+ TransportProtocol: unix.IPPROTO_TCP,
+ HopLimit: 64,
+ PayloadLength: uint16(segmentSize + 20),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv6H.Encode(ipFields)
+ tcpH := header.TCP(b[offset+40:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
+}
+
+func Test_handleVirtioRead(t *testing.T) {
+ tests := []struct {
+ name string
+ hdr virtioNetHdr
+ pktIn []byte
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "tcp4",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
+ gsoSize: 100,
+ hdrLen: 40,
+ csumStart: 20,
+ csumOffset: 16,
+ },
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{140, 140},
+ false,
+ },
+ {
+ "tcp6",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
+ gsoSize: 100,
+ hdrLen: 60,
+ csumStart: 40,
+ csumOffset: 16,
+ },
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{160, 160},
+ false,
+ },
+ {
+ "udp4",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
+ gsoSize: 100,
+ hdrLen: 28,
+ csumStart: 20,
+ csumOffset: 6,
+ },
+ udp4Packet(ip4PortA, ip4PortB, 200),
+ []int{128, 128},
+ false,
+ },
+ {
+ "udp6",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
+ gsoSize: 100,
+ hdrLen: 48,
+ csumStart: 40,
+ csumOffset: 6,
+ },
+ udp6Packet(ip6PortA, ip6PortB, 200),
+ []int{148, 148},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ out := make([][]byte, conn.IdealBatchSize)
+ sizes := make([]int, conn.IdealBatchSize)
+ for i := range out {
+ out[i] = make([]byte, 65535)
+ }
+ tt.hdr.encode(tt.pktIn)
+ n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if n != len(tt.wantLens) {
+ t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
+ }
+ for i := range tt.wantLens {
+ if tt.wantLens[i] != sizes[i] {
+ t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
+ }
+ }
+ })
+ }
+}
+
+func flipTCP4Checksum(b []byte) []byte {
+ at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
+ b[at] ^= 0xFF
+ b[at+1] ^= 0xFF
+ return b
+}
+
+func flipUDP4Checksum(b []byte) []byte {
+ at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6
+ b[at] ^= 0xFF
+ b[at+1] ^= 0xFF
+ return b
+}
+
+func Fuzz_handleGRO(f *testing.F) {
+ pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
+ pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
+ pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
+ pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
+ pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
+ pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
+ pkt6 := udp4Packet(ip4PortA, ip4PortB, 100)
+ pkt7 := udp4Packet(ip4PortA, ip4PortB, 100)
+ pkt8 := udp4Packet(ip4PortA, ip4PortC, 100)
+ pkt9 := udp6Packet(ip6PortA, ip6PortB, 100)
+ pkt10 := udp6Packet(ip6PortA, ip6PortB, 100)
+ pkt11 := udp6Packet(ip6PortA, ip6PortC, 100)
+ f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset)
+ f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) {
+ pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11}
+ toWrite := make([]int, 0, len(pkts))
+ handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite)
+ if len(toWrite) > len(pkts) {
+ t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
+ }
+ seenWriteI := make(map[int]bool)
+ for _, writeI := range toWrite {
+ if writeI < 0 || writeI > len(pkts)-1 {
+ t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
+ }
+ if seenWriteI[writeI] {
+ t.Errorf("duplicate toWrite value: %d", writeI)
+ }
+ seenWriteI[writeI] = true
+ }
+ })
+}
+
+func Test_handleGRO(t *testing.T) {
+ tests := []struct {
+ name string
+ pktsIn [][]byte
+ canUDPGRO bool
+ wantToWrite []int
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "multiple protocols and flows",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
+ tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ },
+ true,
+ []int{0, 1, 2, 4, 5, 7, 9},
+ []int{240, 228, 128, 140, 260, 160, 248},
+ false,
+ },
+ {
+ "multiple protocols and flows no UDP GRO",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
+ tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ },
+ false,
+ []int{0, 1, 2, 4, 5, 7, 8, 9, 10},
+ []int{240, 128, 128, 140, 260, 160, 128, 148, 148},
+ false,
+ },
+ {
+ "PSH interleaved",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1
+ },
+ true,
+ []int{0, 2, 4, 6},
+ []int{240, 240, 260, 260},
+ false,
+ },
+ {
+ "coalesceItemInvalidCSum",
+ [][]byte{
+ flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ },
+ true,
+ []int{0, 1, 3, 4},
+ []int{140, 240, 128, 228},
+ false,
+ },
+ {
+ "out of order",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ },
+ true,
+ []int{0},
+ []int{340},
+ false,
+ },
+ {
+ "unequal TTL",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.TTL++
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.TTL++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal ToS",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.TOS++
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.TOS++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal flags more fragments set",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.Flags = 1
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.Flags = 1
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal flags DF set",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.Flags = 2
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.Flags = 2
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "ipv6 unequal hop limit",
+ [][]byte{
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
+ tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
+ fields.HopLimit++
+ }),
+ udp6Packet(ip6PortA, ip6PortB, 100),
+ udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
+ fields.HopLimit++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{160, 160, 148, 148},
+ false,
+ },
+ {
+ "ipv6 unequal traffic class",
+ [][]byte{
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
+ tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
+ fields.TrafficClass++
+ }),
+ udp6Packet(ip6PortA, ip6PortB, 100),
+ udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
+ fields.TrafficClass++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{160, 160, 148, 148},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ toWrite := make([]int, 0, len(tt.pktsIn))
+ err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if len(toWrite) != len(tt.wantToWrite) {
+ t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
+ }
+ for i, pktI := range tt.wantToWrite {
+ if tt.wantToWrite[i] != toWrite[i] {
+ t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
+ }
+ if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
+ t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
+ }
+ }
+ })
+ }
+}
+
+func Test_packetIsGROCandidate(t *testing.T) {
+ tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
+ tcp4TooShort := tcp4[:39]
+ ip4InvalidHeaderLen := make([]byte, len(tcp4))
+ copy(ip4InvalidHeaderLen, tcp4)
+ ip4InvalidHeaderLen[0] = 0x46
+ ip4InvalidProtocol := make([]byte, len(tcp4))
+ copy(ip4InvalidProtocol, tcp4)
+ ip4InvalidProtocol[9] = unix.IPPROTO_GRE
+
+ tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
+ tcp6TooShort := tcp6[:59]
+ ip6InvalidProtocol := make([]byte, len(tcp6))
+ copy(ip6InvalidProtocol, tcp6)
+ ip6InvalidProtocol[6] = unix.IPPROTO_GRE
+
+ udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:]
+ udp4TooShort := udp4[:27]
+
+ udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:]
+ udp6TooShort := udp6[:47]
+
+ tests := []struct {
+ name string
+ b []byte
+ canUDPGRO bool
+ want groCandidateType
+ }{
+ {
+ "tcp4",
+ tcp4,
+ true,
+ tcp4GROCandidate,
+ },
+ {
+ "tcp6",
+ tcp6,
+ true,
+ tcp6GROCandidate,
+ },
+ {
+ "udp4",
+ udp4,
+ true,
+ udp4GROCandidate,
+ },
+ {
+ "udp4 no support",
+ udp4,
+ false,
+ notGROCandidate,
+ },
+ {
+ "udp6",
+ udp6,
+ true,
+ udp6GROCandidate,
+ },
+ {
+ "udp6 no support",
+ udp6,
+ false,
+ notGROCandidate,
+ },
+ {
+ "udp4 too short",
+ udp4TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "udp6 too short",
+ udp6TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "tcp4 too short",
+ tcp4TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "tcp6 too short",
+ tcp6TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "invalid IP version",
+ []byte{0x00},
+ true,
+ notGROCandidate,
+ },
+ {
+ "invalid IP header len",
+ ip4InvalidHeaderLen,
+ true,
+ notGROCandidate,
+ },
+ {
+ "ip4 invalid protocol",
+ ip4InvalidProtocol,
+ true,
+ notGROCandidate,
+ },
+ {
+ "ip6 invalid protocol",
+ ip6InvalidProtocol,
+ true,
+ notGROCandidate,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want {
+ t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func Test_udpPacketsCanCoalesce(t *testing.T) {
+ udp4a := udp4Packet(ip4PortA, ip4PortB, 100)
+ udp4b := udp4Packet(ip4PortA, ip4PortB, 100)
+ udp4c := udp4Packet(ip4PortA, ip4PortB, 110)
+
+ type args struct {
+ pkt []byte
+ iphLen uint8
+ gsoSize uint16
+ item udpGROItem
+ bufs [][]byte
+ bufsOffset int
+ }
+ tests := []struct {
+ name string
+ args args
+ want canCoalesce
+ }{
+ {
+ "coalesceAppend equal gso",
+ args{
+ pkt: udp4a[offset:],
+ iphLen: 20,
+ gsoSize: 100,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceAppend,
+ },
+ {
+ "coalesceAppend smaller gso",
+ args{
+ pkt: udp4a[offset : len(udp4a)-90],
+ iphLen: 20,
+ gsoSize: 10,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceAppend,
+ },
+ {
+ "coalesceUnavailable smaller gso previously appended",
+ args{
+ pkt: udp4a[offset:],
+ iphLen: 20,
+ gsoSize: 100,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4c,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceUnavailable,
+ },
+ {
+ "coalesceUnavailable larger following smaller",
+ args{
+ pkt: udp4c[offset:],
+ iphLen: 20,
+ gsoSize: 110,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4c,
+ },
+ bufsOffset: offset,
+ },
+ coalesceUnavailable,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want {
+ t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/tun/operateonfd.go b/tun/operateonfd.go
index 70d8a07..f1beb6d 100644
--- a/tun/operateonfd.go
+++ b/tun/operateonfd.go
@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
diff --git a/tun/tun.go b/tun/tun.go
index 5521cb7..0ae53d0 100644
--- a/tun/tun.go
+++ b/tun/tun.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -18,12 +18,36 @@ const (
)
type Device interface {
- File() *os.File // returns the file descriptor of the device
- Read([]byte, int) (int, error) // read a packet from the device (without any additional headers)
- Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
- Flush() error // flush all previous writes to the device
- MTU() (int, error) // returns the MTU of the device
- Name() (string, error) // fetches and returns the current name
- Events() chan Event // returns a constant channel of events related to the device
- Close() error // stops the device and closes the event channel
+ // File returns the file descriptor of the device.
+ File() *os.File
+
+ // Read one or more packets from the Device (without any additional headers).
+ // On a successful read it returns the number of packets read, and sets
+ // packet lengths within the sizes slice. len(sizes) must be >= len(bufs).
+ // A nonzero offset can be used to instruct the Device on where to begin
+ // reading into each element of the bufs slice.
+ Read(bufs [][]byte, sizes []int, offset int) (n int, err error)
+
+ // Write one or more packets to the device (without any additional headers).
+ // On a successful write it returns the number of packets written. A nonzero
+ // offset can be used to instruct the Device on where to begin writing from
+ // each packet contained within the bufs slice.
+ Write(bufs [][]byte, offset int) (int, error)
+
+ // MTU returns the MTU of the Device.
+ MTU() (int, error)
+
+ // Name returns the current name of the Device.
+ Name() (string, error)
+
+ // Events returns a channel of type Event, which is fed Device events.
+ Events() <-chan Event
+
+ // Close stops the Device and closes the Event channel.
+ Close() error
+
+ // BatchSize returns the preferred/max number of packets that can be read or
+ // written in a single read/write call. BatchSize must not change over the
+ // lifetime of a Device.
+ BatchSize() int
}
diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go
index 1ce8a46..c9a6c0b 100644
--- a/tun/tun_darwin.go
+++ b/tun/tun_darwin.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -8,6 +8,7 @@ package tun
import (
"errors"
"fmt"
+ "io"
"net"
"os"
"sync"
@@ -15,7 +16,6 @@ import (
"time"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)
@@ -33,7 +33,7 @@ type NativeTun struct {
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
for i := 0; i < 20; i++ {
iface, err = net.InterfaceByIndex(index)
- if err != nil && errors.Is(err, syscall.ENOMEM) {
+ if err != nil && errors.Is(err, unix.ENOMEM) {
time.Sleep(time.Duration(i) * time.Second / 3)
continue
}
@@ -55,7 +55,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
retry:
n, err := unix.Read(tun.routeSocket, data)
if err != nil {
- if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
+ if errno, ok := err.(unix.Errno); ok && errno == unix.EINTR {
goto retry
}
tun.errors <- err
@@ -213,49 +213,50 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+ // TODO: the BSDs look very similar in Read() and Write(). They should be
+ // collapsed, with platform-specific files containing the varying parts of
+ // their implementations.
select {
case err := <-tun.errors:
return 0, err
default:
- buff := buff[offset-4:]
- n, err := tun.tunFile.Read(buff[:])
+ buf := bufs[0][offset-4:]
+ n, err := tun.tunFile.Read(buf[:])
if n < 4 {
return 0, err
}
- return n - 4, err
+ sizes[0] = n - 4
+ return 1, err
}
}
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
- // reserve space for header
-
- buff = buff[offset-4:]
-
- // add packet information header
-
- buff[0] = 0x00
- buff[1] = 0x00
- buff[2] = 0x00
-
- if buff[4]>>4 == ipv6.Version {
- buff[3] = unix.AF_INET6
- } else {
- buff[3] = unix.AF_INET
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ if offset < 4 {
+ return 0, io.ErrShortBuffer
}
-
- // write
-
- return tun.tunFile.Write(buff)
-}
-
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ for i, buf := range bufs {
+ buf = buf[offset-4:]
+ buf[0] = 0x00
+ buf[1] = 0x00
+ buf[2] = 0x00
+ switch buf[4] >> 4 {
+ case 4:
+ buf[3] = unix.AF_INET
+ case 6:
+ buf[3] = unix.AF_INET6
+ default:
+ return i, unix.EAFNOSUPPORT
+ }
+ if _, err := tun.tunFile.Write(buf); err != nil {
+ return i, err
+ }
+ }
+ return len(bufs), nil
}
func (tun *NativeTun) Close() error {
@@ -318,6 +319,10 @@ func (tun *NativeTun) MTU() (int, error) {
return int(ifr.MTU), nil
}
+func (tun *NativeTun) BatchSize() int {
+ return 1
+}
+
func socketCloexec(family, sotype, proto int) (fd int, err error) {
// See go/src/net/sys_cloexec.go for background.
syscall.ForkLock.RLock()
diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go
index e1e8986..7c65fd9 100644
--- a/tun/tun_freebsd.go
+++ b/tun/tun_freebsd.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -329,49 +329,50 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
select {
case err := <-tun.errors:
return 0, err
default:
- buff := buff[offset-4:]
- n, err := tun.tunFile.Read(buff[:])
+ buf := bufs[0][offset-4:]
+ n, err := tun.tunFile.Read(buf[:])
if n < 4 {
return 0, err
}
- return n - 4, err
+ sizes[0] = n - 4
+ return 1, err
}
}
-func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
if offset < 4 {
return 0, io.ErrShortBuffer
}
- buf = buf[offset-4:]
- if len(buf) < 5 {
- return 0, io.ErrShortBuffer
- }
- buf[0] = 0x00
- buf[1] = 0x00
- buf[2] = 0x00
- switch buf[4] >> 4 {
- case 4:
- buf[3] = unix.AF_INET
- case 6:
- buf[3] = unix.AF_INET6
- default:
- return 0, unix.EAFNOSUPPORT
+ for i, buf := range bufs {
+ buf = buf[offset-4:]
+ if len(buf) < 5 {
+ return i, io.ErrShortBuffer
+ }
+ buf[0] = 0x00
+ buf[1] = 0x00
+ buf[2] = 0x00
+ switch buf[4] >> 4 {
+ case 4:
+ buf[3] = unix.AF_INET
+ case 6:
+ buf[3] = unix.AF_INET6
+ default:
+ return i, unix.EAFNOSUPPORT
+ }
+ if _, err := tun.tunFile.Write(buf); err != nil {
+ return i, err
+ }
}
- return tun.tunFile.Write(buf)
-}
-
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ return len(bufs), nil
}
func (tun *NativeTun) Close() error {
@@ -428,3 +429,7 @@ func (tun *NativeTun) MTU() (int, error) {
}
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
}
+
+func (tun *NativeTun) BatchSize() int {
+ return 1
+}
diff --git a/tun/tun_linux.go b/tun/tun_linux.go
index 90cb2df..bd69cb5 100644
--- a/tun/tun_linux.go
+++ b/tun/tun_linux.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -17,9 +17,8 @@ import (
"time"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
-
+ "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
@@ -33,17 +32,27 @@ type NativeTun struct {
index int32 // if index
errors chan error // async error handling
events chan Event // device related events
- nopi bool // the device was passed IFF_NO_PI
netlinkSock int
netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{}
+ batchSize int
+ vnetHdr bool
+ udpGSO bool
closeOnce sync.Once
nameOnce sync.Once // guards calling initNameCache, which sets following fields
nameCache string // name of interface
nameErr error
+
+ readOpMu sync.Mutex // readOpMu guards readBuff
+ readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
+
+ writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable
+ toWrite []int
+ tcpGROTable *tcpGROTable
+ udpGROTable *udpGROTable
}
func (tun *NativeTun) File() *os.File {
@@ -323,60 +332,150 @@ func (tun *NativeTun) nameSlow() (string, error) {
return unix.ByteSliceToString(ifr[:]), nil
}
-func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
- if tun.nopi {
- buf = buf[offset:]
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ tun.writeOpMu.Lock()
+ defer func() {
+ tun.tcpGROTable.reset()
+ tun.udpGROTable.reset()
+ tun.writeOpMu.Unlock()
+ }()
+ var (
+ errs error
+ total int
+ )
+ tun.toWrite = tun.toWrite[:0]
+ if tun.vnetHdr {
+ err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite)
+ if err != nil {
+ return 0, err
+ }
+ offset -= virtioNetHdrLen
} else {
- // reserve space for header
- buf = buf[offset-4:]
-
- // add packet information header
- buf[0] = 0x00
- buf[1] = 0x00
- if buf[4]>>4 == ipv6.Version {
- buf[2] = 0x86
- buf[3] = 0xdd
+ for i := range bufs {
+ tun.toWrite = append(tun.toWrite, i)
+ }
+ }
+ for _, bufsI := range tun.toWrite {
+ n, err := tun.tunFile.Write(bufs[bufsI][offset:])
+ if errors.Is(err, syscall.EBADFD) {
+ return total, os.ErrClosed
+ }
+ if err != nil {
+ errs = errors.Join(errs, err)
} else {
- buf[2] = 0x08
- buf[3] = 0x00
+ total += n
}
}
+ return total, errs
+}
- n, err := tun.tunFile.Write(buf)
- if errors.Is(err, syscall.EBADFD) {
- err = os.ErrClosed
+// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
+// each buffer. It mutates sizes to reflect the size of each element of bufs,
+// and returns the number of packets read.
+func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
+ var hdr virtioNetHdr
+ err := hdr.decode(in)
+ if err != nil {
+ return 0, err
+ }
+ in = in[virtioNetHdrLen:]
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
+ if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
+ // This means CHECKSUM_PARTIAL in skb context. We are responsible
+ // for computing the checksum starting at hdr.csumStart and placing
+ // at hdr.csumOffset.
+ err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
+ if err != nil {
+ return 0, err
+ }
+ }
+ if len(in) > len(bufs[0][offset:]) {
+ return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
+ }
+ n := copy(bufs[0][offset:], in)
+ sizes[0] = n
+ return 1, nil
+ }
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
}
- return n, err
-}
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ ipVersion := in[0] >> 4
+ switch ipVersion {
+ case 4:
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+ }
+ case 6:
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+ }
+ default:
+ return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
+ }
+
+ // Don't trust hdr.hdrLen from the kernel as it can be equal to the length
+ // of the entire first packet when the kernel is handling it as part of a
+ // FORWARD path. Instead, parse the transport header length and add it onto
+ // csumStart, which is synonymous for IP header length.
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ hdr.hdrLen = hdr.csumStart + 8
+ } else {
+ if len(in) <= int(hdr.csumStart+12) {
+ return 0, errors.New("packet is too short")
+ }
+
+ tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
+ if tcpHLen < 20 || tcpHLen > 60 {
+ // A TCP header must be between 20 and 60 bytes in length.
+ return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
+ }
+ hdr.hdrLen = hdr.csumStart + tcpHLen
+ }
+
+ if len(in) < int(hdr.hdrLen) {
+ return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
+ }
+
+ if hdr.hdrLen < hdr.csumStart {
+ return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
+ }
+ cSumAt := int(hdr.csumStart + hdr.csumOffset)
+ if cSumAt+1 >= len(in) {
+ return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
+ }
+
+ return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6)
}
-func (tun *NativeTun) Read(buf []byte, offset int) (n int, err error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+ tun.readOpMu.Lock()
+ defer tun.readOpMu.Unlock()
select {
- case err = <-tun.errors:
+ case err := <-tun.errors:
+ return 0, err
default:
- if tun.nopi {
- n, err = tun.tunFile.Read(buf[offset:])
+ readInto := bufs[0][offset:]
+ if tun.vnetHdr {
+ readInto = tun.readBuff[:]
+ }
+ n, err := tun.tunFile.Read(readInto)
+ if errors.Is(err, syscall.EBADFD) {
+ err = os.ErrClosed
+ }
+ if err != nil {
+ return 0, err
+ }
+ if tun.vnetHdr {
+ return handleVirtioRead(readInto[:n], bufs, sizes, offset)
} else {
- buff := buf[offset-4:]
- n, err = tun.tunFile.Read(buff[:])
- if errors.Is(err, syscall.EBADFD) {
- err = os.ErrClosed
- }
- if n < 4 {
- n = 0
- } else {
- n -= 4
- }
+ sizes[0] = n
+ return 1, nil
}
}
- return
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
@@ -399,6 +498,56 @@ func (tun *NativeTun) Close() error {
return err2
}
+func (tun *NativeTun) BatchSize() int {
+ return tun.batchSize
+}
+
+const (
+ // TODO: support TSO with ECN bits
+ tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
+ tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
+)
+
+func (tun *NativeTun) initFromFlags(name string) error {
+ sc, err := tun.tunFile.SyscallConn()
+ if err != nil {
+ return err
+ }
+ if e := sc.Control(func(fd uintptr) {
+ var (
+ ifr *unix.Ifreq
+ )
+ ifr, err = unix.NewIfreq(name)
+ if err != nil {
+ return
+ }
+ err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
+ if err != nil {
+ return
+ }
+ got := ifr.Uint16()
+ if got&unix.IFF_VNET_HDR != 0 {
+ // tunTCPOffloads were added in Linux v2.6. We require their support
+ // if IFF_VNET_HDR is set.
+ err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads)
+ if err != nil {
+ return
+ }
+ tun.vnetHdr = true
+ tun.batchSize = conn.IdealBatchSize
+ // tunUDPOffloads were added in Linux v6.2. We do not return an
+ // error if they are unsupported at runtime.
+ tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil
+ } else {
+ tun.batchSize = 1
+ }
+ }); e != nil {
+ return e
+ }
+ return err
+}
+
+// CreateTUN creates a Device with the provided name and MTU.
func CreateTUN(name string, mtu int) (Device, error) {
nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil {
@@ -408,25 +557,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
return nil, err
}
- var ifr [ifReqSize]byte
- var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
- nameBytes := []byte(name)
- if len(nameBytes) >= unix.IFNAMSIZ {
- unix.Close(nfd)
- return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG)
+ ifr, err := unix.NewIfreq(name)
+ if err != nil {
+ return nil, err
}
- copy(ifr[:], nameBytes)
- *(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags
-
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(nfd),
- uintptr(unix.TUNSETIFF),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
- if errno != 0 {
- unix.Close(nfd)
- return nil, errno
+ // IFF_VNET_HDR enables the "tun status hack" via routineHackListener()
+ // where a null write will return EINVAL indicating the TUN is up.
+ ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
+ err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr)
+ if err != nil {
+ return nil, err
}
err = unix.SetNonblock(nfd, true)
@@ -441,13 +581,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
return CreateTUNFromFile(fd, mtu)
}
+// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 5),
errors: make(chan error, 5),
statusListenersShutdown: make(chan struct{}),
- nopi: false,
+ tcpGROTable: newTCPGROTable(),
+ udpGROTable: newUDPGROTable(),
+ toWrite: make([]int, 0, conn.IdealBatchSize),
}
name, err := tun.Name()
@@ -455,8 +598,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err
}
- // start event listener
+ err = tun.initFromFlags(name)
+ if err != nil {
+ return nil, err
+ }
+ // start event listener
tun.index, err = getIFIndex(name)
if err != nil {
return nil, err
@@ -485,6 +632,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return tun, nil
}
+// CreateUnmonitoredTUNFromFD creates a Device from the provided file
+// descriptor.
func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
err := unix.SetNonblock(fd, true)
if err != nil {
@@ -492,14 +641,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
}
file := os.NewFile(uintptr(fd), "/dev/tun")
tun := &NativeTun{
- tunFile: file,
- events: make(chan Event, 5),
- errors: make(chan error, 5),
- nopi: true,
+ tunFile: file,
+ events: make(chan Event, 5),
+ errors: make(chan error, 5),
+ tcpGROTable: newTCPGROTable(),
+ udpGROTable: newUDPGROTable(),
+ toWrite: make([]int, 0, conn.IdealBatchSize),
}
name, err := tun.Name()
if err != nil {
return nil, "", err
}
- return tun, name, nil
+ err = tun.initFromFlags(name)
+ if err != nil {
+ return nil, "", err
+ }
+ return tun, name, err
}
diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go
index b7a33b5..ae571b9 100644
--- a/tun/tun_openbsd.go
+++ b/tun/tun_openbsd.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -8,13 +8,13 @@ package tun
import (
"errors"
"fmt"
+ "io"
"net"
"os"
"sync"
"syscall"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)
@@ -200,49 +200,47 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
select {
case err := <-tun.errors:
return 0, err
default:
- buff := buff[offset-4:]
- n, err := tun.tunFile.Read(buff[:])
+ buf := bufs[0][offset-4:]
+ n, err := tun.tunFile.Read(buf[:])
if n < 4 {
return 0, err
}
- return n - 4, err
+ sizes[0] = n - 4
+ return 1, err
}
}
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
- // reserve space for header
-
- buff = buff[offset-4:]
-
- // add packet information header
-
- buff[0] = 0x00
- buff[1] = 0x00
- buff[2] = 0x00
-
- if buff[4]>>4 == ipv6.Version {
- buff[3] = unix.AF_INET6
- } else {
- buff[3] = unix.AF_INET
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ if offset < 4 {
+ return 0, io.ErrShortBuffer
}
-
- // write
-
- return tun.tunFile.Write(buff)
-}
-
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ for i, buf := range bufs {
+ buf = buf[offset-4:]
+ buf[0] = 0x00
+ buf[1] = 0x00
+ buf[2] = 0x00
+ switch buf[4] >> 4 {
+ case 4:
+ buf[3] = unix.AF_INET
+ case 6:
+ buf[3] = unix.AF_INET6
+ default:
+ return i, unix.EAFNOSUPPORT
+ }
+ if _, err := tun.tunFile.Write(buf); err != nil {
+ return i, err
+ }
+ }
+ return len(bufs), nil
}
func (tun *NativeTun) Close() error {
@@ -329,3 +327,7 @@ func (tun *NativeTun) MTU() (int, error) {
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
}
+
+func (tun *NativeTun) BatchSize() int {
+ return 1
+}
diff --git a/tun/tun_windows.go b/tun/tun_windows.go
index 6782fd4..2af8e3e 100644
--- a/tun/tun_windows.go
+++ b/tun/tun_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -15,7 +15,6 @@ import (
_ "unsafe"
"golang.org/x/sys/windows"
-
"golang.zx2c4.com/wintun"
)
@@ -44,6 +43,7 @@ type NativeTun struct {
closeOnce sync.Once
close atomic.Bool
forcedMTU int
+ outSizes []int
}
var (
@@ -102,7 +102,7 @@ func (tun *NativeTun) File() *os.File {
return nil
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
@@ -127,6 +127,9 @@ func (tun *NativeTun) MTU() (int, error) {
// TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
func (tun *NativeTun) ForceMTU(mtu int) {
+ if tun.close.Load() {
+ return
+ }
update := tun.forcedMTU != mtu
tun.forcedMTU = mtu
if update {
@@ -134,9 +137,14 @@ func (tun *NativeTun) ForceMTU(mtu int) {
}
}
+func (tun *NativeTun) BatchSize() int {
+ // TODO: implement batching with wintun
+ return 1
+}
+
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
tun.running.Add(1)
defer tun.running.Done()
retry:
@@ -152,11 +160,11 @@ retry:
packet, err := tun.session.ReceivePacket()
switch err {
case nil:
- packetSize := len(packet)
- copy(buff[offset:], packet)
+ n := copy(bufs[0][offset:], packet)
+ sizes[0] = n
tun.session.ReleaseReceivePacket(packet)
- tun.rate.update(uint64(packetSize))
- return packetSize, nil
+ tun.rate.update(uint64(n))
+ return 1, nil
case windows.ERROR_NO_MORE_ITEMS:
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
@@ -173,33 +181,33 @@ retry:
}
}
-func (tun *NativeTun) Flush() error {
- return nil
-}
-
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
tun.running.Add(1)
defer tun.running.Done()
if tun.close.Load() {
return 0, os.ErrClosed
}
- packetSize := len(buff) - offset
- tun.rate.update(uint64(packetSize))
+ for i, buf := range bufs {
+ packetSize := len(buf) - offset
+ tun.rate.update(uint64(packetSize))
- packet, err := tun.session.AllocateSendPacket(packetSize)
- if err == nil {
- copy(packet, buff[offset:])
- tun.session.SendPacket(packet)
- return packetSize, nil
- }
- switch err {
- case windows.ERROR_HANDLE_EOF:
- return 0, os.ErrClosed
- case windows.ERROR_BUFFER_OVERFLOW:
- return 0, nil // Dropping when ring is full.
+ packet, err := tun.session.AllocateSendPacket(packetSize)
+ switch err {
+ case nil:
+ // TODO: Explore options to eliminate this copy.
+ copy(packet, buf[offset:])
+ tun.session.SendPacket(packet)
+ continue
+ case windows.ERROR_HANDLE_EOF:
+ return i, os.ErrClosed
+ case windows.ERROR_BUFFER_OVERFLOW:
+ continue // Dropping when ring is full.
+ default:
+ return i, fmt.Errorf("Write failed: %w", err)
+ }
}
- return 0, fmt.Errorf("Write failed: %w", err)
+ return len(bufs), nil
}
// LUID returns Windows interface instance ID.
diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go
index 8196c34..d07e860 100644
--- a/tun/tuntest/tuntest.go
+++ b/tun/tuntest/tuntest.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tuntest
@@ -110,38 +110,45 @@ type chTun struct {
func (t *chTun) File() *os.File { return nil }
-func (t *chTun) Read(data []byte, offset int) (int, error) {
+func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) {
select {
case <-t.c.closed:
return 0, os.ErrClosed
case msg := <-t.c.Outbound:
- return copy(data[offset:], msg), nil
+ n := copy(packets[0][offset:], msg)
+ sizes[0] = n
+ return 1, nil
}
}
// Write is called by the wireguard device to deliver a packet for routing.
-func (t *chTun) Write(data []byte, offset int) (int, error) {
+func (t *chTun) Write(packets [][]byte, offset int) (int, error) {
if offset == -1 {
close(t.c.closed)
close(t.c.events)
return 0, io.EOF
}
- msg := make([]byte, len(data)-offset)
- copy(msg, data[offset:])
- select {
- case <-t.c.closed:
- return 0, os.ErrClosed
- case t.c.Inbound <- msg:
- return len(data) - offset, nil
+ for i, data := range packets {
+ msg := make([]byte, len(data)-offset)
+ copy(msg, data[offset:])
+ select {
+ case <-t.c.closed:
+ return i, os.ErrClosed
+ case t.c.Inbound <- msg:
+ }
}
+ return len(packets), nil
+}
+
+func (t *chTun) BatchSize() int {
+ return 1
}
const DefaultMTU = 1420
-func (t *chTun) Flush() error { return nil }
-func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
-func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
-func (t *chTun) Events() chan tun.Event { return t.c.events }
+func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
+func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
+func (t *chTun) Events() <-chan tun.Event { return t.c.events }
func (t *chTun) Close() error {
t.Write(nil, -1)
return nil
diff --git a/version.go b/version.go
index 79f0884..db75bb9 100644
--- a/version.go
+++ b/version.go
@@ -1,3 +1,3 @@
package main
-const Version = "0.0.20220316"
+const Version = "0.0.20230223"