102 files changed, 3482 insertions, 1703 deletions
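Note: the bulk of this change set teaches conn.StdNetBind to use UDP generic segmentation offload (GSO) and generic receive offload (GRO) on Linux, so that many WireGuard packets move through a single syscall. For orientation, here is a minimal standalone sketch, not part of the diff and Linux-only, of the two socket options everything below builds on; the getsockopt probe mirrors the new supportsUDPOffload(), and the UDP_GRO enable mirrors the control function added to conn/controlfns_linux.go.

package main

import (
	"log"
	"net"

	"golang.org/x/sys/unix"
)

func main() {
	conn, err := net.ListenUDP("udp4", &net.UDPAddr{}) // ephemeral port
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
	rc, err := conn.SyscallConn()
	if err != nil {
		log.Fatal(err)
	}
	err = rc.Control(func(fd uintptr) {
		// TX: UDP_SEGMENT lets one sendmsg() carry several equal-size
		// datagrams that the kernel/NIC splits apart (GSO).
		if _, err := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT); err != nil {
			log.Printf("UDP_SEGMENT unsupported: %v", err)
		}
		// RX: UDP_GRO coalesces several datagrams into one recvmsg(),
		// reporting the segment size via a control message (GRO).
		if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1); err != nil {
			log.Printf("UDP_GRO unsupported: %v", err)
		}
	})
	if err != nil {
		log.Fatal(err)
	}
}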
@@ -56,7 +56,7 @@ $ make ## License - Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in diff --git a/conn/bind_std.go b/conn/bind_std.go index 69789b3..f5c8816 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn @@ -8,6 +8,7 @@ package conn import ( "context" "errors" + "fmt" "net" "net/netip" "runtime" @@ -29,16 +30,19 @@ var ( // methods for sending and receiving multiple datagrams per-syscall. See the // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. type StdNetBind struct { - mu sync.Mutex // protects all fields except as specified - ipv4 *net.UDPConn - ipv6 *net.UDPConn - ipv4PC *ipv4.PacketConn // will be nil on non-Linux - ipv6PC *ipv6.PacketConn // will be nil on non-Linux - - // these three fields are not guarded by mu - udpAddrPool sync.Pool - ipv4MsgsPool sync.Pool - ipv6MsgsPool sync.Pool + mu sync.Mutex // protects all fields except as specified + ipv4 *net.UDPConn + ipv6 *net.UDPConn + ipv4PC *ipv4.PacketConn // will be nil on non-Linux + ipv6PC *ipv6.PacketConn // will be nil on non-Linux + ipv4TxOffload bool + ipv4RxOffload bool + ipv6TxOffload bool + ipv6RxOffload bool + + // these two fields are not guarded by mu + udpAddrPool sync.Pool + msgsPool sync.Pool blackhole4 bool blackhole6 bool @@ -54,23 +58,14 @@ func NewStdNetBind() Bind { }, }, - ipv4MsgsPool: sync.Pool{ - New: func() any { - msgs := make([]ipv4.Message, IdealBatchSize) - for i := range msgs { - msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, srcControlSize) - } - return &msgs - }, - }, - - ipv6MsgsPool: sync.Pool{ + msgsPool: sync.Pool{ New: func() any { + // ipv6.Message and ipv4.Message are interchangeable as they are + // both aliases for x/net/internal/socket.Message. msgs := make([]ipv6.Message, IdealBatchSize) for i := range msgs { msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, srcControlSize) + msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize) } return &msgs }, @@ -81,11 +76,10 @@ func NewStdNetBind() Bind { type StdNetEndpoint struct { // AddrPort is the endpoint destination. netip.AddrPort - // src is the current sticky source address and interface index, if supported. - src struct { - netip.Addr - ifidx int32 - } + // src is the current sticky source address and interface index, if + // supported. Typically this is a PKTINFO structure from/for control + // messages, see unix.PKTINFO for an example. + src []byte } var ( @@ -104,21 +98,17 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { } func (e *StdNetEndpoint) ClearSrc() { - e.src.ifidx = 0 - e.src.Addr = netip.Addr{} + if e.src != nil { + // Truncate src, no need to reallocate. + e.src = e.src[:0] + } } func (e *StdNetEndpoint) DstIP() netip.Addr { return e.AddrPort.Addr() } -func (e *StdNetEndpoint) SrcIP() netip.Addr { - return e.src.Addr -} - -func (e *StdNetEndpoint) SrcIfidx() int32 { - return e.src.ifidx -} +// See control_default,linux, etc for implementations of SrcIP and SrcIfidx. 
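Note: src is now the verbatim PKTINFO control message rather than a parsed address/ifindex pair, which lets setSrcControl copy it straight into a message's OOB buffer. A minimal sketch of what that cached buffer holds for IPv4 (Linux-only, hypothetical values; the layout matches the setSrc() test helper added in conn/sticky_linux_test.go further down):

package main

import (
	"fmt"
	"unsafe"

	"golang.org/x/sys/unix"
)

func main() {
	src := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&src[0]))
	hdr.Level = unix.IPPROTO_IP
	hdr.Type = unix.IP_PKTINFO
	hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
	info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&src[unix.CmsgLen(0)]))
	info.Ifindex = 5                     // hypothetical interface index
	info.Spec_dst = [4]byte{10, 0, 0, 1} // hypothetical source address
	// setSrcControl appends this buffer verbatim to a message's OOB
	// data; SrcIP and SrcIfidx decode the same bytes back out.
	fmt.Printf("cmsg: % x\n", src)
}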
func (e *StdNetEndpoint) DstToBytes() []byte { b, _ := e.AddrPort.MarshalBinary() @@ -129,10 +119,6 @@ func (e *StdNetEndpoint) DstToString() string { return e.AddrPort.String() } -func (e *StdNetEndpoint) SrcToString() string { - return e.src.Addr.String() -} - func listenNet(network string, port int) (*net.UDPConn, int, error) { conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) if err != nil { @@ -188,19 +174,21 @@ again: } var fns []ReceiveFunc if v4conn != nil { - if runtime.GOOS == "linux" { + s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn) + if runtime.GOOS == "linux" || runtime.GOOS == "android" { v4pc = ipv4.NewPacketConn(v4conn) s.ipv4PC = v4pc } - fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn)) + fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload)) s.ipv4 = v4conn } if v6conn != nil { - if runtime.GOOS == "linux" { + s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn) + if runtime.GOOS == "linux" || runtime.GOOS == "android" { v6pc = ipv6.NewPacketConn(v6conn) s.ipv6PC = v6pc } - fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn)) + fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload)) s.ipv6 = v6conn } if len(fns) == 0 { @@ -210,76 +198,101 @@ again: return fns, uint16(port), nil } -func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) - defer s.ipv4MsgsPool.Put(msgs) - for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] - } - var numMsgs int - if runtime.GOOS == "linux" { - numMsgs, err = pc.ReadBatch(*msgs, 0) +func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) { + for i := range *msgs { + (*msgs)[i].OOB = (*msgs)[i].OOB[:0] + (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB} + } + s.msgsPool.Put(msgs) +} + +func (s *StdNetBind) getMessages() *[]ipv6.Message { + return s.msgsPool.Get().(*[]ipv6.Message) +} + +var ( + // If compilation fails here these are no longer the same underlying type. 
+ _ ipv6.Message = ipv4.Message{} +) + +type batchReader interface { + ReadBatch([]ipv6.Message, int) (int, error) +} + +type batchWriter interface { + WriteBatch([]ipv6.Message, int) (int, error) +} + +func (s *StdNetBind) receiveIP( + br batchReader, + conn *net.UDPConn, + rxOffload bool, + bufs [][]byte, + sizes []int, + eps []Endpoint, +) (n int, err error) { + msgs := s.getMessages() + for i := range bufs { + (*msgs)[i].Buffers[0] = bufs[i] + (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] + } + defer s.putMessages(msgs) + var numMsgs int + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + if rxOffload { + readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams) + numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0) + if err != nil { + return 0, err + } + numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize) if err != nil { return 0, err } } else { - msg := &(*msgs)[0] - msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + numMsgs, err = br.ReadBatch(*msgs, 0) if err != nil { return 0, err } - numMsgs = 1 } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - sizes[i] = msg.N - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation - getSrcFromControl(msg.OOB[:msg.NN], ep) - eps[i] = ep + } else { + msg := &(*msgs)[0] + msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + if err != nil { + return 0, err + } + numMsgs = 1 + } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + if sizes[i] == 0 { + continue } - return numMsgs, nil + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation + getSrcFromControl(msg.OOB[:msg.NN], ep) + eps[i] = ep } + return numMsgs, nil } -func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc { +func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) - defer s.ipv6MsgsPool.Put(msgs) - for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] - } - var numMsgs int - if runtime.GOOS == "linux" { - numMsgs, err = pc.ReadBatch(*msgs, 0) - if err != nil { - return 0, err - } - } else { - msg := &(*msgs)[0] - msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) - if err != nil { - return 0, err - } - numMsgs = 1 - } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - sizes[i] = msg.N - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation - getSrcFromControl(msg.OOB[:msg.NN], ep) - eps[i] = ep - } - return numMsgs, nil + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + } +} + +func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) } } // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and // rename the IdealBatchSize constant to BatchSize. 
func (s *StdNetBind) BatchSize() int { - if runtime.GOOS == "linux" { + if runtime.GOOS == "linux" || runtime.GOOS == "android" { return IdealBatchSize } return 1 @@ -302,28 +315,42 @@ func (s *StdNetBind) Close() error { } s.blackhole4 = false s.blackhole6 = false + s.ipv4TxOffload = false + s.ipv4RxOffload = false + s.ipv6TxOffload = false + s.ipv6RxOffload = false if err1 != nil { return err1 } return err2 } +type ErrUDPGSODisabled struct { + onLaddr string + RetryErr error +} + +func (e ErrUDPGSODisabled) Error() string { + return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr) +} + +func (e ErrUDPGSODisabled) Unwrap() error { + return e.RetryErr +} + func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { s.mu.Lock() blackhole := s.blackhole4 conn := s.ipv4 - var ( - pc4 *ipv4.PacketConn - pc6 *ipv6.PacketConn - ) + offload := s.ipv4TxOffload + br := batchWriter(s.ipv4PC) is6 := false if endpoint.DstIP().Is6() { blackhole = s.blackhole6 conn = s.ipv6 - pc6 = s.ipv6PC + br = s.ipv6PC is6 = true - } else { - pc4 = s.ipv4PC + offload = s.ipv6TxOffload } s.mu.Unlock() @@ -333,85 +360,185 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { if conn == nil { return syscall.EAFNOSUPPORT } + + msgs := s.getMessages() + defer s.putMessages(msgs) + ua := s.udpAddrPool.Get().(*net.UDPAddr) + defer s.udpAddrPool.Put(ua) if is6 { - return s.send6(conn, pc6, endpoint, bufs) + as16 := endpoint.DstIP().As16() + copy(ua.IP, as16[:]) + ua.IP = ua.IP[:16] + } else { + as4 := endpoint.DstIP().As4() + copy(ua.IP, as4[:]) + ua.IP = ua.IP[:4] + } + ua.Port = int(endpoint.(*StdNetEndpoint).Port()) + var ( + retried bool + err error + ) +retry: + if offload { + n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize) + err = s.send(conn, br, (*msgs)[:n]) + if err != nil && offload && errShouldDisableUDPGSO(err) { + offload = false + s.mu.Lock() + if is6 { + s.ipv6TxOffload = false + } else { + s.ipv4TxOffload = false + } + s.mu.Unlock() + retried = true + goto retry + } } else { - return s.send4(conn, pc4, endpoint, bufs) + for i := range bufs { + (*msgs)[i].Addr = ua + (*msgs)[i].Buffers[0] = bufs[i] + setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint)) + } + err = s.send(conn, br, (*msgs)[:len(bufs)]) + } + if retried { + return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err} } + return err } -func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error { - ua := s.udpAddrPool.Get().(*net.UDPAddr) - as4 := ep.DstIP().As4() - copy(ua.IP, as4[:]) - ua.IP = ua.IP[:4] - ua.Port = int(ep.(*StdNetEndpoint).Port()) - msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) - for i, buf := range bufs { - (*msgs)[i].Buffers[0] = buf - (*msgs)[i].Addr = ua - setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) - } +func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error { var ( n int err error start int ) - if runtime.GOOS == "linux" { + if runtime.GOOS == "linux" || runtime.GOOS == "android" { for { - n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) - if err != nil || n == len((*msgs)[start:len(bufs)]) { + n, err = pc.WriteBatch(msgs[start:], 0) + if err != nil || n == len(msgs[start:]) { break } start += n } } else { - for i, buf := range bufs { - _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) + for _, msg := range msgs { + _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr)) if err != nil { break } 
} } - s.udpAddrPool.Put(ua) - s.ipv4MsgsPool.Put(msgs) return err } -func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error { - ua := s.udpAddrPool.Get().(*net.UDPAddr) - as16 := ep.DstIP().As16() - copy(ua.IP, as16[:]) - ua.IP = ua.IP[:16] - ua.Port = int(ep.(*StdNetEndpoint).Port()) - msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) - for i, buf := range bufs { - (*msgs)[i].Buffers[0] = buf - (*msgs)[i].Addr = ua - setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) - } +const ( + // Exceeding these values results in EMSGSIZE. They account for layer3 and + // layer4 headers. IPv6 does not need to account for itself as the payload + // length field is self excluding. + maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 + maxIPv6PayloadLen = 1<<16 - 1 - 8 + + // This is a hard limit imposed by the kernel. + udpSegmentMaxDatagrams = 64 +) + +type setGSOFunc func(control *[]byte, gsoSize uint16) + +func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int { var ( - n int - err error - start int + base = -1 // index of msg we are currently coalescing into + gsoSize int // segmentation size of msgs[base] + dgramCnt int // number of dgrams coalesced into msgs[base] + endBatch bool // tracking flag to start a new batch on next iteration of bufs ) - if runtime.GOOS == "linux" { - for { - n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) - if err != nil || n == len((*msgs)[start:len(bufs)]) { - break + maxPayloadLen := maxIPv4PayloadLen + if ep.DstIP().Is6() { + maxPayloadLen = maxIPv6PayloadLen + } + for i, buf := range bufs { + if i > 0 { + msgLen := len(buf) + baseLenBefore := len(msgs[base].Buffers[0]) + freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore + if msgLen+baseLenBefore <= maxPayloadLen && + msgLen <= gsoSize && + msgLen <= freeBaseCap && + dgramCnt < udpSegmentMaxDatagrams && + !endBatch { + msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...) + if i == len(bufs)-1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + dgramCnt++ + if msgLen < gsoSize { + // A smaller than gsoSize packet on the tail is legal, but + // it must end the batch. + endBatch = true + } + continue } - start += n } - } else { - for i, buf := range bufs { - _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) - if err != nil { - break + if dgramCnt > 1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + // Reset prior to incrementing base since we are preparing to start a + // new potential batch. 
+ endBatch = false + base++ + gsoSize = len(buf) + setSrcControl(&msgs[base].OOB, ep) + msgs[base].Buffers[0] = buf + msgs[base].Addr = addr + dgramCnt = 1 + } + return base + 1 +} + +type getGSOFunc func(control []byte) (int, error) + +func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) { + for i := firstMsgAt; i < len(msgs); i++ { + msg := &msgs[i] + if msg.N == 0 { + return n, err + } + var ( + gsoSize int + start int + end = msg.N + numToSplit = 1 + ) + gsoSize, err = getGSO(msg.OOB[:msg.NN]) + if err != nil { + return n, err + } + if gsoSize > 0 { + numToSplit = (msg.N + gsoSize - 1) / gsoSize + end = gsoSize + } + for j := 0; j < numToSplit; j++ { + if n > i { + return n, errors.New("splitting coalesced packet resulted in overflow") } + copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) + msgs[n].N = copied + msgs[n].Addr = msg.Addr + start = end + end += gsoSize + if end > msg.N { + end = msg.N + } + n++ + } + if i != n-1 { + // It is legal for bytes to move within msg.Buffers[0] as a result + // of splitting, so we only zero the source msg len when it is not + // the destination of the last split operation above. + msg.N = 0 } } - s.udpAddrPool.Put(ua) - s.ipv6MsgsPool.Put(msgs) - return err + return n, nil } diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go index 1e46776..34a3c9a 100644 --- a/conn/bind_std_test.go +++ b/conn/bind_std_test.go @@ -1,6 +1,12 @@ package conn -import "testing" +import ( + "encoding/binary" + "net" + "testing" + + "golang.org/x/net/ipv6" +) func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { bind := NewStdNetBind().(*StdNetBind) @@ -20,3 +26,225 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { fn(bufs, sizes, eps) } } + +func mockSetGSOSize(control *[]byte, gsoSize uint16) { + *control = (*control)[:cap(*control)] + binary.LittleEndian.PutUint16(*control, gsoSize) +} + +func Test_coalesceMessages(t *testing.T) { + cases := []struct { + name string + buffs [][]byte + wantLens []int + wantGSO []int + }{ + { + name: "one message no coalesce", + buffs: [][]byte{ + make([]byte, 1, 1), + }, + wantLens: []int{1}, + wantGSO: []int{0}, + }, + { + name: "two messages equal len coalesce", + buffs: [][]byte{ + make([]byte, 1, 2), + make([]byte, 1, 1), + }, + wantLens: []int{2}, + wantGSO: []int{1}, + }, + { + name: "two messages unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + }, + wantLens: []int{3}, + wantGSO: []int{2}, + }, + { + name: "three messages second unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + make([]byte, 2, 2), + }, + wantLens: []int{3, 2}, + wantGSO: []int{2, 0}, + }, + { + name: "three messages limited cap coalesce", + buffs: [][]byte{ + make([]byte, 2, 4), + make([]byte, 2, 2), + make([]byte, 2, 2), + }, + wantLens: []int{4, 2}, + wantGSO: []int{2, 0}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1").To4(), + Port: 1, + } + msgs := make([]ipv6.Message, len(tt.buffs)) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].OOB = make([]byte, 0, 2) + } + got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize) + if got != len(tt.wantLens) { + t.Fatalf("got len %d want: %d", got, len(tt.wantLens)) + } + for i := 0; i < got; i++ { + if msgs[i].Addr != addr { + t.Errorf("msgs[%d].Addr != passed addr", i) + } + gotLen := 
len(msgs[i].Buffers[0]) + if gotLen != tt.wantLens[i] { + t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i]) + } + gotGSO, err := mockGetGSOSize(msgs[i].OOB) + if err != nil { + t.Fatalf("msgs[%d] getGSOSize err: %v", i, err) + } + if gotGSO != tt.wantGSO[i] { + t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i]) + } + } + }) + } +} + +func mockGetGSOSize(control []byte) (int, error) { + if len(control) < 2 { + return 0, nil + } + return int(binary.LittleEndian.Uint16(control)), nil +} + +func Test_splitCoalescedMessages(t *testing.T) { + newMsg := func(n, gso int) ipv6.Message { + msg := ipv6.Message{ + Buffers: [][]byte{make([]byte, 1<<16-1)}, + N: n, + OOB: make([]byte, 2), + } + binary.LittleEndian.PutUint16(msg.OOB, uint16(gso)) + if gso > 0 { + msg.NN = 2 + } + return msg + } + + cases := []struct { + name string + msgs []ipv6.Message + firstMsgAt int + wantNumEval int + wantMsgLens []int + wantErr bool + }{ + { + name: "second last split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(3, 1), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 3, + wantMsgLens: []int{1, 1, 1, 0}, + wantErr: false, + }, + { + name: "second last no split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 1, + wantMsgLens: []int{1, 0, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last no split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(1, 0), + }, + firstMsgAt: 2, + wantNumEval: 2, + wantMsgLens: []int{1, 1, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(3, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(2, 1), + newMsg(2, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last no split last split overflow", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(4, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize) + if err != nil && !tt.wantErr { + t.Fatalf("err: %v", err) + } + if got != tt.wantNumEval { + t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval) + } + for i, msg := range tt.msgs { + if msg.N != tt.wantMsgLens[i] { + t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i]) + } + } + }) + } +} diff --git a/conn/bind_windows.go b/conn/bind_windows.go index 228167e..a3b8460 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. 
*/ package conn @@ -164,7 +164,7 @@ func (e *WinRingEndpoint) DstToBytes() []byte { func (e *WinRingEndpoint) DstToString() string { switch e.family { case windows.AF_INET: - netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String() + return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String() case windows.AF_INET6: var zone string if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 74e7add..46e20e6 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package bindtest diff --git a/conn/boundif_android.go b/conn/boundif_android.go index dd3ca5b..be69b2a 100644 --- a/conn/boundif_android.go +++ b/conn/boundif_android.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/conn.go b/conn/conn.go index a1f57d2..1304657 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ // Package conn implements WireGuard's network connections. diff --git a/conn/conn_test.go b/conn/conn_test.go index c6194ee..618d02b 100644 --- a/conn/conn_test.go +++ b/conn/conn_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/controlfns.go b/conn/controlfns.go index 4f7d90f..27421bd 100644 --- a/conn/controlfns.go +++ b/conn/controlfns.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go index a2396fe..f0deefa 100644 --- a/conn/controlfns_linux.go +++ b/conn/controlfns_linux.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn @@ -13,6 +13,35 @@ import ( "golang.org/x/sys/unix" ) +// Taken from go/src/internal/syscall/unix/kernel_version_linux.go +func kernelVersion() (major, minor int) { + var uname unix.Utsname + if err := unix.Uname(&uname); err != nil { + return + } + + var ( + values [2]int + value, vi int + ) + for _, c := range uname.Release { + if '0' <= c && c <= '9' { + value = (value * 10) + int(c-'0') + } else { + // Note that we're assuming N.N.N here. + // If we see anything else, we are likely to mis-parse it. 
+ values[vi] = value + vi++ + if vi >= len(values) { + break + } + value = 0 + } + } + + return values[0], values[1] +} + func init() { controlFns = append(controlFns, @@ -57,5 +86,24 @@ func init() { } return err }, + + // Attempt to enable UDP_GRO + func(network, address string, c syscall.RawConn) error { + // Kernels below 5.12 are missing 98184612aca0 ("net: + // udp: Add support for getsockopt(..., ..., UDP_GRO, + // ..., ...);"), which means we can't read this back + // later. We could pipe the return value through to + // the rest of the code, but UDP_GRO is kind of buggy + // anyway, so just gate this here. + major, minor := kernelVersion() + if major < 5 || (major == 5 && minor < 12) { + return nil + } + + c.Control(func(fd uintptr) { + _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1) + }) + return nil + }, ) } diff --git a/conn/controlfns_unix.go b/conn/controlfns_unix.go index c4536d4..b2e7570 100644 --- a/conn/controlfns_unix.go +++ b/conn/controlfns_unix.go @@ -1,8 +1,8 @@ -//go:build !windows && !linux && !js +//go:build !windows && !linux && !wasm /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/controlfns_windows.go b/conn/controlfns_windows.go index c3bdf7d..5e38305 100644 --- a/conn/controlfns_windows.go +++ b/conn/controlfns_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/default.go b/conn/default.go index b6f761b..2ce1579 100644 --- a/conn/default.go +++ b/conn/default.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/errors_default.go b/conn/errors_default.go new file mode 100644 index 0000000..3c9b223 --- /dev/null +++ b/conn/errors_default.go @@ -0,0 +1,12 @@ +//go:build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +func errShouldDisableUDPGSO(_ error) bool { + return false +} diff --git a/conn/errors_linux.go b/conn/errors_linux.go new file mode 100644 index 0000000..037d820 --- /dev/null +++ b/conn/errors_linux.go @@ -0,0 +1,26 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "errors" + "os" + + "golang.org/x/sys/unix" +) + +func errShouldDisableUDPGSO(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + // EIO is returned by udp_send_skb() if the device driver does not have + // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT. + // See: + // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 + return serr.Err == unix.EIO + } + return false +} diff --git a/conn/features_default.go b/conn/features_default.go new file mode 100644 index 0000000..9fc5088 --- /dev/null +++ b/conn/features_default.go @@ -0,0 +1,15 @@ +//go:build !linux +// +build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. 
All Rights Reserved. + */ + +package conn + +import "net" + +func supportsUDPOffload(_ *net.UDPConn) (txOffload, rxOffload bool) { + return +} diff --git a/conn/features_linux.go b/conn/features_linux.go new file mode 100644 index 0000000..6386023 --- /dev/null +++ b/conn/features_linux.go @@ -0,0 +1,29 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "net" + + "golang.org/x/sys/unix" +) + +func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { + rc, err := conn.SyscallConn() + if err != nil { + return + } + err = rc.Control(func(fd uintptr) { + _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) + txOffload = errSyscall == nil + opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO) + rxOffload = errSyscall == nil && opt == 1 + }) + if err != nil { + return false, false + } + return txOffload, rxOffload +} diff --git a/conn/gso_default.go b/conn/gso_default.go new file mode 100644 index 0000000..a9a3e80 --- /dev/null +++ b/conn/gso_default.go @@ -0,0 +1,21 @@ +//go:build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. +func setGSOSize(control *[]byte, gsoSize uint16) { +} + +// gsoControlSize returns the recommended buffer size for pooling sticky and UDP +// offloading control data. +const gsoControlSize = 0 diff --git a/conn/gso_linux.go b/conn/gso_linux.go new file mode 100644 index 0000000..4ee31fa --- /dev/null +++ b/conn/gso_linux.go @@ -0,0 +1,65 @@ +//go:build linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "fmt" + "unsafe" + + "golang.org/x/sys/unix" +) + +const ( + sizeOfGSOData = 2 +) + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + var ( + hdr unix.Cmsghdr + data []byte + rem = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) + if err != nil { + return 0, fmt.Errorf("error parsing socket control message: %w", err) + } + if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData { + var gso uint16 + copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData]) + return int(gso), nil + } + } + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing +// data in control untouched. +func setGSOSize(control *[]byte, gsoSize uint16) { + existingLen := len(*control) + avail := cap(*control) - existingLen + space := unix.CmsgSpace(sizeOfGSOData) + if avail < space { + return + } + *control = (*control)[:cap(*control)] + gsoControl := (*control)[existingLen:] + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0])) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + hdr.SetLen(unix.CmsgLen(sizeOfGSOData)) + copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) + *control = (*control)[:existingLen+space] +} + +// gsoControlSize returns the recommended buffer size for pooling UDP +// offloading control data. 
+var gsoControlSize = unix.CmsgSpace(sizeOfGSOData) diff --git a/conn/mark_default.go b/conn/mark_default.go index 3102384..72b266e 100644 --- a/conn/mark_default.go +++ b/conn/mark_default.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/mark_unix.go b/conn/mark_unix.go index d9e46ee..d0580d5 100644 --- a/conn/mark_unix.go +++ b/conn/mark_unix.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn diff --git a/conn/sticky_default.go b/conn/sticky_default.go index 05f00ea..15b65af 100644 --- a/conn/sticky_default.go +++ b/conn/sticky_default.go @@ -2,13 +2,28 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn -// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but -// use alternatively named flags and need ports and require testing. +import "net/netip" + +func (e *StdNetEndpoint) SrcIP() netip.Addr { + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return "" +} + +// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets +// {get,set}srcControl feature set, but use alternatively named flags and need +// ports and require testing. // getSrcFromControl parses the control for PKTINFO and if found updates ep with // the source information found. @@ -20,8 +35,8 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) { func setSrcControl(control *[]byte, ep *StdNetEndpoint) { } -// srcControlSize returns the recommended buffer size for pooling sticky control -// data. -const srcControlSize = 0 +// stickyControlSize returns the recommended buffer size for pooling sticky +// offloading control data. +const stickyControlSize = 0 const StdNetSupportsStickySockets = false diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go index 274fa38..adfedc1 100644 --- a/conn/sticky_linux.go +++ b/conn/sticky_linux.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn @@ -14,6 +14,37 @@ import ( "golang.org/x/sys/unix" ) +func (e *StdNetEndpoint) SrcIP() netip.Addr { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return netip.AddrFrom4(info.Spec_dst) + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + // TODO: set zone. in order to do so we need to check if the address is + // link local, and if it is perform a syscall to turn the ifindex into a + // zone string because netip uses string zones. 
+ return netip.AddrFrom16(info.Addr) + } + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return info.Ifindex + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return int32(info.Ifindex) + } + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return e.SrcIP().String() +} + // getSrcFromControl parses the control for PKTINFO and if found updates ep with // the source information found. func getSrcFromControl(control []byte, ep *StdNetEndpoint) { @@ -35,83 +66,47 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) { if hdr.Level == unix.IPPROTO_IP && hdr.Type == unix.IP_PKTINFO { - info := pktInfoFromBuf[unix.Inet4Pktinfo](data) - ep.src.Addr = netip.AddrFrom4(info.Spec_dst) - ep.src.ifidx = info.Ifindex + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + } + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) return } if hdr.Level == unix.IPPROTO_IPV6 && hdr.Type == unix.IPV6_PKTINFO { - info := pktInfoFromBuf[unix.Inet6Pktinfo](data) - ep.src.Addr = netip.AddrFrom16(info.Addr) - ep.src.ifidx = int32(info.Ifindex) + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + } + + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) return } } } -// pktInfoFromBuf returns type T populated from the provided buf via copy(). It -// panics if buf is of insufficient size. -func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) { - size := int(unsafe.Sizeof(t)) - if len(buf) < size { - panic("pktInfoFromBuf: buffer too small") - } - copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf) - return t -} - // setSrcControl sets an IP{V6}_PKTINFO in control based on the source address // and source ifindex found in ep. control's len will be set to 0 in the event // that ep is a default value. 
func setSrcControl(control *[]byte, ep *StdNetEndpoint) { - *control = (*control)[:cap(*control)] - if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) { - *control = (*control)[:0] + if cap(*control) < len(ep.src) { return } - - if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() { - *control = (*control)[:0] - return - } - - if len(*control) < srcControlSize { - *control = (*control)[:0] - return - } - - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0])) - if ep.SrcIP().Is4() { - hdr.Level = unix.IPPROTO_IP - hdr.Type = unix.IP_PKTINFO - hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) - - info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr])) - info.Ifindex = ep.src.ifidx - if ep.SrcIP().IsValid() { - info.Spec_dst = ep.SrcIP().As4() - } - *control = (*control)[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] - } else { - hdr.Level = unix.IPPROTO_IPV6 - hdr.Type = unix.IPV6_PKTINFO - hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) - - info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr])) - info.Ifindex = uint32(ep.src.ifidx) - if ep.SrcIP().IsValid() { - info.Addr = ep.SrcIP().As16() - } - *control = (*control)[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] - } - + *control = (*control)[:0] + *control = append(*control, ep.src...) } -var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) +// stickyControlSize returns the recommended buffer size for pooling sticky +// offloading control data. +var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) const StdNetSupportsStickySockets = true diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go index 0219ac3..1b1ee68 100644 --- a/conn/sticky_linux_test.go +++ b/conn/sticky_linux_test.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. 
*/ package conn @@ -18,15 +18,49 @@ import ( "golang.org/x/sys/unix" ) +func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) { + var buf []byte + if addr.Is4() { + buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IP, + Type: unix.IP_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) + copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + + info := unix.Inet4Pktinfo{ + Ifindex: ifidx, + Spec_dst: addr.As4(), + } + copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo)) + } else { + buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) + copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + + info := unix.Inet6Pktinfo{ + Ifindex: uint32(ifidx), + Addr: addr.As16(), + } + copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo)) + } + + ep.src = buf +} + func Test_setSrcControl(t *testing.T) { t.Run("IPv4", func(t *testing.T) { ep := &StdNetEndpoint{ AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), } - ep.src.Addr = netip.MustParseAddr("127.0.0.1") - ep.src.ifidx = 5 + setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5) - control := make([]byte, srcControlSize) + control := make([]byte, stickyControlSize) setSrcControl(&control, ep) @@ -53,10 +87,9 @@ func Test_setSrcControl(t *testing.T) { ep := &StdNetEndpoint{ AddrPort: netip.MustParseAddrPort("[::1]:1234"), } - ep.src.Addr = netip.MustParseAddr("::1") - ep.src.ifidx = 5 + setSrc(ep, netip.MustParseAddr("::1"), 5) - control := make([]byte, srcControlSize) + control := make([]byte, stickyControlSize) setSrcControl(&control, ep) @@ -80,7 +113,7 @@ func Test_setSrcControl(t *testing.T) { }) t.Run("ClearOnNoSrc", func(t *testing.T) { - control := make([]byte, srcControlSize) + control := make([]byte, stickyControlSize) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = 1 hdr.Type = 2 @@ -96,7 +129,7 @@ func Test_setSrcControl(t *testing.T) { func Test_getSrcFromControl(t *testing.T) { t.Run("IPv4", func(t *testing.T) { - control := make([]byte, srcControlSize) + control := make([]byte, stickyControlSize) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IP hdr.Type = unix.IP_PKTINFO @@ -108,15 +141,15 @@ func Test_getSrcFromControl(t *testing.T) { ep := &StdNetEndpoint{} getSrcFromControl(control, ep) - if ep.src.Addr != netip.MustParseAddr("127.0.0.1") { - t.Errorf("unexpected address: %v", ep.src.Addr) + if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) } - if ep.src.ifidx != 5 { - t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) t.Run("IPv6", func(t *testing.T) { - control := make([]byte, srcControlSize) + control := make([]byte, stickyControlSize) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IPV6 hdr.Type = unix.IPV6_PKTINFO @@ -131,22 +164,21 @@ func Test_getSrcFromControl(t *testing.T) { if ep.SrcIP() != netip.MustParseAddr("::1") { t.Errorf("unexpected address: %v", ep.SrcIP()) } - if ep.src.ifidx != 5 { - t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) t.Run("ClearOnEmpty", func(t *testing.T) { - 
control := make([]byte, srcControlSize) + var control []byte ep := &StdNetEndpoint{} - ep.src.Addr = netip.MustParseAddr("::1") - ep.src.ifidx = 5 + setSrc(ep, netip.MustParseAddr("::1"), 5) getSrcFromControl(control, ep) if ep.SrcIP().IsValid() { - t.Errorf("unexpected address: %v", ep.src.Addr) + t.Errorf("unexpected address: %v", ep.SrcIP()) } - if ep.src.ifidx != 0 { - t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + if ep.SrcIfidx() != 0 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) t.Run("Multiple", func(t *testing.T) { @@ -154,7 +186,7 @@ func Test_getSrcFromControl(t *testing.T) { zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0])) zeroHdr.SetLen(unix.CmsgLen(0)) - control := make([]byte, srcControlSize) + control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IP hdr.Type = unix.IP_PKTINFO @@ -170,11 +202,11 @@ func Test_getSrcFromControl(t *testing.T) { ep := &StdNetEndpoint{} getSrcFromControl(combined, ep) - if ep.src.Addr != netip.MustParseAddr("127.0.0.1") { - t.Errorf("unexpected address: %v", ep.src.Addr) + if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) } - if ep.src.ifidx != 5 { - t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) } diff --git a/conn/winrio/rio_windows.go b/conn/winrio/rio_windows.go index d1037bb..c396658 100644 --- a/conn/winrio/rio_windows.go +++ b/conn/winrio/rio_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package winrio diff --git a/device/allowedips.go b/device/allowedips.go index fa46f97..d15373c 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. 
*/ package device @@ -223,6 +223,60 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) } } +func (node *trieEntry) remove() { + node.removeFromPeerEntries() + node.peer = nil + if node.child[0] != nil && node.child[1] != nil { + return + } + bit := 0 + if node.child[0] == nil { + bit = 1 + } + child := node.child[bit] + if child != nil { + child.parent = node.parent + } + *node.parent.parentBit = child + if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { + node.zeroizePointers() + return + } + parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) + if parent.peer != nil { + node.zeroizePointers() + return + } + child = parent.child[node.parent.parentBitType^1] + if child != nil { + child.parent = parent.parent + } + *parent.parent.parentBit = child + node.zeroizePointers() + parent.zeroizePointers() +} + +func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + var node *trieEntry + var exact bool + + if prefix.Addr().Is6() { + ip := prefix.Addr().As16() + node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits())) + } else if prefix.Addr().Is4() { + ip := prefix.Addr().As4() + node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits())) + } else { + panic(errors.New("removing unknown address type")) + } + if !exact || node == nil || peer != node.peer { + return + } + node.remove() +} + func (table *AllowedIPs) RemoveByPeer(peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() @@ -230,38 +284,7 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { var next *list.Element for elem := peer.trieEntries.Front(); elem != nil; elem = next { next = elem.Next() - node := elem.Value.(*trieEntry) - - node.removeFromPeerEntries() - node.peer = nil - if node.child[0] != nil && node.child[1] != nil { - continue - } - bit := 0 - if node.child[0] == nil { - bit = 1 - } - child := node.child[bit] - if child != nil { - child.parent = node.parent - } - *node.parent.parentBit = child - if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { - node.zeroizePointers() - continue - } - parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) - if parent.peer != nil { - node.zeroizePointers() - continue - } - child = parent.child[node.parent.parentBitType^1] - if child != nil { - child.parent = parent.parent - } - *parent.parent.parentBit = child - node.zeroizePointers() - parent.zeroizePointers() + elem.Value.(*trieEntry).remove() } } diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go index 07065c3..b863696 100644 --- a/device/allowedips_rand_test.go +++ b/device/allowedips_rand_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. 
*/ package device @@ -83,7 +83,7 @@ func TestTrieRandom(t *testing.T) { var peers []*Peer var allowedIPs AllowedIPs - rand.Seed(1) + rng := rand.New(rand.NewSource(1)) for n := 0; n < NumberOfPeers; n++ { peers = append(peers, &Peer{}) @@ -91,14 +91,14 @@ func TestTrieRandom(t *testing.T) { for n := 0; n < NumberOfAddresses; n++ { var addr4 [4]byte - rand.Read(addr4[:]) + rng.Read(addr4[:]) cidr := uint8(rand.Intn(32) + 1) index := rand.Intn(NumberOfPeers) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index]) slow4 = slow4.Insert(addr4[:], cidr, peers[index]) var addr6 [16]byte - rand.Read(addr6[:]) + rng.Read(addr6[:]) cidr = uint8(rand.Intn(128) + 1) index = rand.Intn(NumberOfPeers) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index]) @@ -109,7 +109,7 @@ func TestTrieRandom(t *testing.T) { for p = 0; ; p++ { for n := 0; n < NumberOfTests; n++ { var addr4 [4]byte - rand.Read(addr4[:]) + rng.Read(addr4[:]) peer1 := slow4.Lookup(addr4[:]) peer2 := allowedIPs.Lookup(addr4[:]) if peer1 != peer2 { @@ -117,7 +117,7 @@ func TestTrieRandom(t *testing.T) { } var addr6 [16]byte - rand.Read(addr6[:]) + rng.Read(addr6[:]) peer1 = slow6.Lookup(addr6[:]) peer2 = allowedIPs.Lookup(addr6[:]) if peer1 != peer2 { diff --git a/device/allowedips_test.go b/device/allowedips_test.go index cde068e..a4b08a3 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -39,12 +39,12 @@ func TestCommonBits(t *testing.T) { } } -func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { +func benchmarkTrie(peerNumber, addressNumber, _ int, b *testing.B) { var trie *trieEntry var peers []*Peer root := parentIndirection{&trie, 2} - rand.Seed(1) + rng := rand.New(rand.NewSource(1)) const AddressLength = 4 @@ -54,15 +54,15 @@ func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { for n := 0; n < addressNumber; n++ { var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint8(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % peerNumber + rng.Read(addr[:]) + cidr := uint8(rng.Uint32() % (AddressLength * 8)) + index := rng.Int() % peerNumber root.insert(addr[:], cidr, peers[index]) } for n := 0; n < b.N; n++ { var addr [AddressLength]byte - rand.Read(addr[:]) + rng.Read(addr[:]) trie.lookup(addr[:]) } } @@ -101,6 +101,10 @@ func TestTrieIPv4(t *testing.T) { allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) } + remove := func(peer *Peer, a, b, c, d byte, cidr uint8) { + allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) + } + assertEQ := func(peer *Peer, a, b, c, d byte) { p := allowedIPs.Lookup([]byte{a, b, c, d}) if p != peer { @@ -176,6 +180,21 @@ func TestTrieIPv4(t *testing.T) { allowedIPs.RemoveByPeer(a) assertNEQ(a, 192, 168, 0, 1) + + insert(a, 1, 0, 0, 0, 32) + insert(a, 192, 0, 0, 0, 24) + assertEQ(a, 1, 0, 0, 0) + assertEQ(a, 192, 0, 0, 1) + remove(a, 192, 0, 0, 0, 32) + assertEQ(a, 192, 0, 0, 1) + remove(nil, 192, 0, 0, 0, 24) + assertEQ(a, 192, 0, 0, 1) + remove(b, 192, 0, 0, 0, 24) + assertEQ(a, 192, 0, 0, 1) + remove(a, 192, 0, 0, 0, 24) + assertNEQ(a, 192, 0, 0, 1) + remove(a, 1, 0, 0, 0, 32) + assertNEQ(a, 1, 0, 0, 0) } /* Test ported from kernel implementation: @@ -211,6 +230,15 @@ func 
TestTrieIPv6(t *testing.T) { allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) } + remove := func(peer *Peer, a, b, c, d uint32, cidr uint8) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) + } + assertEQ := func(peer *Peer, a, b, c, d uint32) { var addr []byte addr = append(addr, expand(a)...) @@ -223,6 +251,18 @@ func TestTrieIPv6(t *testing.T) { } } + assertNEQ := func(peer *Peer, a, b, c, d uint32) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + p := allowedIPs.Lookup(addr) + if p == peer { + t.Error("Assert NEQ failed") + } + } + insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) insert(c, 0x26075300, 0x60006b00, 0, 0, 64) insert(e, 0, 0, 0, 0, 0) @@ -244,4 +284,21 @@ func TestTrieIPv6(t *testing.T) { assertEQ(h, 0x24046800, 0x40040800, 0, 0) assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) + + insert(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + insert(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010) + remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(nil, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + assertNEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010) + remove(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + assertNEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010) } diff --git a/device/bind_test.go b/device/bind_test.go index 302a521..d3fa565 100644 --- a/device/bind_test.go +++ b/device/bind_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/channels.go b/device/channels.go index 039d8df..be15d1c 100644 --- a/device/channels.go +++ b/device/channels.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -19,13 +19,13 @@ import ( // call wg.Done to remove the initial reference. // When the refcount hits 0, the queue's channel is closed. type outboundQueue struct { - c chan *QueueOutboundElement + c chan *QueueOutboundElementsContainer wg sync.WaitGroup } func newOutboundQueue() *outboundQueue { q := &outboundQueue{ - c: make(chan *QueueOutboundElement, QueueOutboundSize), + c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), } q.wg.Add(1) go func() { @@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue { // A inboundQueue is similar to an outboundQueue; see those docs. 
type inboundQueue struct { - c chan *QueueInboundElement + c chan *QueueInboundElementsContainer wg sync.WaitGroup } func newInboundQueue() *inboundQueue { q := &inboundQueue{ - c: make(chan *QueueInboundElement, QueueInboundSize), + c: make(chan *QueueInboundElementsContainer, QueueInboundSize), } q.wg.Add(1) go func() { @@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue { } type autodrainingInboundQueue struct { - c chan *[]*QueueInboundElement + c chan *QueueInboundElementsContainer } // newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd. @@ -81,7 +81,7 @@ type autodrainingInboundQueue struct { // some other means, such as sending a sentinel nil values. func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { q := &autodrainingInboundQueue{ - c: make(chan *[]*QueueInboundElement, QueueInboundSize), + c: make(chan *QueueInboundElementsContainer, QueueInboundSize), } runtime.SetFinalizer(q, device.flushInboundQueue) return q @@ -90,13 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { for { select { - case elems := <-q.c: - for _, elem := range *elems { - elem.Lock() + case elemsContainer := <-q.c: + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } - device.PutInboundElementsSlice(elems) + device.PutInboundElementsContainer(elemsContainer) default: return } @@ -104,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { } type autodrainingOutboundQueue struct { - c chan *[]*QueueOutboundElement + c chan *QueueOutboundElementsContainer } // newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd. @@ -114,7 +114,7 @@ type autodrainingOutboundQueue struct { // All sends to the channel must be best-effort, because there may be no receivers. func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { q := &autodrainingOutboundQueue{ - c: make(chan *[]*QueueOutboundElement, QueueOutboundSize), + c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), } runtime.SetFinalizer(q, device.flushOutboundQueue) return q @@ -123,13 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { for { select { - case elems := <-q.c: - for _, elem := range *elems { - elem.Lock() + case elemsContainer := <-q.c: + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsSlice(elems) + device.PutOutboundElementsContainer(elemsContainer) default: return } diff --git a/device/constants.go b/device/constants.go index 59854a1..41da618 100644 --- a/device/constants.go +++ b/device/constants.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/cookie.go b/device/cookie.go index 876f05d..a093c8b 100644 --- a/device/cookie.go +++ b/device/cookie.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. 
*/ package device diff --git a/device/cookie_test.go b/device/cookie_test.go index 4f1e50a..c937290 100644 --- a/device/cookie_test.go +++ b/device/cookie_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/device.go b/device/device.go index 091c8d4..6854ed8 100644 --- a/device/device.go +++ b/device/device.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -68,11 +68,11 @@ type Device struct { cookieChecker CookieChecker pool struct { - outboundElementsSlice *WaitPool - inboundElementsSlice *WaitPool - messageBuffers *WaitPool - inboundElements *WaitPool - outboundElements *WaitPool + inboundElementsContainer *WaitPool + outboundElementsContainer *WaitPool + messageBuffers *WaitPool + inboundElements *WaitPool + outboundElements *WaitPool } queue struct { @@ -370,6 +370,8 @@ func (device *Device) RemoveAllPeers() { func (device *Device) Close() { device.state.Lock() defer device.state.Unlock() + device.ipcMutex.Lock() + defer device.ipcMutex.Unlock() if device.isClosed() { return } @@ -459,11 +461,7 @@ func (device *Device) BindSetMark(mark uint32) error { // clear cached source addresses device.peers.RLock() for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } + peer.markEndpointSrcForClearing() } device.peers.RUnlock() @@ -513,11 +511,7 @@ func (device *Device) BindUpdate() error { // clear cached source addresses device.peers.RLock() for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } + peer.markEndpointSrcForClearing() } device.peers.RUnlock() diff --git a/device/device_test.go b/device/device_test.go index fff172b..0091e20 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/endpoint_test.go b/device/endpoint_test.go index 93a4998..85482d8 100644 --- a/device/endpoint_test.go +++ b/device/endpoint_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/indextable.go b/device/indextable.go index 00ade7d..2460fa6 100644 --- a/device/indextable.go +++ b/device/indextable.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/ip.go b/device/ip.go index eaf2363..f558744 100644 --- a/device/ip.go +++ b/device/ip.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/kdf_test.go b/device/kdf_test.go index f9c76d6..325db59 100644 --- a/device/kdf_test.go +++ b/device/kdf_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
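One detail of the BindSetMark and BindUpdate hunks above deserves a note. The removed code ran peer.Lock(); defer peer.Unlock() inside a range loop, and a defer fires at function return, not at the end of an iteration, so every peer's lock stayed held until the whole function finished. The new markEndpointSrcForClearing helper (defined later in this diff, in device/peer.go) locks and releases one peer at a time. A standalone illustration of the pitfall:

package main

import (
	"fmt"
	"sync"
)

type peer struct{ mu sync.Mutex }

// lockAllUntilReturn shows the old behavior: the deferred unlocks all run at
// function return, so the locks accumulate across iterations.
func lockAllUntilReturn(peers []*peer) {
	for _, p := range peers {
		p.mu.Lock()
		defer p.mu.Unlock() // runs when the function returns, not per iteration
	}
	fmt.Println("all", len(peers), "locks held simultaneously")
}

// lockOneAtATime shows the fixed behavior: each lock is released before the
// next peer is touched.
func lockOneAtATime(peers []*peer) {
	for _, p := range peers {
		p.mu.Lock()
		// ... mutate this peer ...
		p.mu.Unlock()
	}
}

func main() {
	peers := []*peer{{}, {}, {}}
	lockAllUntilReturn(peers)
	lockOneAtATime(peers)
}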
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/keypair.go b/device/keypair.go index e3540d7..0b72e19 100644 --- a/device/keypair.go +++ b/device/keypair.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/logger.go b/device/logger.go index 22b0df0..a2adea3 100644 --- a/device/logger.go +++ b/device/logger.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/mobilequirks.go b/device/mobilequirks.go index 4e5051d..af4be31 100644 --- a/device/mobilequirks.go +++ b/device/mobilequirks.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -11,9 +11,9 @@ func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() { device.net.brokenRoaming = true device.peers.RLock() for _, peer := range device.peers.keyMap { - peer.Lock() - peer.disableRoaming = peer.endpoint != nil - peer.Unlock() + peer.endpoint.Lock() + peer.endpoint.disableRoaming = peer.endpoint.val != nil + peer.endpoint.Unlock() } device.peers.RUnlock() } diff --git a/device/noise-helpers.go b/device/noise-helpers.go index c2f356b..35dd907 100644 --- a/device/noise-helpers.go +++ b/device/noise-helpers.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/noise-protocol.go b/device/noise-protocol.go index e8f6145..5cf1702 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -1,11 +1,12 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. 
*/ package device import ( + "encoding/binary" "errors" "fmt" "sync" @@ -115,6 +116,98 @@ type MessageCookieReply struct { Cookie [blake2s.Size128 + poly1305.TagSize]byte } +var errMessageLengthMismatch = errors.New("message length mismatch") + +func (msg *MessageInitiation) unmarshal(b []byte) error { + if len(b) != MessageInitiationSize { + return errMessageLengthMismatch + } + + msg.Type = binary.LittleEndian.Uint32(b) + msg.Sender = binary.LittleEndian.Uint32(b[4:]) + copy(msg.Ephemeral[:], b[8:]) + copy(msg.Static[:], b[8+len(msg.Ephemeral):]) + copy(msg.Timestamp[:], b[8+len(msg.Ephemeral)+len(msg.Static):]) + copy(msg.MAC1[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):]) + copy(msg.MAC2[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):]) + + return nil +} + +func (msg *MessageInitiation) marshal(b []byte) error { + if len(b) != MessageInitiationSize { + return errMessageLengthMismatch + } + + binary.LittleEndian.PutUint32(b, msg.Type) + binary.LittleEndian.PutUint32(b[4:], msg.Sender) + copy(b[8:], msg.Ephemeral[:]) + copy(b[8+len(msg.Ephemeral):], msg.Static[:]) + copy(b[8+len(msg.Ephemeral)+len(msg.Static):], msg.Timestamp[:]) + copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):], msg.MAC1[:]) + copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):], msg.MAC2[:]) + + return nil +} + +func (msg *MessageResponse) unmarshal(b []byte) error { + if len(b) != MessageResponseSize { + return errMessageLengthMismatch + } + + msg.Type = binary.LittleEndian.Uint32(b) + msg.Sender = binary.LittleEndian.Uint32(b[4:]) + msg.Receiver = binary.LittleEndian.Uint32(b[8:]) + copy(msg.Ephemeral[:], b[12:]) + copy(msg.Empty[:], b[12+len(msg.Ephemeral):]) + copy(msg.MAC1[:], b[12+len(msg.Ephemeral)+len(msg.Empty):]) + copy(msg.MAC2[:], b[12+len(msg.Ephemeral)+len(msg.Empty)+len(msg.MAC1):]) + + return nil +} + +func (msg *MessageResponse) marshal(b []byte) error { + if len(b) != MessageResponseSize { + return errMessageLengthMismatch + } + + binary.LittleEndian.PutUint32(b, msg.Type) + binary.LittleEndian.PutUint32(b[4:], msg.Sender) + binary.LittleEndian.PutUint32(b[8:], msg.Receiver) + copy(b[12:], msg.Ephemeral[:]) + copy(b[12+len(msg.Ephemeral):], msg.Empty[:]) + copy(b[12+len(msg.Ephemeral)+len(msg.Empty):], msg.MAC1[:]) + copy(b[12+len(msg.Ephemeral)+len(msg.Empty)+len(msg.MAC1):], msg.MAC2[:]) + + return nil +} + +func (msg *MessageCookieReply) unmarshal(b []byte) error { + if len(b) != MessageCookieReplySize { + return errMessageLengthMismatch + } + + msg.Type = binary.LittleEndian.Uint32(b) + msg.Receiver = binary.LittleEndian.Uint32(b[4:]) + copy(msg.Nonce[:], b[8:]) + copy(msg.Cookie[:], b[8+len(msg.Nonce):]) + + return nil +} + +func (msg *MessageCookieReply) marshal(b []byte) error { + if len(b) != MessageCookieReplySize { + return errMessageLengthMismatch + } + + binary.LittleEndian.PutUint32(b, msg.Type) + binary.LittleEndian.PutUint32(b[4:], msg.Receiver) + copy(b[8:], msg.Nonce[:]) + copy(b[8+len(msg.Nonce):], msg.Cookie[:]) + + return nil +} + type Handshake struct { state handshakeState mutex sync.RWMutex diff --git a/device/noise-types.go b/device/noise-types.go index e850359..41c944e 100644 --- a/device/noise-types.go +++ b/device/noise-types.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. 
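The hand-rolled marshal and unmarshal methods above replace the reflection-driven binary.Read and binary.Write calls that are removed later in this diff (in device/receive.go and device/send.go). Fixed offsets avoid per-message reflection plus the bytes.Reader and bytes.Buffer allocations, the offset chains built from len(msg.Field) track the array sizes automatically, and the up-front length check turns truncated input into one typed error. A self-contained round-trip sketch in the same style, with made-up field sizes:

package main

import (
	"encoding/binary"
	"errors"
	"fmt"
)

const msgSize = 4 + 4 + 8 // Type + Sender + Nonce

type msg struct {
	Type   uint32
	Sender uint32
	Nonce  [8]byte
}

var errLen = errors.New("message length mismatch")

func (m *msg) marshal(b []byte) error {
	if len(b) != msgSize {
		return errLen
	}
	binary.LittleEndian.PutUint32(b, m.Type)
	binary.LittleEndian.PutUint32(b[4:], m.Sender)
	copy(b[8:], m.Nonce[:])
	return nil
}

func (m *msg) unmarshal(b []byte) error {
	if len(b) != msgSize {
		return errLen
	}
	m.Type = binary.LittleEndian.Uint32(b)
	m.Sender = binary.LittleEndian.Uint32(b[4:])
	copy(m.Nonce[:], b[8:])
	return nil
}

func main() {
	in := msg{Type: 1, Sender: 42, Nonce: [8]byte{1, 2, 3}}
	buf := make([]byte, msgSize)
	_ = in.marshal(buf)
	var out msg
	if err := out.unmarshal(buf); err != nil || out != in {
		panic("round trip failed")
	}
	fmt.Printf("%+v\n", out)
}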
*/ package device diff --git a/device/noise_test.go b/device/noise_test.go index 2dd5324..f0928ac 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/peer.go b/device/peer.go index 0ac4896..ebf25f9 100644 --- a/device/peer.go +++ b/device/peer.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -17,17 +17,20 @@ import ( type Peer struct { isRunning atomic.Bool - sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer keypairs Keypairs handshake Handshake device *Device - endpoint conn.Endpoint stopping sync.WaitGroup // routines pending stop txBytes atomic.Uint64 // bytes send to peer (endpoint) rxBytes atomic.Uint64 // bytes received from peer lastHandshakeNano atomic.Int64 // nano seconds since epoch - disableRoaming bool + endpoint struct { + sync.Mutex + val conn.Endpoint + clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission + disableRoaming bool + } timers struct { retransmitHandshake *Timer @@ -45,9 +48,9 @@ type Peer struct { } queue struct { - staged chan *[]*QueueOutboundElement // staged packets before a handshake is available - outbound *autodrainingOutboundQueue // sequential ordering of udp transmission - inbound *autodrainingInboundQueue // sequential ordering of tun writing + staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available + outbound *autodrainingOutboundQueue // sequential ordering of udp transmission + inbound *autodrainingInboundQueue // sequential ordering of tun writing } cookieGenerator CookieGenerator @@ -74,14 +77,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // create peer peer := new(Peer) - peer.Lock() - defer peer.Unlock() peer.cookieGenerator.Init(pk) peer.device = device peer.queue.outbound = newAutodrainingOutboundQueue(device) peer.queue.inbound = newAutodrainingInboundQueue(device) - peer.queue.staged = make(chan *[]*QueueOutboundElement, QueueStagedSize) + peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize) // map public key _, ok := device.peers.keyMap[pk] @@ -97,7 +98,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { handshake.mutex.Unlock() // reset endpoint - peer.endpoint = nil + peer.endpoint.Lock() + peer.endpoint.val = nil + peer.endpoint.disableRoaming = false + peer.endpoint.clearSrcOnTx = false + peer.endpoint.Unlock() // init timers peer.timersInit() @@ -116,14 +121,19 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error { return nil } - peer.RLock() - defer peer.RUnlock() - - if peer.endpoint == nil { + peer.endpoint.Lock() + endpoint := peer.endpoint.val + if endpoint == nil { + peer.endpoint.Unlock() return errors.New("no known endpoint for peer") } + if peer.endpoint.clearSrcOnTx { + endpoint.ClearSrc() + peer.endpoint.clearSrcOnTx = false + } + peer.endpoint.Unlock() - err := peer.device.net.bind.Send(buffers, peer.endpoint) + err := peer.device.net.bind.Send(buffers, endpoint) if err == nil { var totalLen uint64 for _, b := range buffers { @@ -267,10 +277,20 @@ func (peer *Peer) Stop() { } func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { - if peer.disableRoaming { + 
peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.disableRoaming { + return + } + peer.endpoint.clearSrcOnTx = false + peer.endpoint.val = endpoint +} + +func (peer *Peer) markEndpointSrcForClearing() { + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.val == nil { return } - peer.Lock() - peer.endpoint = endpoint - peer.Unlock() + peer.endpoint.clearSrcOnTx = true } diff --git a/device/pools.go b/device/pools.go index 02a5d6a..2c18f41 100644 --- a/device/pools.go +++ b/device/pools.go @@ -1,20 +1,19 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device import ( "sync" - "sync/atomic" ) type WaitPool struct { pool sync.Pool cond sync.Cond lock sync.Mutex - count atomic.Uint32 + count uint32 // Get calls not yet Put back max uint32 } @@ -27,10 +26,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool { func (p *WaitPool) Get() any { if p.max != 0 { p.lock.Lock() - for p.count.Load() >= p.max { + for p.count >= p.max { p.cond.Wait() } - p.count.Add(1) + p.count++ p.lock.Unlock() } return p.pool.Get() @@ -41,18 +40,20 @@ func (p *WaitPool) Put(x any) { if p.max == 0 { return } - p.count.Add(^uint32(0)) + p.lock.Lock() + defer p.lock.Unlock() + p.count-- p.cond.Signal() } func (device *Device) PopulatePools() { - device.pool.outboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any { - s := make([]*QueueOutboundElement, 0, device.BatchSize()) - return &s - }) - device.pool.inboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any { + device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { s := make([]*QueueInboundElement, 0, device.BatchSize()) - return &s + return &QueueInboundElementsContainer{elems: s} + }) + device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueOutboundElement, 0, device.BatchSize()) + return &QueueOutboundElementsContainer{elems: s} }) device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new([MaxMessageSize]byte) @@ -65,28 +66,32 @@ func (device *Device) PopulatePools() { }) } -func (device *Device) GetOutboundElementsSlice() *[]*QueueOutboundElement { - return device.pool.outboundElementsSlice.Get().(*[]*QueueOutboundElement) +func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { + c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer) + c.Mutex = sync.Mutex{} + return c } -func (device *Device) PutOutboundElementsSlice(s *[]*QueueOutboundElement) { - for i := range *s { - (*s)[i] = nil +func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } - *s = (*s)[:0] - device.pool.outboundElementsSlice.Put(s) + c.elems = c.elems[:0] + device.pool.inboundElementsContainer.Put(c) } -func (device *Device) GetInboundElementsSlice() *[]*QueueInboundElement { - return device.pool.inboundElementsSlice.Get().(*[]*QueueInboundElement) +func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer { + c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer) + c.Mutex = sync.Mutex{} + return c } -func (device *Device) PutInboundElementsSlice(s *[]*QueueInboundElement) { - for i := range *s { - (*s)[i] = nil +func (device *Device) PutOutboundElementsContainer(c 
*QueueOutboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } - *s = (*s)[:0] - device.pool.inboundElementsSlice.Put(s) + c.elems = c.elems[:0] + device.pool.outboundElementsContainer.Put(c) } func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { diff --git a/device/pools_test.go b/device/pools_test.go index 82d7493..8381d5a 100644 --- a/device/pools_test.go +++ b/device/pools_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -32,7 +32,9 @@ func TestWaitPool(t *testing.T) { wg.Add(workers) var max atomic.Uint32 updateMax := func() { - count := p.count.Load() + p.lock.Lock() + count := p.count + p.lock.Unlock() if count > p.max { t.Errorf("count (%d) > max (%d)", count, p.max) } diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go index 3d80ead..236dea1 100644 --- a/device/queueconstants_android.go +++ b/device/queueconstants_android.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -14,6 +14,6 @@ const ( QueueOutboundSize = 1024 QueueInboundSize = 1024 QueueHandshakeSize = 1024 - MaxSegmentSize = 2200 + MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram PreallocatedBuffersPerPool = 4096 ) diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go index ea763d0..b061185 100644 --- a/device/queueconstants_default.go +++ b/device/queueconstants_default.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/queueconstants_ios.go b/device/queueconstants_ios.go index acd3cec..632e29d 100644 --- a/device/queueconstants_ios.go +++ b/device/queueconstants_ios.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/queueconstants_windows.go b/device/queueconstants_windows.go index 1eee32b..9a296d6 100644 --- a/device/queueconstants_windows.go +++ b/device/queueconstants_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/race_disabled_test.go b/device/race_disabled_test.go index bb5c450..14b3284 100644 --- a/device/race_disabled_test.go +++ b/device/race_disabled_test.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/race_enabled_test.go b/device/race_enabled_test.go index 4e9daea..f1ea5cf 100644 --- a/device/race_enabled_test.go +++ b/device/race_enabled_test.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. 
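The WaitPool hunks above swap the atomic counter for a plain uint32 guarded by the pool's existing mutex, and Put now decrements and signals while holding that lock. This closes a lost-wakeup window in the removed code: a Get could observe count >= max and be about to Wait while a concurrent Put decremented and signaled before any waiter was registered, and that signal was simply lost. A compilable sketch of the resulting semantics (simplified: the real pool treats max == 0 as unbounded):

package main

import (
	"fmt"
	"sync"
	"time"
)

type waitPool struct {
	pool  sync.Pool
	cond  sync.Cond
	lock  sync.Mutex
	count uint32 // Get calls not yet Put back
	max   uint32
}

func newWaitPool(max uint32, newFn func() any) *waitPool {
	p := &waitPool{pool: sync.Pool{New: newFn}, max: max}
	p.cond = sync.Cond{L: &p.lock}
	return p
}

func (p *waitPool) Get() any {
	p.lock.Lock()
	for p.count >= p.max {
		p.cond.Wait()
	}
	p.count++
	p.lock.Unlock()
	return p.pool.Get()
}

func (p *waitPool) Put(x any) {
	p.pool.Put(x)
	p.lock.Lock()
	defer p.lock.Unlock()
	p.count-- // decrement and signal under the same lock the Get side checks
	p.cond.Signal()
}

func main() {
	p := newWaitPool(1, func() any { return new([16]byte) })
	buf := p.Get()
	done := make(chan struct{})
	go func() {
		p.Get() // blocks: one object is already outstanding
		close(done)
	}()
	time.Sleep(50 * time.Millisecond)
	p.Put(buf)
	<-done
	fmt.Println("second Get unblocked after Put")
}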
*/ package device diff --git a/device/receive.go b/device/receive.go index e24d29f..1392957 100644 --- a/device/receive.go +++ b/device/receive.go @@ -1,12 +1,11 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device import ( - "bytes" "encoding/binary" "errors" "net" @@ -27,7 +26,6 @@ type QueueHandshakeElement struct { } type QueueInboundElement struct { - sync.Mutex buffer *[MaxMessageSize]byte packet []byte counter uint64 @@ -35,6 +33,11 @@ type QueueInboundElement struct { endpoint conn.Endpoint } +type QueueInboundElementsContainer struct { + sync.Mutex + elems []*QueueInboundElement +} + // clearPointers clears elem fields that contain pointers. // This makes the garbage collector's life easier and // avoids accidentally keeping other objects around unnecessarily. @@ -87,7 +90,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive count int endpoints = make([]conn.Endpoint, maxBatchSize) deathSpiral int - elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize) + elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize) ) for i := range bufsArrs { @@ -170,15 +173,14 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive elem.keypair = keypair elem.endpoint = endpoints[i] elem.counter = 0 - elem.Mutex = sync.Mutex{} - elem.Lock() elemsForPeer, ok := elemsByPeer[peer] if !ok { - elemsForPeer = device.GetInboundElementsSlice() + elemsForPeer = device.GetInboundElementsContainer() + elemsForPeer.Lock() elemsByPeer[peer] = elemsForPeer } - *elemsForPeer = append(*elemsForPeer, elem) + elemsForPeer.elems = append(elemsForPeer.elems, elem) bufsArrs[i] = device.GetMessageBuffer() bufs[i] = bufsArrs[i][:] continue @@ -217,18 +219,16 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive default: } } - for peer, elems := range elemsByPeer { + for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { - peer.queue.inbound.c <- elems - for _, elem := range *elems { - device.queue.decryption.c <- elem - } + peer.queue.inbound.c <- elemsContainer + device.queue.decryption.c <- elemsContainer } else { - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } - device.PutInboundElementsSlice(elems) + device.PutInboundElementsContainer(elemsContainer) } delete(elemsByPeer, peer) } @@ -241,26 +241,28 @@ func (device *Device) RoutineDecryption(id int) { defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) device.log.Verbosef("Routine: decryption worker %d - started", id) - for elem := range device.queue.decryption.c { - // split message into fields - counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] - content := elem.packet[MessageTransportOffsetContent:] - - // decrypt and release to consumer - var err error - elem.counter = binary.LittleEndian.Uint64(counter) - // copy counter to nonce - binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) - elem.packet, err = elem.keypair.receive.Open( - content[:0], - nonce[:], - content, - nil, - ) - if err != nil { - elem.packet = nil + for elemsContainer := range device.queue.decryption.c { + for _, elem := range elemsContainer.elems { + // split message into fields + counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] + 
content := elem.packet[MessageTransportOffsetContent:] + + // decrypt and release to consumer + var err error + elem.counter = binary.LittleEndian.Uint64(counter) + // copy counter to nonce + binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) + elem.packet, err = elem.keypair.receive.Open( + content[:0], + nonce[:], + content, + nil, + ) + if err != nil { + elem.packet = nil + } } - elem.Unlock() + elemsContainer.Unlock() } } @@ -284,8 +286,7 @@ func (device *Device) RoutineHandshake(id int) { // unmarshal packet var reply MessageCookieReply - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &reply) + err := reply.unmarshal(elem.packet) if err != nil { device.log.Verbosef("Failed to decode cookie reply") goto skip @@ -350,8 +351,7 @@ func (device *Device) RoutineHandshake(id int) { // unmarshal var msg MessageInitiation - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) + err := msg.unmarshal(elem.packet) if err != nil { device.log.Errorf("Failed to decode initiation message") goto skip @@ -383,8 +383,7 @@ func (device *Device) RoutineHandshake(id int) { // unmarshal var msg MessageResponse - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) + err := msg.unmarshal(elem.packet) if err != nil { device.log.Errorf("Failed to decode response message") goto skip @@ -437,12 +436,15 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { bufs := make([][]byte, 0, maxBatchSize) - for elems := range peer.queue.inbound.c { - if elems == nil { + for elemsContainer := range peer.queue.inbound.c { + if elemsContainer == nil { return } - for _, elem := range *elems { - elem.Lock() + elemsContainer.Lock() + validTailPacket := -1 + dataPacketReceived := false + rxBytesLen := uint64(0) + for i, elem := range elemsContainer.elems { if elem.packet == nil { // decryption failed continue @@ -452,21 +454,19 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { continue } - peer.SetEndpointFromPacket(elem.endpoint) + validTailPacket = i if peer.ReceivedWithKeypair(elem.keypair) { + peer.SetEndpointFromPacket(elem.endpoint) peer.timersHandshakeComplete() peer.SendStagedPackets() } - peer.keepKeyFreshReceiving() - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketReceived() - peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize)) + rxBytesLen += uint64(len(elem.packet) + MinMessageSize) if len(elem.packet) == 0 { device.log.Verbosef("%v - Receiving keepalive packet", peer) continue } - peer.timersDataReceived() + dataPacketReceived = true switch elem.packet[0] >> 4 { case 4: @@ -509,17 +509,28 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) } + + peer.rxBytes.Add(rxBytesLen) + if validTailPacket >= 0 { + peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint) + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + } + if dataPacketReceived { + peer.timersDataReceived() + } if len(bufs) > 0 { _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent) if err != nil && !device.isClosed() { device.log.Errorf("Failed to write packets to TUN device: %v", err) } } - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } bufs = bufs[:0] - 
device.PutInboundElementsSlice(elems) + device.PutInboundElementsContainer(elemsContainer) } } diff --git a/device/send.go b/device/send.go index d22bf26..ff8f7da 100644 --- a/device/send.go +++ b/device/send.go @@ -1,12 +1,11 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device import ( - "bytes" "encoding/binary" "errors" "net" @@ -17,6 +16,7 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/tun" ) @@ -45,7 +45,6 @@ import ( */ type QueueOutboundElement struct { - sync.Mutex buffer *[MaxMessageSize]byte // slice holding the packet data packet []byte // slice of "buffer" (always!) nonce uint64 // nonce for encryption @@ -53,10 +52,14 @@ type QueueOutboundElement struct { peer *Peer // related peer } +type QueueOutboundElementsContainer struct { + sync.Mutex + elems []*QueueOutboundElement +} + func (device *Device) NewOutboundElement() *QueueOutboundElement { elem := device.GetOutboundElement() elem.buffer = device.GetMessageBuffer() - elem.Mutex = sync.Mutex{} elem.nonce = 0 // keypair and peer were cleared (if necessary) by clearPointers. return elem @@ -78,15 +81,15 @@ func (elem *QueueOutboundElement) clearPointers() { func (peer *Peer) SendKeepalive() { if len(peer.queue.staged) == 0 && peer.isRunning.Load() { elem := peer.device.NewOutboundElement() - elems := peer.device.GetOutboundElementsSlice() - *elems = append(*elems, elem) + elemsContainer := peer.device.GetOutboundElementsContainer() + elemsContainer.elems = append(elemsContainer.elems, elem) select { - case peer.queue.staged <- elems: + case peer.queue.staged <- elemsContainer: peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) - peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) } } peer.SendStagedPackets() @@ -120,10 +123,8 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { return err } - var buf [MessageInitiationSize]byte - writer := bytes.NewBuffer(buf[:0]) - binary.Write(writer, binary.LittleEndian, msg) - packet := writer.Bytes() + packet := make([]byte, MessageInitiationSize) + _ = msg.marshal(packet) peer.cookieGenerator.AddMacs(packet) peer.timersAnyAuthenticatedPacketTraversal() @@ -151,10 +152,8 @@ func (peer *Peer) SendHandshakeResponse() error { return err } - var buf [MessageResponseSize]byte - writer := bytes.NewBuffer(buf[:0]) - binary.Write(writer, binary.LittleEndian, response) - packet := writer.Bytes() + packet := make([]byte, MessageResponseSize) + _ = response.marshal(packet) peer.cookieGenerator.AddMacs(packet) err = peer.BeginSymmetricSession() @@ -185,11 +184,11 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) return err } - var buf [MessageCookieReplySize]byte - writer := bytes.NewBuffer(buf[:0]) - binary.Write(writer, binary.LittleEndian, reply) + packet := make([]byte, MessageCookieReplySize) + _ = reply.marshal(packet) // TODO: allocation could be avoided - device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) + device.net.bind.Send([][]byte{packet}, initiatingElem.endpoint) + return nil } @@ -218,7 +217,7 @@ func (device *Device) RoutineReadFromTUN() { readErr error elems = make([]*QueueOutboundElement, batchSize) bufs = make([][]byte, 
batchSize) - elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize) + elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) count = 0 sizes = make([]int, batchSize) offset = MessageTransportHeaderSize @@ -275,10 +274,10 @@ func (device *Device) RoutineReadFromTUN() { } elemsForPeer, ok := elemsByPeer[peer] if !ok { - elemsForPeer = device.GetOutboundElementsSlice() + elemsForPeer = device.GetOutboundElementsContainer() elemsByPeer[peer] = elemsForPeer } - *elemsForPeer = append(*elemsForPeer, elem) + elemsForPeer.elems = append(elemsForPeer.elems, elem) elems[i] = device.NewOutboundElement() bufs[i] = elems[i].buffer[:] } @@ -288,11 +287,11 @@ func (device *Device) RoutineReadFromTUN() { peer.StagePackets(elemsForPeer) peer.SendStagedPackets() } else { - for _, elem := range *elemsForPeer { + for _, elem := range elemsForPeer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsSlice(elemsForPeer) + device.PutOutboundElementsContainer(elemsForPeer) } delete(elemsByPeer, peer) } @@ -316,7 +315,7 @@ func (device *Device) RoutineReadFromTUN() { } } -func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) { +func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { for { select { case peer.queue.staged <- elems: @@ -325,11 +324,11 @@ func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) { } select { case tooOld := <-peer.queue.staged: - for _, elem := range *tooOld { + for _, elem := range tooOld.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.device.PutOutboundElementsSlice(tooOld) + peer.device.PutOutboundElementsContainer(tooOld) default: } } @@ -348,54 +347,52 @@ top: } for { - var elemsOOO *[]*QueueOutboundElement + var elemsContainerOOO *QueueOutboundElementsContainer select { - case elems := <-peer.queue.staged: + case elemsContainer := <-peer.queue.staged: i := 0 - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { elem.peer = peer elem.nonce = keypair.sendNonce.Add(1) - 1 if elem.nonce >= RejectAfterMessages { keypair.sendNonce.Store(RejectAfterMessages) - if elemsOOO == nil { - elemsOOO = peer.device.GetOutboundElementsSlice() + if elemsContainerOOO == nil { + elemsContainerOOO = peer.device.GetOutboundElementsContainer() } - *elemsOOO = append(*elemsOOO, elem) + elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem) continue } else { - (*elems)[i] = elem + elemsContainer.elems[i] = elem i++ } elem.keypair = keypair - elem.Lock() } - *elems = (*elems)[:i] + elemsContainer.Lock() + elemsContainer.elems = elemsContainer.elems[:i] - if elemsOOO != nil { - peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans + if elemsContainerOOO != nil { + peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans } - if len(*elems) == 0 { - peer.device.PutOutboundElementsSlice(elems) + if len(elemsContainer.elems) == 0 { + peer.device.PutOutboundElementsContainer(elemsContainer) goto top } // add to parallel and sequential queue if peer.isRunning.Load() { - peer.queue.outbound.c <- elems - for _, elem := range *elems { - peer.device.queue.encryption.c <- elem - } + peer.queue.outbound.c <- elemsContainer + peer.device.queue.encryption.c <- elemsContainer } else { - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - 
peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) } - if elemsOOO != nil { + if elemsContainerOOO != nil { goto top } default: @@ -407,12 +404,12 @@ top: func (peer *Peer) FlushStagedPackets() { for { select { - case elems := <-peer.queue.staged: - for _, elem := range *elems { + case elemsContainer := <-peer.queue.staged: + for _, elem := range elemsContainer.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) default: return } @@ -446,32 +443,34 @@ func (device *Device) RoutineEncryption(id int) { defer device.log.Verbosef("Routine: encryption worker %d - stopped", id) device.log.Verbosef("Routine: encryption worker %d - started", id) - for elem := range device.queue.encryption.c { - // populate header fields - header := elem.buffer[:MessageTransportHeaderSize] + for elemsContainer := range device.queue.encryption.c { + for _, elem := range elemsContainer.elems { + // populate header fields + header := elem.buffer[:MessageTransportHeaderSize] - fieldType := header[0:4] - fieldReceiver := header[4:8] - fieldNonce := header[8:16] + fieldType := header[0:4] + fieldReceiver := header[4:8] + fieldNonce := header[8:16] - binary.LittleEndian.PutUint32(fieldType, MessageTransportType) - binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) - binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) + binary.LittleEndian.PutUint32(fieldType, MessageTransportType) + binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) + binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) - // pad content to multiple of 16 - paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) - elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) + // pad content to multiple of 16 + paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) + elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) - // encrypt content and release to consumer + // encrypt content and release to consumer - binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) - elem.packet = elem.keypair.send.Seal( - header, - nonce[:], - elem.packet, - nil, - ) - elem.Unlock() + binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) + elem.packet = elem.keypair.send.Seal( + header, + nonce[:], + elem.packet, + nil, + ) + } + elemsContainer.Unlock() } } @@ -485,9 +484,9 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { bufs := make([][]byte, 0, maxBatchSize) - for elems := range peer.queue.outbound.c { + for elemsContainer := range peer.queue.outbound.c { bufs = bufs[:0] - if elems == nil { + if elemsContainer == nil { return } if !peer.isRunning.Load() { @@ -497,16 +496,17 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { // The timers and SendBuffers code are resilient to a few stragglers. // TODO: rework peer shutdown order to ensure // that we never accidentally keep timers alive longer than necessary. 
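Stepping back from the staging and encryption hunks above: with containers, SendStagedPackets performs exactly two channel sends per batch, one to the peer's outbound queue (which fixes per-peer packet order) and one to the shared encryption queue (which feeds the parallel workers). RoutineEncryption seals every element and then unlocks the container once, and the sequential sender's Lock on the same container is the hand-off that enforces encrypted-before-sent. A small sketch of that two-queue discipline, with strings standing in for packet buffers:

package main

import (
	"fmt"
	"sync"
)

type batch struct {
	sync.Mutex
	packets []string
}

func main() {
	sequential := make(chan *batch, 8) // preserves per-peer packet order
	workers := make(chan *batch, 8)    // feeds parallel encryption work

	go func() { // encryption worker
		for b := range workers {
			for i := range b.packets {
				b.packets[i] = "sealed(" + b.packets[i] + ")"
			}
			b.Unlock() // publish the whole finished batch at once
		}
	}()

	for i := 0; i < 3; i++ {
		b := &batch{packets: []string{fmt.Sprintf("pkt%d", i)}}
		b.Lock()        // stays locked until the worker finishes the batch
		sequential <- b // ordering queue
		workers <- b    // work queue
	}
	close(workers)

	for i := 0; i < 3; i++ { // sequential sender
		b := <-sequential
		b.Lock() // blocks until this batch is sealed
		fmt.Println(b.packets)
		b.Unlock()
	}
}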
- for _, elem := range *elems { - elem.Lock() + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } + device.PutOutboundElementsContainer(elemsContainer) continue } dataSent := false - for _, elem := range *elems { - elem.Lock() + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { if len(elem.packet) != MessageKeepaliveSize { dataSent = true } @@ -520,11 +520,18 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { if dataSent { peer.timersDataSent() } - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsSlice(elems) + device.PutOutboundElementsContainer(elemsContainer) + if err != nil { + var errGSO conn.ErrUDPGSODisabled + if errors.As(err, &errGSO) { + device.log.Verbosef(err.Error()) + err = errGSO.RetryErr + } + } if err != nil { device.log.Errorf("%v - Failed to send data packets: %v", peer, err) continue diff --git a/device/sticky_default.go b/device/sticky_default.go index 1038256..22e1e15 100644 --- a/device/sticky_default.go +++ b/device/sticky_default.go @@ -7,6 +7,6 @@ import ( "golang.zx2c4.com/wireguard/rwcancel" ) -func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { +func (device *Device) startRouteListener(_ conn.Bind) (*rwcancel.RWCancel, error) { return nil, nil } diff --git a/device/sticky_linux.go b/device/sticky_linux.go index f9230f8..f23ff02 100644 --- a/device/sticky_linux.go +++ b/device/sticky_linux.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. * * This implements userspace semantics of "sticky sockets", modeled after * WireGuard's kernelspace implementation. This is more or less a straight port @@ -9,7 +9,7 @@ * * Currently there is no way to achieve this within the net package: * See e.g. https://github.com/golang/go/issues/17930 - * So this code is remains platform dependent. + * So this code remains platform dependent. 
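The new error handling at the tail of RoutineSequentialSender above special-cases conn.ErrUDPGSODisabled: the condition is detected with errors.As, logged at verbose level, and the error is replaced by the outcome of the retried send without segmentation offload, so only a genuine failure reaches the error log. A generic sketch of that unwrap-and-continue idiom, with a stand-in type for the real error in the conn package:

package main

import (
	"errors"
	"fmt"
)

// gsoError stands in for conn.ErrUDPGSODisabled: a send failure that also
// carries the result of an automatic retry without offload.
type gsoError struct {
	RetryErr error
}

func (e gsoError) Error() string { return "UDP GSO disabled; retried without offload" }

func send(gsoBroke bool) error {
	if gsoBroke {
		return gsoError{RetryErr: nil} // the retry itself succeeded
	}
	return nil
}

func main() {
	err := send(true)
	var errGSO gsoError
	if errors.As(err, &errGSO) {
		fmt.Println("note:", err) // downgrade to a verbose log line
		err = errGSO.RetryErr     // from here on, only the retry's outcome matters
	}
	if err != nil {
		fmt.Println("failed to send:", err)
	} else {
		fmt.Println("sent")
	}
}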
*/ package device @@ -47,7 +47,7 @@ func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, er return netlinkCancel, nil } -func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { +func (device *Device) routineRouteListener(_ conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { type peerEndpointPtr struct { peer *Peer endpoint *conn.Endpoint @@ -110,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl if !ok { break } - pePtr.peer.Lock() - if &pePtr.peer.endpoint != pePtr.endpoint { - pePtr.peer.Unlock() + pePtr.peer.endpoint.Lock() + if &pePtr.peer.endpoint.val != pePtr.endpoint { + pePtr.peer.endpoint.Unlock() break } - if uint32(pePtr.peer.endpoint.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { - pePtr.peer.Unlock() + if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { + pePtr.peer.endpoint.Unlock() break } - pePtr.peer.endpoint.(*conn.StdNetEndpoint).ClearSrc() - pePtr.peer.Unlock() + pePtr.peer.endpoint.clearSrcOnTx = true + pePtr.peer.endpoint.Unlock() } attr = attr[attrhdr.Len:] } @@ -134,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl device.peers.RLock() i := uint32(1) for _, peer := range device.peers.keyMap { - peer.RLock() - if peer.endpoint == nil { - peer.RUnlock() + peer.endpoint.Lock() + if peer.endpoint.val == nil { + peer.endpoint.Unlock() continue } - nativeEP, _ := peer.endpoint.(*conn.StdNetEndpoint) + nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint) if nativeEP == nil { - peer.RUnlock() + peer.endpoint.Unlock() continue } if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 { - peer.RUnlock() + peer.endpoint.Unlock() break } nlmsg := struct { @@ -188,10 +188,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl reqPeerLock.Lock() reqPeer[i] = peerEndpointPtr{ peer: peer, - endpoint: &peer.endpoint, + endpoint: &peer.endpoint.val, } reqPeerLock.Unlock() - peer.RUnlock() + peer.endpoint.Unlock() i++ _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) if err != nil { diff --git a/device/timers.go b/device/timers.go index e28732c..32519aa 100644 --- a/device/timers.go +++ b/device/timers.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. * * This is based heavily on timers.c from the kernel implementation. */ @@ -100,11 +100,7 @@ func expiredRetransmitHandshake(peer *Peer) { peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1) /* We clear the endpoint address src address, in case this is the cause of trouble. */ - peer.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.Unlock() + peer.markEndpointSrcForClearing() peer.SendHandshakeInitiation(true) } @@ -123,11 +119,7 @@ func expiredSendKeepalive(peer *Peer) { func expiredNewHandshake(peer *Peer) { peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) /* We clear the endpoint address src address, in case this is the cause of trouble. 
*/ - peer.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.Unlock() + peer.markEndpointSrcForClearing() peer.SendHandshakeInitiation(false) } diff --git a/device/tun.go b/device/tun.go index 2a2ace9..c85dd50 100644 --- a/device/tun.go +++ b/device/tun.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/uapi.go b/device/uapi.go index 617dcd3..cc69488 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -99,33 +99,31 @@ func (device *Device) IpcGetOperation(w io.Writer) error { for _, peer := range device.peers.keyMap { // Serialize peer state. - // Do the work in an anonymous function so that we can use defer. - func() { - peer.RLock() - defer peer.RUnlock() - - keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) - keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) - sendf("protocol_version=1") - if peer.endpoint != nil { - sendf("endpoint=%s", peer.endpoint.DstToString()) - } - - nano := peer.lastHandshakeNano.Load() - secs := nano / time.Second.Nanoseconds() - nano %= time.Second.Nanoseconds() - - sendf("last_handshake_time_sec=%d", secs) - sendf("last_handshake_time_nsec=%d", nano) - sendf("tx_bytes=%d", peer.txBytes.Load()) - sendf("rx_bytes=%d", peer.rxBytes.Load()) - sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) - - device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { - sendf("allowed_ip=%s", prefix.String()) - return true - }) - }() + peer.handshake.mutex.RLock() + keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) + keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) + peer.handshake.mutex.RUnlock() + sendf("protocol_version=1") + peer.endpoint.Lock() + if peer.endpoint.val != nil { + sendf("endpoint=%s", peer.endpoint.val.DstToString()) + } + peer.endpoint.Unlock() + + nano := peer.lastHandshakeNano.Load() + secs := nano / time.Second.Nanoseconds() + nano %= time.Second.Nanoseconds() + + sendf("last_handshake_time_sec=%d", secs) + sendf("last_handshake_time_nsec=%d", nano) + sendf("tx_bytes=%d", peer.txBytes.Load()) + sendf("rx_bytes=%d", peer.rxBytes.Load()) + sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) + + device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { + sendf("allowed_ip=%s", prefix.String()) + return true + }) } }() @@ -262,7 +260,7 @@ func (peer *ipcSetPeer) handlePostConfig() { return } if peer.created { - peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil + peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil } if peer.device.isUp() { peer.Start() @@ -345,9 +343,9 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) } - peer.Lock() - defer peer.Unlock() - peer.endpoint = endpoint + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + peer.endpoint.val = endpoint case "persistent_keepalive_interval": device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer) @@ -373,7 +371,14 @@ func (device *Device) handlePeerLine(peer 
*ipcSetPeer, key, value string) error device.allowedips.RemoveByPeer(peer.Peer) case "allowed_ip": - device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) + add := true + verb := "Adding" + if len(value) > 0 && value[0] == '-' { + add = false + verb = "Removing" + value = value[1:] + } + device.log.Verbosef("%v - UAPI: %s allowedip", peer.Peer, verb) prefix, err := netip.ParsePrefix(value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) @@ -381,7 +386,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error if peer.dummy { return nil } - device.allowedips.Insert(prefix, peer.Peer) + if add { + device.allowedips.Insert(prefix, peer.Peer) + } else { + device.allowedips.Remove(prefix, peer.Peer) + } case "protocol_version": if value != "1" { diff --git a/format_test.go b/format_test.go index 6f6cab7..4d02c48 100644 --- a/format_test.go +++ b/format_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package main @@ -1,16 +1,16 @@ module golang.zx2c4.com/wireguard -go 1.20 +go 1.23.1 require ( - golang.org/x/crypto v0.6.0 - golang.org/x/net v0.7.0 - golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 + golang.org/x/crypto v0.37.0 + golang.org/x/net v0.39.0 + golang.org/x/sys v0.32.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 - gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 + gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c ) require ( - github.com/google/btree v1.0.1 // indirect - golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect + github.com/google/btree v1.1.2 // indirect + golang.org/x/time v0.7.0 // indirect ) @@ -1,14 +1,14 @@ -github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= -github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= -golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= -golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 h1:260HNjMTPDya+jq5AM1zZLgG9pv9GASPAGiEEJUbRg4= -golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= +github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= +golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.zx2c4.com/wintun 
v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= -gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 h1:Wobr37noukisGxpKo5jAsLREcpj61RxrWYzD8uwveOY= -gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0/go.mod h1:Dn5idtptoW1dIos9U6A2rpebLs/MtTwFacjKb8jLdQA= +gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI= +gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g= diff --git a/ipc/uapi_bsd.go b/ipc/uapi_bsd.go index ddcaf27..fd433a5 100644 --- a/ipc/uapi_bsd.go +++ b/ipc/uapi_bsd.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package ipc diff --git a/ipc/uapi_linux.go b/ipc/uapi_linux.go index 1562a18..fddded0 100644 --- a/ipc/uapi_linux.go +++ b/ipc/uapi_linux.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package ipc diff --git a/ipc/uapi_unix.go b/ipc/uapi_unix.go index e67be26..dcce167 100644 --- a/ipc/uapi_unix.go +++ b/ipc/uapi_unix.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package ipc diff --git a/ipc/uapi_js.go b/ipc/uapi_wasm.go index 2570515..50ac091 100644 --- a/ipc/uapi_js.go +++ b/ipc/uapi_wasm.go @@ -1,11 +1,11 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package ipc -// Made up sentinel error codes for the js/wasm platform. +// Made up sentinel error codes for {js,wasip1}/wasm. const ( IpcErrorIO = 1 IpcErrorInvalid = 2 diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go index aa023c9..86e60b0 100644 --- a/ipc/uapi_windows.go +++ b/ipc/uapi_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package ipc @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package main diff --git a/main_windows.go b/main_windows.go index a4dc46f..67036cf 100644 --- a/main_windows.go +++ b/main_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package main diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go index f7d05ef..ac69e3a 100644 --- a/ratelimiter/ratelimiter.go +++ b/ratelimiter/ratelimiter.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package ratelimiter diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go index 0bfa3af..71140da 100644 --- a/ratelimiter/ratelimiter_test.go +++ b/ratelimiter/ratelimiter_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
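Back in device/uapi.go, handlePeerLine gained removal semantics for allowed_ip: a value beginning with '-' now removes that prefix from the peer instead of inserting it, which is what the new AllowedIPs.Remove calls in the trie tests at the top of this diff exercise. A sketch of driving the syntax through the existing IpcSet entry point (the public key is a placeholder, not a valid 64-hex-digit key, and dev is assumed to be an already configured *device.Device):

package example

import "golang.zx2c4.com/wireguard/device"

// reconfigure adds one prefix to a peer and removes another in a single
// UAPI transaction, using the new '-' removal syntax.
func reconfigure(dev *device.Device) error {
	return dev.IpcSet(
		"public_key=abcd...placeholder...\n" + // selects an existing peer
			"allowed_ip=10.0.0.0/24\n" + // insert
			"allowed_ip=-10.0.1.0/24\n") // remove: new in this change
}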
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package ratelimiter diff --git a/replay/replay.go b/replay/replay.go index 8b99e23..46e224d 100644 --- a/replay/replay.go +++ b/replay/replay.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ // Package replay implements an efficient anti-replay algorithm as specified in RFC 6479. diff --git a/replay/replay_test.go b/replay/replay_test.go index 9a9e4a8..8378ec3 100644 --- a/replay/replay_test.go +++ b/replay/replay_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package replay diff --git a/rwcancel/rwcancel.go b/rwcancel/rwcancel.go index 63e1510..4372453 100644 --- a/rwcancel/rwcancel.go +++ b/rwcancel/rwcancel.go @@ -1,8 +1,8 @@ -//go:build !windows && !js +//go:build !windows && !wasm /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ // Package rwcancel implements cancelable read/write operations on @@ -64,7 +64,7 @@ func (rw *RWCancel) ReadyRead() bool { func (rw *RWCancel) ReadyWrite() bool { closeFd := int32(rw.closingReader.Fd()) - pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}} + pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLIN}} var err error for { _, err = unix.Poll(pollFds, -1) diff --git a/rwcancel/rwcancel_stub.go b/rwcancel/rwcancel_stub.go index 182940b..2a98b2b 100644 --- a/rwcancel/rwcancel_stub.go +++ b/rwcancel/rwcancel_stub.go @@ -1,4 +1,4 @@ -//go:build windows || js +//go:build windows || wasm // SPDX-License-Identifier: MIT diff --git a/tai64n/tai64n.go b/tai64n/tai64n.go index 8f10b39..e1a97a5 100644 --- a/tai64n/tai64n.go +++ b/tai64n/tai64n.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tai64n diff --git a/tai64n/tai64n_test.go b/tai64n/tai64n_test.go index c70fc1a..d0b4425 100644 --- a/tai64n/tai64n_test.go +++ b/tai64n/tai64n_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tai64n diff --git a/tun/alignment_windows_test.go b/tun/alignment_windows_test.go index 67a785e..e3252b2 100644 --- a/tun/alignment_windows_test.go +++ b/tun/alignment_windows_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tun diff --git a/tun/checksum.go b/tun/checksum.go index f4f8471..b489c56 100644 --- a/tun/checksum.go +++ b/tun/checksum.go @@ -1,26 +1,86 @@ package tun -import "encoding/binary" +import ( + "encoding/binary" + "math/bits" +) // TODO: Explore SIMD and/or other assembly optimizations. 
func checksumNoFold(b []byte, initial uint64) uint64 { - ac := initial - i := 0 - n := len(b) - for n >= 4 { - ac += uint64(binary.BigEndian.Uint32(b[i : i+4])) - n -= 4 - i += 4 - } - for n >= 2 { - ac += uint64(binary.BigEndian.Uint16(b[i : i+2])) - n -= 2 - i += 2 - } - if n == 1 { - ac += uint64(b[i]) << 8 - } - return ac + tmp := make([]byte, 8) + binary.NativeEndian.PutUint64(tmp, initial) + ac := binary.BigEndian.Uint64(tmp) + var carry uint64 + + for len(b) >= 128 { + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[64:72]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[72:80]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[80:88]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[88:96]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[96:104]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[104:112]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[112:120]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[120:128]), carry) + ac += carry + b = b[128:] + } + if len(b) >= 64 { + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry) + ac += carry + b = b[64:] + } + if len(b) >= 32 { + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry) + ac += carry + b = b[32:] + } + if len(b) >= 16 { + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0) + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry) + ac += carry + b = b[16:] + } + if len(b) >= 8 { + ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0) + ac += carry + b = b[8:] + } + if len(b) >= 4 { + ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint32(b[:4])), 0) + ac += carry + b = b[4:] + } + if len(b) >= 2 { + ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint16(b[:2])), 0) + ac += carry + b = b[2:] + } + if len(b) == 1 { + tmp := binary.NativeEndian.Uint16([]byte{b[0], 0}) + ac, carry = bits.Add64(ac, uint64(tmp), 0) + ac += carry + } + + binary.NativeEndian.PutUint64(tmp, ac) + return binary.BigEndian.Uint64(tmp) } func checksum(b []byte, initial uint64) uint16 { diff --git a/tun/checksum_test.go b/tun/checksum_test.go new 
file mode 100644 index 0000000..4ea9b8b --- /dev/null +++ b/tun/checksum_test.go @@ -0,0 +1,98 @@ +package tun + +import ( + "encoding/binary" + "fmt" + "math/rand" + "testing" + + "golang.org/x/sys/unix" +) + +func checksumRef(b []byte, initial uint16) uint16 { + ac := uint64(initial) + + for len(b) >= 2 { + ac += uint64(binary.BigEndian.Uint16(b)) + b = b[2:] + } + if len(b) == 1 { + ac += uint64(b[0]) << 8 + } + + for (ac >> 16) > 0 { + ac = (ac >> 16) + (ac & 0xffff) + } + return uint16(ac) +} + +func pseudoHeaderChecksumRefNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 { + sum := checksumRef(srcAddr, 0) + sum = checksumRef(dstAddr, sum) + sum = checksumRef([]byte{0, protocol}, sum) + tmp := make([]byte, 2) + binary.BigEndian.PutUint16(tmp, totalLen) + return checksumRef(tmp, sum) +} + +func TestChecksum(t *testing.T) { + for length := 0; length <= 9001; length++ { + buf := make([]byte, length) + rng := rand.New(rand.NewSource(1)) + rng.Read(buf) + csum := checksum(buf, 0x1234) + csumRef := checksumRef(buf, 0x1234) + if csum != csumRef { + t.Error("Expected checksum", csumRef, "got", csum) + } + } +} + +func TestPseudoHeaderChecksum(t *testing.T) { + for _, addrLen := range []int{4, 16} { + for length := 0; length <= 9001; length++ { + srcAddr := make([]byte, addrLen) + dstAddr := make([]byte, addrLen) + buf := make([]byte, length) + rng := rand.New(rand.NewSource(1)) + rng.Read(srcAddr) + rng.Read(dstAddr) + rng.Read(buf) + phSum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length)) + csum := checksum(buf, phSum) + phSumRef := pseudoHeaderChecksumRefNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length)) + csumRef := checksumRef(buf, phSumRef) + if csum != csumRef { + t.Error("Expected checksumRef", csumRef, "got", csum) + } + } + } +} + +func BenchmarkChecksum(b *testing.B) { + lengths := []int{ + 64, + 128, + 256, + 512, + 1024, + 1500, + 2048, + 4096, + 8192, + 9000, + 9001, + } + + for _, length := range lengths { + b.Run(fmt.Sprintf("%d", length), func(b *testing.B) { + buf := make([]byte, length) + rng := rand.New(rand.NewSource(1)) + rng.Read(buf) + b.ResetTimer() + for i := 0; i < b.N; i++ { + checksum(buf, 0) + } + }) + } +} diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go index ccd32ed..d71267d 100644 --- a/tun/netstack/examples/http_client.go +++ b/tun/netstack/examples/http_client.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package main diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go index f5b7a8f..7278851 100644 --- a/tun/netstack/examples/http_server.go +++ b/tun/netstack/examples/http_server.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package main diff --git a/tun/netstack/examples/ping_client.go b/tun/netstack/examples/ping_client.go index 2eef0fb..d1b562f 100644 --- a/tun/netstack/examples/ping_client.go +++ b/tun/netstack/examples/ping_client.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. 
*/ package main diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index fa15f53..a7aec9e 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package netstack @@ -25,7 +25,7 @@ import ( "golang.zx2c4.com/wireguard/tun" "golang.org/x/net/dns/dnsmessage" - "gvisor.dev/gvisor/pkg/bufferv2" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -43,7 +43,8 @@ type netTun struct { ep *channel.Endpoint stack *stack.Stack events chan tun.Event - incomingPacket chan *bufferv2.View + notifyHandle *channel.NotificationHandle + incomingPacket chan *buffer.View mtu int dnsServers []netip.Addr hasV4, hasV6 bool @@ -61,7 +62,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, ep: channel.New(1024, uint32(mtu), ""), stack: stack.New(opts), events: make(chan tun.Event, 10), - incomingPacket: make(chan *bufferv2.View), + incomingPacket: make(chan *buffer.View), dnsServers: dnsServers, mtu: mtu, } @@ -70,7 +71,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, if tcpipErr != nil { return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) } - dev.ep.AddNotify(dev) + dev.notifyHandle = dev.ep.AddNotify(dev) tcpipErr = dev.stack.CreateNIC(1, dev.ep) if tcpipErr != nil { return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) @@ -84,7 +85,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, } protoAddr := tcpip.ProtocolAddress{ Protocol: protoNumber, - AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(), + AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), } tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) if tcpipErr != nil { @@ -140,7 +141,7 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { continue } - pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)}) + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) switch packet[0] >> 4 { case 4: tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) @@ -155,7 +156,7 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { func (tun *netTun) WriteNotify() { pkt := tun.ep.Read() - if pkt.IsNil() { + if pkt == nil { return } @@ -167,13 +168,14 @@ func (tun *netTun) WriteNotify() { func (tun *netTun) Close() error { tun.stack.RemoveNIC(1) + tun.stack.Close() + tun.ep.RemoveNotify(tun.notifyHandle) + tun.ep.Close() if tun.events != nil { close(tun.events) } - tun.ep.Close() - if tun.incomingPacket != nil { close(tun.incomingPacket) } @@ -198,7 +200,7 @@ func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.Networ } return tcpip.FullAddress{ NIC: 1, - Addr: tcpip.Address(endpoint.Addr().AsSlice()), + Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), Port: endpoint.Port(), }, protoNumber } @@ -453,7 +455,7 @@ func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) } - remoteAddr, _ := netip.AddrFromSlice([]byte(res.RemoteAddr.Addr)) + remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice()) return res.Count, &PingAddr{remoteAddr}, nil } @@ -912,7 +914,7 @@ func (tnet *Net) LookupContextHost(ctx 
context.Context, host string) ([]string, } } } - // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled + // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled var addrs []netip.Addr if tnet.hasV6 { addrs = append(addrsV6, addrsV4...) diff --git a/tun/offload_linux.go b/tun/offload_linux.go new file mode 100644 index 0000000..5f0db06 --- /dev/null +++ b/tun/offload_linux.go @@ -0,0 +1,993 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "unsafe" + + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" +) + +const tcpFlagsOffset = 13 + +const ( + tcpFlagFIN uint8 = 0x01 + tcpFlagPSH uint8 = 0x08 + tcpFlagACK uint8 = 0x10 +) + +// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The +// kernel symbol is virtio_net_hdr. +type virtioNetHdr struct { + flags uint8 + gsoType uint8 + hdrLen uint16 + gsoSize uint16 + csumStart uint16 + csumOffset uint16 +} + +func (v *virtioNetHdr) decode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen]) + return nil +} + +func (v *virtioNetHdr) encode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen)) + return nil +} + +const ( + // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the + // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). + virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) +) + +// tcpFlowKey represents the key for a TCP flow. +type tcpFlowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. + isV6 bool +} + +// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO. +type tcpGROTable struct { + itemsByFlow map[tcpFlowKey][]tcpGROItem + itemsPool [][]tcpGROItem +} + +func newTCPGROTable() *tcpGROTable { + t := &tcpGROTable{ + itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize), + itemsPool: make([][]tcpGROItem, conn.IdealBatchSize), + } + for i := range t.itemsPool { + t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize) + } + return t +} + +func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey { + key := tcpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) + key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) + key.isV6 = addrSize == 16 + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + items, ok := t.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. 
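The TODO above notes that the miss path hashes the flow key twice: once here and once again inside insert(). One possible rearrangement, sketched purely as an assumption rather than a proposed change to this patch, threads the already-computed key through to a key-taking insert helper:

// Hypothetical variant; insertKeyed is assumed to be insert() minus the
// key recomputation. Neither exists in this patch.
func (t *tcpGROTable) lookupOrInsertKeyed(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
	key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
	if items, ok := t.itemsByFlow[key]; ok {
		return items, true
	}
	t.insertKeyed(key, pkt, tcphOffset, tcphLen, bufsIndex)
	return nil, false
}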
+ t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. +func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + item := tcpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), + iphLen: uint8(tcphOffset), + tcphLen: uint8(tcphLen), + sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), + pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, + } + items, ok := t.itemsByFlow[key] + if !ok { + items = t.newItems() + } + items = append(items, item) + t.itemsByFlow[key] = items +} + +func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { + items, _ := t.itemsByFlow[item.key] + items[i] = item +} + +func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) { + items, _ := t.itemsByFlow[key] + items = append(items[:i], items[i+1:]...) + t.itemsByFlow[key] = items +} + +// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type tcpGROItem struct { + key tcpFlowKey + sentSeq uint32 // the sequence number + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + tcphLen uint8 // tcp header len + pshSet bool // psh flag is set +} + +func (t *tcpGROTable) newItems() []tcpGROItem { + var items []tcpGROItem + items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1] + return items +} + +func (t *tcpGROTable) reset() { + for k, items := range t.itemsByFlow { + items = items[:0] + t.itemsPool = append(t.itemsPool, items) + delete(t.itemsByFlow, k) + } +} + +// udpFlowKey represents the key for a UDP flow. +type udpFlowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + isV6 bool +} + +// udpGROTable holds flow and coalescing information for the purposes of UDP GRO. +type udpGROTable struct { + itemsByFlow map[udpFlowKey][]udpGROItem + itemsPool [][]udpGROItem +} + +func newUDPGROTable() *udpGROTable { + u := &udpGROTable{ + itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize), + itemsPool: make([][]udpGROItem, conn.IdealBatchSize), + } + for i := range u.itemsPool { + u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize) + } + return u +} + +func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey { + key := udpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:]) + key.isV6 = addrSize == 16 + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + items, ok := u.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. 
+ u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. +func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + item := udpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[udphOffset+udphLen:])), + iphLen: uint8(udphOffset), + cSumKnownInvalid: cSumKnownInvalid, + } + items, ok := u.itemsByFlow[key] + if !ok { + items = u.newItems() + } + items = append(items, item) + u.itemsByFlow[key] = items +} + +func (u *udpGROTable) updateAt(item udpGROItem, i int) { + items, _ := u.itemsByFlow[item.key] + items[i] = item +} + +// udpGROItem represents bookkeeping data for a UDP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type udpGROItem struct { + key udpFlowKey + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown. +} + +func (u *udpGROTable) newItems() []udpGROItem { + var items []udpGROItem + items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1] + return items +} + +func (u *udpGROTable) reset() { + for k, items := range u.itemsByFlow { + items = items[:0] + u.itemsPool = append(u.itemsPool, items) + delete(u.itemsByFlow, k) + } +} + +// canCoalesce represents the outcome of checking if two TCP packets are +// candidates for coalescing. +type canCoalesce int + +const ( + coalescePrepend canCoalesce = -1 + coalesceUnavailable canCoalesce = 0 + coalesceAppend canCoalesce = 1 +) + +// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB +// meet all requirements to be merged as part of a GRO operation, otherwise it +// returns false. +func ipHeadersCanCoalesce(pktA, pktB []byte) bool { + if len(pktA) < 9 || len(pktB) < 9 { + return false + } + if pktA[0]>>4 == 6 { + if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 { + // cannot coalesce with unequal Traffic class values + return false + } + if pktA[7] != pktB[7] { + // cannot coalesce with unequal Hop limit values + return false + } + } else { + if pktA[1] != pktB[1] { + // cannot coalesce with unequal ToS values + return false + } + if pktA[6]>>5 != pktB[6]>>5 { + // cannot coalesce with unequal DF or reserved bits. MF is checked + // further up the stack. + return false + } + if pktA[8] != pktB[8] { + // cannot coalesce with unequal TTL values + return false + } + } + return true +} + +// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. iphLen and gsoSize describe pkt. bufs is the vector of +// packets involved in the current GRO evaluation. bufsOffset is the offset at +// which packet data begins within bufs. +func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { + pktTarget := bufs[item.bufsIndex][bufsOffset:] + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable + } + if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. 
+ return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + return coalesceAppend +} + +// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. This function makes considerations that match the kernel's +// GRO self tests, which can be found in tools/testing/selftests/net/gro.c. +func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { + pktTarget := bufs[item.bufsIndex][bufsOffset:] + if tcphLen != item.tcphLen { + // cannot coalesce with unequal tcp options len + return coalesceUnavailable + } + if tcphLen > 20 { + if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) { + // cannot coalesce with unequal tcp options + return coalesceUnavailable + } + } + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable + } + // seq adjacency + lhsLen := item.gsoSize + lhsLen += item.numMerged * item.gsoSize + if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective + if item.pshSet { + // We cannot append to a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + return coalesceAppend + } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective + if pshSet { + // We cannot prepend with a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if gsoSize < item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + if gsoSize > item.gsoSize && item.numMerged > 0 { + // There's at least one previous merge, and we're larger than all + // previous. This would put multiple smaller packets on the end. + return coalesceUnavailable + } + return coalescePrepend + } + return coalesceUnavailable +} + +func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool { + srcAddrAt := ipv4SrcAddrOffset + addrSize := 4 + if isV6 { + srcAddrAt = ipv6SrcAddrOffset + addrSize = 16 + } + lenForPseudo := uint16(len(pkt) - int(iphLen)) + cSum := pseudoHeaderChecksumNoFold(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo) + return ^checksum(pkt[iphLen:], cSum) == 0 +} + +// coalesceResult represents the result of attempting to coalesce two TCP +// packets. +type coalesceResult int + +const ( + coalesceInsufficientCap coalesceResult = iota + coalescePSHEnding + coalesceItemInvalidCSum + coalescePktInvalidCSum + coalesceSuccess +) + +// coalesceUDPPackets attempts to coalesce pkt with the packet described by +// item, and returns the outcome. 
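The implementation below, like its TCP counterpart further down, grows the head packet in place: it verifies that the head buffer's spare capacity can absorb the merge, appends zeroed bytes, then copies the trailing payload over them. In isolation the idiom looks like this, assuming buffers preallocated with generous capacity the way the tests do (make([]byte, n, 65535)):

head := make([]byte, 128, 65535) // head packet: 128 bytes used, ample spare capacity
seg := make([]byte, 228)         // follow-on packet: 28 header bytes + 200 payload bytes
payload := seg[28:]              // drop the IP+UDP headers before merging
if cap(head)-len(head) >= len(payload) {
	head = append(head, make([]byte, len(payload))...) // extends in place, no reallocation
	copy(head[128:], payload)                          // splice the payload onto the head
}

The capacity check exists to keep the merge allocation-free; when the spare room is not there, coalescing is simply declined via coalesceInsufficientCap.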
+func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { + pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front + headersLen := item.iphLen + udphLen + coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) + + if cap(pktHead)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if item.numMerged == 0 { + if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalescePktInvalidCSum + } + extendBy := len(pkt) - int(headersLen) + bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) + copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + + item.numMerged++ + return coalesceSuccess +} + +// coalesceTCPPackets attempts to coalesce pkt with the packet described by +// item, and returns the outcome. This function may swap bufs elements in the +// event of a prepend as item's bufs index is already being tracked for writing +// to a Device. +func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { + var pktHead []byte // the packet that will end up at the front + headersLen := item.iphLen + item.tcphLen + coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) + + // Copy data + if mode == coalescePrepend { + pktHead = pkt + if cap(pkt)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if pshSet { + return coalescePSHEnding + } + if item.numMerged == 0 { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalescePktInvalidCSum + } + item.sentSeq = seq + extendBy := coalescedLen - len(pktHead) + bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...) + copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):]) + // Flip the slice headers in bufs as part of prepend. The index of item + // is already being tracked for writing. + bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex] + } else { + pktHead = bufs[item.bufsIndex][bufsOffset:] + if cap(pktHead)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if item.numMerged == 0 { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalescePktInvalidCSum + } + if pshSet { + // We are appending a segment with PSH set. + item.pshSet = pshSet + pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH + } + extendBy := len(pkt) - int(headersLen) + bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) 
+ copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + } + + if gsoSize > item.gsoSize { + item.gsoSize = gsoSize + } + + item.numMerged++ + return coalesceSuccess +} + +const ( + ipv4FlagMoreFragments uint8 = 0x20 +) + +const ( + ipv4SrcAddrOffset = 12 + ipv6SrcAddrOffset = 8 + maxUint16 = 1<<16 - 1 +) + +type groResult int + +const ( + groResultNoop groResult = iota + groResultTableInsert + groResultCoalesced +) + +// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with +// existing packets tracked in table. It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult { + pkt := bufs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. + return groResultNoop + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return groResultNoop + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return groResultNoop + } + } + if len(pkt) < iphLen { + return groResultNoop + } + tcphLen := int((pkt[iphLen+12] >> 4) * 4) + if tcphLen < 20 || tcphLen > 60 { + return groResultNoop + } + if len(pkt) < iphLen+tcphLen { + return groResultNoop + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { + // no GRO support for fragmented segments for now + return groResultNoop + } + } + tcpFlags := pkt[iphLen+tcpFlagsOffset] + var pshSet bool + // not a candidate if any non-ACK flags (except PSH+ACK) are set + if tcpFlags != tcpFlagACK { + if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { + return groResultNoop + } + pshSet = true + } + gsoSize := uint16(len(pkt) - tcphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return groResultNoop + } + seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + if !existing { + return groResultTableInsert + } + for i := len(items) - 1; i >= 0; i-- { + // In the best case of packets arriving in order iterating in reverse is + // more efficient if there are multiple items for a given flow. This + // also enables a natural table.deleteAt() in the + // coalesceItemInvalidCSum case without the need for index tracking. + // This algorithm makes a best effort to coalesce in the event of + // unordered packets, where pkt may land anywhere in items from a + // sequence number perspective, however once an item is inserted into + // the table it is never compared across other items later. 
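The delete-without-index-tracking property claimed in the comment above holds because removing items[i] during a high-to-low walk only shifts elements at indices greater than i, all of which the walk has already visited. In miniature:

for i := len(items) - 1; i >= 0; i-- {
	if invalidCSum(items[i]) { // hypothetical stand-in for the coalesceItemInvalidCSum case
		// Safe mid-iteration: the shift touches only indices > i,
		// which this reverse walk has already passed.
		items = append(items[:i], items[i+1:]...)
	}
}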
+ item := items[i] + can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset) + if can != coalesceUnavailable { + result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, i) + return groResultCoalesced + case coalesceItemInvalidCSum: + // delete the item with an invalid csum + table.deleteAt(item.key, i) + case coalescePktInvalidCSum: + // no point in inserting an item that we can't coalesce + return groResultNoop + default: + } + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + return groResultTableInsert +} + +// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. +func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + item.tcphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 16, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + if item.key.isV6 { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Calculate the pseudo header checksum and place it at the TCP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the tcp header and payload checksum. + addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if item.key.isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + +// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. 
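Note the checksum strategy shared by the accounting pass above and its UDP sibling below: a merged super-packet never receives a full transport checksum here. Only the folded (uncomplemented) pseudo-header sum is written at csumStart+csumOffset, and VIRTIO_NET_HDR_F_NEEDS_CSUM, which becomes CHECKSUM_PARTIAL on the kernel side, tells downstream offload to sum everything from csumStart onward and finish the job. Condensed, with src, dst, and l4Len standing in for the values derived above:

// src and dst are the IP source/destination address slices; l4Len is the
// transport header plus payload length. All three are placeholders here.
psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, src, dst, uint16(l4Len))
// checksum over an empty buffer just folds psum down to 16 bits.
binary.BigEndian.PutUint16(pkt[csumStart+csumOffset:], checksum([]byte{}, psum))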
+func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + udphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 6, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4 + if item.key.isV6 { + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Recalculate the UDP len field value + binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:]))) + + // Calculate the pseudo header checksum and place it at the UDP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the udp header and payload checksum. + addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if item.key.isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + +type groCandidateType uint8 + +const ( + notGROCandidate groCandidateType = iota + tcp4GROCandidate + tcp6GROCandidate + udp4GROCandidate + udp6GROCandidate +) + +func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType { + if len(b) < 28 { + return notGROCandidate + } + if b[0]>>4 == 4 { + if b[0]&0x0F != 5 { + // IPv4 packets w/IP options do not coalesce + return notGROCandidate + } + if b[9] == unix.IPPROTO_TCP && len(b) >= 40 { + return tcp4GROCandidate + } + if b[9] == unix.IPPROTO_UDP && canUDPGRO { + return udp4GROCandidate + } + } else if b[0]>>4 == 6 { + if b[6] == unix.IPPROTO_TCP && len(b) >= 60 { + return tcp6GROCandidate + } + if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO { + return udp6GROCandidate + } + } + return notGROCandidate +} + +const ( + udphLen = 8 +) + +// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with +// existing packets tracked in table. It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult { + pkt := bufs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. 
+ return groResultNoop + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return groResultNoop + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return groResultNoop + } + } + if len(pkt) < iphLen { + return groResultNoop + } + if len(pkt) < iphLen+udphLen { + return groResultNoop + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { + // no GRO support for fragmented segments for now + return groResultNoop + } + } + gsoSize := uint16(len(pkt) - udphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return groResultNoop + } + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI) + if !existing { + return groResultTableInsert + } + // With UDP we only check the last item, otherwise we could reorder packets + // for a given flow. We must also always insert a new item, or successfully + // coalesce with an existing item, for the same reason. + item := items[len(items)-1] + can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset) + var pktCSumKnownInvalid bool + if can == coalesceAppend { + result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, len(items)-1) + return groResultCoalesced + case coalesceItemInvalidCSum: + // If the existing item has an invalid csum we take no action. A new + // item will be stored after it, and the existing item will never be + // revisited as part of future coalescing candidacy checks. + case coalescePktInvalidCSum: + // We must insert a new item, but we also mark it as invalid csum + // to prevent a repeat checksum validation. + pktCSumKnownInvalid = true + default: + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid) + return groResultTableInsert +} + +// handleGRO evaluates bufs for GRO, and writes the indices of the resulting +// packets into toWrite. toWrite, tcpTable, and udpTable should initially be +// empty (but non-nil), and are passed in to save allocs as the caller may reset +// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is +// supported. 
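Per this contract, the caller owns the tables and the toWrite slice and recycles them between packet vectors. The driver side is not part of this hunk, so the following is an assumed usage sketch rather than quoted code (batches and the write step are placeholders):

tcpTable, udpTable := newTCPGROTable(), newUDPGROTable()
toWrite := make([]int, 0, conn.IdealBatchSize)
for _, bufs := range batches { // hypothetical stream of [][]byte packet vectors
	toWrite = toWrite[:0]
	if err := handleGRO(bufs, virtioNetHdrLen, tcpTable, udpTable, true, &toWrite); err != nil {
		return err
	}
	for _, i := range toWrite {
		// write bufs[i], whose first virtioNetHdrLen bytes now hold the
		// encoded virtioNetHdr, out to the TUN device
	}
	tcpTable.reset()
	udpTable.reset()
}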
+func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error { + for i := range bufs { + if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { + return errors.New("invalid offset") + } + var result groResult + switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) { + case tcp4GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, false) + case tcp6GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, true) + case udp4GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, false) + case udp6GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, true) + } + switch result { + case groResultNoop: + hdr := virtioNetHdr{} + err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + fallthrough + case groResultTableInsert: + *toWrite = append(*toWrite, i) + } + } + errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable) + errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable) + return errors.Join(errTCP, errUDP) +} + +// gsoSplit splits packets from in into outBuffs, writing the size of each +// element into sizes. It returns the number of buffers populated, and/or an +// error. +func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) { + iphLen := int(hdr.csumStart) + srcAddrOffset := ipv6SrcAddrOffset + addrLen := 16 + if !isV6 { + in[10], in[11] = 0, 0 // clear ipv4 header checksum + srcAddrOffset = ipv4SrcAddrOffset + addrLen = 4 + } + transportCsumAt := int(hdr.csumStart + hdr.csumOffset) + in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum + var firstTCPSeqNum uint32 + var protocol uint8 + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 { + protocol = unix.IPPROTO_TCP + firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:]) + } else { + protocol = unix.IPPROTO_UDP + } + nextSegmentDataAt := int(hdr.hdrLen) + i := 0 + for ; nextSegmentDataAt < len(in); i++ { + if i == len(outBuffs) { + return i - 1, ErrTooManySegments + } + nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize) + if nextSegmentEnd > len(in) { + nextSegmentEnd = len(in) + } + segmentDataLen := nextSegmentEnd - nextSegmentDataAt + totalLen := int(hdr.hdrLen) + segmentDataLen + sizes[i] = totalLen + out := outBuffs[i][outOffset:] + + copy(out, in[:iphLen]) + if !isV6 { + // For IPv4 we are responsible for incrementing the ID field, + // updating the total len field, and recalculating the header + // checksum. + if i > 0 { + id := binary.BigEndian.Uint16(out[4:]) + id += uint16(i) + binary.BigEndian.PutUint16(out[4:], id) + } + binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) + ipv4CSum := ^checksum(out[:iphLen], 0) + binary.BigEndian.PutUint16(out[10:], ipv4CSum) + } else { + // For IPv6 we are responsible for updating the payload length field. 
+ binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) + } + + // copy transport header + copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen]) + + if protocol == unix.IPPROTO_TCP { + // set TCP seq and adjust TCP flags + tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i)) + binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq) + if nextSegmentEnd != len(in) { + // FIN and PSH should only be set on last segment + clearFlags := tcpFlagFIN | tcpFlagPSH + out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags + } + } else { + // set UDP header len + binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart)) + } + + // payload + copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) + + // transport checksum + transportHeaderLen := int(hdr.hdrLen - hdr.csumStart) + lenForPseudo := uint16(transportHeaderLen + segmentDataLen) + transportCSumNoFold := pseudoHeaderChecksumNoFold(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo) + transportCSum := ^checksum(out[hdr.csumStart:totalLen], transportCSumNoFold) + binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum) + + nextSegmentDataAt += int(hdr.gsoSize) + } + return i, nil +} + +func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error { + cSumAt := cSumStart + cSumOffset + // The initial value at the checksum offset should be summed with the + // checksum we compute. This is typically the pseudo-header checksum. + initial := binary.BigEndian.Uint16(in[cSumAt:]) + in[cSumAt], in[cSumAt+1] = 0, 0 + binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial))) + return nil +} diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go new file mode 100644 index 0000000..d87e636 --- /dev/null +++ b/tun/offload_linux_test.go @@ -0,0 +1,752 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. 
+ */ + +package tun + +import ( + "net/netip" + "testing" + + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const ( + offset = virtioNetHdrLen +) + +var ( + ip4PortA = netip.MustParseAddrPort("192.0.2.1:1") + ip4PortB = netip.MustParseAddrPort("192.0.2.2:1") + ip4PortC = netip.MustParseAddrPort("192.0.2.3:1") + ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1") + ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1") + ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1") +) + +func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte { + totalLen := 28 + payloadLen + b := make([]byte, offset+int(totalLen), 65535) + ipv4H := header.IPv4(b[offset:]) + srcAs4 := srcIPPort.Addr().As4() + dstAs4 := dstIPPort.Addr().As4() + ipFields := &header.IPv4Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), + DstAddr: tcpip.AddrFromSlice(dstAs4[:]), + Protocol: unix.IPPROTO_UDP, + TTL: 64, + TotalLength: uint16(totalLen), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv4H.Encode(ipFields) + udpH := header.UDP(b[offset+20:]) + udpH.Encode(&header.UDPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + Length: uint16(payloadLen + udphLen), + }) + ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen)) + udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum)) + return b +} + +func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte { + return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) +} + +func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte { + totalLen := 48 + payloadLen + b := make([]byte, offset+int(totalLen), 65535) + ipv6H := header.IPv6(b[offset:]) + srcAs16 := srcIPPort.Addr().As16() + dstAs16 := dstIPPort.Addr().As16() + ipFields := &header.IPv6Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), + DstAddr: tcpip.AddrFromSlice(dstAs16[:]), + TransportProtocol: unix.IPPROTO_UDP, + HopLimit: 64, + PayloadLength: uint16(payloadLen + udphLen), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv6H.Encode(ipFields) + udpH := header.UDP(b[offset+40:]) + udpH.Encode(&header.UDPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + Length: uint16(payloadLen + udphLen), + }) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen)) + udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum)) + return b +} + +func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte { + return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) +} + +func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte { + totalLen := 40 + segmentSize + b := make([]byte, offset+int(totalLen), 65535) + ipv4H := header.IPv4(b[offset:]) + srcAs4 := srcIPPort.Addr().As4() + dstAs4 := dstIPPort.Addr().As4() + ipFields := &header.IPv4Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), + DstAddr: tcpip.AddrFromSlice(dstAs4[:]), + Protocol: unix.IPPROTO_TCP, + TTL: 64, + TotalLength: uint16(totalLen), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv4H.Encode(ipFields) + tcpH := header.TCP(b[offset+20:]) + 
tcpH.Encode(&header.TCPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + SeqNum: seq, + AckNum: 1, + DataOffset: 20, + Flags: flags, + WindowSize: 3000, + }) + ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize)) + tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) + return b +} + +func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { + return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) +} + +func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte { + totalLen := 60 + segmentSize + b := make([]byte, offset+int(totalLen), 65535) + ipv6H := header.IPv6(b[offset:]) + srcAs16 := srcIPPort.Addr().As16() + dstAs16 := dstIPPort.Addr().As16() + ipFields := &header.IPv6Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), + DstAddr: tcpip.AddrFromSlice(dstAs16[:]), + TransportProtocol: unix.IPPROTO_TCP, + HopLimit: 64, + PayloadLength: uint16(segmentSize + 20), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv6H.Encode(ipFields) + tcpH := header.TCP(b[offset+40:]) + tcpH.Encode(&header.TCPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + SeqNum: seq, + AckNum: 1, + DataOffset: 20, + Flags: flags, + WindowSize: 3000, + }) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize)) + tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) + return b +} + +func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { + return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) +} + +func Test_handleVirtioRead(t *testing.T) { + tests := []struct { + name string + hdr virtioNetHdr + pktIn []byte + wantLens []int + wantErr bool + }{ + { + "tcp4", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4, + gsoSize: 100, + hdrLen: 40, + csumStart: 20, + csumOffset: 16, + }, + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), + []int{140, 140}, + false, + }, + { + "tcp6", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6, + gsoSize: 100, + hdrLen: 60, + csumStart: 40, + csumOffset: 16, + }, + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), + []int{160, 160}, + false, + }, + { + "udp4", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, + gsoSize: 100, + hdrLen: 28, + csumStart: 20, + csumOffset: 6, + }, + udp4Packet(ip4PortA, ip4PortB, 200), + []int{128, 128}, + false, + }, + { + "udp6", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, + gsoSize: 100, + hdrLen: 48, + csumStart: 40, + csumOffset: 6, + }, + udp6Packet(ip6PortA, ip6PortB, 200), + []int{148, 148}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out := make([][]byte, conn.IdealBatchSize) + sizes := make([]int, conn.IdealBatchSize) + for i := range out { + out[i] = make([]byte, 65535) + } + tt.hdr.encode(tt.pktIn) + n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("got err: %v", err) + } + if n != 
len(tt.wantLens) { + t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) + } + for i := range tt.wantLens { + if tt.wantLens[i] != sizes[i] { + t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) + } + } + }) + } +} + +func flipTCP4Checksum(b []byte) []byte { + at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16 + b[at] ^= 0xFF + b[at+1] ^= 0xFF + return b +} + +func flipUDP4Checksum(b []byte) []byte { + at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6 + b[at] ^= 0xFF + b[at+1] ^= 0xFF + return b +} + +func Fuzz_handleGRO(f *testing.F) { + pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1) + pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101) + pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201) + pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1) + pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101) + pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201) + pkt6 := udp4Packet(ip4PortA, ip4PortB, 100) + pkt7 := udp4Packet(ip4PortA, ip4PortB, 100) + pkt8 := udp4Packet(ip4PortA, ip4PortC, 100) + pkt9 := udp6Packet(ip6PortA, ip6PortB, 100) + pkt10 := udp6Packet(ip6PortA, ip6PortB, 100) + pkt11 := udp6Packet(ip6PortA, ip6PortC, 100) + f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset) + f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) { + pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11} + toWrite := make([]int, 0, len(pkts)) + handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite) + if len(toWrite) > len(pkts) { + t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) + } + seenWriteI := make(map[int]bool) + for _, writeI := range toWrite { + if writeI < 0 || writeI > len(pkts)-1 { + t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) + } + if seenWriteI[writeI] { + t.Errorf("duplicate toWrite value: %d", writeI) + } + seenWriteI[writeI] = true + } + }) +} + +func Test_handleGRO(t *testing.T) { + tests := []struct { + name string + pktsIn [][]byte + canUDPGRO bool + wantToWrite []int + wantLens []int + wantErr bool + }{ + { + "multiple protocols and flows", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 + tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + }, + true, + []int{0, 1, 2, 4, 5, 7, 9}, + []int{240, 228, 128, 140, 260, 160, 248}, + false, + }, + { + "multiple protocols and flows no UDP GRO", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 
101), // tcp4 flow 1 + tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + }, + false, + []int{0, 1, 2, 4, 5, 7, 8, 9, 10}, + []int{240, 128, 128, 140, 260, 160, 128, 148, 148}, + false, + }, + { + "PSH interleaved", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 + }, + true, + []int{0, 2, 4, 6}, + []int{240, 240, 260, 260}, + false, + }, + { + "coalesceItemInvalidCSum", + [][]byte{ + flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4Packet(ip4PortA, ip4PortB, 100), + }, + true, + []int{0, 1, 3, 4}, + []int{140, 240, 128, 228}, + false, + }, + { + "out of order", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + }, + true, + []int{0}, + []int{340}, + false, + }, + { + "unequal TTL", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.TTL++ + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.TTL++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "unequal ToS", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.TOS++ + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.TOS++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "unequal flags more fragments set", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.Flags = 1 + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + 
fields.Flags = 1 + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "unequal flags DF set", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.Flags = 2 + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.Flags = 2 + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "ipv6 unequal hop limit", + [][]byte{ + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), + tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { + fields.HopLimit++ + }), + udp6Packet(ip6PortA, ip6PortB, 100), + udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { + fields.HopLimit++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{160, 160, 148, 148}, + false, + }, + { + "ipv6 unequal traffic class", + [][]byte{ + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), + tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { + fields.TrafficClass++ + }), + udp6Packet(ip6PortA, ip6PortB, 100), + udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { + fields.TrafficClass++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{160, 160, 148, 148}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + toWrite := make([]int, 0, len(tt.pktsIn)) + err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("got err: %v", err) + } + if len(toWrite) != len(tt.wantToWrite) { + t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) + } + for i, pktI := range tt.wantToWrite { + if tt.wantToWrite[i] != toWrite[i] { + t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) + } + if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { + t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) + } + } + }) + } +} + +func Test_packetIsGROCandidate(t *testing.T) { + tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] + tcp4TooShort := tcp4[:39] + ip4InvalidHeaderLen := make([]byte, len(tcp4)) + copy(ip4InvalidHeaderLen, tcp4) + ip4InvalidHeaderLen[0] = 0x46 + ip4InvalidProtocol := make([]byte, len(tcp4)) + copy(ip4InvalidProtocol, tcp4) + ip4InvalidProtocol[9] = unix.IPPROTO_GRE + + tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] + tcp6TooShort := tcp6[:59] + ip6InvalidProtocol := make([]byte, len(tcp6)) + copy(ip6InvalidProtocol, tcp6) + ip6InvalidProtocol[6] = unix.IPPROTO_GRE + + udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:] + udp4TooShort := udp4[:27] + + udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:] + udp6TooShort := udp6[:47] + + tests := []struct { + name string + b []byte + canUDPGRO bool + want groCandidateType + }{ + { + "tcp4", + tcp4, + true, + tcp4GROCandidate, + }, + { + "tcp6", + tcp6, + true, + tcp6GROCandidate, + }, + { + "udp4", + udp4, + true, + udp4GROCandidate, + }, + { + "udp4 no support", + udp4, + false, + notGROCandidate, + }, + { + "udp6", + udp6, + true, + udp6GROCandidate, + }, + { + "udp6 no support", + udp6, + false, + notGROCandidate, + }, 
+ { + "udp4 too short", + udp4TooShort, + true, + notGROCandidate, + }, + { + "udp6 too short", + udp6TooShort, + true, + notGROCandidate, + }, + { + "tcp4 too short", + tcp4TooShort, + true, + notGROCandidate, + }, + { + "tcp6 too short", + tcp6TooShort, + true, + notGROCandidate, + }, + { + "invalid IP version", + []byte{0x00}, + true, + notGROCandidate, + }, + { + "invalid IP header len", + ip4InvalidHeaderLen, + true, + notGROCandidate, + }, + { + "ip4 invalid protocol", + ip4InvalidProtocol, + true, + notGROCandidate, + }, + { + "ip6 invalid protocol", + ip6InvalidProtocol, + true, + notGROCandidate, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want { + t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_udpPacketsCanCoalesce(t *testing.T) { + udp4a := udp4Packet(ip4PortA, ip4PortB, 100) + udp4b := udp4Packet(ip4PortA, ip4PortB, 100) + udp4c := udp4Packet(ip4PortA, ip4PortB, 110) + + type args struct { + pkt []byte + iphLen uint8 + gsoSize uint16 + item udpGROItem + bufs [][]byte + bufsOffset int + } + tests := []struct { + name string + args args + want canCoalesce + }{ + { + "coalesceAppend equal gso", + args{ + pkt: udp4a[offset:], + iphLen: 20, + gsoSize: 100, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4b, + }, + bufsOffset: offset, + }, + coalesceAppend, + }, + { + "coalesceAppend smaller gso", + args{ + pkt: udp4a[offset : len(udp4a)-90], + iphLen: 20, + gsoSize: 10, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4b, + }, + bufsOffset: offset, + }, + coalesceAppend, + }, + { + "coalesceUnavailable smaller gso previously appended", + args{ + pkt: udp4a[offset:], + iphLen: 20, + gsoSize: 100, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4c, + udp4b, + }, + bufsOffset: offset, + }, + coalesceUnavailable, + }, + { + "coalesceUnavailable larger following smaller", + args{ + pkt: udp4c[offset:], + iphLen: 20, + gsoSize: 110, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4c, + }, + bufsOffset: offset, + }, + coalesceUnavailable, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want { + t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/tun/operateonfd.go b/tun/operateonfd.go index f1beb6d..343f754 100644 --- a/tun/operateonfd.go +++ b/tun/operateonfd.go @@ -2,7 +2,7 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tun diff --git a/tun/tcp_offload_linux.go b/tun/tcp_offload_linux.go deleted file mode 100644 index 4912efd..0000000 --- a/tun/tcp_offload_linux.go +++ /dev/null @@ -1,612 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package tun - -import ( - "bytes" - "encoding/binary" - "errors" - "io" - "unsafe" - - "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/conn" -) - -const tcpFlagsOffset = 13 - -const ( - tcpFlagFIN uint8 = 0x01 - tcpFlagPSH uint8 = 0x08 - tcpFlagACK uint8 = 0x10 -) - -// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. 
The -// kernel symbol is virtio_net_hdr. -type virtioNetHdr struct { - flags uint8 - gsoType uint8 - hdrLen uint16 - gsoSize uint16 - csumStart uint16 - csumOffset uint16 -} - -func (v *virtioNetHdr) decode(b []byte) error { - if len(b) < virtioNetHdrLen { - return io.ErrShortBuffer - } - copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen]) - return nil -} - -func (v *virtioNetHdr) encode(b []byte) error { - if len(b) < virtioNetHdrLen { - return io.ErrShortBuffer - } - copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen)) - return nil -} - -const ( - // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the - // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). - virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) -) - -// flowKey represents the key for a flow. -type flowKey struct { - srcAddr, dstAddr [16]byte - srcPort, dstPort uint16 - rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. -} - -// tcpGROTable holds flow and coalescing information for the purposes of GRO. -type tcpGROTable struct { - itemsByFlow map[flowKey][]tcpGROItem - itemsPool [][]tcpGROItem -} - -func newTCPGROTable() *tcpGROTable { - t := &tcpGROTable{ - itemsByFlow: make(map[flowKey][]tcpGROItem, conn.IdealBatchSize), - itemsPool: make([][]tcpGROItem, conn.IdealBatchSize), - } - for i := range t.itemsPool { - t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize) - } - return t -} - -func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey { - key := flowKey{} - addrSize := dstAddr - srcAddr - copy(key.srcAddr[:], pkt[srcAddr:dstAddr]) - copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize]) - key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) - key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) - key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) - return key -} - -// lookupOrInsert looks up a flow for the provided packet and metadata, -// returning the packets found for the flow, or inserting a new one if none -// is found. -func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { - key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) - items, ok := t.itemsByFlow[key] - if ok { - return items, ok - } - // TODO: insert() performs another map lookup. This could be rearranged to avoid. - t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex) - return nil, false -} - -// insert an item in the table for the provided packet and packet metadata. -func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { - key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) - item := tcpGROItem{ - key: key, - bufsIndex: uint16(bufsIndex), - gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), - iphLen: uint8(tcphOffset), - tcphLen: uint8(tcphLen), - sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), - pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, - } - items, ok := t.itemsByFlow[key] - if !ok { - items = t.newItems() - } - items = append(items, item) - t.itemsByFlow[key] = items -} - -func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { - items, _ := t.itemsByFlow[item.key] - items[i] = item -} - -func (t *tcpGROTable) deleteAt(key flowKey, i int) { - items, _ := t.itemsByFlow[key] - items = append(items[:i], items[i+1:]...) 
- t.itemsByFlow[key] = items -} - -// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime -// of a GRO evaluation across a vector of packets. -type tcpGROItem struct { - key flowKey - sentSeq uint32 // the sequence number - bufsIndex uint16 // the index into the original bufs slice - numMerged uint16 // the number of packets merged into this item - gsoSize uint16 // payload size - iphLen uint8 // ip header len - tcphLen uint8 // tcp header len - pshSet bool // psh flag is set -} - -func (t *tcpGROTable) newItems() []tcpGROItem { - var items []tcpGROItem - items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1] - return items -} - -func (t *tcpGROTable) reset() { - for k, items := range t.itemsByFlow { - items = items[:0] - t.itemsPool = append(t.itemsPool, items) - delete(t.itemsByFlow, k) - } -} - -// canCoalesce represents the outcome of checking if two TCP packets are -// candidates for coalescing. -type canCoalesce int - -const ( - coalescePrepend canCoalesce = -1 - coalesceUnavailable canCoalesce = 0 - coalesceAppend canCoalesce = 1 -) - -// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet -// described by item. This function makes considerations that match the kernel's -// GRO self tests, which can be found in tools/testing/selftests/net/gro.c. -func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { - pktTarget := bufs[item.bufsIndex][bufsOffset:] - if tcphLen != item.tcphLen { - // cannot coalesce with unequal tcp options len - return coalesceUnavailable - } - if tcphLen > 20 { - if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) { - // cannot coalesce with unequal tcp options - return coalesceUnavailable - } - } - if pkt[1] != pktTarget[1] { - // cannot coalesce with unequal ToS values - return coalesceUnavailable - } - if pkt[6]>>5 != pktTarget[6]>>5 { - // cannot coalesce with unequal DF or reserved bits. MF is checked - // further up the stack. - return coalesceUnavailable - } - // seq adjacency - lhsLen := item.gsoSize - lhsLen += item.numMerged * item.gsoSize - if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective - if item.pshSet { - // We cannot append to a segment that has the PSH flag set, PSH - // can only be set on the final segment in a reassembled group. - return coalesceUnavailable - } - if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 { - // A smaller than gsoSize packet has been appended previously. - // Nothing can come after a smaller packet on the end. - return coalesceUnavailable - } - if gsoSize > item.gsoSize { - // We cannot have a larger packet following a smaller one. - return coalesceUnavailable - } - return coalesceAppend - } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective - if pshSet { - // We cannot prepend with a segment that has the PSH flag set, PSH - // can only be set on the final segment in a reassembled group. - return coalesceUnavailable - } - if gsoSize < item.gsoSize { - // We cannot have a larger packet following a smaller one. - return coalesceUnavailable - } - if gsoSize > item.gsoSize && item.numMerged > 0 { - // There's at least one previous merge, and we're larger than all - // previous. This would put multiple smaller packets on the end. 
- return coalesceUnavailable - } - return coalescePrepend - } - return coalesceUnavailable -} - -func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool { - srcAddrAt := ipv4SrcAddrOffset - addrSize := 4 - if isV6 { - srcAddrAt = ipv6SrcAddrOffset - addrSize = 16 - } - tcpTotalLen := uint16(len(pkt) - int(iphLen)) - tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen) - return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0 -} - -// coalesceResult represents the result of attempting to coalesce two TCP -// packets. -type coalesceResult int - -const ( - coalesceInsufficientCap coalesceResult = 0 - coalescePSHEnding coalesceResult = 1 - coalesceItemInvalidCSum coalesceResult = 2 - coalescePktInvalidCSum coalesceResult = 3 - coalesceSuccess coalesceResult = 4 -) - -// coalesceTCPPackets attempts to coalesce pkt with the packet described by -// item, returning the outcome. This function may swap bufs elements in the -// event of a prepend as item's bufs index is already being tracked for writing -// to a Device. -func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { - var pktHead []byte // the packet that will end up at the front - headersLen := item.iphLen + item.tcphLen - coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) - - // Copy data - if mode == coalescePrepend { - pktHead = pkt - if cap(pkt)-bufsOffset < coalescedLen { - // We don't want to allocate a new underlying array if capacity is - // too small. - return coalesceInsufficientCap - } - if pshSet { - return coalescePSHEnding - } - if item.numMerged == 0 { - if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) { - return coalesceItemInvalidCSum - } - } - if !tcpChecksumValid(pkt, item.iphLen, isV6) { - return coalescePktInvalidCSum - } - item.sentSeq = seq - extendBy := coalescedLen - len(pktHead) - bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...) - copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):]) - // Flip the slice headers in bufs as part of prepend. The index of item - // is already being tracked for writing. - bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex] - } else { - pktHead = bufs[item.bufsIndex][bufsOffset:] - if cap(pktHead)-bufsOffset < coalescedLen { - // We don't want to allocate a new underlying array if capacity is - // too small. - return coalesceInsufficientCap - } - if item.numMerged == 0 { - if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) { - return coalesceItemInvalidCSum - } - } - if !tcpChecksumValid(pkt, item.iphLen, isV6) { - return coalescePktInvalidCSum - } - if pshSet { - // We are appending a segment with PSH set. - item.pshSet = pshSet - pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH - } - extendBy := len(pkt) - int(headersLen) - bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) 
- copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) - } - - if gsoSize > item.gsoSize { - item.gsoSize = gsoSize - } - hdr := virtioNetHdr{ - flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb - hdrLen: uint16(headersLen), - gsoSize: uint16(item.gsoSize), - csumStart: uint16(item.iphLen), - csumOffset: 16, - } - - // Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the - // (IPv4) header checksum. - if isV6 { - hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 - binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len - } else { - hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 - pktHead[10], pktHead[11] = 0, 0 // clear checksum field - binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length - iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum - binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field - } - hdr.encode(bufs[item.bufsIndex][bufsOffset-virtioNetHdrLen:]) - - // Calculate the pseudo header checksum and place it at the TCP checksum - // offset. Downstream checksum offloading will combine this with computation - // of the tcp header and payload checksum. - addrLen := 4 - addrOffset := ipv4SrcAddrOffset - if isV6 { - addrLen = 16 - addrOffset = ipv6SrcAddrOffset - } - srcAddrAt := bufsOffset + addrOffset - srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] - dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen))) - binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) - - item.numMerged++ - return coalesceSuccess -} - -const ( - ipv4FlagMoreFragments = 0x80 -) - -const ( - ipv4SrcAddrOffset = 12 - ipv6SrcAddrOffset = 8 - maxUint16 = 1<<16 - 1 -) - -// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with -// existing packets tracked in table. It will return false when pktI is not -// coalesced, otherwise true. This indicates to the caller if bufs[pktI] -// should be written to the Device. -func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) { - pkt := bufs[pktI][offset:] - if len(pkt) > maxUint16 { - // A valid IPv4 or IPv6 packet will never exceed this. 
- return false - } - iphLen := int((pkt[0] & 0x0F) * 4) - if isV6 { - iphLen = 40 - ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) - if ipv6HPayloadLen != len(pkt)-iphLen { - return false - } - } else { - totalLen := int(binary.BigEndian.Uint16(pkt[2:])) - if totalLen != len(pkt) { - return false - } - } - if len(pkt) < iphLen { - return false - } - tcphLen := int((pkt[iphLen+12] >> 4) * 4) - if tcphLen < 20 || tcphLen > 60 { - return false - } - if len(pkt) < iphLen+tcphLen { - return false - } - if !isV6 { - if pkt[6]&ipv4FlagMoreFragments != 0 || (pkt[6]<<3 != 0 || pkt[7] != 0) { - // no GRO support for fragmented segments for now - return false - } - } - tcpFlags := pkt[iphLen+tcpFlagsOffset] - var pshSet bool - // not a candidate if any non-ACK flags (except PSH+ACK) are set - if tcpFlags != tcpFlagACK { - if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { - return false - } - pshSet = true - } - gsoSize := uint16(len(pkt) - tcphLen - iphLen) - // not a candidate if payload len is 0 - if gsoSize < 1 { - return false - } - seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) - srcAddrOffset := ipv4SrcAddrOffset - addrLen := 4 - if isV6 { - srcAddrOffset = ipv6SrcAddrOffset - addrLen = 16 - } - items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) - if !existing { - return false - } - for i := len(items) - 1; i >= 0; i-- { - // In the best case of packets arriving in order iterating in reverse is - // more efficient if there are multiple items for a given flow. This - // also enables a natural table.deleteAt() in the - // coalesceItemInvalidCSum case without the need for index tracking. - // This algorithm makes a best effort to coalesce in the event of - // unordered packets, where pkt may land anywhere in items from a - // sequence number perspective, however once an item is inserted into - // the table it is never compared across other items later. - item := items[i] - can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset) - if can != coalesceUnavailable { - result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6) - switch result { - case coalesceSuccess: - table.updateAt(item, i) - return true - case coalesceItemInvalidCSum: - // delete the item with an invalid csum - table.deleteAt(item.key, i) - case coalescePktInvalidCSum: - // no point in inserting an item that we can't coalesce - return false - default: - } - } - } - // failed to coalesce with any other packets; store the item in the flow - table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) - return false -} - -func isTCP4NoIPOptions(b []byte) bool { - if len(b) < 40 { - return false - } - if b[0]>>4 != 4 { - return false - } - if b[0]&0x0F != 5 { - return false - } - if b[9] != unix.IPPROTO_TCP { - return false - } - return true -} - -func isTCP6NoEH(b []byte) bool { - if len(b) < 60 { - return false - } - if b[0]>>4 != 6 { - return false - } - if b[6] != unix.IPPROTO_TCP { - return false - } - return true -} - -// handleGRO evaluates bufs for GRO, and writes the indices of the resulting -// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be -// empty (but non-nil), and are passed in to save allocs as the caller may reset -// and recycle them across vectors of packets. 
-func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error { - for i := range bufs { - if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { - return errors.New("invalid offset") - } - var coalesced bool - switch { - case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce - coalesced = tcpGRO(bufs, offset, i, tcp4Table, false) - case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce - coalesced = tcpGRO(bufs, offset, i, tcp6Table, true) - } - if !coalesced { - hdr := virtioNetHdr{} - err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) - if err != nil { - return err - } - *toWrite = append(*toWrite, i) - } - } - return nil -} - -// tcpTSO splits packets from in into outBuffs, writing the size of each -// element into sizes. It returns the number of buffers populated, and/or an -// error. -func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) { - iphLen := int(hdr.csumStart) - srcAddrOffset := ipv6SrcAddrOffset - addrLen := 16 - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { - in[10], in[11] = 0, 0 // clear ipv4 header checksum - srcAddrOffset = ipv4SrcAddrOffset - addrLen = 4 - } - tcpCSumAt := int(hdr.csumStart + hdr.csumOffset) - in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum - firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:]) - nextSegmentDataAt := int(hdr.hdrLen) - i := 0 - for ; nextSegmentDataAt < len(in); i++ { - if i == len(outBuffs) { - return i - 1, ErrTooManySegments - } - nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize) - if nextSegmentEnd > len(in) { - nextSegmentEnd = len(in) - } - segmentDataLen := nextSegmentEnd - nextSegmentDataAt - totalLen := int(hdr.hdrLen) + segmentDataLen - sizes[i] = totalLen - out := outBuffs[i][outOffset:] - - copy(out, in[:iphLen]) - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { - // For IPv4 we are responsible for incrementing the ID field, - // updating the total len field, and recalculating the header - // checksum. - if i > 0 { - id := binary.BigEndian.Uint16(out[4:]) - id += uint16(i) - binary.BigEndian.PutUint16(out[4:], id) - } - binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) - ipv4CSum := ^checksum(out[:iphLen], 0) - binary.BigEndian.PutUint16(out[10:], ipv4CSum) - } else { - // For IPv6 we are responsible for updating the payload length field. 
- binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) - } - - // TCP header - copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen]) - tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i)) - binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq) - if nextSegmentEnd != len(in) { - // FIN and PSH should only be set on last segment - clearFlags := tcpFlagFIN | tcpFlagPSH - out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags - } - - // payload - copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) - - // TCP checksum - tcpHLen := int(hdr.hdrLen - hdr.csumStart) - tcpLenForPseudo := uint16(tcpHLen + segmentDataLen) - tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo) - tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold) - binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum) - - nextSegmentDataAt += int(hdr.gsoSize) - } - return i, nil -} - -func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error { - cSumAt := cSumStart + cSumOffset - // The initial value at the checksum offset should be summed with the - // checksum we compute. This is typically the pseudo-header checksum. - initial := binary.BigEndian.Uint16(in[cSumAt:]) - in[cSumAt], in[cSumAt+1] = 0, 0 - binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial))) - return nil -} diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go deleted file mode 100644 index 046e177..0000000 --- a/tun/tcp_offload_linux_test.go +++ /dev/null @@ -1,323 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package tun - -import ( - "net/netip" - "testing" - - "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/conn" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -const ( - offset = virtioNetHdrLen -) - -var ( - ip4PortA = netip.MustParseAddrPort("192.0.2.1:1") - ip4PortB = netip.MustParseAddrPort("192.0.2.2:1") - ip4PortC = netip.MustParseAddrPort("192.0.2.3:1") - ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1") - ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1") - ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1") -) - -func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { - totalLen := 40 + segmentSize - b := make([]byte, offset+int(totalLen), 65535) - ipv4H := header.IPv4(b[offset:]) - srcAs4 := srcIPPort.Addr().As4() - dstAs4 := dstIPPort.Addr().As4() - ipv4H.Encode(&header.IPv4Fields{ - SrcAddr: tcpip.Address(srcAs4[:]), - DstAddr: tcpip.Address(dstAs4[:]), - Protocol: unix.IPPROTO_TCP, - TTL: 64, - TotalLength: uint16(totalLen), - }) - tcpH := header.TCP(b[offset+20:]) - tcpH.Encode(&header.TCPFields{ - SrcPort: srcIPPort.Port(), - DstPort: dstIPPort.Port(), - SeqNum: seq, - AckNum: 1, - DataOffset: 20, - Flags: flags, - WindowSize: 3000, - }) - ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) - pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize)) - tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) - return b -} - -func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { - totalLen := 60 + segmentSize - b := make([]byte, offset+int(totalLen), 65535) - ipv6H := header.IPv6(b[offset:]) - srcAs16 := srcIPPort.Addr().As16() - 
dstAs16 := dstIPPort.Addr().As16() - ipv6H.Encode(&header.IPv6Fields{ - SrcAddr: tcpip.Address(srcAs16[:]), - DstAddr: tcpip.Address(dstAs16[:]), - TransportProtocol: unix.IPPROTO_TCP, - HopLimit: 64, - PayloadLength: uint16(segmentSize + 20), - }) - tcpH := header.TCP(b[offset+40:]) - tcpH.Encode(&header.TCPFields{ - SrcPort: srcIPPort.Port(), - DstPort: dstIPPort.Port(), - SeqNum: seq, - AckNum: 1, - DataOffset: 20, - Flags: flags, - WindowSize: 3000, - }) - pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize)) - tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) - return b -} - -func Test_handleVirtioRead(t *testing.T) { - tests := []struct { - name string - hdr virtioNetHdr - pktIn []byte - wantLens []int - wantErr bool - }{ - { - "tcp4", - virtioNetHdr{ - flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, - gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4, - gsoSize: 100, - hdrLen: 40, - csumStart: 20, - csumOffset: 16, - }, - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), - []int{140, 140}, - false, - }, - { - "tcp6", - virtioNetHdr{ - flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, - gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6, - gsoSize: 100, - hdrLen: 60, - csumStart: 40, - csumOffset: 16, - }, - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), - []int{160, 160}, - false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - out := make([][]byte, conn.IdealBatchSize) - sizes := make([]int, conn.IdealBatchSize) - for i := range out { - out[i] = make([]byte, 65535) - } - tt.hdr.encode(tt.pktIn) - n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) - if err != nil { - if tt.wantErr { - return - } - t.Fatalf("got err: %v", err) - } - if n != len(tt.wantLens) { - t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) - } - for i := range tt.wantLens { - if tt.wantLens[i] != sizes[i] { - t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) - } - } - }) - } -} - -func flipTCP4Checksum(b []byte) []byte { - at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16 - b[at] ^= 0xFF - b[at+1] ^= 0xFF - return b -} - -func Fuzz_handleGRO(f *testing.F) { - pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1) - pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101) - pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201) - pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1) - pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101) - pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201) - f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset) - f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) { - pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5} - toWrite := make([]int, 0, len(pkts)) - handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite) - if len(toWrite) > len(pkts) { - t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) - } - seenWriteI := make(map[int]bool) - for _, writeI := range toWrite { - if writeI < 0 || writeI > len(pkts)-1 { - t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) - } - if seenWriteI[writeI] { - t.Errorf("duplicate toWrite value: %d", writeI) - } - seenWriteI[writeI] = true - } - }) -} - -func Test_handleGRO(t *testing.T) { - tests := []struct { - name string - pktsIn [][]byte - wantToWrite []int - wantLens []int - 
wantErr bool - }{ - { - "multiple flows", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2 - }, - []int{0, 2, 3, 5}, - []int{240, 140, 260, 160}, - false, - }, - { - "PSH interleaved", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 - }, - []int{0, 2, 4, 6}, - []int{240, 240, 260, 260}, - false, - }, - { - "coalesceItemInvalidCSum", - [][]byte{ - flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 - }, - []int{0, 1}, - []int{140, 240}, - false, - }, - { - "out of order", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 - }, - []int{0}, - []int{340}, - false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - toWrite := make([]int, 0, len(tt.pktsIn)) - err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite) - if err != nil { - if tt.wantErr { - return - } - t.Fatalf("got err: %v", err) - } - if len(toWrite) != len(tt.wantToWrite) { - t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) - } - for i, pktI := range tt.wantToWrite { - if tt.wantToWrite[i] != toWrite[i] { - t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) - } - if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { - t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) - } - } - }) - } -} - -func Test_isTCP4NoIPOptions(t *testing.T) { - valid := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] - invalidLen := valid[:39] - invalidHeaderLen := make([]byte, len(valid)) - copy(invalidHeaderLen, valid) - invalidHeaderLen[0] = 0x46 - invalidProtocol := make([]byte, len(valid)) - copy(invalidProtocol, valid) - invalidProtocol[9] = unix.IPPROTO_TCP + 1 - - tests := []struct { - name string - b []byte - want bool - }{ - { - "valid", - valid, - true, - }, - { - "invalid length", - invalidLen, - false, - }, - { - "invalid version", - []byte{0x00}, - false, - }, - { - "invalid header len", - invalidHeaderLen, - false, - }, - { - "invalid protocol", - invalidProtocol, - false, - }, - } - for _, tt := range 
tests { - t.Run(tt.name, func(t *testing.T) { - if got := isTCP4NoIPOptions(tt.b); got != tt.want { - t.Errorf("isTCP4NoIPOptions() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 b/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 deleted file mode 100644 index 5461e79..0000000 --- a/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 +++ /dev/null @@ -1,8 +0,0 @@ -go test fuzz v1 -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -int(34) diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d b/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d deleted file mode 100644 index b441819..0000000 --- a/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d +++ /dev/null @@ -1,8 +0,0 @@ -go test fuzz v1 -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -int(-48) @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tun diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go index c9a6c0b..341afe3 100644 --- a/tun/tun_darwin.go +++ b/tun/tun_darwin.go @@ -1,19 +1,17 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tun import ( - "errors" "fmt" "io" "net" "os" "sync" "syscall" - "time" "unsafe" "golang.org/x/sys/unix" @@ -30,18 +28,6 @@ type NativeTun struct { closeOnce sync.Once } -func retryInterfaceByIndex(index int) (iface *net.Interface, err error) { - for i := 0; i < 20; i++ { - iface, err = net.InterfaceByIndex(index) - if err != nil && errors.Is(err, unix.ENOMEM) { - time.Sleep(time.Duration(i) * time.Second / 3) - continue - } - return iface, err - } - return nil, err -} - func (tun *NativeTun) routineRouteListener(tunIfindex int) { var ( statusUp bool @@ -62,26 +48,22 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) { return } - if n < 14 { + if n < 28 { continue } - if data[3 /* type */] != unix.RTM_IFINFO { + if data[3 /* ifm_type */] != unix.RTM_IFINFO { continue } - ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifindex */]))) + ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifm_index */]))) if ifindex != tunIfindex { continue } - iface, err := retryInterfaceByIndex(ifindex) - if err != nil { - tun.errors <- err - return - } + flags := int(*(*uint32)(unsafe.Pointer(&data[8 /* ifm_flags */]))) // Up / Down event - up := (iface.Flags & net.FlagUp) != 0 + up := (flags & syscall.IFF_UP) != 0 if up != statusUp && up { tun.events <- EventUp } @@ -90,11 +72,13 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) { } statusUp = up + mtu := int(*(*uint32)(unsafe.Pointer(&data[24 /* ifm_data.ifi_mtu */]))) + // MTU changes - if iface.MTU != statusMTU { + if mtu != statusMTU { tun.events <- EventMTUUpdate } - statusMTU = iface.MTU + statusMTU = mtu } } diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go index 7c65fd9..4adf3a1 100644 --- a/tun/tun_freebsd.go +++ b/tun/tun_freebsd.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tun diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 12cd49f..1461e06 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tun @@ -38,6 +38,7 @@ type NativeTun struct { statusListenersShutdown chan struct{} batchSize int vnetHdr bool + udpGSO bool closeOnce sync.Once @@ -48,9 +49,10 @@ type NativeTun struct { readOpMu sync.Mutex // readOpMu guards readBuff readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr - writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable - toWrite []int - tcp4GROTable, tcp6GROTable *tcpGROTable + writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable + toWrite []int + tcpGROTable *tcpGROTable + udpGROTable *udpGROTable } func (tun *NativeTun) File() *os.File { @@ -333,8 +335,8 @@ func (tun *NativeTun) nameSlow() (string, error) { func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { tun.writeOpMu.Lock() defer func() { - tun.tcp4GROTable.reset() - tun.tcp6GROTable.reset() + tun.tcpGROTable.reset() + tun.udpGROTable.reset() tun.writeOpMu.Unlock() }() var ( @@ -343,7 +345,7 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { ) tun.toWrite = tun.toWrite[:0] if tun.vnetHdr { - err := handleGRO(bufs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite) + err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite) if err != nil { return 0, err } @@ -394,37 +396,42 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e sizes[0] = n return 1, nil } - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType) } ipVersion := in[0] >> 4 switch ipVersion { case 4: - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 { + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) } case 6: - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) } default: return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) } - if len(in) <= int(hdr.csumStart+12) { - return 0, errors.New("packet is too short") - } // Don't trust hdr.hdrLen from the kernel as it can be equal to the length // of the entire first packet when the kernel is handling it as part of a - // FORWARD path. Instead, parse the TCP header length and add it onto + // FORWARD path. Instead, parse the transport header length and add it onto // csumStart, which is synonymous for IP header length. - tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) - if tcpHLen < 20 || tcpHLen > 60 { - // A TCP header must be between 20 and 60 bytes in length. 
- return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 { + hdr.hdrLen = hdr.csumStart + 8 + } else { + if len(in) <= int(hdr.csumStart+12) { + return 0, errors.New("packet is too short") + } + + tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) + if tcpHLen < 20 || tcpHLen > 60 { + // A TCP header must be between 20 and 60 bytes in length. + return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) + } + hdr.hdrLen = hdr.csumStart + tcpHLen } - hdr.hdrLen = hdr.csumStart + tcpHLen if len(in) < int(hdr.hdrLen) { return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen) @@ -438,7 +445,7 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) } - return tcpTSO(in, hdr, bufs, sizes, offset) + return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6) } func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { @@ -497,7 +504,8 @@ func (tun *NativeTun) BatchSize() int { const ( // TODO: support TSO with ECN bits - tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6 ) func (tun *NativeTun) initFromFlags(name string) error { @@ -519,12 +527,17 @@ func (tun *NativeTun) initFromFlags(name string) error { } got := ifr.Uint16() if got&unix.IFF_VNET_HDR != 0 { - err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads) + // tunTCPOffloads were added in Linux v2.6. We require their support + // if IFF_VNET_HDR is set. + err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads) if err != nil { return } tun.vnetHdr = true tun.batchSize = conn.IdealBatchSize + // tunUDPOffloads were added in Linux v6.2. We do not return an + // error if they are unsupported at runtime. + tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil } else { tun.batchSize = 1 } @@ -575,8 +588,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { events: make(chan Event, 5), errors: make(chan error, 5), statusListenersShutdown: make(chan struct{}), - tcp4GROTable: newTCPGROTable(), - tcp6GROTable: newTCPGROTable(), + tcpGROTable: newTCPGROTable(), + udpGROTable: newUDPGROTable(), toWrite: make([]int, 0, conn.IdealBatchSize), } @@ -628,12 +641,12 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { } file := os.NewFile(uintptr(fd), "/dev/tun") tun := &NativeTun{ - tunFile: file, - events: make(chan Event, 5), - errors: make(chan error, 5), - tcp4GROTable: newTCPGROTable(), - tcp6GROTable: newTCPGROTable(), - toWrite: make([]int, 0, conn.IdealBatchSize), + tunFile: file, + events: make(chan Event, 5), + errors: make(chan error, 5), + tcpGROTable: newTCPGROTable(), + udpGROTable: newUDPGROTable(), + toWrite: make([]int, 0, conn.IdealBatchSize), } name, err := tun.Name() if err != nil { diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go index ae571b9..5aa9070 100644 --- a/tun/tun_openbsd.go +++ b/tun/tun_openbsd.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. 
*/ package tun diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 0cb4ce1..de65fb4 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tun @@ -127,6 +127,9 @@ func (tun *NativeTun) MTU() (int, error) { // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. func (tun *NativeTun) ForceMTU(mtu int) { + if tun.close.Load() { + return + } update := tun.forcedMTU != mtu tun.forcedMTU = mtu if update { @@ -157,11 +160,10 @@ retry: packet, err := tun.session.ReceivePacket() switch err { case nil: - packetSize := len(packet) - copy(bufs[0][offset:], packet) - sizes[0] = packetSize + n := copy(bufs[0][offset:], packet) + sizes[0] = n tun.session.ReleaseReceivePacket(packet) - tun.rate.update(uint64(packetSize)) + tun.rate.update(uint64(n)) return 1, nil case windows.ERROR_NO_MORE_ITEMS: if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go index d07e860..9c4564f 100644 --- a/tun/tuntest/tuntest.go +++ b/tun/tuntest/tuntest.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tuntest @@ -1,3 +1,3 @@ package main -const Version = "0.0.20230223" +const Version = "0.0.20250522"
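
The Test_packetIsGROCandidate table earlier in this diff exercises a classifier whose implementation does not appear in this section. Below is a minimal reconstruction, inferred only from the test fixtures: the 40- and 28-byte IPv4 floors and the 60- and 48-byte IPv6 floors come from the "too short" cases, the IHL and protocol checks from the "invalid" cases, and the canUDPGRO gate from the "no support" cases. The upstream function may differ in detail.

package main

import "fmt"

type groCandidateType int

const (
	notGROCandidate groCandidateType = iota
	tcp4GROCandidate
	tcp6GROCandidate
	udp4GROCandidate
	udp6GROCandidate
)

// packetIsGROCandidate is a sketch of the classifier exercised by
// Test_packetIsGROCandidate; all thresholds are read off that test.
func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType {
	if len(b) < 28 { // smaller than a minimal IPv4+UDP packet
		return notGROCandidate
	}
	switch b[0] >> 4 {
	case 4:
		if b[0]&0x0F != 5 {
			return notGROCandidate // IPv4 packets with options do not coalesce
		}
		if b[9] == 6 && len(b) >= 40 { // IPPROTO_TCP, minimal IPv4+TCP
			return tcp4GROCandidate
		}
		if b[9] == 17 && canUDPGRO { // IPPROTO_UDP
			return udp4GROCandidate
		}
	case 6:
		if len(b) < 48 { // smaller than a minimal IPv6+UDP packet
			return notGROCandidate
		}
		if b[6] == 6 && len(b) >= 60 { // next header TCP, no extension headers
			return tcp6GROCandidate
		}
		if b[6] == 17 && canUDPGRO { // next header UDP
			return udp6GROCandidate
		}
	}
	return notGROCandidate
}

func main() {
	pkt := make([]byte, 140)
	pkt[0] = 0x45 // IPv4, IHL 5
	pkt[9] = 6    // TCP
	fmt.Println(packetIsGROCandidate(pkt, true) == tcp4GROCandidate) // true
}

Gating UDP on canUDPGRO matches the "no support" cases: handleGRO now receives tun.udpGSO, since coalescing UDP datagrams is only useful when the device can re-segment them on write.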
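
The deleted tcp_offload_linux.go encodes and decodes virtioNetHdr by copying the struct's bytes directly, which is safe because the six fields pack with no padding to exactly sizeof(virtio_net_hdr) = 10 bytes. A self-contained round trip using the same encode/decode bodies shown in the diff; the main function is illustrative only.

package main

import (
	"fmt"
	"io"
	"unsafe"
)

// virtioNetHdr mirrors the kernel's struct virtio_net_hdr, as in the
// deleted file.
type virtioNetHdr struct {
	flags      uint8
	gsoType    uint8
	hdrLen     uint16
	gsoSize    uint16
	csumStart  uint16
	csumOffset uint16
}

const virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) // 10 bytes

func (v *virtioNetHdr) encode(b []byte) error {
	if len(b) < virtioNetHdrLen {
		return io.ErrShortBuffer
	}
	copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
	return nil
}

func (v *virtioNetHdr) decode(b []byte) error {
	if len(b) < virtioNetHdrLen {
		return io.ErrShortBuffer
	}
	copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
	return nil
}

func main() {
	in := virtioNetHdr{hdrLen: 40, gsoSize: 1460}
	buf := make([]byte, virtioNetHdrLen)
	if err := in.encode(buf); err != nil {
		panic(err)
	}
	var out virtioNetHdr
	if err := out.decode(buf); err != nil {
		panic(err)
	}
	fmt.Printf("round trip: %+v\n", out)
}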
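
tcpPacketsCanCoalesce in the deleted file decides append vs. prepend purely by sequence arithmetic before applying the stricter checks (equal TCP options, ToS and DF bits, PSH placement, gsoSize ordering). A simplified sketch of just that adjacency test, with itemLen standing in for item.gsoSize plus the already-merged payload; all the other checks are deliberately omitted.

package main

import "fmt"

type canCoalesce int

const (
	coalescePrepend     canCoalesce = -1
	coalesceUnavailable canCoalesce = 0
	coalesceAppend      canCoalesce = 1
)

// seqAdjacency distills the sequence-number test from the deleted
// tcpPacketsCanCoalesce: pkt appends if it starts exactly where the
// reassembled item ends, and prepends if it ends exactly where the
// item starts.
func seqAdjacency(pktSeq uint32, gsoSize uint16, itemSeq, itemLen uint32) canCoalesce {
	if pktSeq == itemSeq+itemLen {
		return coalesceAppend
	}
	if pktSeq+uint32(gsoSize) == itemSeq {
		return coalescePrepend
	}
	return coalesceUnavailable
}

func main() {
	fmt.Println(seqAdjacency(201, 100, 1, 200))   // 1: appends after two 100-byte segments at seq 1
	fmt.Println(seqAdjacency(101, 100, 201, 100)) // -1: prepends in front of the segment at seq 201
	fmt.Println(seqAdjacency(150, 100, 1, 200))   // 0: overlaps, cannot coalesce
}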
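
A note on why coalesceTCPPackets writes only a pseudo-header sum into the TCP checksum field: with VIRTIO_NET_HDR_F_NEEDS_CSUM set, downstream checksum offloading combines that seed with the sum over the transport header and payload. A generic RFC 1071 sketch of the fold and pseudo-header arithmetic, under the assumption (consistent with the ^checksum(...) call sites in the diff) that upstream's checksum() returns the folded but uncomplemented sum; this is not the optimized implementation in checksum.go.

package main

import (
	"encoding/binary"
	"fmt"
)

// checksumNoFold sums 16-bit big-endian words into a 64-bit accumulator
// without folding carries; an odd trailing byte is treated as the high
// byte of a final word.
func checksumNoFold(b []byte, initial uint64) uint64 {
	sum := initial
	for len(b) >= 2 {
		sum += uint64(binary.BigEndian.Uint16(b))
		b = b[2:]
	}
	if len(b) == 1 {
		sum += uint64(b[0]) << 8
	}
	return sum
}

// fold collapses the accumulator into the final 16-bit one's-complement sum.
func fold(sum uint64) uint16 {
	for sum > 0xFFFF {
		sum = (sum & 0xFFFF) + (sum >> 16)
	}
	return uint16(sum)
}

// pseudoHeaderChecksumNoFold matches the call shape used in the deleted
// file: protocol, source and destination addresses, and transport length.
func pseudoHeaderChecksumNoFold(proto uint8, src, dst []byte, totalLen uint16) uint64 {
	sum := checksumNoFold(src, 0)
	sum = checksumNoFold(dst, sum)
	sum += uint64(proto)
	sum += uint64(totalLen)
	return sum
}

func main() {
	src := []byte{192, 0, 2, 1}
	dst := []byte{192, 0, 2, 2}
	// The GRO path stores only this folded pseudo-header sum at the TCP
	// checksum offset; the device (or the split path) finishes the job.
	psum := pseudoHeaderChecksumNoFold(6, src, dst, 120) // IPPROTO_TCP, 120-byte segment
	fmt.Printf("0x%04x\n", fold(psum))
}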
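
tcpTSO (renamed gsoSplit in this diff, where it also gains UDP support) walks the payload in gsoSize strides, emitting hdrLen bytes of headers plus up to gsoSize bytes of payload per output packet. A sketch of only the size bookkeeping, which reproduces the 140/140 lengths the tcp4 case of Test_handleVirtioRead expects; segmentSizes is a hypothetical helper, not an upstream name.

package main

import "fmt"

// segmentSizes reproduces the per-segment size arithmetic of the split
// loop: every output packet carries the full headers plus up to gsoSize
// bytes of payload, and the final segment takes the remainder.
func segmentSizes(totalLen, hdrLen, gsoSize int) []int {
	var sizes []int
	for at := hdrLen; at < totalLen; at += gsoSize {
		data := gsoSize
		if remaining := totalLen - at; remaining < data {
			data = remaining
		}
		sizes = append(sizes, hdrLen+data)
	}
	return sizes
}

func main() {
	// 40 bytes of IPv4+TCP headers and 200 bytes of payload, gsoSize 100:
	fmt.Println(segmentSizes(240, 40, 100)) // [140 140]
}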
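
Finally, the tun_linux.go hunk above probes UDP GSO by attempting TUNSETOFFLOAD with the USO flags and tolerating failure, since per the diff TUN_F_USO4/TUN_F_USO6 only exist on Linux v6.2+ while the TCP offloads date to v2.6. A standalone version of that probe; probeOffloads is a hypothetical wrapper, and a real caller issues TUNSETIFF with IFF_VNET_HDR before this point, so the bare fd opened in main may be rejected.

//go:build linux

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

const (
	tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 // Linux v2.6+
	tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6                   // Linux v6.2+
)

// probeOffloads mirrors the initFromFlags logic in the diff: TCP offloads
// are required, UDP offloads are probed and their absence tolerated.
func probeOffloads(fd int) (udpGSO bool, err error) {
	if err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads); err != nil {
		return false, err
	}
	udpGSO = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil
	return udpGSO, nil
}

func main() {
	fd, err := unix.Open("/dev/net/tun", unix.O_RDWR, 0)
	if err != nil {
		panic(err)
	}
	defer unix.Close(fd)
	// TUNSETIFF with IFF_VNET_HDR is omitted here for brevity, so the
	// kernel may refuse TUNSETOFFLOAD on this fd; the error shows that.
	udpGSO, err := probeOffloads(fd)
	fmt.Println(udpGSO, err)
}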