aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJordan Whited <jordan@tailscale.com>2023-10-31 19:53:35 -0700
committerJason A. Donenfeld <Jason@zx2c4.com>2023-12-11 16:27:22 +0100
commitd0bc03c707974a84a672716c718f99fab49e7eb8 (patch)
tree2d11dff4d00bdee8b22a570b1db665f523f38a73
parenttun: fix Device.Read() buf length assumption on Windows (diff)
downloadwireguard-go-d0bc03c707974a84a672716c718f99fab49e7eb8.tar.xz
wireguard-go-d0bc03c707974a84a672716c718f99fab49e7eb8.zip
tun: implement UDP GSO/GRO for Linux
Implement UDP GSO and GRO for the Linux tun.Device, which is made possible by virtio extensions in the kernel's TUN driver starting in v6.2. secnetperf, a QUIC benchmark utility from microsoft/msquic@8e1eb1a, is used to demonstrate the effect of this commit between two Linux computers with i5-12400 CPUs. There is roughly ~13us of round trip latency between them. secnetperf was invoked with the following command line options: -stats:1 -exec:maxtput -test:tput -download:10000 -timed:1 -encrypt:0 The first result is from commit 2e0774f without UDP GSO/GRO on the TUN. [conn][0x55739a144980] STATS: EcnCapable=0 RTT=3973 us SendTotalPackets=55859 SendSuspectedLostPackets=61 SendSpuriousLostPackets=59 SendCongestionCount=27 SendEcnCongestionCount=0 RecvTotalPackets=2779122 RecvReorderedPackets=0 RecvDroppedPackets=0 RecvDuplicatePackets=0 RecvDecryptionFailures=0 Result: 3654977571 bytes @ 2922821 kbps (10003.972 ms). The second result is with UDP GSO/GRO on the TUN. [conn][0x56493dfd09a0] STATS: EcnCapable=0 RTT=1216 us SendTotalPackets=165033 SendSuspectedLostPackets=64 SendSpuriousLostPackets=61 SendCongestionCount=53 SendEcnCongestionCount=0 RecvTotalPackets=11845268 RecvReorderedPackets=25267 RecvDroppedPackets=0 RecvDuplicatePackets=0 RecvDecryptionFailures=0 Result: 15574671184 bytes @ 12458214 kbps (10001.222 ms). Signed-off-by: Jordan Whited <jordan@tailscale.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--tun/offload_linux.go (renamed from tun/tcp_offload_linux.go)598
-rw-r--r--tun/offload_linux_test.go752
-rw-r--r--tun/tcp_offload_linux_test.go411
-rw-r--r--tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b778
-rw-r--r--tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d8
-rw-r--r--tun/tun_linux.go71
6 files changed, 1258 insertions, 590 deletions
diff --git a/tun/tcp_offload_linux.go b/tun/offload_linux.go
index 1afd27e..9ff7fea 100644
--- a/tun/tcp_offload_linux.go
+++ b/tun/offload_linux.go
@@ -57,22 +57,23 @@ const (
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
)
-// flowKey represents the key for a flow.
-type flowKey struct {
+// tcpFlowKey represents the key for a TCP flow.
+type tcpFlowKey struct {
srcAddr, dstAddr [16]byte
srcPort, dstPort uint16
rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
+ isV6 bool
}
-// tcpGROTable holds flow and coalescing information for the purposes of GRO.
+// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO.
type tcpGROTable struct {
- itemsByFlow map[flowKey][]tcpGROItem
+ itemsByFlow map[tcpFlowKey][]tcpGROItem
itemsPool [][]tcpGROItem
}
func newTCPGROTable() *tcpGROTable {
t := &tcpGROTable{
- itemsByFlow: make(map[flowKey][]tcpGROItem, conn.IdealBatchSize),
+ itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize),
itemsPool: make([][]tcpGROItem, conn.IdealBatchSize),
}
for i := range t.itemsPool {
@@ -81,14 +82,15 @@ func newTCPGROTable() *tcpGROTable {
return t
}
-func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
- key := flowKey{}
- addrSize := dstAddr - srcAddr
- copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
- copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
+func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey {
+ key := tcpFlowKey{}
+ addrSize := dstAddrOffset - srcAddrOffset
+ copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
+ copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
+ key.isV6 = addrSize == 16
return key
}
@@ -96,7 +98,7 @@ func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
// returning the packets found for the flow, or inserting a new one if none
// is found.
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
- key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+ key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
items, ok := t.itemsByFlow[key]
if ok {
return items, ok
@@ -108,7 +110,7 @@ func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, t
// insert an item in the table for the provided packet and packet metadata.
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
- key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+ key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
item := tcpGROItem{
key: key,
bufsIndex: uint16(bufsIndex),
@@ -131,7 +133,7 @@ func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
items[i] = item
}
-func (t *tcpGROTable) deleteAt(key flowKey, i int) {
+func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) {
items, _ := t.itemsByFlow[key]
items = append(items[:i], items[i+1:]...)
t.itemsByFlow[key] = items
@@ -140,7 +142,7 @@ func (t *tcpGROTable) deleteAt(key flowKey, i int) {
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
// of a GRO evaluation across a vector of packets.
type tcpGROItem struct {
- key flowKey
+ key tcpFlowKey
sentSeq uint32 // the sequence number
bufsIndex uint16 // the index into the original bufs slice
numMerged uint16 // the number of packets merged into this item
@@ -164,6 +166,103 @@ func (t *tcpGROTable) reset() {
}
}
+// udpFlowKey represents the key for a UDP flow.
+type udpFlowKey struct {
+ srcAddr, dstAddr [16]byte
+ srcPort, dstPort uint16
+ isV6 bool
+}
+
+// udpGROTable holds flow and coalescing information for the purposes of UDP GRO.
+type udpGROTable struct {
+ itemsByFlow map[udpFlowKey][]udpGROItem
+ itemsPool [][]udpGROItem
+}
+
+func newUDPGROTable() *udpGROTable {
+ u := &udpGROTable{
+ itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize),
+ itemsPool: make([][]udpGROItem, conn.IdealBatchSize),
+ }
+ for i := range u.itemsPool {
+ u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize)
+ }
+ return u
+}
+
+func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey {
+ key := udpFlowKey{}
+ addrSize := dstAddrOffset - srcAddrOffset
+ copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
+ copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
+ key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:])
+ key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:])
+ key.isV6 = addrSize == 16
+ return key
+}
+
+// lookupOrInsert looks up a flow for the provided packet and metadata,
+// returning the packets found for the flow, or inserting a new one if none
+// is found.
+func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) {
+ key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
+ items, ok := u.itemsByFlow[key]
+ if ok {
+ return items, ok
+ }
+ // TODO: insert() performs another map lookup. This could be rearranged to avoid.
+ u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false)
+ return nil, false
+}
+
+// insert an item in the table for the provided packet and packet metadata.
+func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) {
+ key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
+ item := udpGROItem{
+ key: key,
+ bufsIndex: uint16(bufsIndex),
+ gsoSize: uint16(len(pkt[udphOffset+udphLen:])),
+ iphLen: uint8(udphOffset),
+ cSumKnownInvalid: cSumKnownInvalid,
+ }
+ items, ok := u.itemsByFlow[key]
+ if !ok {
+ items = u.newItems()
+ }
+ items = append(items, item)
+ u.itemsByFlow[key] = items
+}
+
+func (u *udpGROTable) updateAt(item udpGROItem, i int) {
+ items, _ := u.itemsByFlow[item.key]
+ items[i] = item
+}
+
+// udpGROItem represents bookkeeping data for a UDP packet during the lifetime
+// of a GRO evaluation across a vector of packets.
+type udpGROItem struct {
+ key udpFlowKey
+ bufsIndex uint16 // the index into the original bufs slice
+ numMerged uint16 // the number of packets merged into this item
+ gsoSize uint16 // payload size
+ iphLen uint8 // ip header len
+ cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown.
+}
+
+func (u *udpGROTable) newItems() []udpGROItem {
+ var items []udpGROItem
+ items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1]
+ return items
+}
+
+func (u *udpGROTable) reset() {
+ for k, items := range u.itemsByFlow {
+ items = items[:0]
+ u.itemsPool = append(u.itemsPool, items)
+ delete(u.itemsByFlow, k)
+ }
+}
+
// canCoalesce represents the outcome of checking if two TCP packets are
// candidates for coalescing.
type canCoalesce int
@@ -174,6 +273,61 @@ const (
coalesceAppend canCoalesce = 1
)
+// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB
+// meet all requirements to be merged as part of a GRO operation, otherwise it
+// returns false.
+func ipHeadersCanCoalesce(pktA, pktB []byte) bool {
+ if len(pktA) < 9 || len(pktB) < 9 {
+ return false
+ }
+ if pktA[0]>>4 == 6 {
+ if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 {
+ // cannot coalesce with unequal Traffic class values
+ return false
+ }
+ if pktA[7] != pktB[7] {
+ // cannot coalesce with unequal Hop limit values
+ return false
+ }
+ } else {
+ if pktA[1] != pktB[1] {
+ // cannot coalesce with unequal ToS values
+ return false
+ }
+ if pktA[6]>>5 != pktB[6]>>5 {
+ // cannot coalesce with unequal DF or reserved bits. MF is checked
+ // further up the stack.
+ return false
+ }
+ if pktA[8] != pktB[8] {
+ // cannot coalesce with unequal TTL values
+ return false
+ }
+ }
+ return true
+}
+
+// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
+// described by item. iphLen and gsoSize describe pkt. bufs is the vector of
+// packets involved in the current GRO evaluation. bufsOffset is the offset at
+// which packet data begins within bufs.
+func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
+ pktTarget := bufs[item.bufsIndex][bufsOffset:]
+ if !ipHeadersCanCoalesce(pkt, pktTarget) {
+ return coalesceUnavailable
+ }
+ if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 {
+ // A smaller than gsoSize packet has been appended previously.
+ // Nothing can come after a smaller packet on the end.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ return coalesceAppend
+}
+
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. This function makes considerations that match the kernel's
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
@@ -189,29 +343,8 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet
return coalesceUnavailable
}
}
- if pkt[0]>>4 == 6 {
- if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
- // cannot coalesce with unequal Traffic class values
- return coalesceUnavailable
- }
- if pkt[7] != pktTarget[7] {
- // cannot coalesce with unequal Hop limit values
- return coalesceUnavailable
- }
- } else {
- if pkt[1] != pktTarget[1] {
- // cannot coalesce with unequal ToS values
- return coalesceUnavailable
- }
- if pkt[6]>>5 != pktTarget[6]>>5 {
- // cannot coalesce with unequal DF or reserved bits. MF is checked
- // further up the stack.
- return coalesceUnavailable
- }
- if pkt[8] != pktTarget[8] {
- // cannot coalesce with unequal TTL values
- return coalesceUnavailable
- }
+ if !ipHeadersCanCoalesce(pkt, pktTarget) {
+ return coalesceUnavailable
}
// seq adjacency
lhsLen := item.gsoSize
@@ -252,16 +385,16 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet
return coalesceUnavailable
}
-func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
+func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool {
srcAddrAt := ipv4SrcAddrOffset
addrSize := 4
if isV6 {
srcAddrAt = ipv6SrcAddrOffset
addrSize = 16
}
- tcpTotalLen := uint16(len(pkt) - int(iphLen))
- tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
- return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0
+ lenForPseudo := uint16(len(pkt) - int(iphLen))
+ cSum := pseudoHeaderChecksumNoFold(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo)
+ return ^checksum(pkt[iphLen:], cSum) == 0
}
// coalesceResult represents the result of attempting to coalesce two TCP
@@ -276,8 +409,36 @@ const (
coalesceSuccess
)
+// coalesceUDPPackets attempts to coalesce pkt with the packet described by
+// item, and returns the outcome.
+func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
+ pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front
+ headersLen := item.iphLen + udphLen
+ coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
+
+ if cap(pktHead)-bufsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if item.numMerged == 0 {
+ if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) {
+ return coalescePktInvalidCSum
+ }
+ extendBy := len(pkt) - int(headersLen)
+ bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
+ copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
+
+ item.numMerged++
+ return coalesceSuccess
+}
+
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
-// item, returning the outcome. This function may swap bufs elements in the
+// item, and returns the outcome. This function may swap bufs elements in the
// event of a prepend as item's bufs index is already being tracked for writing
// to a Device.
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
@@ -297,11 +458,11 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize
return coalescePSHEnding
}
if item.numMerged == 0 {
- if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
+ if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
return coalesceItemInvalidCSum
}
}
- if !tcpChecksumValid(pkt, item.iphLen, isV6) {
+ if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
return coalescePktInvalidCSum
}
item.sentSeq = seq
@@ -319,11 +480,11 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize
return coalesceInsufficientCap
}
if item.numMerged == 0 {
- if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
+ if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
return coalesceItemInvalidCSum
}
}
- if !tcpChecksumValid(pkt, item.iphLen, isV6) {
+ if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
return coalescePktInvalidCSum
}
if pshSet {
@@ -354,52 +515,52 @@ const (
maxUint16 = 1<<16 - 1
)
-type tcpGROResult int
+type groResult int
const (
- tcpGROResultNoop tcpGROResult = iota
- tcpGROResultTableInsert
- tcpGROResultCoalesced
+ groResultNoop groResult = iota
+ groResultTableInsert
+ groResultCoalesced
)
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
-// existing packets tracked in table. It returns a tcpGROResultNoop when no
-// action was taken, tcpGROResultTableInsert when the evaluated packet was
-// inserted into table, and tcpGROResultCoalesced when the evaluated packet was
+// existing packets tracked in table. It returns a groResultNoop when no
+// action was taken, groResultTableInsert when the evaluated packet was
+// inserted into table, and groResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
-func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult {
+func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult {
pkt := bufs[pktI][offset:]
if len(pkt) > maxUint16 {
// A valid IPv4 or IPv6 packet will never exceed this.
- return tcpGROResultNoop
+ return groResultNoop
}
iphLen := int((pkt[0] & 0x0F) * 4)
if isV6 {
iphLen = 40
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
if ipv6HPayloadLen != len(pkt)-iphLen {
- return tcpGROResultNoop
+ return groResultNoop
}
} else {
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
if totalLen != len(pkt) {
- return tcpGROResultNoop
+ return groResultNoop
}
}
if len(pkt) < iphLen {
- return tcpGROResultNoop
+ return groResultNoop
}
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
if tcphLen < 20 || tcphLen > 60 {
- return tcpGROResultNoop
+ return groResultNoop
}
if len(pkt) < iphLen+tcphLen {
- return tcpGROResultNoop
+ return groResultNoop
}
if !isV6 {
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
// no GRO support for fragmented segments for now
- return tcpGROResultNoop
+ return groResultNoop
}
}
tcpFlags := pkt[iphLen+tcpFlagsOffset]
@@ -407,14 +568,14 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool)
// not a candidate if any non-ACK flags (except PSH+ACK) are set
if tcpFlags != tcpFlagACK {
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
- return tcpGROResultNoop
+ return groResultNoop
}
pshSet = true
}
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
// not a candidate if payload len is 0
if gsoSize < 1 {
- return tcpGROResultNoop
+ return groResultNoop
}
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
srcAddrOffset := ipv4SrcAddrOffset
@@ -425,7 +586,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool)
}
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
if !existing {
- return tcpGROResultNoop
+ return groResultTableInsert
}
for i := len(items) - 1; i >= 0; i-- {
// In the best case of packets arriving in order iterating in reverse is
@@ -443,54 +604,25 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool)
switch result {
case coalesceSuccess:
table.updateAt(item, i)
- return tcpGROResultCoalesced
+ return groResultCoalesced
case coalesceItemInvalidCSum:
// delete the item with an invalid csum
table.deleteAt(item.key, i)
case coalescePktInvalidCSum:
// no point in inserting an item that we can't coalesce
- return tcpGROResultNoop
+ return groResultNoop
default:
}
}
}
// failed to coalesce with any other packets; store the item in the flow
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
- return tcpGROResultTableInsert
-}
-
-func isTCP4NoIPOptions(b []byte) bool {
- if len(b) < 40 {
- return false
- }
- if b[0]>>4 != 4 {
- return false
- }
- if b[0]&0x0F != 5 {
- return false
- }
- if b[9] != unix.IPPROTO_TCP {
- return false
- }
- return true
+ return groResultTableInsert
}
-func isTCP6NoEH(b []byte) bool {
- if len(b) < 60 {
- return false
- }
- if b[0]>>4 != 6 {
- return false
- }
- if b[6] != unix.IPPROTO_TCP {
- return false
- }
- return true
-}
-
-// applyCoalesceAccounting updates bufs to account for coalescing based on the
+// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
-func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error {
+func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error {
for _, items := range table.itemsByFlow {
for _, item := range items {
if item.numMerged > 0 {
@@ -505,7 +637,7 @@ func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6
// Recalculate the total len (IPv4) or payload len (IPv6).
// Recalculate the (IPv4) header checksum.
- if isV6 {
+ if item.key.isV6 {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
} else {
@@ -525,7 +657,7 @@ func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6
// this with computation of the tcp header and payload checksum.
addrLen := 4
addrOffset := ipv4SrcAddrOffset
- if isV6 {
+ if item.key.isV6 {
addrLen = 16
addrOffset = ipv6SrcAddrOffset
}
@@ -546,54 +678,245 @@ func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6
return nil
}
+// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the
+// metadata found in table.
+func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error {
+ for _, items := range table.itemsByFlow {
+ for _, item := range items {
+ if item.numMerged > 0 {
+ hdr := virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
+ hdrLen: uint16(item.iphLen + udphLen),
+ gsoSize: item.gsoSize,
+ csumStart: uint16(item.iphLen),
+ csumOffset: 6,
+ }
+ pkt := bufs[item.bufsIndex][offset:]
+
+ // Recalculate the total len (IPv4) or payload len (IPv6).
+ // Recalculate the (IPv4) header checksum.
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
+ if item.key.isV6 {
+ binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
+ } else {
+ pkt[10], pkt[11] = 0, 0
+ binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
+ iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum
+ binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
+ }
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+
+ // Recalculate the UDP len field value
+ binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:])))
+
+ // Calculate the pseudo header checksum and place it at the UDP
+ // checksum offset. Downstream checksum offloading will combine
+ // this with computation of the udp header and payload checksum.
+ addrLen := 4
+ addrOffset := ipv4SrcAddrOffset
+ if item.key.isV6 {
+ addrLen = 16
+ addrOffset = ipv6SrcAddrOffset
+ }
+ srcAddrAt := offset + addrOffset
+ srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
+ dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
+ psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
+ binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
+ } else {
+ hdr := virtioNetHdr{}
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+ return nil
+}
+
+type groCandidateType uint8
+
+const (
+ notGROCandidate groCandidateType = iota
+ tcp4GROCandidate
+ tcp6GROCandidate
+ udp4GROCandidate
+ udp6GROCandidate
+)
+
+func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType {
+ if len(b) < 28 {
+ return notGROCandidate
+ }
+ if b[0]>>4 == 4 {
+ if b[0]&0x0F != 5 {
+ // IPv4 packets w/IP options do not coalesce
+ return notGROCandidate
+ }
+ if b[9] == unix.IPPROTO_TCP && len(b) >= 40 {
+ return tcp4GROCandidate
+ }
+ if b[9] == unix.IPPROTO_UDP && canUDPGRO {
+ return udp4GROCandidate
+ }
+ } else if b[0]>>4 == 6 {
+ if b[6] == unix.IPPROTO_TCP && len(b) >= 60 {
+ return tcp6GROCandidate
+ }
+ if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO {
+ return udp6GROCandidate
+ }
+ }
+ return notGROCandidate
+}
+
+const (
+ udphLen = 8
+)
+
+// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with
+// existing packets tracked in table. It returns a groResultNoop when no
+// action was taken, groResultTableInsert when the evaluated packet was
+// inserted into table, and groResultCoalesced when the evaluated packet was
+// coalesced with another packet in table.
+func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult {
+ pkt := bufs[pktI][offset:]
+ if len(pkt) > maxUint16 {
+ // A valid IPv4 or IPv6 packet will never exceed this.
+ return groResultNoop
+ }
+ iphLen := int((pkt[0] & 0x0F) * 4)
+ if isV6 {
+ iphLen = 40
+ ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
+ if ipv6HPayloadLen != len(pkt)-iphLen {
+ return groResultNoop
+ }
+ } else {
+ totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
+ if totalLen != len(pkt) {
+ return groResultNoop
+ }
+ }
+ if len(pkt) < iphLen {
+ return groResultNoop
+ }
+ if len(pkt) < iphLen+udphLen {
+ return groResultNoop
+ }
+ if !isV6 {
+ if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
+ // no GRO support for fragmented segments for now
+ return groResultNoop
+ }
+ }
+ gsoSize := uint16(len(pkt) - udphLen - iphLen)
+ // not a candidate if payload len is 0
+ if gsoSize < 1 {
+ return groResultNoop
+ }
+ srcAddrOffset := ipv4SrcAddrOffset
+ addrLen := 4
+ if isV6 {
+ srcAddrOffset = ipv6SrcAddrOffset
+ addrLen = 16
+ }
+ items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI)
+ if !existing {
+ return groResultTableInsert
+ }
+ // With UDP we only check the last item, otherwise we could reorder packets
+ // for a given flow. We must also always insert a new item, or successfully
+ // coalesce with an existing item, for the same reason.
+ item := items[len(items)-1]
+ can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset)
+ var pktCSumKnownInvalid bool
+ if can == coalesceAppend {
+ result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6)
+ switch result {
+ case coalesceSuccess:
+ table.updateAt(item, len(items)-1)
+ return groResultCoalesced
+ case coalesceItemInvalidCSum:
+ // If the existing item has an invalid csum we take no action. A new
+ // item will be stored after it, and the existing item will never be
+ // revisited as part of future coalescing candidacy checks.
+ case coalescePktInvalidCSum:
+ // We must insert a new item, but we also mark it as invalid csum
+ // to prevent a repeat checksum validation.
+ pktCSumKnownInvalid = true
+ default:
+ }
+ }
+ // failed to coalesce with any other packets; store the item in the flow
+ table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid)
+ return groResultTableInsert
+}
+
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
-// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
+// packets into toWrite. toWrite, tcpTable, and udpTable should initially be
// empty (but non-nil), and are passed in to save allocs as the caller may reset
-// and recycle them across vectors of packets.
-func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
+// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is
+// supported.
+func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error {
for i := range bufs {
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
return errors.New("invalid offset")
}
- var result tcpGROResult
- switch {
- case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
- result = tcpGRO(bufs, offset, i, tcp4Table, false)
- case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
- result = tcpGRO(bufs, offset, i, tcp6Table, true)
+ var result groResult
+ switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) {
+ case tcp4GROCandidate:
+ result = tcpGRO(bufs, offset, i, tcpTable, false)
+ case tcp6GROCandidate:
+ result = tcpGRO(bufs, offset, i, tcpTable, true)
+ case udp4GROCandidate:
+ result = udpGRO(bufs, offset, i, udpTable, false)
+ case udp6GROCandidate:
+ result = udpGRO(bufs, offset, i, udpTable, true)
}
switch result {
- case tcpGROResultNoop:
+ case groResultNoop:
hdr := virtioNetHdr{}
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
if err != nil {
return err
}
fallthrough
- case tcpGROResultTableInsert:
+ case groResultTableInsert:
*toWrite = append(*toWrite, i)
}
}
- err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false)
- err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true)
- return errors.Join(err4, err6)
+ errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable)
+ errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable)
+ return errors.Join(errTCP, errUDP)
}
-// tcpTSO splits packets from in into outBuffs, writing the size of each
+// gsoSplit splits packets from in into outBuffs, writing the size of each
// element into sizes. It returns the number of buffers populated, and/or an
// error.
-func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
+func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) {
iphLen := int(hdr.csumStart)
srcAddrOffset := ipv6SrcAddrOffset
addrLen := 16
- if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+ if !isV6 {
in[10], in[11] = 0, 0 // clear ipv4 header checksum
srcAddrOffset = ipv4SrcAddrOffset
addrLen = 4
}
- tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
- in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
- firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
+ transportCsumAt := int(hdr.csumStart + hdr.csumOffset)
+ in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum
+ var firstTCPSeqNum uint32
+ var protocol uint8
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+ protocol = unix.IPPROTO_TCP
+ firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:])
+ } else {
+ protocol = unix.IPPROTO_UDP
+ }
nextSegmentDataAt := int(hdr.hdrLen)
i := 0
for ; nextSegmentDataAt < len(in); i++ {
@@ -610,7 +933,7 @@ func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffs
out := outBuffs[i][outOffset:]
copy(out, in[:iphLen])
- if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+ if !isV6 {
// For IPv4 we are responsible for incrementing the ID field,
// updating the total len field, and recalculating the header
// checksum.
@@ -627,25 +950,32 @@ func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffs
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
}
- // TCP header
+ // copy transport header
copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
- tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
- binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
- if nextSegmentEnd != len(in) {
- // FIN and PSH should only be set on last segment
- clearFlags := tcpFlagFIN | tcpFlagPSH
- out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
+
+ if protocol == unix.IPPROTO_TCP {
+ // set TCP seq and adjust TCP flags
+ tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
+ binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
+ if nextSegmentEnd != len(in) {
+ // FIN and PSH should only be set on last segment
+ clearFlags := tcpFlagFIN | tcpFlagPSH
+ out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
+ }
+ } else {
+ // set UDP header len
+ binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart))
}
// payload
copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
- // TCP checksum
- tcpHLen := int(hdr.hdrLen - hdr.csumStart)
- tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
- tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
- tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold)
- binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
+ // transport checksum
+ transportHeaderLen := int(hdr.hdrLen - hdr.csumStart)
+ lenForPseudo := uint16(transportHeaderLen + segmentDataLen)
+ transportCSumNoFold := pseudoHeaderChecksumNoFold(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo)
+ transportCSum := ^checksum(out[hdr.csumStart:totalLen], transportCSumNoFold)
+ binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum)
nextSegmentDataAt += int(hdr.gsoSize)
}
diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go
new file mode 100644
index 0000000..ae55c8c
--- /dev/null
+++ b/tun/offload_linux_test.go
@@ -0,0 +1,752 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "net/netip"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+ offset = virtioNetHdrLen
+)
+
+var (
+ ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
+ ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
+ ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
+ ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
+ ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
+ ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
+)
+
+func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte {
+ totalLen := 28 + payloadLen
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv4H := header.IPv4(b[offset:])
+ srcAs4 := srcIPPort.Addr().As4()
+ dstAs4 := dstIPPort.Addr().As4()
+ ipFields := &header.IPv4Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
+ Protocol: unix.IPPROTO_UDP,
+ TTL: 64,
+ TotalLength: uint16(totalLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv4H.Encode(ipFields)
+ udpH := header.UDP(b[offset+20:])
+ udpH.Encode(&header.UDPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ Length: uint16(payloadLen + udphLen),
+ })
+ ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen))
+ udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
+ return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
+}
+
+func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte {
+ totalLen := 48 + payloadLen
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv6H := header.IPv6(b[offset:])
+ srcAs16 := srcIPPort.Addr().As16()
+ dstAs16 := dstIPPort.Addr().As16()
+ ipFields := &header.IPv6Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
+ TransportProtocol: unix.IPPROTO_UDP,
+ HopLimit: 64,
+ PayloadLength: uint16(payloadLen + udphLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv6H.Encode(ipFields)
+ udpH := header.UDP(b[offset+40:])
+ udpH.Encode(&header.UDPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ Length: uint16(payloadLen + udphLen),
+ })
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen))
+ udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
+ return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
+}
+
+func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte {
+ totalLen := 40 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv4H := header.IPv4(b[offset:])
+ srcAs4 := srcIPPort.Addr().As4()
+ dstAs4 := dstIPPort.Addr().As4()
+ ipFields := &header.IPv4Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
+ Protocol: unix.IPPROTO_TCP,
+ TTL: 64,
+ TotalLength: uint16(totalLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv4H.Encode(ipFields)
+ tcpH := header.TCP(b[offset+20:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
+}
+
+func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte {
+ totalLen := 60 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv6H := header.IPv6(b[offset:])
+ srcAs16 := srcIPPort.Addr().As16()
+ dstAs16 := dstIPPort.Addr().As16()
+ ipFields := &header.IPv6Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
+ TransportProtocol: unix.IPPROTO_TCP,
+ HopLimit: 64,
+ PayloadLength: uint16(segmentSize + 20),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv6H.Encode(ipFields)
+ tcpH := header.TCP(b[offset+40:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
+}
+
+func Test_handleVirtioRead(t *testing.T) {
+ tests := []struct {
+ name string
+ hdr virtioNetHdr
+ pktIn []byte
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "tcp4",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
+ gsoSize: 100,
+ hdrLen: 40,
+ csumStart: 20,
+ csumOffset: 16,
+ },
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{140, 140},
+ false,
+ },
+ {
+ "tcp6",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
+ gsoSize: 100,
+ hdrLen: 60,
+ csumStart: 40,
+ csumOffset: 16,
+ },
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{160, 160},
+ false,
+ },
+ {
+ "udp4",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
+ gsoSize: 100,
+ hdrLen: 28,
+ csumStart: 20,
+ csumOffset: 6,
+ },
+ udp4Packet(ip4PortA, ip4PortB, 200),
+ []int{128, 128},
+ false,
+ },
+ {
+ "udp6",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
+ gsoSize: 100,
+ hdrLen: 48,
+ csumStart: 40,
+ csumOffset: 6,
+ },
+ udp6Packet(ip6PortA, ip6PortB, 200),
+ []int{148, 148},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ out := make([][]byte, conn.IdealBatchSize)
+ sizes := make([]int, conn.IdealBatchSize)
+ for i := range out {
+ out[i] = make([]byte, 65535)
+ }
+ tt.hdr.encode(tt.pktIn)
+ n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if n != len(tt.wantLens) {
+ t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
+ }
+ for i := range tt.wantLens {
+ if tt.wantLens[i] != sizes[i] {
+ t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
+ }
+ }
+ })
+ }
+}
+
+func flipTCP4Checksum(b []byte) []byte {
+ at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
+ b[at] ^= 0xFF
+ b[at+1] ^= 0xFF
+ return b
+}
+
+func flipUDP4Checksum(b []byte) []byte {
+ at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6
+ b[at] ^= 0xFF
+ b[at+1] ^= 0xFF
+ return b
+}
+
+func Fuzz_handleGRO(f *testing.F) {
+ pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
+ pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
+ pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
+ pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
+ pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
+ pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
+ pkt6 := udp4Packet(ip4PortA, ip4PortB, 100)
+ pkt7 := udp4Packet(ip4PortA, ip4PortB, 100)
+ pkt8 := udp4Packet(ip4PortA, ip4PortC, 100)
+ pkt9 := udp6Packet(ip6PortA, ip6PortB, 100)
+ pkt10 := udp6Packet(ip6PortA, ip6PortB, 100)
+ pkt11 := udp6Packet(ip6PortA, ip6PortC, 100)
+ f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset)
+ f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) {
+ pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11}
+ toWrite := make([]int, 0, len(pkts))
+ handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite)
+ if len(toWrite) > len(pkts) {
+ t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
+ }
+ seenWriteI := make(map[int]bool)
+ for _, writeI := range toWrite {
+ if writeI < 0 || writeI > len(pkts)-1 {
+ t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
+ }
+ if seenWriteI[writeI] {
+ t.Errorf("duplicate toWrite value: %d", writeI)
+ }
+ seenWriteI[writeI] = true
+ }
+ })
+}
+
+func Test_handleGRO(t *testing.T) {
+ tests := []struct {
+ name string
+ pktsIn [][]byte
+ canUDPGRO bool
+ wantToWrite []int
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "multiple protocols and flows",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
+ tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ },
+ true,
+ []int{0, 1, 2, 4, 5, 7, 9},
+ []int{240, 228, 128, 140, 260, 160, 248},
+ false,
+ },
+ {
+ "multiple protocols and flows no UDP GRO",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
+ tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ },
+ false,
+ []int{0, 1, 2, 4, 5, 7, 8, 9, 10},
+ []int{240, 128, 128, 140, 260, 160, 128, 148, 148},
+ false,
+ },
+ {
+ "PSH interleaved",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1
+ },
+ true,
+ []int{0, 2, 4, 6},
+ []int{240, 240, 260, 260},
+ false,
+ },
+ {
+ "coalesceItemInvalidCSum",
+ [][]byte{
+ flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ },
+ true,
+ []int{0, 1, 3, 4},
+ []int{140, 240, 128, 228},
+ false,
+ },
+ {
+ "out of order",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ },
+ true,
+ []int{0},
+ []int{340},
+ false,
+ },
+ {
+ "unequal TTL",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.TTL++
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.TTL++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal ToS",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.TOS++
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.TOS++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal flags more fragments set",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.Flags = 1
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.Flags = 1
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal flags DF set",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.Flags = 2
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.Flags = 2
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "ipv6 unequal hop limit",
+ [][]byte{
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
+ tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
+ fields.HopLimit++
+ }),
+ udp6Packet(ip6PortA, ip6PortB, 100),
+ udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
+ fields.HopLimit++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{160, 160, 148, 148},
+ false,
+ },
+ {
+ "ipv6 unequal traffic class",
+ [][]byte{
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
+ tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
+ fields.TrafficClass++
+ }),
+ udp6Packet(ip6PortA, ip6PortB, 100),
+ udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
+ fields.TrafficClass++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{160, 160, 148, 148},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ toWrite := make([]int, 0, len(tt.pktsIn))
+ err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if len(toWrite) != len(tt.wantToWrite) {
+ t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
+ }
+ for i, pktI := range tt.wantToWrite {
+ if tt.wantToWrite[i] != toWrite[i] {
+ t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
+ }
+ if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
+ t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
+ }
+ }
+ })
+ }
+}
+
+func Test_packetIsGROCandidate(t *testing.T) {
+ tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
+ tcp4TooShort := tcp4[:39]
+ ip4InvalidHeaderLen := make([]byte, len(tcp4))
+ copy(ip4InvalidHeaderLen, tcp4)
+ ip4InvalidHeaderLen[0] = 0x46
+ ip4InvalidProtocol := make([]byte, len(tcp4))
+ copy(ip4InvalidProtocol, tcp4)
+ ip4InvalidProtocol[9] = unix.IPPROTO_GRE
+
+ tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
+ tcp6TooShort := tcp6[:59]
+ ip6InvalidProtocol := make([]byte, len(tcp6))
+ copy(ip6InvalidProtocol, tcp6)
+ ip6InvalidProtocol[6] = unix.IPPROTO_GRE
+
+ udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:]
+ udp4TooShort := udp4[:27]
+
+ udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:]
+ udp6TooShort := udp6[:47]
+
+ tests := []struct {
+ name string
+ b []byte
+ canUDPGRO bool
+ want groCandidateType
+ }{
+ {
+ "tcp4",
+ tcp4,
+ true,
+ tcp4GROCandidate,
+ },
+ {
+ "tcp6",
+ tcp6,
+ true,
+ tcp6GROCandidate,
+ },
+ {
+ "udp4",
+ udp4,
+ true,
+ udp4GROCandidate,
+ },
+ {
+ "udp4 no support",
+ udp4,
+ false,
+ notGROCandidate,
+ },
+ {
+ "udp6",
+ udp6,
+ true,
+ udp6GROCandidate,
+ },
+ {
+ "udp6 no support",
+ udp6,
+ false,
+ notGROCandidate,
+ },
+ {
+ "udp4 too short",
+ udp4TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "udp6 too short",
+ udp6TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "tcp4 too short",
+ tcp4TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "tcp6 too short",
+ tcp6TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "invalid IP version",
+ []byte{0x00},
+ true,
+ notGROCandidate,
+ },
+ {
+ "invalid IP header len",
+ ip4InvalidHeaderLen,
+ true,
+ notGROCandidate,
+ },
+ {
+ "ip4 invalid protocol",
+ ip4InvalidProtocol,
+ true,
+ notGROCandidate,
+ },
+ {
+ "ip6 invalid protocol",
+ ip6InvalidProtocol,
+ true,
+ notGROCandidate,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want {
+ t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func Test_udpPacketsCanCoalesce(t *testing.T) {
+ udp4a := udp4Packet(ip4PortA, ip4PortB, 100)
+ udp4b := udp4Packet(ip4PortA, ip4PortB, 100)
+ udp4c := udp4Packet(ip4PortA, ip4PortB, 110)
+
+ type args struct {
+ pkt []byte
+ iphLen uint8
+ gsoSize uint16
+ item udpGROItem
+ bufs [][]byte
+ bufsOffset int
+ }
+ tests := []struct {
+ name string
+ args args
+ want canCoalesce
+ }{
+ {
+ "coalesceAppend equal gso",
+ args{
+ pkt: udp4a[offset:],
+ iphLen: 20,
+ gsoSize: 100,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceAppend,
+ },
+ {
+ "coalesceAppend smaller gso",
+ args{
+ pkt: udp4a[offset : len(udp4a)-90],
+ iphLen: 20,
+ gsoSize: 10,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceAppend,
+ },
+ {
+ "coalesceUnavailable smaller gso previously appended",
+ args{
+ pkt: udp4a[offset:],
+ iphLen: 20,
+ gsoSize: 100,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4c,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceUnavailable,
+ },
+ {
+ "coalesceUnavailable larger following smaller",
+ args{
+ pkt: udp4c[offset:],
+ iphLen: 20,
+ gsoSize: 110,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4c,
+ },
+ bufsOffset: offset,
+ },
+ coalesceUnavailable,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want {
+ t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go
deleted file mode 100644
index ddddc48..0000000
--- a/tun/tcp_offload_linux_test.go
+++ /dev/null
@@ -1,411 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
- */
-
-package tun
-
-import (
- "net/netip"
- "testing"
-
- "golang.org/x/sys/unix"
- "golang.zx2c4.com/wireguard/conn"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/header"
-)
-
-const (
- offset = virtioNetHdrLen
-)
-
-var (
- ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
- ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
- ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
- ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
- ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
- ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
-)
-
-func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte {
- totalLen := 40 + segmentSize
- b := make([]byte, offset+int(totalLen), 65535)
- ipv4H := header.IPv4(b[offset:])
- srcAs4 := srcIPPort.Addr().As4()
- dstAs4 := dstIPPort.Addr().As4()
- ipFields := &header.IPv4Fields{
- SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
- DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
- Protocol: unix.IPPROTO_TCP,
- TTL: 64,
- TotalLength: uint16(totalLen),
- }
- if ipFn != nil {
- ipFn(ipFields)
- }
- ipv4H.Encode(ipFields)
- tcpH := header.TCP(b[offset+20:])
- tcpH.Encode(&header.TCPFields{
- SrcPort: srcIPPort.Port(),
- DstPort: dstIPPort.Port(),
- SeqNum: seq,
- AckNum: 1,
- DataOffset: 20,
- Flags: flags,
- WindowSize: 3000,
- })
- ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
- pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
- tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
- return b
-}
-
-func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
- return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
-}
-
-func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte {
- totalLen := 60 + segmentSize
- b := make([]byte, offset+int(totalLen), 65535)
- ipv6H := header.IPv6(b[offset:])
- srcAs16 := srcIPPort.Addr().As16()
- dstAs16 := dstIPPort.Addr().As16()
- ipFields := &header.IPv6Fields{
- SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
- DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
- TransportProtocol: unix.IPPROTO_TCP,
- HopLimit: 64,
- PayloadLength: uint16(segmentSize + 20),
- }
- if ipFn != nil {
- ipFn(ipFields)
- }
- ipv6H.Encode(ipFields)
- tcpH := header.TCP(b[offset+40:])
- tcpH.Encode(&header.TCPFields{
- SrcPort: srcIPPort.Port(),
- DstPort: dstIPPort.Port(),
- SeqNum: seq,
- AckNum: 1,
- DataOffset: 20,
- Flags: flags,
- WindowSize: 3000,
- })
- pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
- tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
- return b
-}
-
-func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
- return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
-}
-
-func Test_handleVirtioRead(t *testing.T) {
- tests := []struct {
- name string
- hdr virtioNetHdr
- pktIn []byte
- wantLens []int
- wantErr bool
- }{
- {
- "tcp4",
- virtioNetHdr{
- flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
- gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
- gsoSize: 100,
- hdrLen: 40,
- csumStart: 20,
- csumOffset: 16,
- },
- tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
- []int{140, 140},
- false,
- },
- {
- "tcp6",
- virtioNetHdr{
- flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
- gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
- gsoSize: 100,
- hdrLen: 60,
- csumStart: 40,
- csumOffset: 16,
- },
- tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
- []int{160, 160},
- false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- out := make([][]byte, conn.IdealBatchSize)
- sizes := make([]int, conn.IdealBatchSize)
- for i := range out {
- out[i] = make([]byte, 65535)
- }
- tt.hdr.encode(tt.pktIn)
- n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
- if err != nil {
- if tt.wantErr {
- return
- }
- t.Fatalf("got err: %v", err)
- }
- if n != len(tt.wantLens) {
- t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
- }
- for i := range tt.wantLens {
- if tt.wantLens[i] != sizes[i] {
- t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
- }
- }
- })
- }
-}
-
-func flipTCP4Checksum(b []byte) []byte {
- at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
- b[at] ^= 0xFF
- b[at+1] ^= 0xFF
- return b
-}
-
-func Fuzz_handleGRO(f *testing.F) {
- pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
- pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
- pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
- pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
- pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
- pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
- f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset)
- f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) {
- pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5}
- toWrite := make([]int, 0, len(pkts))
- handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
- if len(toWrite) > len(pkts) {
- t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
- }
- seenWriteI := make(map[int]bool)
- for _, writeI := range toWrite {
- if writeI < 0 || writeI > len(pkts)-1 {
- t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
- }
- if seenWriteI[writeI] {
- t.Errorf("duplicate toWrite value: %d", writeI)
- }
- seenWriteI[writeI] = true
- }
- })
-}
-
-func Test_handleGRO(t *testing.T) {
- tests := []struct {
- name string
- pktsIn [][]byte
- wantToWrite []int
- wantLens []int
- wantErr bool
- }{
- {
- "multiple flows",
- [][]byte{
- tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
- tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1
- tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2
- tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
- tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1
- tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2
- },
- []int{0, 2, 3, 5},
- []int{240, 140, 260, 160},
- false,
- },
- {
- "PSH interleaved",
- [][]byte{
- tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
- tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
- tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1
- tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1
- tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
- tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
- tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1
- tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1
- },
- []int{0, 2, 4, 6},
- []int{240, 240, 260, 260},
- false,
- },
- {
- "coalesceItemInvalidCSum",
- [][]byte{
- flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
- tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
- tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
- },
- []int{0, 1},
- []int{140, 240},
- false,
- },
- {
- "out of order",
- [][]byte{
- tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
- tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100
- tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
- },
- []int{0},
- []int{340},
- false,
- },
- {
- "tcp4 unequal TTL",
- [][]byte{
- tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
- tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
- fields.TTL++
- }),
- },
- []int{0, 1},
- []int{140, 140},
- false,
- },
- {
- "tcp4 unequal ToS",
- [][]byte{
- tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
- tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
- fields.TOS++
- }),
- },
- []int{0, 1},
- []int{140, 140},
- false,
- },
- {
- "tcp4 unequal flags more fragments set",
- [][]byte{
- tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
- tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
- fields.Flags = 1
- }),
- },
- []int{0, 1},
- []int{140, 140},
- false,
- },
- {
- "tcp4 unequal flags DF set",
- [][]byte{
- tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
- tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
- fields.Flags = 2
- }),
- },
- []int{0, 1},
- []int{140, 140},
- false,
- },
- {
- "tcp6 unequal hop limit",
- [][]byte{
- tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
- tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
- fields.HopLimit++
- }),
- },
- []int{0, 1},
- []int{160, 160},
- false,
- },
- {
- "tcp6 unequal traffic class",
- [][]byte{
- tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
- tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
- fields.TrafficClass++
- }),
- },
- []int{0, 1},
- []int{160, 160},
- false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- toWrite := make([]int, 0, len(tt.pktsIn))
- err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
- if err != nil {
- if tt.wantErr {
- return
- }
- t.Fatalf("got err: %v", err)
- }
- if len(toWrite) != len(tt.wantToWrite) {
- t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
- }
- for i, pktI := range tt.wantToWrite {
- if tt.wantToWrite[i] != toWrite[i] {
- t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
- }
- if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
- t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
- }
- }
- })
- }
-}
-
-func Test_isTCP4NoIPOptions(t *testing.T) {
- valid := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
- invalidLen := valid[:39]
- invalidHeaderLen := make([]byte, len(valid))
- copy(invalidHeaderLen, valid)
- invalidHeaderLen[0] = 0x46
- invalidProtocol := make([]byte, len(valid))
- copy(invalidProtocol, valid)
- invalidProtocol[9] = unix.IPPROTO_TCP + 1
-
- tests := []struct {
- name string
- b []byte
- want bool
- }{
- {
- "valid",
- valid,
- true,
- },
- {
- "invalid length",
- invalidLen,
- false,
- },
- {
- "invalid version",
- []byte{0x00},
- false,
- },
- {
- "invalid header len",
- invalidHeaderLen,
- false,
- },
- {
- "invalid protocol",
- invalidProtocol,
- false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := isTCP4NoIPOptions(tt.b); got != tt.want {
- t.Errorf("isTCP4NoIPOptions() = %v, want %v", got, tt.want)
- }
- })
- }
-}
diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 b/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77
deleted file mode 100644
index 5461e79..0000000
--- a/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77
+++ /dev/null
@@ -1,8 +0,0 @@
-go test fuzz v1
-[]byte("0")
-[]byte("0")
-[]byte("0")
-[]byte("0")
-[]byte("0")
-[]byte("0")
-int(34)
diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d b/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d
deleted file mode 100644
index b441819..0000000
--- a/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d
+++ /dev/null
@@ -1,8 +0,0 @@
-go test fuzz v1
-[]byte("0")
-[]byte("0")
-[]byte("0")
-[]byte("0")
-[]byte("0")
-[]byte("0")
-int(-48)
diff --git a/tun/tun_linux.go b/tun/tun_linux.go
index 12cd49f..bd69cb5 100644
--- a/tun/tun_linux.go
+++ b/tun/tun_linux.go
@@ -38,6 +38,7 @@ type NativeTun struct {
statusListenersShutdown chan struct{}
batchSize int
vnetHdr bool
+ udpGSO bool
closeOnce sync.Once
@@ -48,9 +49,10 @@ type NativeTun struct {
readOpMu sync.Mutex // readOpMu guards readBuff
readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
- writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
- toWrite []int
- tcp4GROTable, tcp6GROTable *tcpGROTable
+ writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable
+ toWrite []int
+ tcpGROTable *tcpGROTable
+ udpGROTable *udpGROTable
}
func (tun *NativeTun) File() *os.File {
@@ -333,8 +335,8 @@ func (tun *NativeTun) nameSlow() (string, error) {
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
tun.writeOpMu.Lock()
defer func() {
- tun.tcp4GROTable.reset()
- tun.tcp6GROTable.reset()
+ tun.tcpGROTable.reset()
+ tun.udpGROTable.reset()
tun.writeOpMu.Unlock()
}()
var (
@@ -343,7 +345,7 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
)
tun.toWrite = tun.toWrite[:0]
if tun.vnetHdr {
- err := handleGRO(bufs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
+ err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite)
if err != nil {
return 0, err
}
@@ -394,37 +396,42 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e
sizes[0] = n
return 1, nil
}
- if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
}
ipVersion := in[0] >> 4
switch ipVersion {
case 4:
- if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
}
case 6:
- if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
}
default:
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
}
- if len(in) <= int(hdr.csumStart+12) {
- return 0, errors.New("packet is too short")
- }
// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
// of the entire first packet when the kernel is handling it as part of a
- // FORWARD path. Instead, parse the TCP header length and add it onto
+ // FORWARD path. Instead, parse the transport header length and add it onto
// csumStart, which is synonymous for IP header length.
- tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
- if tcpHLen < 20 || tcpHLen > 60 {
- // A TCP header must be between 20 and 60 bytes in length.
- return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ hdr.hdrLen = hdr.csumStart + 8
+ } else {
+ if len(in) <= int(hdr.csumStart+12) {
+ return 0, errors.New("packet is too short")
+ }
+
+ tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
+ if tcpHLen < 20 || tcpHLen > 60 {
+ // A TCP header must be between 20 and 60 bytes in length.
+ return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
+ }
+ hdr.hdrLen = hdr.csumStart + tcpHLen
}
- hdr.hdrLen = hdr.csumStart + tcpHLen
if len(in) < int(hdr.hdrLen) {
return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
@@ -438,7 +445,7 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
}
- return tcpTSO(in, hdr, bufs, sizes, offset)
+ return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6)
}
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
@@ -497,7 +504,8 @@ func (tun *NativeTun) BatchSize() int {
const (
// TODO: support TSO with ECN bits
- tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
+ tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
+ tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
)
func (tun *NativeTun) initFromFlags(name string) error {
@@ -519,12 +527,17 @@ func (tun *NativeTun) initFromFlags(name string) error {
}
got := ifr.Uint16()
if got&unix.IFF_VNET_HDR != 0 {
- err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads)
+ // tunTCPOffloads were added in Linux v2.6. We require their support
+ // if IFF_VNET_HDR is set.
+ err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads)
if err != nil {
return
}
tun.vnetHdr = true
tun.batchSize = conn.IdealBatchSize
+ // tunUDPOffloads were added in Linux v6.2. We do not return an
+ // error if they are unsupported at runtime.
+ tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil
} else {
tun.batchSize = 1
}
@@ -575,8 +588,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
events: make(chan Event, 5),
errors: make(chan error, 5),
statusListenersShutdown: make(chan struct{}),
- tcp4GROTable: newTCPGROTable(),
- tcp6GROTable: newTCPGROTable(),
+ tcpGROTable: newTCPGROTable(),
+ udpGROTable: newUDPGROTable(),
toWrite: make([]int, 0, conn.IdealBatchSize),
}
@@ -628,12 +641,12 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
}
file := os.NewFile(uintptr(fd), "/dev/tun")
tun := &NativeTun{
- tunFile: file,
- events: make(chan Event, 5),
- errors: make(chan error, 5),
- tcp4GROTable: newTCPGROTable(),
- tcp6GROTable: newTCPGROTable(),
- toWrite: make([]int, 0, conn.IdealBatchSize),
+ tunFile: file,
+ events: make(chan Event, 5),
+ errors: make(chan error, 5),
+ tcpGROTable: newTCPGROTable(),
+ udpGROTable: newUDPGROTable(),
+ toWrite: make([]int, 0, conn.IdealBatchSize),
}
name, err := tun.Name()
if err != nil {