diff options
Diffstat (limited to 'tun/tun_linux.go')
-rw-r--r-- | tun/tun_linux.go | 308 |
1 files changed, 228 insertions, 80 deletions
diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 1cc84cb..bd69cb5 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun @@ -9,7 +9,6 @@ package tun */ import ( - "bytes" "errors" "fmt" "os" @@ -18,9 +17,8 @@ import ( "time" "unsafe" - "golang.org/x/net/ipv6" "golang.org/x/sys/unix" - + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/rwcancel" ) @@ -34,17 +32,27 @@ type NativeTun struct { index int32 // if index errors chan error // async error handling events chan Event // device related events - nopi bool // the device was passed IFF_NO_PI netlinkSock int netlinkCancel *rwcancel.RWCancel hackListenerClosed sync.Mutex statusListenersShutdown chan struct{} + batchSize int + vnetHdr bool + udpGSO bool closeOnce sync.Once nameOnce sync.Once // guards calling initNameCache, which sets following fields nameCache string // name of interface nameErr error + + readOpMu sync.Mutex // readOpMu guards readBuff + readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr + + writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable + toWrite []int + tcpGROTable *tcpGROTable + udpGROTable *udpGROTable } func (tun *NativeTun) File() *os.File { @@ -100,7 +108,7 @@ func (tun *NativeTun) routineHackListener() { } func createNetlinkSocket() (int, error) { - sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) + sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE) if err != nil { return -1, err } @@ -195,7 +203,7 @@ func (tun *NativeTun) routineNetlinkListener() { func getIFIndex(name string) (int32, error) { fd, err := unix.Socket( unix.AF_INET, - unix.SOCK_DGRAM, + unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) if err != nil { @@ -229,10 +237,9 @@ func (tun *NativeTun) setMTU(n int) error { // open datagram socket fd, err := unix.Socket( unix.AF_INET, - unix.SOCK_DGRAM, + unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) - if err != nil { return err } @@ -266,10 +273,9 @@ func (tun *NativeTun) MTU() (int, error) { // open datagram socket fd, err := unix.Socket( unix.AF_INET, - unix.SOCK_DGRAM, + unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) - if err != nil { return 0, err } @@ -323,67 +329,153 @@ func (tun *NativeTun) nameSlow() (string, error) { if errno != 0 { return "", fmt.Errorf("failed to get name of TUN device: %w", errno) } - name := ifr[:] - if i := bytes.IndexByte(name, 0); i != -1 { - name = name[:i] - } - return string(name), nil + return unix.ByteSliceToString(ifr[:]), nil } -func (tun *NativeTun) Write(buf []byte, offset int) (int, error) { - if tun.nopi { - buf = buf[offset:] +func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { + tun.writeOpMu.Lock() + defer func() { + tun.tcpGROTable.reset() + tun.udpGROTable.reset() + tun.writeOpMu.Unlock() + }() + var ( + errs error + total int + ) + tun.toWrite = tun.toWrite[:0] + if tun.vnetHdr { + err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite) + if err != nil { + return 0, err + } + offset -= virtioNetHdrLen } else { - // reserve space for header - buf = buf[offset-4:] - - // add packet information header - buf[0] = 0x00 - buf[1] = 0x00 - if buf[4]>>4 == ipv6.Version { - buf[2] = 0x86 - buf[3] = 0xdd + for i := range bufs { + tun.toWrite = append(tun.toWrite, i) + } + } + for _, bufsI := range tun.toWrite { + n, err := tun.tunFile.Write(bufs[bufsI][offset:]) + if errors.Is(err, syscall.EBADFD) { + return total, os.ErrClosed + } + if err != nil { + errs = errors.Join(errs, err) } else { - buf[2] = 0x08 - buf[3] = 0x00 + total += n } } + return total, errs +} - n, err := tun.tunFile.Write(buf) - if errors.Is(err, syscall.EBADFD) { - err = os.ErrClosed +// handleVirtioRead splits in into bufs, leaving offset bytes at the front of +// each buffer. It mutates sizes to reflect the size of each element of bufs, +// and returns the number of packets read. +func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) { + var hdr virtioNetHdr + err := hdr.decode(in) + if err != nil { + return 0, err + } + in = in[virtioNetHdrLen:] + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE { + if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { + // This means CHECKSUM_PARTIAL in skb context. We are responsible + // for computing the checksum starting at hdr.csumStart and placing + // at hdr.csumOffset. + err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset) + if err != nil { + return 0, err + } + } + if len(in) > len(bufs[0][offset:]) { + return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:])) + } + n := copy(bufs[0][offset:], in) + sizes[0] = n + return 1, nil + } + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { + return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType) } - return n, err -} -func (tun *NativeTun) Flush() error { - // TODO: can flushing be implemented by buffering and using sendmmsg? - return nil + ipVersion := in[0] >> 4 + switch ipVersion { + case 4: + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) + } + case 6: + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) + } + default: + return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) + } + + // Don't trust hdr.hdrLen from the kernel as it can be equal to the length + // of the entire first packet when the kernel is handling it as part of a + // FORWARD path. Instead, parse the transport header length and add it onto + // csumStart, which is synonymous for IP header length. + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 { + hdr.hdrLen = hdr.csumStart + 8 + } else { + if len(in) <= int(hdr.csumStart+12) { + return 0, errors.New("packet is too short") + } + + tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) + if tcpHLen < 20 || tcpHLen > 60 { + // A TCP header must be between 20 and 60 bytes in length. + return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) + } + hdr.hdrLen = hdr.csumStart + tcpHLen + } + + if len(in) < int(hdr.hdrLen) { + return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen) + } + + if hdr.hdrLen < hdr.csumStart { + return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart) + } + cSumAt := int(hdr.csumStart + hdr.csumOffset) + if cSumAt+1 >= len(in) { + return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) + } + + return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6) } -func (tun *NativeTun) Read(buf []byte, offset int) (n int, err error) { +func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { + tun.readOpMu.Lock() + defer tun.readOpMu.Unlock() select { - case err = <-tun.errors: + case err := <-tun.errors: + return 0, err default: - if tun.nopi { - n, err = tun.tunFile.Read(buf[offset:]) + readInto := bufs[0][offset:] + if tun.vnetHdr { + readInto = tun.readBuff[:] + } + n, err := tun.tunFile.Read(readInto) + if errors.Is(err, syscall.EBADFD) { + err = os.ErrClosed + } + if err != nil { + return 0, err + } + if tun.vnetHdr { + return handleVirtioRead(readInto[:n], bufs, sizes, offset) } else { - buff := buf[offset-4:] - n, err = tun.tunFile.Read(buff[:]) - if errors.Is(err, syscall.EBADFD) { - err = os.ErrClosed - } - if n < 4 { - n = 0 - } else { - n -= 4 - } + sizes[0] = n + return 1, nil } } - return } -func (tun *NativeTun) Events() chan Event { +func (tun *NativeTun) Events() <-chan Event { return tun.events } @@ -406,8 +498,58 @@ func (tun *NativeTun) Close() error { return err2 } +func (tun *NativeTun) BatchSize() int { + return tun.batchSize +} + +const ( + // TODO: support TSO with ECN bits + tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6 +) + +func (tun *NativeTun) initFromFlags(name string) error { + sc, err := tun.tunFile.SyscallConn() + if err != nil { + return err + } + if e := sc.Control(func(fd uintptr) { + var ( + ifr *unix.Ifreq + ) + ifr, err = unix.NewIfreq(name) + if err != nil { + return + } + err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr) + if err != nil { + return + } + got := ifr.Uint16() + if got&unix.IFF_VNET_HDR != 0 { + // tunTCPOffloads were added in Linux v2.6. We require their support + // if IFF_VNET_HDR is set. + err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads) + if err != nil { + return + } + tun.vnetHdr = true + tun.batchSize = conn.IdealBatchSize + // tunUDPOffloads were added in Linux v6.2. We do not return an + // error if they are unsupported at runtime. + tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil + } else { + tun.batchSize = 1 + } + }); e != nil { + return e + } + return err +} + +// CreateTUN creates a Device with the provided name and MTU. func CreateTUN(name string, mtu int) (Device, error) { - nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0) + nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0) if err != nil { if os.IsNotExist(err) { return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath) @@ -415,25 +557,16 @@ func CreateTUN(name string, mtu int) (Device, error) { return nil, err } - var ifr [ifReqSize]byte - var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack) - nameBytes := []byte(name) - if len(nameBytes) >= unix.IFNAMSIZ { - unix.Close(nfd) - return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG) + ifr, err := unix.NewIfreq(name) + if err != nil { + return nil, err } - copy(ifr[:], nameBytes) - *(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags - - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(nfd), - uintptr(unix.TUNSETIFF), - uintptr(unsafe.Pointer(&ifr[0])), - ) - if errno != 0 { - unix.Close(nfd) - return nil, errno + // IFF_VNET_HDR enables the "tun status hack" via routineHackListener() + // where a null write will return EINVAL indicating the TUN is up. + ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR) + err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr) + if err != nil { + return nil, err } err = unix.SetNonblock(nfd, true) @@ -448,13 +581,16 @@ func CreateTUN(name string, mtu int) (Device, error) { return CreateTUNFromFile(fd, mtu) } +// CreateTUNFromFile creates a Device from an os.File with the provided MTU. func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { tun := &NativeTun{ tunFile: file, events: make(chan Event, 5), errors: make(chan error, 5), statusListenersShutdown: make(chan struct{}), - nopi: false, + tcpGROTable: newTCPGROTable(), + udpGROTable: newUDPGROTable(), + toWrite: make([]int, 0, conn.IdealBatchSize), } name, err := tun.Name() @@ -462,8 +598,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return nil, err } - // start event listener + err = tun.initFromFlags(name) + if err != nil { + return nil, err + } + // start event listener tun.index, err = getIFIndex(name) if err != nil { return nil, err @@ -492,6 +632,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return tun, nil } +// CreateUnmonitoredTUNFromFD creates a Device from the provided file +// descriptor. func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { err := unix.SetNonblock(fd, true) if err != nil { @@ -499,14 +641,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { } file := os.NewFile(uintptr(fd), "/dev/tun") tun := &NativeTun{ - tunFile: file, - events: make(chan Event, 5), - errors: make(chan error, 5), - nopi: true, + tunFile: file, + events: make(chan Event, 5), + errors: make(chan error, 5), + tcpGROTable: newTCPGROTable(), + udpGROTable: newUDPGROTable(), + toWrite: make([]int, 0, conn.IdealBatchSize), } name, err := tun.Name() if err != nil { return nil, "", err } - return tun, name, nil + err = tun.initFromFlags(name) + if err != nil { + return nil, "", err + } + return tun, name, err } |