about summary refs log tree commit diff stats
path: root/tun/tun_linux.go
diff options
context:
space:
mode:
Diffstat (limited to 'tun/tun_linux.go')
-rw-r--r-- tun/tun_linux.go 308
1 files changed, 228 insertions, 80 deletions
diff --git a/tun/tun_linux.go b/tun/tun_linux.go
index 1cc84cb..bd69cb5 100644
--- a/tun/tun_linux.go
+++ b/tun/tun_linux.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -9,7 +9,6 @@ package tun
*/
import (
- "bytes"
"errors"
"fmt"
"os"
@@ -18,9 +17,8 @@ import (
"time"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
-
+ "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
@@ -34,17 +32,27 @@ type NativeTun struct {
index int32 // if index
errors chan error // async error handling
events chan Event // device related events
- nopi bool // the device was passed IFF_NO_PI
netlinkSock int
netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{}
+ batchSize int
+ vnetHdr bool
+ udpGSO bool
closeOnce sync.Once
nameOnce sync.Once // guards calling initNameCache, which sets following fields
nameCache string // name of interface
nameErr error
+
+ readOpMu sync.Mutex // readOpMu guards readBuff
+ readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
+
+ writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable, udpGROTable
+ toWrite []int
+ tcpGROTable *tcpGROTable
+ udpGROTable *udpGROTable
}
func (tun *NativeTun) File() *os.File {
@@ -100,7 +108,7 @@ func (tun *NativeTun) routineHackListener() {
}
func createNetlinkSocket() (int, error) {
- sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
+ sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
@@ -195,7 +203,7 @@ func (tun *NativeTun) routineNetlinkListener() {
func getIFIndex(name string) (int32, error) {
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
if err != nil {
@@ -229,10 +237,9 @@ func (tun *NativeTun) setMTU(n int) error {
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
-
if err != nil {
return err
}
@@ -266,10 +273,9 @@ func (tun *NativeTun) MTU() (int, error) {
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
-
if err != nil {
return 0, err
}
@@ -323,67 +329,153 @@ func (tun *NativeTun) nameSlow() (string, error) {
if errno != 0 {
return "", fmt.Errorf("failed to get name of TUN device: %w", errno)
}
- name := ifr[:]
- if i := bytes.IndexByte(name, 0); i != -1 {
- name = name[:i]
- }
- return string(name), nil
+ return unix.ByteSliceToString(ifr[:]), nil
}
-func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
- if tun.nopi {
- buf = buf[offset:]
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ tun.writeOpMu.Lock()
+ defer func() {
+ tun.tcpGROTable.reset()
+ tun.udpGROTable.reset()
+ tun.writeOpMu.Unlock()
+ }()
+ var (
+ errs error
+ total int
+ )
+ tun.toWrite = tun.toWrite[:0]
+ if tun.vnetHdr {
+ err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite)
+ if err != nil {
+ return 0, err
+ }
+ offset -= virtioNetHdrLen
} else {
- // reserve space for header
- buf = buf[offset-4:]
-
- // add packet information header
- buf[0] = 0x00
- buf[1] = 0x00
- if buf[4]>>4 == ipv6.Version {
- buf[2] = 0x86
- buf[3] = 0xdd
+ for i := range bufs {
+ tun.toWrite = append(tun.toWrite, i)
+ }
+ }
+ for _, bufsI := range tun.toWrite {
+ n, err := tun.tunFile.Write(bufs[bufsI][offset:])
+ if errors.Is(err, syscall.EBADFD) {
+ return total, os.ErrClosed
+ }
+ if err != nil {
+ errs = errors.Join(errs, err)
} else {
- buf[2] = 0x08
- buf[3] = 0x00
+ total += n
}
}
+ return total, errs
+}
- n, err := tun.tunFile.Write(buf)
- if errors.Is(err, syscall.EBADFD) {
- err = os.ErrClosed
+// handleVirtioRead splits in (a virtioNetHdr followed by packet data) into bufs,
+// leaving offset bytes at the front of each buffer. It mutates sizes to reflect
+// the size of each element of bufs, and returns the number of packets read.
+func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
+ var hdr virtioNetHdr
+ err := hdr.decode(in)
+ if err != nil {
+ return 0, err
+ }
+ in = in[virtioNetHdrLen:]
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
+ if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
+ // This means CHECKSUM_PARTIAL in skb context. We are responsible
+ // for computing the checksum starting at hdr.csumStart and placing
+ // at hdr.csumOffset.
+ err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
+ if err != nil {
+ return 0, err
+ }
+ }
+ if len(in) > len(bufs[0][offset:]) {
+ return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
+ }
+ n := copy(bufs[0][offset:], in)
+ sizes[0] = n
+ return 1, nil
+ }
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
}
- return n, err
-}
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ ipVersion := in[0] >> 4
+ switch ipVersion {
+ case 4:
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+ }
+ case 6:
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+ }
+ default:
+ return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
+ }
+
+ // Don't trust hdr.hdrLen from the kernel as it can be equal to the length
+ // of the entire first packet when the kernel is handling it as part of a
+ // FORWARD path. Instead, parse the transport header length and add it onto
+ // csumStart, which is synonymous with IP header length.
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ hdr.hdrLen = hdr.csumStart + 8
+ } else {
+ if len(in) <= int(hdr.csumStart+12) {
+ return 0, errors.New("packet is too short")
+ }
+
+ tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
+ if tcpHLen < 20 || tcpHLen > 60 {
+ // A TCP header must be between 20 and 60 bytes in length.
+ return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
+ }
+ hdr.hdrLen = hdr.csumStart + tcpHLen
+ }
+
+ if len(in) < int(hdr.hdrLen) {
+ return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
+ }
+
+ if hdr.hdrLen < hdr.csumStart {
+ return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
+ }
+ cSumAt := int(hdr.csumStart + hdr.csumOffset)
+ if cSumAt+1 >= len(in) {
+ return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
+ }
+
+ return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6)
}
-func (tun *NativeTun) Read(buf []byte, offset int) (n int, err error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+ tun.readOpMu.Lock()
+ defer tun.readOpMu.Unlock()
select {
- case err = <-tun.errors:
+ case err := <-tun.errors:
+ return 0, err
default:
- if tun.nopi {
- n, err = tun.tunFile.Read(buf[offset:])
+ readInto := bufs[0][offset:]
+ if tun.vnetHdr {
+ readInto = tun.readBuff[:]
+ }
+ n, err := tun.tunFile.Read(readInto)
+ if errors.Is(err, syscall.EBADFD) {
+ err = os.ErrClosed
+ }
+ if err != nil {
+ return 0, err
+ }
+ if tun.vnetHdr {
+ return handleVirtioRead(readInto[:n], bufs, sizes, offset)
} else {
- buff := buf[offset-4:]
- n, err = tun.tunFile.Read(buff[:])
- if errors.Is(err, syscall.EBADFD) {
- err = os.ErrClosed
- }
- if n < 4 {
- n = 0
- } else {
- n -= 4
- }
+ sizes[0] = n
+ return 1, nil
}
}
- return
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
@@ -406,8 +498,58 @@ func (tun *NativeTun) Close() error {
return err2
}
+func (tun *NativeTun) BatchSize() int {
+ return tun.batchSize
+}
+
+const (
+ // TODO: support TSO with ECN bits
+ tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
+ tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
+)
+
+func (tun *NativeTun) initFromFlags(name string) error {
+ sc, err := tun.tunFile.SyscallConn()
+ if err != nil {
+ return err
+ }
+ if e := sc.Control(func(fd uintptr) {
+ var (
+ ifr *unix.Ifreq
+ )
+ ifr, err = unix.NewIfreq(name)
+ if err != nil {
+ return
+ }
+ err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
+ if err != nil {
+ return
+ }
+ got := ifr.Uint16()
+ if got&unix.IFF_VNET_HDR != 0 {
+ // tunTCPOffloads were added in Linux v2.6. We require their support
+ // if IFF_VNET_HDR is set.
+ err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads)
+ if err != nil {
+ return
+ }
+ tun.vnetHdr = true
+ tun.batchSize = conn.IdealBatchSize
+ // tunUDPOffloads were added in Linux v6.2. We do not return an
+ // error if they are unsupported at runtime.
+ tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil
+ } else {
+ tun.batchSize = 1
+ }
+ }); e != nil {
+ return e
+ }
+ return err
+}
+
+// CreateTUN creates a Device with the provided name and MTU.
func CreateTUN(name string, mtu int) (Device, error) {
- nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0)
+ nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil {
if os.IsNotExist(err) {
return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
@@ -415,25 +557,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
return nil, err
}
- var ifr [ifReqSize]byte
- var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
- nameBytes := []byte(name)
- if len(nameBytes) >= unix.IFNAMSIZ {
- unix.Close(nfd)
- return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG)
+ ifr, err := unix.NewIfreq(name)
+ if err != nil {
+ return nil, err
}
- copy(ifr[:], nameBytes)
- *(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags
-
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(nfd),
- uintptr(unix.TUNSETIFF),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
- if errno != 0 {
- unix.Close(nfd)
- return nil, errno
+ // IFF_VNET_HDR enables the "tun status hack" via routineHackListener()
+ // where a null write will return EINVAL indicating the TUN is up.
+ ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
+ err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr)
+ if err != nil {
+ return nil, err
}
err = unix.SetNonblock(nfd, true)
@@ -448,13 +581,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
return CreateTUNFromFile(fd, mtu)
}
+// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 5),
errors: make(chan error, 5),
statusListenersShutdown: make(chan struct{}),
- nopi: false,
+ tcpGROTable: newTCPGROTable(),
+ udpGROTable: newUDPGROTable(),
+ toWrite: make([]int, 0, conn.IdealBatchSize),
}
name, err := tun.Name()
@@ -462,8 +598,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err
}
- // start event listener
+ err = tun.initFromFlags(name)
+ if err != nil {
+ return nil, err
+ }
+ // start event listener
tun.index, err = getIFIndex(name)
if err != nil {
return nil, err
@@ -492,6 +632,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return tun, nil
}
+// CreateUnmonitoredTUNFromFD creates a Device from the provided file
+// descriptor.
func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
err := unix.SetNonblock(fd, true)
if err != nil {
@@ -499,14 +641,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
}
file := os.NewFile(uintptr(fd), "/dev/tun")
tun := &NativeTun{
- tunFile: file,
- events: make(chan Event, 5),
- errors: make(chan error, 5),
- nopi: true,
+ tunFile: file,
+ events: make(chan Event, 5),
+ errors: make(chan error, 5),
+ tcpGROTable: newTCPGROTable(),
+ udpGROTable: newUDPGROTable(),
+ toWrite: make([]int, 0, conn.IdealBatchSize),
}
name, err := tun.Name()
if err != nil {
return nil, "", err
}
- return tun, name, nil
+ err = tun.initFromFlags(name)
+ if err != nil {
+ return nil, "", err
+ }
+ return tun, name, err
}