aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJordan Whited <jordan@tailscale.com>2023-03-02 15:08:28 -0800
committerJason A. Donenfeld <Jason@zx2c4.com>2023-03-10 14:52:17 +0100
commit9e2f3860220280a5630971478b53c8ad9a991ca8 (patch)
tree218f1bd9a8dd649a8fdb50571a921d1ccff4cae5
parentconn, device, tun: implement vectorized I/O plumbing (diff)
downloadwireguard-go-9e2f3860220280a5630971478b53c8ad9a991ca8.tar.xz
wireguard-go-9e2f3860220280a5630971478b53c8ad9a991ca8.zip
conn, device, tun: implement vectorized I/O on Linux
Implement TCP offloading via TSO and GRO for the Linux tun.Device, which is made possible by virtio extensions in the kernel's TUN driver. Delete conn.LinuxSocketEndpoint in favor of a collapsed conn.StdNetBind. conn.StdNetBind makes use of recvmmsg() and sendmmsg() on Linux. All platforms now fall under conn.StdNetBind, except for Windows, which remains in conn.WinRingBind, which still needs to be adjusted to handle multiple packets. Also refactor sticky sockets support to eventually be applicable on platforms other than just Linux. However Linux remains the sole platform that fully implements it for now. Co-authored-by: James Tucker <james@tailscale.com> Signed-off-by: James Tucker <james@tailscale.com> Signed-off-by: Jordan Whited <jordan@tailscale.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--conn/bind_linux.go587
-rw-r--r--conn/bind_std.go339
-rw-r--r--conn/boundif_android.go8
-rw-r--r--conn/conn.go2
-rw-r--r--conn/controlfns.go36
-rw-r--r--conn/controlfns_linux.go41
-rw-r--r--conn/controlfns_unix.go28
-rw-r--r--conn/default.go2
-rw-r--r--conn/mark_default.go2
-rw-r--r--conn/mark_unix.go10
-rw-r--r--conn/sticky_default.go26
-rw-r--r--conn/sticky_linux.go111
-rw-r--r--conn/sticky_linux_test.go207
-rw-r--r--device/queueconstants_android.go4
-rw-r--r--device/queueconstants_default.go4
-rw-r--r--device/sticky_linux.go14
-rw-r--r--go.mod6
-rw-r--r--go.sum12
-rw-r--r--tun/checksum.go42
-rw-r--r--tun/tcp_offload_linux.go612
-rw-r--r--tun/tcp_offload_linux_test.go273
-rw-r--r--tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b778
-rw-r--r--tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d8
-rw-r--r--tun/tun_linux.go275
24 files changed, 1870 insertions, 787 deletions
diff --git a/conn/bind_linux.go b/conn/bind_linux.go
deleted file mode 100644
index b6bc0dc..0000000
--- a/conn/bind_linux.go
+++ /dev/null
@@ -1,587 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
- */
-
-package conn
-
-import (
- "errors"
- "net"
- "net/netip"
- "strconv"
- "sync"
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/unix"
-)
-
-type ipv4Source struct {
- Src [4]byte
- Ifindex int32
-}
-
-type ipv6Source struct {
- src [16]byte
- // ifindex belongs in dst.ZoneId
-}
-
-type LinuxSocketEndpoint struct {
- mu sync.Mutex
- dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
- src [unsafe.Sizeof(ipv6Source{})]byte
- isV6 bool
-}
-
-func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() }
-func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
-func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 }
-
-func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
- return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source {
- return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 {
- return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
- return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
-}
-
-// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
-type LinuxSocketBind struct {
- // mu guards sock4 and sock6 and the associated fds.
- // As long as someone holds mu (read or write), the associated fds are valid.
- mu sync.RWMutex
- sock4 int
- sock6 int
-}
-
-func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
-func NewDefaultBind() Bind { return NewLinuxSocketBind() }
-
-var (
- _ Endpoint = (*LinuxSocketEndpoint)(nil)
- _ Bind = (*LinuxSocketBind)(nil)
-)
-
-func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
- var end LinuxSocketEndpoint
- e, err := netip.ParseAddrPort(s)
- if err != nil {
- return nil, err
- }
-
- if e.Addr().Is4() {
- dst := end.dst4()
- end.isV6 = false
- dst.Port = int(e.Port())
- dst.Addr = e.Addr().As4()
- end.ClearSrc()
- return &end, nil
- }
-
- if e.Addr().Is6() {
- zone, err := zoneToUint32(e.Addr().Zone())
- if err != nil {
- return nil, err
- }
- dst := end.dst6()
- end.isV6 = true
- dst.Port = int(e.Port())
- dst.ZoneId = zone
- dst.Addr = e.Addr().As16()
- end.ClearSrc()
- return &end, nil
- }
-
- return nil, errors.New("invalid IP address")
-}
-
-func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
- bind.mu.Lock()
- defer bind.mu.Unlock()
-
- var err error
- var newPort uint16
- var tries int
-
- if bind.sock4 != -1 || bind.sock6 != -1 {
- return nil, 0, ErrBindAlreadyOpen
- }
-
- originalPort := port
-
-again:
- port = originalPort
- var sock4, sock6 int
- // Attempt ipv6 bind, update port if successful.
- sock6, newPort, err = create6(port)
- if err != nil {
- if !errors.Is(err, syscall.EAFNOSUPPORT) {
- return nil, 0, err
- }
- } else {
- port = newPort
- }
-
- // Attempt ipv4 bind, update port if successful.
- sock4, newPort, err = create4(port)
- if err != nil {
- if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
- unix.Close(sock6)
- tries++
- goto again
- }
- if !errors.Is(err, syscall.EAFNOSUPPORT) {
- unix.Close(sock6)
- return nil, 0, err
- }
- } else {
- port = newPort
- }
-
- var fns []ReceiveFunc
- if sock4 != -1 {
- bind.sock4 = sock4
- fns = append(fns, bind.receiveIPv4)
- }
- if sock6 != -1 {
- bind.sock6 = sock6
- fns = append(fns, bind.receiveIPv6)
- }
- if len(fns) == 0 {
- return nil, 0, syscall.EAFNOSUPPORT
- }
- return fns, port, nil
-}
-
-func (bind *LinuxSocketBind) SetMark(value uint32) error {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
-
- if bind.sock6 != -1 {
- err := unix.SetsockoptInt(
- bind.sock6,
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
- if err != nil {
- return err
- }
- }
-
- if bind.sock4 != -1 {
- err := unix.SetsockoptInt(
- bind.sock4,
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
- if err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (bind *LinuxSocketBind) BatchSize() int {
- return 1
-}
-
-func (bind *LinuxSocketBind) Close() error {
- // Take a readlock to shut down the sockets...
- bind.mu.RLock()
- if bind.sock6 != -1 {
- unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
- }
- if bind.sock4 != -1 {
- unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
- }
- bind.mu.RUnlock()
- // ...and a write lock to close the fd.
- // This ensures that no one else is using the fd.
- bind.mu.Lock()
- defer bind.mu.Unlock()
- var err1, err2 error
- if bind.sock6 != -1 {
- err1 = unix.Close(bind.sock6)
- bind.sock6 = -1
- }
- if bind.sock4 != -1 {
- err2 = unix.Close(bind.sock4)
- bind.sock4 = -1
- }
-
- if err1 != nil {
- return err1
- }
- return err2
-}
-
-func (bind *LinuxSocketBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if bind.sock4 == -1 {
- return 0, net.ErrClosed
- }
- var end LinuxSocketEndpoint
- n, err := receive4(bind.sock4, buffs[0], &end)
- if err != nil {
- return 0, err
- }
- eps[0] = &end
- sizes[0] = n
- return 1, nil
-}
-
-func (bind *LinuxSocketBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if bind.sock6 == -1 {
- return 0, net.ErrClosed
- }
- var end LinuxSocketEndpoint
- n, err := receive6(bind.sock6, buffs[0], &end)
- if err != nil {
- return 0, err
- }
- eps[0] = &end
- sizes[0] = n
- return 1, nil
-}
-
-func (bind *LinuxSocketBind) Send(buffs [][]byte, end Endpoint) error {
- nend, ok := end.(*LinuxSocketEndpoint)
- if !ok {
- return ErrWrongEndpointType
- }
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if !nend.isV6 {
- if bind.sock4 == -1 {
- return net.ErrClosed
- }
- for _, buff := range buffs {
- err := send4(bind.sock4, nend, buff)
- if err != nil {
- return err
- }
- }
- } else {
- if bind.sock6 == -1 {
- return net.ErrClosed
- }
- for _, buff := range buffs {
- err := send6(bind.sock6, nend, buff)
- if err != nil {
- return err
- }
- }
- }
- return nil
-}
-
-func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
- if !end.isV6 {
- return netip.AddrFrom4(end.src4().Src)
- } else {
- return netip.AddrFrom16(end.src6().src)
- }
-}
-
-func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
- if !end.isV6 {
- return netip.AddrFrom4(end.dst4().Addr)
- } else {
- return netip.AddrFrom16(end.dst6().Addr)
- }
-}
-
-func (end *LinuxSocketEndpoint) DstToBytes() []byte {
- if !end.isV6 {
- return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
- } else {
- return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
- }
-}
-
-func (end *LinuxSocketEndpoint) SrcToString() string {
- return end.SrcIP().String()
-}
-
-func (end *LinuxSocketEndpoint) DstToString() string {
- var port int
- if !end.isV6 {
- port = end.dst4().Port
- } else {
- port = end.dst6().Port
- }
- return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
-}
-
-func (end *LinuxSocketEndpoint) ClearDst() {
- for i := range end.dst {
- end.dst[i] = 0
- }
-}
-
-func (end *LinuxSocketEndpoint) ClearSrc() {
- for i := range end.src {
- end.src[i] = 0
- }
-}
-
-func zoneToUint32(zone string) (uint32, error) {
- if zone == "" {
- return 0, nil
- }
- if intr, err := net.InterfaceByName(zone); err == nil {
- return uint32(intr.Index), nil
- }
- n, err := strconv.ParseUint(zone, 10, 32)
- return uint32(n), err
-}
-
-func create4(port uint16) (int, uint16, error) {
- // create socket
-
- fd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
- 0,
- )
- if err != nil {
- return -1, 0, err
- }
-
- addr := unix.SockaddrInet4{
- Port: int(port),
- }
-
- // set sockopts and bind
-
- if err := func() error {
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IP,
- unix.IP_PKTINFO,
- 1,
- ); err != nil {
- return err
- }
-
- return unix.Bind(fd, &addr)
- }(); err != nil {
- unix.Close(fd)
- return -1, 0, err
- }
-
- sa, err := unix.Getsockname(fd)
- if err == nil {
- addr.Port = sa.(*unix.SockaddrInet4).Port
- }
-
- return fd, uint16(addr.Port), err
-}
-
-func create6(port uint16) (int, uint16, error) {
- // create socket
-
- fd, err := unix.Socket(
- unix.AF_INET6,
- unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
- 0,
- )
- if err != nil {
- return -1, 0, err
- }
-
- // set sockopts and bind
-
- addr := unix.SockaddrInet6{
- Port: int(port),
- }
-
- if err := func() error {
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IPV6,
- unix.IPV6_RECVPKTINFO,
- 1,
- ); err != nil {
- return err
- }
-
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IPV6,
- unix.IPV6_V6ONLY,
- 1,
- ); err != nil {
- return err
- }
-
- return unix.Bind(fd, &addr)
- }(); err != nil {
- unix.Close(fd)
- return -1, 0, err
- }
-
- sa, err := unix.Getsockname(fd)
- if err == nil {
- addr.Port = sa.(*unix.SockaddrInet6).Port
- }
-
- return fd, uint16(addr.Port), err
-}
-
-func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error {
- // construct message header
-
- cmsg := struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
- }{
- unix.Cmsghdr{
- Level: unix.IPPROTO_IP,
- Type: unix.IP_PKTINFO,
- Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
- },
- unix.Inet4Pktinfo{
- Spec_dst: end.src4().Src,
- Ifindex: end.src4().Ifindex,
- },
- }
-
- end.mu.Lock()
- _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
- end.mu.Unlock()
-
- if err == nil {
- return nil
- }
-
- // clear src and retry
-
- if err == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet4Pktinfo{}
- end.mu.Lock()
- _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
- end.mu.Unlock()
- }
-
- return err
-}
-
-func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error {
- // construct message header
-
- cmsg := struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
- }{
- unix.Cmsghdr{
- Level: unix.IPPROTO_IPV6,
- Type: unix.IPV6_PKTINFO,
- Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
- },
- unix.Inet6Pktinfo{
- Addr: end.src6().src,
- Ifindex: end.dst6().ZoneId,
- },
- }
-
- if cmsg.pktinfo.Addr == [16]byte{} {
- cmsg.pktinfo.Ifindex = 0
- }
-
- end.mu.Lock()
- _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
- end.mu.Unlock()
-
- if err == nil {
- return nil
- }
-
- // clear src and retry
-
- if err == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet6Pktinfo{}
- end.mu.Lock()
- _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
- end.mu.Unlock()
- }
-
- return err
-}
-
-func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
- // construct message header
-
- var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
- }
-
- size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
- if err != nil {
- return 0, err
- }
- end.isV6 = false
-
- if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
- *end.dst4() = *newDst4
- }
-
- // update source cache
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
- cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
- end.src4().Src = cmsg.pktinfo.Spec_dst
- end.src4().Ifindex = cmsg.pktinfo.Ifindex
- }
-
- return size, nil
-}
-
-func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
- // construct message header
-
- var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
- }
-
- size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
- if err != nil {
- return 0, err
- }
- end.isV6 = true
-
- if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
- *end.dst6() = *newDst6
- }
-
- // update source cache
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
- cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
- end.src6().src = cmsg.pktinfo.Addr
- end.dst6().ZoneId = cmsg.pktinfo.Ifindex
- }
-
- return size, nil
-}
diff --git a/conn/bind_std.go b/conn/bind_std.go
index 98fe23c..a164f56 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -6,32 +6,91 @@
package conn
import (
+ "context"
"errors"
"net"
"net/netip"
+ "strconv"
"sync"
"syscall"
+
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+)
+
+var (
+ _ Bind = (*StdNetBind)(nil)
)
-// StdNetBind is meant to be a temporary solution on platforms for which
-// the sticky socket / source caching behavior has not yet been implemented.
-// It uses the Go's net package to implement networking.
-// See LinuxSocketBind for a proper implementation on the Linux platform.
+// StdNetBind implements Bind for all platforms except Windows.
type StdNetBind struct {
- mu sync.Mutex // protects following fields
- ipv4 *net.UDPConn
- ipv6 *net.UDPConn
- blackhole4 bool
- blackhole6 bool
+ mu sync.Mutex // protects following fields
+ ipv4 *net.UDPConn
+ ipv6 *net.UDPConn
+ blackhole4 bool
+ blackhole6 bool
+ ipv4PC *ipv4.PacketConn
+ ipv6PC *ipv6.PacketConn
+ batchSize int
+ udpAddrPool sync.Pool
+ ipv4MsgsPool sync.Pool
+ ipv6MsgsPool sync.Pool
}
-func NewStdNetBind() Bind { return &StdNetBind{} }
+func NewStdNetBind() Bind { return NewStdNetBindBatch(DefaultBatchSize) }
+
+func NewStdNetBindBatch(maxBatchSize int) Bind {
+ if maxBatchSize == 0 {
+ maxBatchSize = DefaultBatchSize
+ }
+ return &StdNetBind{
+ batchSize: maxBatchSize,
+
+ udpAddrPool: sync.Pool{
+ New: func() any {
+ return &net.UDPAddr{
+ IP: make([]byte, 16),
+ }
+ },
+ },
-type StdNetEndpoint netip.AddrPort
+ ipv4MsgsPool: sync.Pool{
+ New: func() any {
+ msgs := make([]ipv4.Message, maxBatchSize)
+ for i := range msgs {
+ msgs[i].Buffers = make(net.Buffers, 1)
+ msgs[i].OOB = make([]byte, srcControlSize)
+ }
+ return &msgs
+ },
+ },
+
+ ipv6MsgsPool: sync.Pool{
+ New: func() any {
+ msgs := make([]ipv6.Message, maxBatchSize)
+ for i := range msgs {
+ msgs[i].Buffers = make(net.Buffers, 1)
+ msgs[i].OOB = make([]byte, srcControlSize)
+ }
+ return &msgs
+ },
+ },
+ }
+}
+
+type StdNetEndpoint struct {
+ // AddrPort is the endpoint destination.
+ netip.AddrPort
+ // src is the current sticky source address and interface index, if supported.
+ src struct {
+ netip.Addr
+ ifidx int32
+ }
+}
var (
_ Bind = (*StdNetBind)(nil)
- _ Endpoint = StdNetEndpoint{}
+ _ Endpoint = &StdNetEndpoint{}
)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
@@ -39,31 +98,38 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
return asEndpoint(e), err
}
-func (StdNetEndpoint) ClearSrc() {}
+func (e *StdNetEndpoint) ClearSrc() {
+ e.src.ifidx = 0
+ e.src.Addr = netip.Addr{}
+}
+
+func (e *StdNetEndpoint) DstIP() netip.Addr {
+ return e.AddrPort.Addr()
+}
-func (e StdNetEndpoint) DstIP() netip.Addr {
- return (netip.AddrPort)(e).Addr()
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+ return e.src.Addr
}
-func (e StdNetEndpoint) SrcIP() netip.Addr {
- return netip.Addr{} // not supported
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+ return e.src.ifidx
}
-func (e StdNetEndpoint) DstToBytes() []byte {
- b, _ := (netip.AddrPort)(e).MarshalBinary()
+func (e *StdNetEndpoint) DstToBytes() []byte {
+ b, _ := e.AddrPort.MarshalBinary()
return b
}
-func (e StdNetEndpoint) DstToString() string {
- return (netip.AddrPort)(e).String()
+func (e *StdNetEndpoint) DstToString() string {
+ return e.AddrPort.String()
}
-func (e StdNetEndpoint) SrcToString() string {
- return ""
+func (e *StdNetEndpoint) SrcToString() string {
+ return e.src.Addr.String()
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
- conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
+ conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil {
return nil, 0, err
}
@@ -77,17 +143,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
if err != nil {
return nil, 0, err
}
- return conn, uaddr.Port, nil
+ return conn.(*net.UDPConn), uaddr.Port, nil
}
-func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
- bind.mu.Lock()
- defer bind.mu.Unlock()
+func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
var err error
var tries int
- if bind.ipv4 != nil || bind.ipv6 != nil {
+ if s.ipv4 != nil || s.ipv6 != nil {
return nil, 0, ErrBindAlreadyOpen
}
@@ -95,104 +161,121 @@ func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
// If uport is 0, we can retry on failure.
again:
port := int(uport)
- var ipv4, ipv6 *net.UDPConn
+ var v4conn, v6conn *net.UDPConn
- ipv4, port, err = listenNet("udp4", port)
+ v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
- ipv6, port, err = listenNet("udp6", port)
+ v6conn, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
- ipv4.Close()
+ v4conn.Close()
tries++
goto again
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
- ipv4.Close()
+ v4conn.Close()
return nil, 0, err
}
var fns []ReceiveFunc
- if ipv4 != nil {
- fns = append(fns, bind.makeReceiveIPv4(ipv4))
- bind.ipv4 = ipv4
+ if v4conn != nil {
+ fns = append(fns, s.receiveIPv4)
+ s.ipv4 = v4conn
}
- if ipv6 != nil {
- fns = append(fns, bind.makeReceiveIPv6(ipv6))
- bind.ipv6 = ipv6
+ if v6conn != nil {
+ fns = append(fns, s.receiveIPv6)
+ s.ipv6 = v6conn
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
- return fns, uint16(port), nil
-}
-func (bind *StdNetBind) BatchSize() int {
- return 1
-}
+ s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
+ s.ipv6PC = ipv6.NewPacketConn(s.ipv6)
-func (bind *StdNetBind) Close() error {
- bind.mu.Lock()
- defer bind.mu.Unlock()
+ return fns, uint16(port), nil
+}
- var err1, err2 error
- if bind.ipv4 != nil {
- err1 = bind.ipv4.Close()
- bind.ipv4 = nil
+func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
+ defer s.ipv4MsgsPool.Put(msgs)
+ for i := range buffs {
+ (*msgs)[i].Buffers[0] = buffs[i]
}
- if bind.ipv6 != nil {
- err2 = bind.ipv6.Close()
- bind.ipv6 = nil
+ numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
+ if err != nil {
+ return 0, err
}
- bind.blackhole4 = false
- bind.blackhole6 = false
- if err1 != nil {
- return err1
+ for i := 0; i < numMsgs; i++ {
+ msg := &(*msgs)[i]
+ sizes[i] = msg.N
+ addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+ ep := asEndpoint(addrPort)
+ getSrcFromControl(msg.OOB, ep)
+ eps[i] = ep
}
- return err2
+ return numMsgs, nil
}
-func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
- return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
- size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
- if err == nil {
- sizes[0] = size
- eps[0] = asEndpoint(endpoint)
- return 1, nil
- }
+func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
+ defer s.ipv6MsgsPool.Put(msgs)
+ for i := range buffs {
+ (*msgs)[i].Buffers[0] = buffs[i]
+ }
+ numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
+ if err != nil {
return 0, err
}
+ for i := 0; i < numMsgs; i++ {
+ msg := &(*msgs)[i]
+ sizes[i] = msg.N
+ addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+ ep := asEndpoint(addrPort)
+ getSrcFromControl(msg.OOB, ep)
+ eps[i] = ep
+ }
+ return numMsgs, nil
}
-func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
- return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
- size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
- if err == nil {
- sizes[0] = size
- eps[0] = asEndpoint(endpoint)
- return 1, nil
- }
- return 0, err
- }
+func (s *StdNetBind) BatchSize() int {
+ return s.batchSize
}
-func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
- var err error
- nend, ok := endpoint.(StdNetEndpoint)
- if !ok {
- return ErrWrongEndpointType
+func (s *StdNetBind) Close() error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ var err1, err2 error
+ if s.ipv4 != nil {
+ err1 = s.ipv4.Close()
+ s.ipv4 = nil
+ }
+ if s.ipv6 != nil {
+ err2 = s.ipv6.Close()
+ s.ipv6 = nil
+ }
+ s.blackhole4 = false
+ s.blackhole6 = false
+ if err1 != nil {
+ return err1
}
- addrPort := netip.AddrPort(nend)
+ return err2
+}
- bind.mu.Lock()
- blackhole := bind.blackhole4
- conn := bind.ipv4
- if addrPort.Addr().Is6() {
- blackhole = bind.blackhole6
- conn = bind.ipv6
+func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
+ s.mu.Lock()
+ blackhole := s.blackhole4
+ conn := s.ipv4
+ is6 := false
+ if endpoint.DstIP().Is6() {
+ blackhole = s.blackhole6
+ conn = s.ipv6
+ is6 = true
}
- bind.mu.Unlock()
+ s.mu.Unlock()
if blackhole {
return nil
@@ -200,13 +283,69 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
if conn == nil {
return syscall.EAFNOSUPPORT
}
- for _, buff := range buffs {
- _, err = conn.WriteToUDPAddrPort(buff, addrPort)
- if err != nil {
- return err
+ if is6 {
+ return s.send6(s.ipv6PC, endpoint, buffs)
+ } else {
+ return s.send4(s.ipv4PC, endpoint, buffs)
+ }
+}
+
+func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
+ ua := s.udpAddrPool.Get().(*net.UDPAddr)
+ as4 := ep.DstIP().As4()
+ copy(ua.IP, as4[:])
+ ua.IP = ua.IP[:4]
+ ua.Port = int(ep.(*StdNetEndpoint).Port())
+ msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
+ for i, buff := range buffs {
+ (*msgs)[i].Buffers[0] = buff
+ (*msgs)[i].Addr = ua
+ setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
+ }
+ var (
+ n int
+ err error
+ start int
+ )
+ for {
+ n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
+ if err != nil || n == len((*msgs)[start:len(buffs)]) {
+ break
+ }
+ start += n
+ }
+ s.udpAddrPool.Put(ua)
+ s.ipv4MsgsPool.Put(msgs)
+ return err
+}
+
+func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
+ ua := s.udpAddrPool.Get().(*net.UDPAddr)
+ as16 := ep.DstIP().As16()
+ copy(ua.IP, as16[:])
+ ua.IP = ua.IP[:16]
+ ua.Port = int(ep.(*StdNetEndpoint).Port())
+ msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
+ for i, buff := range buffs {
+ (*msgs)[i].Buffers[0] = buff
+ (*msgs)[i].Addr = ua
+ setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
+ }
+ var (
+ n int
+ err error
+ start int
+ )
+ for {
+ n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
+ if err != nil || n == len((*msgs)[start:len(buffs)]) {
+ break
}
+ start += n
}
- return nil
+ s.udpAddrPool.Put(ua)
+ s.ipv6MsgsPool.Put(msgs)
+ return err
}
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
@@ -214,17 +353,17 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
// but Endpoints are immutable, so we can re-use them.
var endpointPool = sync.Pool{
New: func() any {
- return make(map[netip.AddrPort]Endpoint)
+ return make(map[netip.AddrPort]*StdNetEndpoint)
},
}
// asEndpoint returns an Endpoint containing ap.
-func asEndpoint(ap netip.AddrPort) Endpoint {
- m := endpointPool.Get().(map[netip.AddrPort]Endpoint)
+func asEndpoint(ap netip.AddrPort) *StdNetEndpoint {
+ m := endpointPool.Get().(map[netip.AddrPort]*StdNetEndpoint)
defer endpointPool.Put(m)
e, ok := m[ap]
if !ok {
- e = Endpoint(StdNetEndpoint(ap))
+ e = &StdNetEndpoint{AddrPort: ap}
m[ap] = e
}
return e
diff --git a/conn/boundif_android.go b/conn/boundif_android.go
index 818e4e6..dd3ca5b 100644
--- a/conn/boundif_android.go
+++ b/conn/boundif_android.go
@@ -5,8 +5,8 @@
package conn
-func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
- sysconn, err := bind.ipv4.SyscallConn()
+func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
+ sysconn, err := s.ipv4.SyscallConn()
if err != nil {
return -1, err
}
@@ -19,8 +19,8 @@ func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
return
}
-func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
- sysconn, err := bind.ipv6.SyscallConn()
+func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
+ sysconn, err := s.ipv6.SyscallConn()
if err != nil {
return -1, err
}
diff --git a/conn/conn.go b/conn/conn.go
index 8c0a827..9cbd0af 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -16,7 +16,7 @@ import (
)
const (
- DefaultBatchSize = 1 // maximum number of packets handled per read and write
+ DefaultBatchSize = 128 // maximum number of packets handled per read and write
)
// A ReceiveFunc receives at least one packet from the network and writes them
diff --git a/conn/controlfns.go b/conn/controlfns.go
new file mode 100644
index 0000000..fe32871
--- /dev/null
+++ b/conn/controlfns.go
@@ -0,0 +1,36 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net"
+ "syscall"
+)
+
+// controlFn is the callback function signature from net.ListenConfig.Control.
+// It is used to apply platform specific configuration to the socket prior to
+// bind.
+type controlFn func(network, address string, c syscall.RawConn) error
+
+// controlFns is a list of functions that are called from the listen config
+// that can apply socket options.
+var controlFns = []controlFn{}
+
+// listenConfig returns a net.ListenConfig that applies the controlFns to the
+// socket prior to bind. This is used to apply socket buffer sizing and packet
+// information OOB configuration for sticky sockets.
+func listenConfig() *net.ListenConfig {
+ return &net.ListenConfig{
+ Control: func(network, address string, c syscall.RawConn) error {
+ for _, fn := range controlFns {
+ if err := fn(network, address, c); err != nil {
+ return err
+ }
+ }
+ return nil
+ },
+ }
+}
diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go
new file mode 100644
index 0000000..9e26d95
--- /dev/null
+++ b/conn/controlfns_linux.go
@@ -0,0 +1,41 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "fmt"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ controlFns = append(controlFns,
+
+ // Enable receiving of the packet information (IP_PKTINFO for IPv4,
+ // IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
+ func(network, address string, c syscall.RawConn) error {
+ var err error
+ switch network {
+ case "udp4":
+ c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
+ })
+ case "udp6":
+ c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
+ if err != nil {
+ return
+ }
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
+ })
+ default:
+ err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
+ }
+ return err
+ },
+ )
+}
diff --git a/conn/controlfns_unix.go b/conn/controlfns_unix.go
new file mode 100644
index 0000000..9738c73
--- /dev/null
+++ b/conn/controlfns_unix.go
@@ -0,0 +1,28 @@
+//go:build !windows && !linux && !js
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ controlFns = append(controlFns,
+ func(network, address string, c syscall.RawConn) error {
+ var err error
+ if network == "udp6" {
+ c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
+ })
+ }
+ return err
+ },
+ )
+}
diff --git a/conn/default.go b/conn/default.go
index c7b4a84..b6f761b 100644
--- a/conn/default.go
+++ b/conn/default.go
@@ -1,4 +1,4 @@
-//go:build !linux && !windows
+//go:build !windows
/* SPDX-License-Identifier: MIT
*
diff --git a/conn/mark_default.go b/conn/mark_default.go
index 9944c38..3102384 100644
--- a/conn/mark_default.go
+++ b/conn/mark_default.go
@@ -7,6 +7,6 @@
package conn
-func (bind *StdNetBind) SetMark(mark uint32) error {
+func (s *StdNetBind) SetMark(mark uint32) error {
return nil
}
diff --git a/conn/mark_unix.go b/conn/mark_unix.go
index 5566b28..d9e46ee 100644
--- a/conn/mark_unix.go
+++ b/conn/mark_unix.go
@@ -26,13 +26,13 @@ func init() {
}
}
-func (bind *StdNetBind) SetMark(mark uint32) error {
+func (s *StdNetBind) SetMark(mark uint32) error {
var operr error
if fwmarkIoctl == 0 {
return nil
}
- if bind.ipv4 != nil {
- fd, err := bind.ipv4.SyscallConn()
+ if s.ipv4 != nil {
+ fd, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
@@ -46,8 +46,8 @@ func (bind *StdNetBind) SetMark(mark uint32) error {
return err
}
}
- if bind.ipv6 != nil {
- fd, err := bind.ipv6.SyscallConn()
+ if s.ipv6 != nil {
+ fd, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
diff --git a/conn/sticky_default.go b/conn/sticky_default.go
new file mode 100644
index 0000000..3ce9a56
--- /dev/null
+++ b/conn/sticky_default.go
@@ -0,0 +1,26 @@
+//go:build !linux
+// +build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but
+// use alternatively named flags and need ports and require testing.
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+}
+
+// setSrcControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+}
+
+// srcControlSize returns the recommended buffer size for pooling sticky control
+// data.
+const srcControlSize = 0
diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go
new file mode 100644
index 0000000..bf17839
--- /dev/null
+++ b/conn/sticky_linux.go
@@ -0,0 +1,111 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net/netip"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+ ep.ClearSrc()
+
+ var (
+ hdr unix.Cmsghdr
+ data []byte
+ rem []byte = control
+ err error
+ )
+
+ for len(rem) > unix.SizeofCmsghdr {
+ hdr, data, rem, err = unix.ParseOneSocketControlMessage(control)
+ if err != nil {
+ return
+ }
+
+ if hdr.Level == unix.IPPROTO_IP &&
+ hdr.Type == unix.IP_PKTINFO {
+
+ info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
+ ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
+ ep.src.ifidx = info.Ifindex
+
+ return
+ }
+
+ if hdr.Level == unix.IPPROTO_IPV6 &&
+ hdr.Type == unix.IPV6_PKTINFO {
+
+ info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
+ ep.src.Addr = netip.AddrFrom16(info.Addr)
+ ep.src.ifidx = int32(info.Ifindex)
+
+ return
+ }
+ }
+}
+
+// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
+// panics if buf is of insufficient size.
+func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
+ size := int(unsafe.Sizeof(t))
+ if len(buf) < size {
+ panic("pktInfoFromBuf: buffer too small")
+ }
+ copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
+ return t
+}
+
+// setSrcControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+ *control = (*control)[:cap(*control)]
+ if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
+ *control = (*control)[:0]
+ return
+ }
+
+ if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
+ *control = (*control)[:0]
+ return
+ }
+
+ if len(*control) < srcControlSize {
+ *control = (*control)[:0]
+ return
+ }
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
+ if ep.SrcIP().Is4() {
+ hdr.Level = unix.IPPROTO_IP
+ hdr.Type = unix.IP_PKTINFO
+ hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
+
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
+ info.Ifindex = ep.src.ifidx
+ if ep.SrcIP().IsValid() {
+ info.Spec_dst = ep.SrcIP().As4()
+ }
+ } else {
+ hdr.Level = unix.IPPROTO_IPV6
+ hdr.Type = unix.IPV6_PKTINFO
+ hdr.Len = unix.SizeofCmsghdr + unix.SizeofInet6Pktinfo
+
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
+ info.Ifindex = uint32(ep.src.ifidx)
+ if ep.SrcIP().IsValid() {
+ info.Addr = ep.SrcIP().As16()
+ }
+ }
+
+ *control = (*control)[:hdr.Len]
+}
+
+var srcControlSize = unix.CmsgLen(unix.SizeofInet6Pktinfo)
diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go
new file mode 100644
index 0000000..a42c89e
--- /dev/null
+++ b/conn/sticky_linux_test.go
@@ -0,0 +1,207 @@
+//go:build linux
+// +build linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "context"
+ "net"
+ "net/netip"
+ "runtime"
+ "testing"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+func Test_setSrcControl(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ ep := &StdNetEndpoint{
+ AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
+ }
+ ep.src.Addr = netip.MustParseAddr("127.0.0.1")
+ ep.src.ifidx = 5
+
+ control := make([]byte, srcControlSize)
+
+ setSrcControl(&control, ep)
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ if hdr.Level != unix.IPPROTO_IP {
+ t.Errorf("unexpected level: %d", hdr.Level)
+ }
+ if hdr.Type != unix.IP_PKTINFO {
+ t.Errorf("unexpected type: %d", hdr.Type)
+ }
+ if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
+ t.Errorf("unexpected length: %d", hdr.Len)
+ }
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
+ t.Errorf("unexpected address: %v", info.Spec_dst)
+ }
+ if info.Ifindex != 5 {
+ t.Errorf("unexpected ifindex: %d", info.Ifindex)
+ }
+ })
+
+ t.Run("IPv6", func(t *testing.T) {
+ ep := &StdNetEndpoint{
+ AddrPort: netip.MustParseAddrPort("[::1]:1234"),
+ }
+ ep.src.Addr = netip.MustParseAddr("::1")
+ ep.src.ifidx = 5
+
+ control := make([]byte, srcControlSize)
+
+ setSrcControl(&control, ep)
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ if hdr.Level != unix.IPPROTO_IPV6 {
+ t.Errorf("unexpected level: %d", hdr.Level)
+ }
+ if hdr.Type != unix.IPV6_PKTINFO {
+ t.Errorf("unexpected type: %d", hdr.Type)
+ }
+ if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
+ t.Errorf("unexpected length: %d", hdr.Len)
+ }
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ if info.Addr != ep.SrcIP().As16() {
+ t.Errorf("unexpected address: %v", info.Addr)
+ }
+ if info.Ifindex != 5 {
+ t.Errorf("unexpected ifindex: %d", info.Ifindex)
+ }
+ })
+
+ t.Run("ClearOnNoSrc", func(t *testing.T) {
+ control := make([]byte, srcControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = 1
+ hdr.Type = 2
+ hdr.Len = 3
+
+ setSrcControl(&control, &StdNetEndpoint{})
+
+ if len(control) != 0 {
+ t.Errorf("unexpected control: %v", control)
+ }
+ })
+}
+
+func Test_getSrcFromControl(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ control := make([]byte, srcControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IP
+ hdr.Type = unix.IP_PKTINFO
+ hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Spec_dst = [4]byte{127, 0, 0, 1}
+ info.Ifindex = 5
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(control, ep)
+
+ if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
+ t.Errorf("unexpected address: %v", ep.src.Addr)
+ }
+ if ep.src.ifidx != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+ }
+ })
+ t.Run("IPv6", func(t *testing.T) {
+ control := make([]byte, srcControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IPV6
+ hdr.Type = unix.IPV6_PKTINFO
+ hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
+ info.Ifindex = 5
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(control, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("::1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.src.ifidx != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+ }
+ })
+ t.Run("ClearOnEmpty", func(t *testing.T) {
+ control := make([]byte, srcControlSize)
+ ep := &StdNetEndpoint{}
+ ep.src.Addr = netip.MustParseAddr("::1")
+ ep.src.ifidx = 5
+
+ getSrcFromControl(control, ep)
+ if ep.SrcIP().IsValid() {
+ t.Errorf("unexpected address: %v", ep.src.Addr)
+ }
+ if ep.src.ifidx != 0 {
+ t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+ }
+ })
+}
+
+func Test_listenConfig(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ sc, err := conn.(*net.UDPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if runtime.GOOS == "linux" {
+ var i int
+ sc.Control(func(fd uintptr) {
+ i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if i != 1 {
+ t.Error("IP_PKTINFO not set!")
+ }
+ } else {
+ t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
+ }
+ })
+ t.Run("IPv6", func(t *testing.T) {
+ conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ sc, err := conn.(*net.UDPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if runtime.GOOS == "linux" {
+ var i int
+ sc.Control(func(fd uintptr) {
+ i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if i != 1 {
+ t.Error("IPV6_PKTINFO not set!")
+ }
+ } else {
+ t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
+ }
+ })
+}
diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go
index fc937b3..1158387 100644
--- a/device/queueconstants_android.go
+++ b/device/queueconstants_android.go
@@ -5,10 +5,12 @@
package device
+import "golang.zx2c4.com/wireguard/conn"
+
/* Reduce memory consumption for Android */
const (
- QueueStagedSize = 128
+ QueueStagedSize = conn.DefaultBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go
index 6b69150..7ed70a1 100644
--- a/device/queueconstants_default.go
+++ b/device/queueconstants_default.go
@@ -7,8 +7,10 @@
package device
+import "golang.zx2c4.com/wireguard/conn"
+
const (
- QueueStagedSize = 128
+ QueueStagedSize = conn.DefaultBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
diff --git a/device/sticky_linux.go b/device/sticky_linux.go
index 7afdf28..3ce0769 100644
--- a/device/sticky_linux.go
+++ b/device/sticky_linux.go
@@ -25,7 +25,7 @@ import (
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
- if _, ok := bind.(*conn.LinuxSocketBind); !ok {
+ if _, ok := bind.(*conn.StdNetBind); !ok {
return nil, nil
}
@@ -112,11 +112,11 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
pePtr.peer.Unlock()
break
}
- if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx {
+ if uint32(pePtr.peer.endpoint.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
pePtr.peer.Unlock()
break
}
- pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc()
+ pePtr.peer.endpoint.(*conn.StdNetEndpoint).ClearSrc()
pePtr.peer.Unlock()
}
attr = attr[attrhdr.Len:]
@@ -136,12 +136,12 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
peer.RUnlock()
continue
}
- nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint)
+ nativeEP, _ := peer.endpoint.(*conn.StdNetEndpoint)
if nativeEP == nil {
peer.RUnlock()
continue
}
- if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
+ if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
peer.RUnlock()
break
}
@@ -169,12 +169,12 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
Len: 8,
Type: unix.RTA_DST,
},
- nativeEP.Dst4().Addr,
+ nativeEP.DstIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
- nativeEP.Src4().Src,
+ nativeEP.SrcIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_MARK,
diff --git a/go.mod b/go.mod
index 46427df..a6f7254 100644
--- a/go.mod
+++ b/go.mod
@@ -3,9 +3,9 @@ module golang.zx2c4.com/wireguard
go 1.19
require (
- golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd
- golang.org/x/net v0.0.0-20220225172249-27dd8689420f
- golang.org/x/sys v0.2.0
+ golang.org/x/crypto v0.3.0
+ golang.org/x/net v0.2.0
+ golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0
)
diff --git a/go.sum b/go.sum
index 97ae95d..a967a10 100644
--- a/go.sum
+++ b/go.sum
@@ -1,11 +1,11 @@
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
-golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38=
-golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
-golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc=
-golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
-golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A=
-golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/crypto v0.3.0 h1:a06MkbcxBrEFc0w0QIZWXrH/9cCX6KJyWbBOIwAn+7A=
+golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4=
+golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU=
+golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
+golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 h1:260HNjMTPDya+jq5AM1zZLgG9pv9GASPAGiEEJUbRg4=
+golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
diff --git a/tun/checksum.go b/tun/checksum.go
new file mode 100644
index 0000000..f4f8471
--- /dev/null
+++ b/tun/checksum.go
@@ -0,0 +1,42 @@
+package tun
+
+import "encoding/binary"
+
+// TODO: Explore SIMD and/or other assembly optimizations.
+func checksumNoFold(b []byte, initial uint64) uint64 {
+ ac := initial
+ i := 0
+ n := len(b)
+ for n >= 4 {
+ ac += uint64(binary.BigEndian.Uint32(b[i : i+4]))
+ n -= 4
+ i += 4
+ }
+ for n >= 2 {
+ ac += uint64(binary.BigEndian.Uint16(b[i : i+2]))
+ n -= 2
+ i += 2
+ }
+ if n == 1 {
+ ac += uint64(b[i]) << 8
+ }
+ return ac
+}
+
+func checksum(b []byte, initial uint64) uint16 {
+ ac := checksumNoFold(b, initial)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ return uint16(ac)
+}
+
+func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
+ sum := checksumNoFold(srcAddr, 0)
+ sum = checksumNoFold(dstAddr, sum)
+ sum = checksumNoFold([]byte{0, protocol}, sum)
+ tmp := make([]byte, 2)
+ binary.BigEndian.PutUint16(tmp, totalLen)
+ return checksumNoFold(tmp, sum)
+}
diff --git a/tun/tcp_offload_linux.go b/tun/tcp_offload_linux.go
new file mode 100644
index 0000000..f3ffa75
--- /dev/null
+++ b/tun/tcp_offload_linux.go
@@ -0,0 +1,612 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "io"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+)
+
+const tcpFlagsOffset = 13
+
+const (
+ tcpFlagFIN uint8 = 0x01
+ tcpFlagPSH uint8 = 0x08
+ tcpFlagACK uint8 = 0x10
+)
+
+// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
+// kernel symbol is virtio_net_hdr.
+type virtioNetHdr struct {
+ flags uint8
+ gsoType uint8
+ hdrLen uint16
+ gsoSize uint16
+ csumStart uint16
+ csumOffset uint16
+}
+
+func (v *virtioNetHdr) decode(b []byte) error {
+ if len(b) < virtioNetHdrLen {
+ return io.ErrShortBuffer
+ }
+ copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
+ return nil
+}
+
+func (v *virtioNetHdr) encode(b []byte) error {
+ if len(b) < virtioNetHdrLen {
+ return io.ErrShortBuffer
+ }
+ copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
+ return nil
+}
+
+const (
+ // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
+ // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
+ virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
+)
+
+// flowKey represents the key for a flow.
+type flowKey struct {
+ srcAddr, dstAddr [16]byte
+ srcPort, dstPort uint16
+ rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
+}
+
+// tcpGROTable holds flow and coalescing information for the purposes of GRO.
+type tcpGROTable struct {
+ itemsByFlow map[flowKey][]tcpGROItem
+ itemsPool [][]tcpGROItem
+}
+
+func newTCPGROTable() *tcpGROTable {
+ t := &tcpGROTable{
+ itemsByFlow: make(map[flowKey][]tcpGROItem, conn.DefaultBatchSize),
+ itemsPool: make([][]tcpGROItem, conn.DefaultBatchSize),
+ }
+ for i := range t.itemsPool {
+ t.itemsPool[i] = make([]tcpGROItem, 0, conn.DefaultBatchSize)
+ }
+ return t
+}
+
+func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
+ key := flowKey{}
+ addrSize := dstAddr - srcAddr
+ copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
+ copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
+ key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
+ key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
+ key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
+ return key
+}
+
+// lookupOrInsert looks up a flow for the provided packet and metadata,
+// returning the packets found for the flow, or inserting a new one if none
+// is found.
+func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex int) ([]tcpGROItem, bool) {
+ key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+ items, ok := t.itemsByFlow[key]
+ if ok {
+ return items, ok
+ }
+ // TODO: insert() performs another map lookup. This could be rearranged to avoid.
+ t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex)
+ return nil, false
+}
+
+// insert an item in the table for the provided packet and packet metadata.
+func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex int) {
+ key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+ item := tcpGROItem{
+ key: key,
+ buffsIndex: uint16(buffsIndex),
+ gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
+ iphLen: uint8(tcphOffset),
+ tcphLen: uint8(tcphLen),
+ sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
+ pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
+ }
+ items, ok := t.itemsByFlow[key]
+ if !ok {
+ items = t.newItems()
+ }
+ items = append(items, item)
+ t.itemsByFlow[key] = items
+}
+
+func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
+ items, _ := t.itemsByFlow[item.key]
+ items[i] = item
+}
+
+func (t *tcpGROTable) deleteAt(key flowKey, i int) {
+ items, _ := t.itemsByFlow[key]
+ items = append(items[:i], items[i+1:]...)
+ t.itemsByFlow[key] = items
+}
+
+// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
+// of a GRO evaluation across a vector of packets.
+type tcpGROItem struct {
+ key flowKey
+ sentSeq uint32 // the sequence number
+ buffsIndex uint16 // the index into the original buffs slice
+ numMerged uint16 // the number of packets merged into this item
+ gsoSize uint16 // payload size
+ iphLen uint8 // ip header len
+ tcphLen uint8 // tcp header len
+ pshSet bool // psh flag is set
+}
+
+func (t *tcpGROTable) newItems() []tcpGROItem {
+ var items []tcpGROItem
+ items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
+ return items
+}
+
+func (t *tcpGROTable) reset() {
+ for k, items := range t.itemsByFlow {
+ items = items[:0]
+ t.itemsPool = append(t.itemsPool, items)
+ delete(t.itemsByFlow, k)
+ }
+}
+
+// canCoalesce represents the outcome of checking if two TCP packets are
+// candidates for coalescing.
+type canCoalesce int
+
+const (
+ coalescePrepend canCoalesce = -1
+ coalesceUnavailable canCoalesce = 0
+ coalesceAppend canCoalesce = 1
+)
+
+// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
+// described by item. This function makes considerations that match the kernel's
+// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
+func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, buffs [][]byte, buffsOffset int) canCoalesce {
+ pktTarget := buffs[item.buffsIndex][buffsOffset:]
+ if tcphLen != item.tcphLen {
+ // cannot coalesce with unequal tcp options len
+ return coalesceUnavailable
+ }
+ if tcphLen > 20 {
+ if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
+ // cannot coalesce with unequal tcp options
+ return coalesceUnavailable
+ }
+ }
+ if pkt[1] != pktTarget[1] {
+ // cannot coalesce with unequal ToS values
+ return coalesceUnavailable
+ }
+ if pkt[6]>>5 != pktTarget[6]>>5 {
+ // cannot coalesce with unequal DF or reserved bits. MF is checked
+ // further up the stack.
+ return coalesceUnavailable
+ }
+ // seq adjacency
+ lhsLen := item.gsoSize
+ lhsLen += item.numMerged * item.gsoSize
+ if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
+ if item.pshSet {
+ // We cannot append to a segment that has the PSH flag set, PSH
+ // can only be set on the final segment in a reassembled group.
+ return coalesceUnavailable
+ }
+ if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
+ // A smaller than gsoSize packet has been appended previously.
+ // Nothing can come after a smaller packet on the end.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ return coalesceAppend
+ } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
+ if pshSet {
+ // We cannot prepend with a segment that has the PSH flag set, PSH
+ // can only be set on the final segment in a reassembled group.
+ return coalesceUnavailable
+ }
+ if gsoSize < item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize && item.numMerged > 0 {
+ // There's at least one previous merge, and we're larger than all
+ // previous. This would put multiple smaller packets on the end.
+ return coalesceUnavailable
+ }
+ return coalescePrepend
+ }
+ return coalesceUnavailable
+}
+
+func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
+ srcAddrAt := ipv4SrcAddrOffset
+ addrSize := 4
+ if isV6 {
+ srcAddrAt = ipv6SrcAddrOffset
+ addrSize = 16
+ }
+ tcpTotalLen := uint16(len(pkt) - int(iphLen))
+ tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
+ return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0
+}
+
+// coalesceResult represents the result of attempting to coalesce two TCP
+// packets.
+type coalesceResult int
+
+const (
+ coalesceInsufficientCap coalesceResult = 0
+ coalescePSHEnding coalesceResult = 1
+ coalesceItemInvalidCSum coalesceResult = 2
+ coalescePktInvalidCSum coalesceResult = 3
+ coalesceSuccess coalesceResult = 4
+)
+
+// coalesceTCPPackets attempts to coalesce pkt with the packet described by
+// item, returning the outcome. This function may swap buffs elements in the
+// event of a prepend as item's buffs index is already being tracked for writing
+// to a Device.
+func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, buffs [][]byte, buffsOffset int, isV6 bool) coalesceResult {
+ var pktHead []byte // the packet that will end up at the front
+ headersLen := item.iphLen + item.tcphLen
+ coalescedLen := len(buffs[item.buffsIndex][buffsOffset:]) + len(pkt) - int(headersLen)
+
+ // Copy data
+ if mode == coalescePrepend {
+ pktHead = pkt
+ if cap(pkt)-buffsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if pshSet {
+ return coalescePSHEnding
+ }
+ if item.numMerged == 0 {
+ if !tcpChecksumValid(buffs[item.buffsIndex][buffsOffset:], item.iphLen, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !tcpChecksumValid(pkt, item.iphLen, isV6) {
+ return coalescePktInvalidCSum
+ }
+ item.sentSeq = seq
+ extendBy := coalescedLen - len(pktHead)
+ buffs[pktBuffsIndex] = append(buffs[pktBuffsIndex], make([]byte, extendBy)...)
+ copy(buffs[pktBuffsIndex][buffsOffset+len(pkt):], buffs[item.buffsIndex][buffsOffset+int(headersLen):])
+ // Flip the slice headers in buffs as part of prepend. The index of item
+ // is already being tracked for writing.
+ buffs[item.buffsIndex], buffs[pktBuffsIndex] = buffs[pktBuffsIndex], buffs[item.buffsIndex]
+ } else {
+ pktHead = buffs[item.buffsIndex][buffsOffset:]
+ if cap(pktHead)-buffsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if item.numMerged == 0 {
+ if !tcpChecksumValid(buffs[item.buffsIndex][buffsOffset:], item.iphLen, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !tcpChecksumValid(pkt, item.iphLen, isV6) {
+ return coalescePktInvalidCSum
+ }
+ if pshSet {
+ // We are appending a segment with PSH set.
+ item.pshSet = pshSet
+ pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
+ }
+ extendBy := len(pkt) - int(headersLen)
+ buffs[item.buffsIndex] = append(buffs[item.buffsIndex], make([]byte, extendBy)...)
+ copy(buffs[item.buffsIndex][buffsOffset+len(pktHead):], pkt[headersLen:])
+ }
+
+ if gsoSize > item.gsoSize {
+ item.gsoSize = gsoSize
+ }
+ hdr := virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
+ hdrLen: uint16(headersLen),
+ gsoSize: uint16(item.gsoSize),
+ csumStart: uint16(item.iphLen),
+ csumOffset: 16,
+ }
+
+ // Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the
+ // (IPv4) header checksum.
+ if isV6 {
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
+ binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len
+ } else {
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
+ pktHead[10], pktHead[11] = 0, 0 // clear checksum field
+ binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length
+ iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum
+ binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field
+ }
+ hdr.encode(buffs[item.buffsIndex][buffsOffset-virtioNetHdrLen:])
+
+ // Calculate the pseudo header checksum and place it at the TCP checksum
+ // offset. Downstream checksum offloading will combine this with computation
+ // of the tcp header and payload checksum.
+ addrLen := 4
+ addrOffset := ipv4SrcAddrOffset
+ if isV6 {
+ addrLen = 16
+ addrOffset = ipv6SrcAddrOffset
+ }
+ srcAddrAt := buffsOffset + addrOffset
+ srcAddr := buffs[item.buffsIndex][srcAddrAt : srcAddrAt+addrLen]
+ dstAddr := buffs[item.buffsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
+ psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen)))
+ binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
+
+ item.numMerged++
+ return coalesceSuccess
+}
+
+const (
+ ipv4FlagMoreFragments = 0x80
+)
+
+const (
+ ipv4SrcAddrOffset = 12
+ ipv6SrcAddrOffset = 8
+ maxUint16 = 1<<16 - 1
+)
+
+// tcpGRO evaluates the TCP packet at pktI in buffs for coalescing with
+// existing packets tracked in table. It will return false when pktI is not
+// coalesced, otherwise true. This indicates to the caller if buffs[pktI]
+// should be written to the Device.
+func tcpGRO(buffs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) {
+ pkt := buffs[pktI][offset:]
+ if len(pkt) > maxUint16 {
+ // A valid IPv4 or IPv6 packet will never exceed this.
+ return false
+ }
+ iphLen := int((pkt[0] & 0x0F) * 4)
+ if isV6 {
+ iphLen = 40
+ ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
+ if ipv6HPayloadLen != len(pkt)-iphLen {
+ return false
+ }
+ } else {
+ totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
+ if totalLen != len(pkt) {
+ return false
+ }
+ if iphLen < 20 || iphLen > 60 {
+ return false
+ }
+ }
+ if len(pkt) < iphLen {
+ return false
+ }
+ tcphLen := int((pkt[iphLen+12] >> 4) * 4)
+ if tcphLen < 20 || tcphLen > 60 {
+ return false
+ }
+ if len(pkt) < iphLen+tcphLen {
+ return false
+ }
+ if !isV6 {
+ if pkt[6]&ipv4FlagMoreFragments != 0 || (pkt[6]<<3 != 0 || pkt[7] != 0) {
+ // no GRO support for fragmented segments for now
+ return false
+ }
+ }
+ tcpFlags := pkt[iphLen+tcpFlagsOffset]
+ var pshSet bool
+ // not a candidate if any non-ACK flags (except PSH+ACK) are set
+ if tcpFlags != tcpFlagACK {
+ if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
+ return false
+ }
+ pshSet = true
+ }
+ gsoSize := uint16(len(pkt) - tcphLen - iphLen)
+ // not a candidate if payload len is 0
+ if gsoSize < 1 {
+ return false
+ }
+ seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
+ srcAddrOffset := ipv4SrcAddrOffset
+ addrLen := 4
+ if isV6 {
+ srcAddrOffset = ipv6SrcAddrOffset
+ addrLen = 16
+ }
+ items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
+ if !existing {
+ return false
+ }
+ for i := len(items) - 1; i >= 0; i-- {
+ // In the best case of packets arriving in order iterating in reverse is
+ // more efficient if there are multiple items for a given flow. This
+ // also enables a natural table.deleteAt() in the
+ // coalesceItemInvalidCSum case without the need for index tracking.
+ // This algorithm makes a best effort to coalesce in the event of
+ // unordered packets, where pkt may land anywhere in items from a
+ // sequence number perspective, however once an item is inserted into
+ // the table it is never compared across other items later.
+ item := items[i]
+ can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, buffs, offset)
+ if can != coalesceUnavailable {
+ result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, buffs, offset, isV6)
+ switch result {
+ case coalesceSuccess:
+ table.updateAt(item, i)
+ return true
+ case coalesceItemInvalidCSum:
+ // delete the item with an invalid csum
+ table.deleteAt(item.key, i)
+ case coalescePktInvalidCSum:
+ // no point in inserting an item that we can't coalesce
+ return false
+ default:
+ }
+ }
+ }
+ // failed to coalesce with any other packets; store the item in the flow
+ table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
+ return false
+}
+
+func isTCP4(b []byte) bool {
+ if len(b) < 40 {
+ return false
+ }
+ if b[0]>>4 != 4 {
+ return false
+ }
+ if b[9] != unix.IPPROTO_TCP {
+ return false
+ }
+ return true
+}
+
+func isTCP6NoEH(b []byte) bool {
+ if len(b) < 60 {
+ return false
+ }
+ if b[0]>>4 != 6 {
+ return false
+ }
+ if b[6] != unix.IPPROTO_TCP {
+ return false
+ }
+ return true
+}
+
+// handleGRO evaluates buffs for GRO, and writes the indices of the resulting
+// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
+// empty (but non-nil), and are passed in to save allocs as the caller may reset
+// and recycle them across vectors of packets.
+func handleGRO(buffs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
+ for i := range buffs {
+ if offset < virtioNetHdrLen || offset > len(buffs[i])-1 {
+ return errors.New("invalid offset")
+ }
+ var coalesced bool
+ switch {
+ case isTCP4(buffs[i][offset:]):
+ coalesced = tcpGRO(buffs, offset, i, tcp4Table, false)
+ case isTCP6NoEH(buffs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
+ coalesced = tcpGRO(buffs, offset, i, tcp6Table, true)
+ }
+ if !coalesced {
+ hdr := virtioNetHdr{}
+ err := hdr.encode(buffs[i][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+ *toWrite = append(*toWrite, i)
+ }
+ }
+ return nil
+}
+
+// tcpTSO splits packets from in into outBuffs, writing the size of each
+// element into sizes. It returns the number of buffers populated, and/or an
+// error.
+func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
+ iphLen := int(hdr.csumStart)
+ srcAddrOffset := ipv6SrcAddrOffset
+ addrLen := 16
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+ in[10], in[11] = 0, 0 // clear ipv4 header checksum
+ srcAddrOffset = ipv4SrcAddrOffset
+ addrLen = 4
+ }
+ tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
+ in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
+ firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
+ nextSegmentDataAt := int(hdr.hdrLen)
+ i := 0
+ for ; nextSegmentDataAt < len(in); i++ {
+ if i == len(outBuffs) {
+ return i - 1, ErrTooManySegments
+ }
+ nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
+ if nextSegmentEnd > len(in) {
+ nextSegmentEnd = len(in)
+ }
+ segmentDataLen := nextSegmentEnd - nextSegmentDataAt
+ totalLen := int(hdr.hdrLen) + segmentDataLen
+ sizes[i] = totalLen
+ out := outBuffs[i][outOffset:]
+
+ copy(out, in[:iphLen])
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+ // For IPv4 we are responsible for incrementing the ID field,
+ // updating the total len field, and recalculating the header
+ // checksum.
+ if i > 0 {
+ id := binary.BigEndian.Uint16(out[4:])
+ id += uint16(i)
+ binary.BigEndian.PutUint16(out[4:], id)
+ }
+ binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
+ ipv4CSum := ^checksum(out[:iphLen], 0)
+ binary.BigEndian.PutUint16(out[10:], ipv4CSum)
+ } else {
+ // For IPv6 we are responsible for updating the payload length field.
+ binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
+ }
+
+ // TCP header
+ copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
+ tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
+ binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
+ if nextSegmentEnd != len(in) {
+ // FIN and PSH should only be set on last segment
+ clearFlags := tcpFlagFIN | tcpFlagPSH
+ out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
+ }
+
+ // payload
+ copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
+
+ // TCP checksum
+ tcpHLen := int(hdr.hdrLen - hdr.csumStart)
+ tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
+ tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
+ tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold)
+ binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
+
+ nextSegmentDataAt += int(hdr.gsoSize)
+ }
+ return i, nil
+}
+
+func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
+ cSumAt := cSumStart + cSumOffset
+ // The initial value at the checksum offset should be summed with the
+ // checksum we compute. This is typically the pseudo-header checksum.
+ initial := binary.BigEndian.Uint16(in[cSumAt:])
+ in[cSumAt], in[cSumAt+1] = 0, 0
+ binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
+ return nil
+}
diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go
new file mode 100644
index 0000000..7fa0777
--- /dev/null
+++ b/tun/tcp_offload_linux_test.go
@@ -0,0 +1,273 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "net/netip"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+ offset = virtioNetHdrLen
+)
+
+var (
+ ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
+ ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
+ ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
+ ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
+ ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
+ ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
+)
+
+func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ totalLen := 40 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv4H := header.IPv4(b[offset:])
+ srcAs4 := srcIPPort.Addr().As4()
+ dstAs4 := dstIPPort.Addr().As4()
+ ipv4H.Encode(&header.IPv4Fields{
+ SrcAddr: tcpip.Address(srcAs4[:]),
+ DstAddr: tcpip.Address(dstAs4[:]),
+ Protocol: unix.IPPROTO_TCP,
+ TTL: 64,
+ TotalLength: uint16(totalLen),
+ })
+ tcpH := header.TCP(b[offset+20:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ totalLen := 60 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv6H := header.IPv6(b[offset:])
+ srcAs16 := srcIPPort.Addr().As16()
+ dstAs16 := dstIPPort.Addr().As16()
+ ipv6H.Encode(&header.IPv6Fields{
+ SrcAddr: tcpip.Address(srcAs16[:]),
+ DstAddr: tcpip.Address(dstAs16[:]),
+ TransportProtocol: unix.IPPROTO_TCP,
+ HopLimit: 64,
+ PayloadLength: uint16(segmentSize + 20),
+ })
+ tcpH := header.TCP(b[offset+40:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func Test_handleVirtioRead(t *testing.T) {
+ tests := []struct {
+ name string
+ hdr virtioNetHdr
+ pktIn []byte
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "tcp4",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
+ gsoSize: 100,
+ hdrLen: 40,
+ csumStart: 20,
+ csumOffset: 16,
+ },
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{140, 140},
+ false,
+ },
+ {
+ "tcp6",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
+ gsoSize: 100,
+ hdrLen: 60,
+ csumStart: 40,
+ csumOffset: 16,
+ },
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{160, 160},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ out := make([][]byte, conn.DefaultBatchSize)
+ sizes := make([]int, conn.DefaultBatchSize)
+ for i := range out {
+ out[i] = make([]byte, 65535)
+ }
+ tt.hdr.encode(tt.pktIn)
+ n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if n != len(tt.wantLens) {
+ t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
+ }
+ for i := range tt.wantLens {
+ if tt.wantLens[i] != sizes[i] {
+ t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
+ }
+ }
+ })
+ }
+}
+
+func flipTCP4Checksum(b []byte) []byte {
+ at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
+ b[at] ^= 0xFF
+ b[at+1] ^= 0xFF
+ return b
+}
+
+func Fuzz_handleGRO(f *testing.F) {
+ pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
+ pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
+ pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
+ pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
+ pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
+ pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
+ f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset)
+ f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) {
+ pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5}
+ toWrite := make([]int, 0, len(pkts))
+ handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
+ if len(toWrite) > len(pkts) {
+ t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
+ }
+ seenWriteI := make(map[int]bool)
+ for _, writeI := range toWrite {
+ if writeI < 0 || writeI > len(pkts)-1 {
+ t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
+ }
+ if seenWriteI[writeI] {
+ t.Errorf("duplicate toWrite value: %d", writeI)
+ }
+ seenWriteI[writeI] = true
+ }
+ })
+}
+
+func Test_handleGRO(t *testing.T) {
+ tests := []struct {
+ name string
+ pktsIn [][]byte
+ wantToWrite []int
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "multiple flows",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2
+ },
+ []int{0, 2, 3, 5},
+ []int{240, 140, 260, 160},
+ false,
+ },
+ {
+ "PSH interleaved",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1
+ },
+ []int{0, 2, 4, 6},
+ []int{240, 240, 260, 260},
+ false,
+ },
+ {
+ "coalesceItemInvalidCSum",
+ [][]byte{
+ flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ },
+ []int{0, 1},
+ []int{140, 240},
+ false,
+ },
+ {
+ "out of order",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ },
+ []int{0},
+ []int{340},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ toWrite := make([]int, 0, len(tt.pktsIn))
+ err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if len(toWrite) != len(tt.wantToWrite) {
+ t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
+ }
+ for i, pktI := range tt.wantToWrite {
+ if tt.wantToWrite[i] != toWrite[i] {
+ t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
+ }
+ if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
+ t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
+ }
+ }
+ })
+ }
+}
diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 b/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77
new file mode 100644
index 0000000..5461e79
--- /dev/null
+++ b/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77
@@ -0,0 +1,8 @@
+go test fuzz v1
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+int(34)
diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d b/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d
new file mode 100644
index 0000000..b441819
--- /dev/null
+++ b/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d
@@ -0,0 +1,8 @@
+go test fuzz v1
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+[]byte("0")
+int(-48)
diff --git a/tun/tun_linux.go b/tun/tun_linux.go
index 21984ca..d56e3c1 100644
--- a/tun/tun_linux.go
+++ b/tun/tun_linux.go
@@ -17,9 +17,8 @@ import (
"time"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
-
+ "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
@@ -33,17 +32,25 @@ type NativeTun struct {
index int32 // if index
errors chan error // async error handling
events chan Event // device related events
- nopi bool // the device was passed IFF_NO_PI
netlinkSock int
netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{}
+ batchSize int
+ vnetHdr bool
closeOnce sync.Once
nameOnce sync.Once // guards calling initNameCache, which sets following fields
nameCache string // name of interface
nameErr error
+
+ readOpMu sync.Mutex // readOpMu guards readBuff
+ readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
+
+ writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
+ toWrite []int
+ tcp4GROTable, tcp6GROTable *tcpGROTable
}
func (tun *NativeTun) File() *os.File {
@@ -323,60 +330,142 @@ func (tun *NativeTun) nameSlow() (string, error) {
return unix.ByteSliceToString(ifr[:]), nil
}
-func (tun *NativeTun) Write(buffs [][]byte, offset int) (n int, err error) {
- var buf []byte
- if tun.nopi {
- buf = buffs[0][offset:]
+func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
+ tun.writeOpMu.Lock()
+ defer func() {
+ tun.tcp4GROTable.reset()
+ tun.tcp6GROTable.reset()
+ tun.writeOpMu.Unlock()
+ }()
+ var (
+ errs []error
+ total int
+ )
+ tun.toWrite = tun.toWrite[:0]
+ if tun.vnetHdr {
+ err := handleGRO(buffs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
+ if err != nil {
+ return 0, err
+ }
+ offset -= virtioNetHdrLen
} else {
- // reserve space for header
- buf = buffs[0][offset-4:]
-
- // add packet information header
- buf[0] = 0x00
- buf[1] = 0x00
- if buf[4]>>4 == ipv6.Version {
- buf[2] = 0x86
- buf[3] = 0xdd
+ for i := range buffs {
+ tun.toWrite = append(tun.toWrite, i)
+ }
+ }
+ for _, buffsI := range tun.toWrite {
+ n, err := tun.tunFile.Write(buffs[buffsI][offset:])
+ if errors.Is(err, syscall.EBADFD) {
+ return total, os.ErrClosed
+ }
+ if err != nil {
+ errs = append(errs, err)
} else {
- buf[2] = 0x08
- buf[3] = 0x00
+ total += n
+ }
+ }
+ return total, ErrorBatch(errs)
+}
+
+// handleVirtioRead splits in into buffs, leaving offset bytes at the front of
+// each buffer. It mutates sizes to reflect the size of each element of buffs,
+// and returns the number of packets read.
+func handleVirtioRead(in []byte, buffs [][]byte, sizes []int, offset int) (int, error) {
+ var hdr virtioNetHdr
+ err := hdr.decode(in)
+ if err != nil {
+ return 0, err
+ }
+ in = in[virtioNetHdrLen:]
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
+ if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
+ // This means CHECKSUM_PARTIAL in skb context. We are responsible
+ // for computing the checksum starting at hdr.csumStart and placing
+ // at hdr.csumOffset.
+ err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
+ if err != nil {
+ return 0, err
+ }
+ }
+ if len(in) > len(buffs[0][offset:]) {
+ return 0, fmt.Errorf("read len %d overflows buffs element len %d", len(in), len(buffs[0][offset:]))
}
+ n := copy(buffs[0][offset:], in)
+ sizes[0] = n
+ return 1, nil
+ }
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+ return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
+ }
+
+ ipVersion := in[0] >> 4
+ switch ipVersion {
+ case 4:
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+ return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+ }
+ case 6:
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+ return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+ }
+ default:
+ return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
+ }
+
+ if len(in) <= int(hdr.csumStart+12) {
+ return 0, errors.New("packet is too short")
+ }
+ // Don't trust hdr.hdrLen from the kernel as it can be equal to the length
+ // of the entire first packet when the kernel is handling it as part of a
+ // FORWARD path. Instead, parse the TCP header length and add it onto
+ // csumStart, which is synonymous for IP header length.
+ tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
+ if tcpHLen < 20 || tcpHLen > 60 {
+ // A TCP header must be between 20 and 60 bytes in length.
+ return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
+ }
+ hdr.hdrLen = hdr.csumStart + tcpHLen
+
+ if len(in) < int(hdr.hdrLen) {
+ return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
}
- _, err = tun.tunFile.Write(buf)
- if errors.Is(err, syscall.EBADFD) {
- err = os.ErrClosed
- } else if err == nil {
- n = 1
+ if hdr.hdrLen < hdr.csumStart {
+ return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
}
- return n, err
+ cSumAt := int(hdr.csumStart + hdr.csumOffset)
+ if cSumAt+1 >= len(in) {
+ return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
+ }
+
+ return tcpTSO(in, hdr, buffs, sizes, offset)
}
-func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) {
+func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
+ tun.readOpMu.Lock()
+ defer tun.readOpMu.Unlock()
select {
- case err = <-tun.errors:
+ case err := <-tun.errors:
+ return 0, err
default:
- if tun.nopi {
- sizes[0], err = tun.tunFile.Read(buffs[0][offset:])
- if err == nil {
- n = 1
- }
+ readInto := buffs[0][offset:]
+ if tun.vnetHdr {
+ readInto = tun.readBuff[:]
+ }
+ n, err := tun.tunFile.Read(readInto)
+ if errors.Is(err, syscall.EBADFD) {
+ err = os.ErrClosed
+ }
+ if err != nil {
+ return 0, err
+ }
+ if tun.vnetHdr {
+ return handleVirtioRead(readInto[:n], buffs, sizes, offset)
} else {
- buff := buffs[0][offset-4:]
- sizes[0], err = tun.tunFile.Read(buff[:])
- if errors.Is(err, syscall.EBADFD) {
- err = os.ErrClosed
- } else if err == nil {
- n = 1
- }
- if sizes[0] < 4 {
- sizes[0] = 0
- } else {
- sizes[0] -= 4
- }
+ sizes[0] = n
+ return 1, nil
}
}
- return
}
func (tun *NativeTun) Events() <-chan Event {
@@ -403,9 +492,49 @@ func (tun *NativeTun) Close() error {
}
func (tun *NativeTun) BatchSize() int {
- return 1
+ return tun.batchSize
}
+const (
+ // TODO: support TSO with ECN bits
+ tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
+)
+
+func (tun *NativeTun) initFromFlags(name string) error {
+ sc, err := tun.tunFile.SyscallConn()
+ if err != nil {
+ return err
+ }
+ if e := sc.Control(func(fd uintptr) {
+ var (
+ ifr *unix.Ifreq
+ )
+ ifr, err = unix.NewIfreq(name)
+ if err != nil {
+ return
+ }
+ err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
+ if err != nil {
+ return
+ }
+ got := ifr.Uint16()
+ if got&unix.IFF_VNET_HDR != 0 {
+ err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads)
+ if err != nil {
+ return
+ }
+ tun.vnetHdr = true
+ tun.batchSize = conn.DefaultBatchSize
+ } else {
+ tun.batchSize = 1
+ }
+ }); e != nil {
+ return e
+ }
+ return err
+}
+
+// CreateTUN creates a Device with the provided name and MTU.
func CreateTUN(name string, mtu int) (Device, error) {
nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil {
@@ -415,25 +544,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
return nil, err
}
- var ifr [ifReqSize]byte
- var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
- nameBytes := []byte(name)
- if len(nameBytes) >= unix.IFNAMSIZ {
- unix.Close(nfd)
- return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG)
+ ifr, err := unix.NewIfreq(name)
+ if err != nil {
+ return nil, err
}
- copy(ifr[:], nameBytes)
- *(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags
-
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(nfd),
- uintptr(unix.TUNSETIFF),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
- if errno != 0 {
- unix.Close(nfd)
- return nil, errno
+ // IFF_VNET_HDR enables the "tun status hack" via routineHackListener()
+ // where a null write will return EINVAL indicating the TUN is up.
+ ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
+ err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr)
+ if err != nil {
+ return nil, err
}
err = unix.SetNonblock(nfd, true)
@@ -448,13 +568,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
return CreateTUNFromFile(fd, mtu)
}
+// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 5),
errors: make(chan error, 5),
statusListenersShutdown: make(chan struct{}),
- nopi: false,
+ tcp4GROTable: newTCPGROTable(),
+ tcp6GROTable: newTCPGROTable(),
+ toWrite: make([]int, 0, conn.DefaultBatchSize),
}
name, err := tun.Name()
@@ -462,8 +585,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err
}
- // start event listener
+ err = tun.initFromFlags(name)
+ if err != nil {
+ return nil, err
+ }
+ // start event listener
tun.index, err = getIFIndex(name)
if err != nil {
return nil, err
@@ -492,6 +619,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return tun, nil
}
+// CreateUnmonitoredTUNFromFD creates a Device from the provided file
+// descriptor.
func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
err := unix.SetNonblock(fd, true)
if err != nil {
@@ -499,14 +628,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
}
file := os.NewFile(uintptr(fd), "/dev/tun")
tun := &NativeTun{
- tunFile: file,
- events: make(chan Event, 5),
- errors: make(chan error, 5),
- nopi: true,
+ tunFile: file,
+ events: make(chan Event, 5),
+ errors: make(chan error, 5),
+ tcp4GROTable: newTCPGROTable(),
+ tcp6GROTable: newTCPGROTable(),
+ toWrite: make([]int, 0, conn.DefaultBatchSize),
}
name, err := tun.Name()
if err != nil {
return nil, "", err
}
- return tun, name, nil
+ err = tun.initFromFlags(name)
+ if err != nil {
+ return nil, "", err
+ }
+ return tun, name, err
}