diff options
Diffstat (limited to 'conn')
-rw-r--r-- | conn/bind_std.go | 544 | ||||
-rw-r--r-- | conn/bind_std_test.go | 250 | ||||
-rw-r--r-- | conn/bind_windows.go | 601 | ||||
-rw-r--r-- | conn/bindtest/bindtest.go | 136 | ||||
-rw-r--r-- | conn/boundif_android.go | 10 | ||||
-rw-r--r-- | conn/boundif_windows.go | 59 | ||||
-rw-r--r-- | conn/conn.go | 132 | ||||
-rw-r--r-- | conn/conn_default.go | 176 | ||||
-rw-r--r-- | conn/conn_linux.go | 571 | ||||
-rw-r--r-- | conn/conn_test.go | 24 | ||||
-rw-r--r-- | conn/controlfns.go | 43 | ||||
-rw-r--r-- | conn/controlfns_linux.go | 109 | ||||
-rw-r--r-- | conn/controlfns_unix.go | 35 | ||||
-rw-r--r-- | conn/controlfns_windows.go | 23 | ||||
-rw-r--r-- | conn/default.go | 10 | ||||
-rw-r--r-- | conn/errors_default.go | 12 | ||||
-rw-r--r-- | conn/errors_linux.go | 26 | ||||
-rw-r--r-- | conn/features_default.go | 15 | ||||
-rw-r--r-- | conn/features_linux.go | 29 | ||||
-rw-r--r-- | conn/gso_default.go | 21 | ||||
-rw-r--r-- | conn/gso_linux.go | 65 | ||||
-rw-r--r-- | conn/mark_default.go | 6 | ||||
-rw-r--r-- | conn/mark_unix.go | 14 | ||||
-rw-r--r-- | conn/sticky_default.go | 42 | ||||
-rw-r--r-- | conn/sticky_linux.go | 112 | ||||
-rw-r--r-- | conn/sticky_linux_test.go | 266 | ||||
-rw-r--r-- | conn/winrio/rio_windows.go | 254 |
27 files changed, 2709 insertions, 876 deletions
diff --git a/conn/bind_std.go b/conn/bind_std.go new file mode 100644 index 0000000..f5c8816 --- /dev/null +++ b/conn/bind_std.go @@ -0,0 +1,544 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "runtime" + "strconv" + "sync" + "syscall" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +var ( + _ Bind = (*StdNetBind)(nil) +) + +// StdNetBind implements Bind for all platforms. While Windows has its own Bind +// (see bind_windows.go), it may fall back to StdNetBind. +// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable +// methods for sending and receiving multiple datagrams per-syscall. See the +// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. +type StdNetBind struct { + mu sync.Mutex // protects all fields except as specified + ipv4 *net.UDPConn + ipv6 *net.UDPConn + ipv4PC *ipv4.PacketConn // will be nil on non-Linux + ipv6PC *ipv6.PacketConn // will be nil on non-Linux + ipv4TxOffload bool + ipv4RxOffload bool + ipv6TxOffload bool + ipv6RxOffload bool + + // these two fields are not guarded by mu + udpAddrPool sync.Pool + msgsPool sync.Pool + + blackhole4 bool + blackhole6 bool +} + +func NewStdNetBind() Bind { + return &StdNetBind{ + udpAddrPool: sync.Pool{ + New: func() any { + return &net.UDPAddr{ + IP: make([]byte, 16), + } + }, + }, + + msgsPool: sync.Pool{ + New: func() any { + // ipv6.Message and ipv4.Message are interchangeable as they are + // both aliases for x/net/internal/socket.Message. + msgs := make([]ipv6.Message, IdealBatchSize) + for i := range msgs { + msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize) + } + return &msgs + }, + }, + } +} + +type StdNetEndpoint struct { + // AddrPort is the endpoint destination. + netip.AddrPort + // src is the current sticky source address and interface index, if + // supported. Typically this is a PKTINFO structure from/for control + // messages, see unix.PKTINFO for an example. + src []byte +} + +var ( + _ Bind = (*StdNetBind)(nil) + _ Endpoint = &StdNetEndpoint{} +) + +func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { + e, err := netip.ParseAddrPort(s) + if err != nil { + return nil, err + } + return &StdNetEndpoint{ + AddrPort: e, + }, nil +} + +func (e *StdNetEndpoint) ClearSrc() { + if e.src != nil { + // Truncate src, no need to reallocate. + e.src = e.src[:0] + } +} + +func (e *StdNetEndpoint) DstIP() netip.Addr { + return e.AddrPort.Addr() +} + +// See control_default,linux, etc for implementations of SrcIP and SrcIfidx. + +func (e *StdNetEndpoint) DstToBytes() []byte { + b, _ := e.AddrPort.MarshalBinary() + return b +} + +func (e *StdNetEndpoint) DstToString() string { + return e.AddrPort.String() +} + +func listenNet(network string, port int) (*net.UDPConn, int, error) { + conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) + if err != nil { + return nil, 0, err + } + + // Retrieve port. + laddr := conn.LocalAddr() + uaddr, err := net.ResolveUDPAddr( + laddr.Network(), + laddr.String(), + ) + if err != nil { + return nil, 0, err + } + return conn.(*net.UDPConn), uaddr.Port, nil +} + +func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { + s.mu.Lock() + defer s.mu.Unlock() + + var err error + var tries int + + if s.ipv4 != nil || s.ipv6 != nil { + return nil, 0, ErrBindAlreadyOpen + } + + // Attempt to open ipv4 and ipv6 listeners on the same port. + // If uport is 0, we can retry on failure. +again: + port := int(uport) + var v4conn, v6conn *net.UDPConn + var v4pc *ipv4.PacketConn + var v6pc *ipv6.PacketConn + + v4conn, port, err = listenNet("udp4", port) + if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + return nil, 0, err + } + + // Listen on the same port as we're using for ipv4. + v6conn, port, err = listenNet("udp6", port) + if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { + v4conn.Close() + tries++ + goto again + } + if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + v4conn.Close() + return nil, 0, err + } + var fns []ReceiveFunc + if v4conn != nil { + s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn) + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + v4pc = ipv4.NewPacketConn(v4conn) + s.ipv4PC = v4pc + } + fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload)) + s.ipv4 = v4conn + } + if v6conn != nil { + s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn) + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + v6pc = ipv6.NewPacketConn(v6conn) + s.ipv6PC = v6pc + } + fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload)) + s.ipv6 = v6conn + } + if len(fns) == 0 { + return nil, 0, syscall.EAFNOSUPPORT + } + + return fns, uint16(port), nil +} + +func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) { + for i := range *msgs { + (*msgs)[i].OOB = (*msgs)[i].OOB[:0] + (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB} + } + s.msgsPool.Put(msgs) +} + +func (s *StdNetBind) getMessages() *[]ipv6.Message { + return s.msgsPool.Get().(*[]ipv6.Message) +} + +var ( + // If compilation fails here these are no longer the same underlying type. + _ ipv6.Message = ipv4.Message{} +) + +type batchReader interface { + ReadBatch([]ipv6.Message, int) (int, error) +} + +type batchWriter interface { + WriteBatch([]ipv6.Message, int) (int, error) +} + +func (s *StdNetBind) receiveIP( + br batchReader, + conn *net.UDPConn, + rxOffload bool, + bufs [][]byte, + sizes []int, + eps []Endpoint, +) (n int, err error) { + msgs := s.getMessages() + for i := range bufs { + (*msgs)[i].Buffers[0] = bufs[i] + (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] + } + defer s.putMessages(msgs) + var numMsgs int + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + if rxOffload { + readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams) + numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0) + if err != nil { + return 0, err + } + numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize) + if err != nil { + return 0, err + } + } else { + numMsgs, err = br.ReadBatch(*msgs, 0) + if err != nil { + return 0, err + } + } + } else { + msg := &(*msgs)[0] + msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + if err != nil { + return 0, err + } + numMsgs = 1 + } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + if sizes[i] == 0 { + continue + } + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation + getSrcFromControl(msg.OOB[:msg.NN], ep) + eps[i] = ep + } + return numMsgs, nil +} + +func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + } +} + +func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + } +} + +// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and +// rename the IdealBatchSize constant to BatchSize. +func (s *StdNetBind) BatchSize() int { + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + return IdealBatchSize + } + return 1 +} + +func (s *StdNetBind) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + var err1, err2 error + if s.ipv4 != nil { + err1 = s.ipv4.Close() + s.ipv4 = nil + s.ipv4PC = nil + } + if s.ipv6 != nil { + err2 = s.ipv6.Close() + s.ipv6 = nil + s.ipv6PC = nil + } + s.blackhole4 = false + s.blackhole6 = false + s.ipv4TxOffload = false + s.ipv4RxOffload = false + s.ipv6TxOffload = false + s.ipv6RxOffload = false + if err1 != nil { + return err1 + } + return err2 +} + +type ErrUDPGSODisabled struct { + onLaddr string + RetryErr error +} + +func (e ErrUDPGSODisabled) Error() string { + return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr) +} + +func (e ErrUDPGSODisabled) Unwrap() error { + return e.RetryErr +} + +func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { + s.mu.Lock() + blackhole := s.blackhole4 + conn := s.ipv4 + offload := s.ipv4TxOffload + br := batchWriter(s.ipv4PC) + is6 := false + if endpoint.DstIP().Is6() { + blackhole = s.blackhole6 + conn = s.ipv6 + br = s.ipv6PC + is6 = true + offload = s.ipv6TxOffload + } + s.mu.Unlock() + + if blackhole { + return nil + } + if conn == nil { + return syscall.EAFNOSUPPORT + } + + msgs := s.getMessages() + defer s.putMessages(msgs) + ua := s.udpAddrPool.Get().(*net.UDPAddr) + defer s.udpAddrPool.Put(ua) + if is6 { + as16 := endpoint.DstIP().As16() + copy(ua.IP, as16[:]) + ua.IP = ua.IP[:16] + } else { + as4 := endpoint.DstIP().As4() + copy(ua.IP, as4[:]) + ua.IP = ua.IP[:4] + } + ua.Port = int(endpoint.(*StdNetEndpoint).Port()) + var ( + retried bool + err error + ) +retry: + if offload { + n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize) + err = s.send(conn, br, (*msgs)[:n]) + if err != nil && offload && errShouldDisableUDPGSO(err) { + offload = false + s.mu.Lock() + if is6 { + s.ipv6TxOffload = false + } else { + s.ipv4TxOffload = false + } + s.mu.Unlock() + retried = true + goto retry + } + } else { + for i := range bufs { + (*msgs)[i].Addr = ua + (*msgs)[i].Buffers[0] = bufs[i] + setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint)) + } + err = s.send(conn, br, (*msgs)[:len(bufs)]) + } + if retried { + return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err} + } + return err +} + +func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error { + var ( + n int + err error + start int + ) + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + for { + n, err = pc.WriteBatch(msgs[start:], 0) + if err != nil || n == len(msgs[start:]) { + break + } + start += n + } + } else { + for _, msg := range msgs { + _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr)) + if err != nil { + break + } + } + } + return err +} + +const ( + // Exceeding these values results in EMSGSIZE. They account for layer3 and + // layer4 headers. IPv6 does not need to account for itself as the payload + // length field is self excluding. + maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 + maxIPv6PayloadLen = 1<<16 - 1 - 8 + + // This is a hard limit imposed by the kernel. + udpSegmentMaxDatagrams = 64 +) + +type setGSOFunc func(control *[]byte, gsoSize uint16) + +func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int { + var ( + base = -1 // index of msg we are currently coalescing into + gsoSize int // segmentation size of msgs[base] + dgramCnt int // number of dgrams coalesced into msgs[base] + endBatch bool // tracking flag to start a new batch on next iteration of bufs + ) + maxPayloadLen := maxIPv4PayloadLen + if ep.DstIP().Is6() { + maxPayloadLen = maxIPv6PayloadLen + } + for i, buf := range bufs { + if i > 0 { + msgLen := len(buf) + baseLenBefore := len(msgs[base].Buffers[0]) + freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore + if msgLen+baseLenBefore <= maxPayloadLen && + msgLen <= gsoSize && + msgLen <= freeBaseCap && + dgramCnt < udpSegmentMaxDatagrams && + !endBatch { + msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...) + if i == len(bufs)-1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + dgramCnt++ + if msgLen < gsoSize { + // A smaller than gsoSize packet on the tail is legal, but + // it must end the batch. + endBatch = true + } + continue + } + } + if dgramCnt > 1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + // Reset prior to incrementing base since we are preparing to start a + // new potential batch. + endBatch = false + base++ + gsoSize = len(buf) + setSrcControl(&msgs[base].OOB, ep) + msgs[base].Buffers[0] = buf + msgs[base].Addr = addr + dgramCnt = 1 + } + return base + 1 +} + +type getGSOFunc func(control []byte) (int, error) + +func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) { + for i := firstMsgAt; i < len(msgs); i++ { + msg := &msgs[i] + if msg.N == 0 { + return n, err + } + var ( + gsoSize int + start int + end = msg.N + numToSplit = 1 + ) + gsoSize, err = getGSO(msg.OOB[:msg.NN]) + if err != nil { + return n, err + } + if gsoSize > 0 { + numToSplit = (msg.N + gsoSize - 1) / gsoSize + end = gsoSize + } + for j := 0; j < numToSplit; j++ { + if n > i { + return n, errors.New("splitting coalesced packet resulted in overflow") + } + copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) + msgs[n].N = copied + msgs[n].Addr = msg.Addr + start = end + end += gsoSize + if end > msg.N { + end = msg.N + } + n++ + } + if i != n-1 { + // It is legal for bytes to move within msg.Buffers[0] as a result + // of splitting, so we only zero the source msg len when it is not + // the destination of the last split operation above. + msg.N = 0 + } + } + return n, nil +} diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go new file mode 100644 index 0000000..34a3c9a --- /dev/null +++ b/conn/bind_std_test.go @@ -0,0 +1,250 @@ +package conn + +import ( + "encoding/binary" + "net" + "testing" + + "golang.org/x/net/ipv6" +) + +func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { + bind := NewStdNetBind().(*StdNetBind) + fns, _, err := bind.Open(0) + if err != nil { + t.Fatal(err) + } + bind.Close() + bufs := make([][]byte, 1) + bufs[0] = make([]byte, 1) + sizes := make([]int, 1) + eps := make([]Endpoint, 1) + for _, fn := range fns { + // The ReceiveFuncs must not access conn-related fields on StdNetBind + // unguarded. Close() nils the conn-related fields resulting in a panic + // if they violate the mutex. + fn(bufs, sizes, eps) + } +} + +func mockSetGSOSize(control *[]byte, gsoSize uint16) { + *control = (*control)[:cap(*control)] + binary.LittleEndian.PutUint16(*control, gsoSize) +} + +func Test_coalesceMessages(t *testing.T) { + cases := []struct { + name string + buffs [][]byte + wantLens []int + wantGSO []int + }{ + { + name: "one message no coalesce", + buffs: [][]byte{ + make([]byte, 1, 1), + }, + wantLens: []int{1}, + wantGSO: []int{0}, + }, + { + name: "two messages equal len coalesce", + buffs: [][]byte{ + make([]byte, 1, 2), + make([]byte, 1, 1), + }, + wantLens: []int{2}, + wantGSO: []int{1}, + }, + { + name: "two messages unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + }, + wantLens: []int{3}, + wantGSO: []int{2}, + }, + { + name: "three messages second unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + make([]byte, 2, 2), + }, + wantLens: []int{3, 2}, + wantGSO: []int{2, 0}, + }, + { + name: "three messages limited cap coalesce", + buffs: [][]byte{ + make([]byte, 2, 4), + make([]byte, 2, 2), + make([]byte, 2, 2), + }, + wantLens: []int{4, 2}, + wantGSO: []int{2, 0}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1").To4(), + Port: 1, + } + msgs := make([]ipv6.Message, len(tt.buffs)) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].OOB = make([]byte, 0, 2) + } + got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize) + if got != len(tt.wantLens) { + t.Fatalf("got len %d want: %d", got, len(tt.wantLens)) + } + for i := 0; i < got; i++ { + if msgs[i].Addr != addr { + t.Errorf("msgs[%d].Addr != passed addr", i) + } + gotLen := len(msgs[i].Buffers[0]) + if gotLen != tt.wantLens[i] { + t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i]) + } + gotGSO, err := mockGetGSOSize(msgs[i].OOB) + if err != nil { + t.Fatalf("msgs[%d] getGSOSize err: %v", i, err) + } + if gotGSO != tt.wantGSO[i] { + t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i]) + } + } + }) + } +} + +func mockGetGSOSize(control []byte) (int, error) { + if len(control) < 2 { + return 0, nil + } + return int(binary.LittleEndian.Uint16(control)), nil +} + +func Test_splitCoalescedMessages(t *testing.T) { + newMsg := func(n, gso int) ipv6.Message { + msg := ipv6.Message{ + Buffers: [][]byte{make([]byte, 1<<16-1)}, + N: n, + OOB: make([]byte, 2), + } + binary.LittleEndian.PutUint16(msg.OOB, uint16(gso)) + if gso > 0 { + msg.NN = 2 + } + return msg + } + + cases := []struct { + name string + msgs []ipv6.Message + firstMsgAt int + wantNumEval int + wantMsgLens []int + wantErr bool + }{ + { + name: "second last split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(3, 1), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 3, + wantMsgLens: []int{1, 1, 1, 0}, + wantErr: false, + }, + { + name: "second last no split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 1, + wantMsgLens: []int{1, 0, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last no split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(1, 0), + }, + firstMsgAt: 2, + wantNumEval: 2, + wantMsgLens: []int{1, 1, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(3, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(2, 1), + newMsg(2, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last no split last split overflow", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(4, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize) + if err != nil && !tt.wantErr { + t.Fatalf("err: %v", err) + } + if got != tt.wantNumEval { + t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval) + } + for i, msg := range tt.msgs { + if msg.N != tt.wantMsgLens[i] { + t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i]) + } + } + }) + } +} diff --git a/conn/bind_windows.go b/conn/bind_windows.go new file mode 100644 index 0000000..a3b8460 --- /dev/null +++ b/conn/bind_windows.go @@ -0,0 +1,601 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "encoding/binary" + "io" + "net" + "net/netip" + "strconv" + "sync" + "sync/atomic" + "unsafe" + + "golang.org/x/sys/windows" + + "golang.zx2c4.com/wireguard/conn/winrio" +) + +const ( + packetsPerRing = 1024 + bytesPerPacket = 2048 - 32 + receiveSpins = 15 +) + +type ringPacket struct { + addr WinRingEndpoint + data [bytesPerPacket]byte +} + +type ringBuffer struct { + packets uintptr + head, tail uint32 + id winrio.BufferId + iocp windows.Handle + isFull bool + cq winrio.Cq + mu sync.Mutex + overlapped windows.Overlapped +} + +func (rb *ringBuffer) Push() *ringPacket { + for rb.isFull { + panic("ring is full") + } + ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{})))) + rb.tail += 1 + if rb.tail%packetsPerRing == rb.head%packetsPerRing { + rb.isFull = true + } + return ret +} + +func (rb *ringBuffer) Return(count uint32) { + if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull { + return + } + rb.head += count + rb.isFull = false +} + +type afWinRingBind struct { + sock windows.Handle + rx, tx ringBuffer + rq winrio.Rq + mu sync.Mutex + blackhole bool +} + +// WinRingBind uses Windows registered I/O for fast ring buffered networking. +type WinRingBind struct { + v4, v6 afWinRingBind + mu sync.RWMutex + isOpen atomic.Uint32 // 0, 1, or 2 +} + +func NewDefaultBind() Bind { return NewWinRingBind() } + +func NewWinRingBind() Bind { + if !winrio.Initialize() { + return NewStdNetBind() + } + return new(WinRingBind) +} + +type WinRingEndpoint struct { + family uint16 + data [30]byte +} + +var ( + _ Bind = (*WinRingBind)(nil) + _ Endpoint = (*WinRingEndpoint)(nil) +) + +func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) { + host, port, err := net.SplitHostPort(s) + if err != nil { + return nil, err + } + host16, err := windows.UTF16PtrFromString(host) + if err != nil { + return nil, err + } + port16, err := windows.UTF16PtrFromString(port) + if err != nil { + return nil, err + } + hints := windows.AddrinfoW{ + Flags: windows.AI_NUMERICHOST, + Family: windows.AF_UNSPEC, + Socktype: windows.SOCK_DGRAM, + Protocol: windows.IPPROTO_UDP, + } + var addrinfo *windows.AddrinfoW + err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo) + if err != nil { + return nil, err + } + defer windows.FreeAddrInfoW(addrinfo) + if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) { + return nil, windows.ERROR_INVALID_ADDRESS + } + var dst [unsafe.Sizeof(WinRingEndpoint{})]byte + copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen)) + return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil +} + +func (*WinRingEndpoint) ClearSrc() {} + +func (e *WinRingEndpoint) DstIP() netip.Addr { + switch e.family { + case windows.AF_INET: + return netip.AddrFrom4(*(*[4]byte)(e.data[2:6])) + case windows.AF_INET6: + return netip.AddrFrom16(*(*[16]byte)(e.data[6:22])) + } + return netip.Addr{} +} + +func (e *WinRingEndpoint) SrcIP() netip.Addr { + return netip.Addr{} // not supported +} + +func (e *WinRingEndpoint) DstToBytes() []byte { + switch e.family { + case windows.AF_INET: + b := make([]byte, 0, 6) + b = append(b, e.data[2:6]...) + b = append(b, e.data[1], e.data[0]) + return b + case windows.AF_INET6: + b := make([]byte, 0, 18) + b = append(b, e.data[6:22]...) + b = append(b, e.data[1], e.data[0]) + return b + } + return nil +} + +func (e *WinRingEndpoint) DstToString() string { + switch e.family { + case windows.AF_INET: + return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String() + case windows.AF_INET6: + var zone string + if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { + zone = strconv.FormatUint(uint64(scope), 10) + } + return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String() + } + return "" +} + +func (e *WinRingEndpoint) SrcToString() string { + return "" +} + +func (ring *ringBuffer) CloseAndZero() { + if ring.cq != 0 { + winrio.CloseCompletionQueue(ring.cq) + ring.cq = 0 + } + if ring.iocp != 0 { + windows.CloseHandle(ring.iocp) + ring.iocp = 0 + } + if ring.id != 0 { + winrio.DeregisterBuffer(ring.id) + ring.id = 0 + } + if ring.packets != 0 { + windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE) + ring.packets = 0 + } + ring.head = 0 + ring.tail = 0 + ring.isFull = false +} + +func (bind *afWinRingBind) CloseAndZero() { + bind.rx.CloseAndZero() + bind.tx.CloseAndZero() + if bind.sock != 0 { + windows.CloseHandle(bind.sock) + bind.sock = 0 + } + bind.blackhole = false +} + +func (bind *WinRingBind) closeAndZero() { + bind.isOpen.Store(0) + bind.v4.CloseAndZero() + bind.v6.CloseAndZero() +} + +func (ring *ringBuffer) Open() error { + var err error + packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing + ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) + if err != nil { + return err + } + ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen)) + if err != nil { + return err + } + ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) + if err != nil { + return err + } + ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped) + if err != nil { + return err + } + return nil +} + +func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) { + var err error + bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP) + if err != nil { + return nil, err + } + err = bind.rx.Open() + if err != nil { + return nil, err + } + err = bind.tx.Open() + if err != nil { + return nil, err + } + bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0) + if err != nil { + return nil, err + } + err = windows.Bind(bind.sock, sa) + if err != nil { + return nil, err + } + sa, err = windows.Getsockname(bind.sock) + if err != nil { + return nil, err + } + return sa, nil +} + +func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) { + bind.mu.Lock() + defer bind.mu.Unlock() + defer func() { + if err != nil { + bind.closeAndZero() + } + }() + if bind.isOpen.Load() != 0 { + return nil, 0, ErrBindAlreadyOpen + } + var sa windows.Sockaddr + sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)}) + if err != nil { + return nil, 0, err + } + sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port}) + if err != nil { + return nil, 0, err + } + selectedPort = uint16(sa.(*windows.SockaddrInet6).Port) + for i := 0; i < packetsPerRing; i++ { + err = bind.v4.InsertReceiveRequest() + if err != nil { + return nil, 0, err + } + err = bind.v6.InsertReceiveRequest() + if err != nil { + return nil, 0, err + } + } + bind.isOpen.Store(1) + return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err +} + +func (bind *WinRingBind) Close() error { + bind.mu.RLock() + if bind.isOpen.Load() != 1 { + bind.mu.RUnlock() + return nil + } + bind.isOpen.Store(2) + windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil) + windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil) + windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil) + windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil) + bind.mu.RUnlock() + bind.mu.Lock() + defer bind.mu.Unlock() + bind.closeAndZero() + return nil +} + +// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and +// rename the IdealBatchSize constant to BatchSize. +func (bind *WinRingBind) BatchSize() int { + // TODO: implement batching in and out of the ring + return 1 +} + +func (bind *WinRingBind) SetMark(mark uint32) error { + return nil +} + +func (bind *afWinRingBind) InsertReceiveRequest() error { + packet := bind.rx.Push() + dataBuffer := &winrio.Buffer{ + Id: bind.rx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets), + Length: uint32(len(packet.data)), + } + addressBuffer := &winrio.Buffer{ + Id: bind.rx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets), + Length: uint32(unsafe.Sizeof(packet.addr)), + } + bind.mu.Lock() + defer bind.mu.Unlock() + return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet))) +} + +//go:linkname procyield runtime.procyield +func procyield(cycles uint32) + +func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) { + if isOpen.Load() != 1 { + return 0, nil, net.ErrClosed + } + bind.rx.mu.Lock() + defer bind.rx.mu.Unlock() + + var err error + var count uint32 + var results [1]winrio.Result +retry: + count = 0 + for tries := 0; count == 0 && tries < receiveSpins; tries++ { + if tries > 0 { + if isOpen.Load() != 1 { + return 0, nil, net.ErrClosed + } + procyield(1) + } + count = winrio.DequeueCompletion(bind.rx.cq, results[:]) + } + if count == 0 { + err = winrio.Notify(bind.rx.cq) + if err != nil { + return 0, nil, err + } + var bytes uint32 + var key uintptr + var overlapped *windows.Overlapped + err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE) + if err != nil { + return 0, nil, err + } + if isOpen.Load() != 1 { + return 0, nil, net.ErrClosed + } + count = winrio.DequeueCompletion(bind.rx.cq, results[:]) + if count == 0 { + return 0, nil, io.ErrNoProgress + } + } + bind.rx.Return(1) + err = bind.InsertReceiveRequest() + if err != nil { + return 0, nil, err + } + // We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us + // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to + // attacker bandwidth, just like the rest of the receive path. + if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE { + if isOpen.Load() != 1 { + return 0, nil, net.ErrClosed + } + goto retry + } + if results[0].Status != 0 { + return 0, nil, windows.Errno(results[0].Status) + } + packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext))) + ep := packet.addr + n := copy(buf, packet.data[:results[0].BytesTransferred]) + return n, &ep, nil +} + +func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { + bind.mu.RLock() + defer bind.mu.RUnlock() + n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen) + sizes[0] = n + eps[0] = ep + return 1, err +} + +func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { + bind.mu.RLock() + defer bind.mu.RUnlock() + n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen) + sizes[0] = n + eps[0] = ep + return 1, err +} + +func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error { + if isOpen.Load() != 1 { + return net.ErrClosed + } + if len(buf) > bytesPerPacket { + return io.ErrShortBuffer + } + bind.tx.mu.Lock() + defer bind.tx.mu.Unlock() + var results [packetsPerRing]winrio.Result + count := winrio.DequeueCompletion(bind.tx.cq, results[:]) + if count == 0 && bind.tx.isFull { + err := winrio.Notify(bind.tx.cq) + if err != nil { + return err + } + var bytes uint32 + var key uintptr + var overlapped *windows.Overlapped + err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE) + if err != nil { + return err + } + if isOpen.Load() != 1 { + return net.ErrClosed + } + count = winrio.DequeueCompletion(bind.tx.cq, results[:]) + if count == 0 { + return io.ErrNoProgress + } + } + if count > 0 { + bind.tx.Return(count) + } + packet := bind.tx.Push() + packet.addr = *nend + copy(packet.data[:], buf) + dataBuffer := &winrio.Buffer{ + Id: bind.tx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets), + Length: uint32(len(buf)), + } + addressBuffer := &winrio.Buffer{ + Id: bind.tx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets), + Length: uint32(unsafe.Sizeof(packet.addr)), + } + bind.mu.Lock() + defer bind.mu.Unlock() + return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) +} + +func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error { + nend, ok := endpoint.(*WinRingEndpoint) + if !ok { + return ErrWrongEndpointType + } + bind.mu.RLock() + defer bind.mu.RUnlock() + for _, buf := range bufs { + switch nend.family { + case windows.AF_INET: + if bind.v4.blackhole { + continue + } + if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil { + return err + } + case windows.AF_INET6: + if bind.v6.blackhole { + continue + } + if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil { + return err + } + } + } + return nil +} + +func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { + s.mu.Lock() + defer s.mu.Unlock() + sysconn, err := s.ipv4.SyscallConn() + if err != nil { + return err + } + err2 := sysconn.Control(func(fd uintptr) { + err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex) + }) + if err2 != nil { + return err2 + } + if err != nil { + return err + } + s.blackhole4 = blackhole + return nil +} + +func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + s.mu.Lock() + defer s.mu.Unlock() + sysconn, err := s.ipv6.SyscallConn() + if err != nil { + return err + } + err2 := sysconn.Control(func(fd uintptr) { + err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex) + }) + if err2 != nil { + return err2 + } + if err != nil { + return err + } + s.blackhole6 = blackhole + return nil +} + +func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { + bind.mu.RLock() + defer bind.mu.RUnlock() + if bind.isOpen.Load() != 1 { + return net.ErrClosed + } + err := bindSocketToInterface4(bind.v4.sock, interfaceIndex) + if err != nil { + return err + } + bind.v4.blackhole = blackhole + return nil +} + +func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + bind.mu.RLock() + defer bind.mu.RUnlock() + if bind.isOpen.Load() != 1 { + return net.ErrClosed + } + err := bindSocketToInterface6(bind.v6.sock, interfaceIndex) + if err != nil { + return err + } + bind.v6.blackhole = blackhole + return nil +} + +func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error { + const IP_UNICAST_IF = 31 + /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ + var bytes [4]byte + binary.BigEndian.PutUint32(bytes[:], interfaceIndex) + interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) + err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex)) + if err != nil { + return err + } + return nil +} + +func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error { + const IPV6_UNICAST_IF = 31 + return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex)) +} diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go new file mode 100644 index 0000000..46e20e6 --- /dev/null +++ b/conn/bindtest/bindtest.go @@ -0,0 +1,136 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package bindtest + +import ( + "fmt" + "math/rand" + "net" + "net/netip" + "os" + + "golang.zx2c4.com/wireguard/conn" +) + +type ChannelBind struct { + rx4, tx4 *chan []byte + rx6, tx6 *chan []byte + closeSignal chan bool + source4, source6 ChannelEndpoint + target4, target6 ChannelEndpoint +} + +type ChannelEndpoint uint16 + +var ( + _ conn.Bind = (*ChannelBind)(nil) + _ conn.Endpoint = (*ChannelEndpoint)(nil) +) + +func NewChannelBinds() [2]conn.Bind { + arx4 := make(chan []byte, 8192) + brx4 := make(chan []byte, 8192) + arx6 := make(chan []byte, 8192) + brx6 := make(chan []byte, 8192) + var binds [2]ChannelBind + binds[0].rx4 = &arx4 + binds[0].tx4 = &brx4 + binds[1].rx4 = &brx4 + binds[1].tx4 = &arx4 + binds[0].rx6 = &arx6 + binds[0].tx6 = &brx6 + binds[1].rx6 = &brx6 + binds[1].tx6 = &arx6 + binds[0].target4 = ChannelEndpoint(1) + binds[1].target4 = ChannelEndpoint(2) + binds[0].target6 = ChannelEndpoint(3) + binds[1].target6 = ChannelEndpoint(4) + binds[0].source4 = binds[1].target4 + binds[0].source6 = binds[1].target6 + binds[1].source4 = binds[0].target4 + binds[1].source6 = binds[0].target6 + return [2]conn.Bind{&binds[0], &binds[1]} +} + +func (c ChannelEndpoint) ClearSrc() {} + +func (c ChannelEndpoint) SrcToString() string { return "" } + +func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) } + +func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } + +func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } + +func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} } + +func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { + c.closeSignal = make(chan bool) + fns = append(fns, c.makeReceiveFunc(*c.rx4)) + fns = append(fns, c.makeReceiveFunc(*c.rx6)) + if rand.Uint32()&1 == 0 { + return fns, uint16(c.source4), nil + } else { + return fns, uint16(c.source6), nil + } +} + +func (c *ChannelBind) Close() error { + if c.closeSignal != nil { + select { + case <-c.closeSignal: + default: + close(c.closeSignal) + } + } + return nil +} + +func (c *ChannelBind) BatchSize() int { return 1 } + +func (c *ChannelBind) SetMark(mark uint32) error { return nil } + +func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { + select { + case <-c.closeSignal: + return 0, net.ErrClosed + case rx := <-ch: + copied := copy(bufs[0], rx) + sizes[0] = copied + eps[0] = c.target6 + return 1, nil + } + } +} + +func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error { + for _, b := range bufs { + select { + case <-c.closeSignal: + return net.ErrClosed + default: + bc := make([]byte, len(b)) + copy(bc, b) + if ep.(ChannelEndpoint) == c.target4 { + *c.tx4 <- bc + } else if ep.(ChannelEndpoint) == c.target6 { + *c.tx6 <- bc + } else { + return os.ErrInvalid + } + } + } + return nil +} + +func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { + addr, err := netip.ParseAddrPort(s) + if err != nil { + return nil, err + } + return ChannelEndpoint(addr.Port()), nil +} diff --git a/conn/boundif_android.go b/conn/boundif_android.go index 3e10607..be69b2a 100644 --- a/conn/boundif_android.go +++ b/conn/boundif_android.go @@ -1,12 +1,12 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn -func (bind *nativeBind) PeekLookAtSocketFd4() (fd int, err error) { - sysconn, err := bind.ipv4.SyscallConn() +func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { + sysconn, err := s.ipv4.SyscallConn() if err != nil { return -1, err } @@ -19,8 +19,8 @@ func (bind *nativeBind) PeekLookAtSocketFd4() (fd int, err error) { return } -func (bind *nativeBind) PeekLookAtSocketFd6() (fd int, err error) { - sysconn, err := bind.ipv6.SyscallConn() +func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) { + sysconn, err := s.ipv6.SyscallConn() if err != nil { return -1, err } diff --git a/conn/boundif_windows.go b/conn/boundif_windows.go deleted file mode 100644 index 53a8f09..0000000 --- a/conn/boundif_windows.go +++ /dev/null @@ -1,59 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. - */ - -package conn - -import ( - "encoding/binary" - "unsafe" - - "golang.org/x/sys/windows" -) - -const ( - sockoptIP_UNICAST_IF = 31 - sockoptIPV6_UNICAST_IF = 31 -) - -func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { - /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ - bytes := make([]byte, 4) - binary.BigEndian.PutUint32(bytes, interfaceIndex) - interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) - - sysconn, err := bind.ipv4.SyscallConn() - if err != nil { - return err - } - err2 := sysconn.Control(func(fd uintptr) { - err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, sockoptIP_UNICAST_IF, int(interfaceIndex)) - }) - if err2 != nil { - return err2 - } - if err != nil { - return err - } - bind.blackhole4 = blackhole - return nil -} - -func (bind *nativeBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { - sysconn, err := bind.ipv6.SyscallConn() - if err != nil { - return err - } - err2 := sysconn.Control(func(fd uintptr) { - err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, sockoptIPV6_UNICAST_IF, int(interfaceIndex)) - }) - if err2 != nil { - return err2 - } - if err != nil { - return err - } - bind.blackhole6 = blackhole - return nil -} diff --git a/conn/conn.go b/conn/conn.go index ad91d2d..1304657 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ // Package conn implements WireGuard's network connections. @@ -8,49 +8,53 @@ package conn import ( "errors" - "net" + "fmt" + "net/netip" + "reflect" + "runtime" "strings" ) +const ( + IdealBatchSize = 128 // maximum number of packets handled per read and write +) + +// A ReceiveFunc receives at least one packet from the network and writes them +// into packets. On a successful read it returns the number of elements of +// sizes, packets, and endpoints that should be evaluated. Some elements of +// sizes may be zero, and callers should ignore them. Callers must pass a sizes +// and eps slice with a length greater than or equal to the length of packets. +// These lengths must not exceed the length of the associated Bind.BatchSize(). +type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error) + // A Bind listens on a port for both IPv6 and IPv4 UDP traffic. // // A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface, // depending on the platform-specific implementation. type Bind interface { - // LastMark reports the last mark set for this Bind. - LastMark() uint32 + // Open puts the Bind into a listening state on a given port and reports the actual + // port that it bound to. Passing zero results in a random selection. + // fns is the set of functions that will be called to receive packets. + Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error) + + // Close closes the Bind listener. + // All fns returned by Open must return net.ErrClosed after a call to Close. + Close() error // SetMark sets the mark for each packet sent through this Bind. // This mark is passed to the kernel as the socket option SO_MARK. SetMark(mark uint32) error - // ReceiveIPv6 reads an IPv6 UDP packet into b. - // - // It reports the number of bytes read, n, - // the packet source address ep, - // and any error. - ReceiveIPv6(buff []byte) (n int, ep Endpoint, err error) - - // ReceiveIPv4 reads an IPv4 UDP packet into b. - // - // It reports the number of bytes read, n, - // the packet source address ep, - // and any error. - ReceiveIPv4(b []byte) (n int, ep Endpoint, err error) - - // Send writes a packet b to address ep. - Send(b []byte, ep Endpoint) error + // Send writes one or more packets in bufs to address ep. The length of + // bufs must not exceed BatchSize(). + Send(bufs [][]byte, ep Endpoint) error - // Close closes the Bind connection. - Close() error -} + // ParseEndpoint creates a new endpoint from a string. + ParseEndpoint(s string) (Endpoint, error) -// CreateBind creates a Bind bound to a port. -// -// The value actualPort reports the actual port number the Bind -// object gets bound to. -func CreateBind(port uint16) (b Bind, actualPort uint16, err error) { - return createBind(port) + // BatchSize is the number of buffers expected to be passed to + // the ReceiveFuncs, and the maximum expected to be passed to SendBatch. + BatchSize() int } // BindSocketToInterface is implemented by Bind objects that support being @@ -69,43 +73,61 @@ type PeekLookAtSocketFd interface { // An Endpoint maintains the source/destination caching for a peer. // -// dst : the remote address of a peer ("endpoint" in uapi terminology) -// src : the local address from which datagrams originate going to the peer +// dst: the remote address of a peer ("endpoint" in uapi terminology) +// src: the local address from which datagrams originate going to the peer type Endpoint interface { ClearSrc() // clears the source address SrcToString() string // returns the local source address (ip:port) DstToString() string // returns the destination address (ip:port) DstToBytes() []byte // used for mac2 cookie calculations - DstIP() net.IP - SrcIP() net.IP + DstIP() netip.Addr + SrcIP() netip.Addr } -func parseEndpoint(s string) (*net.UDPAddr, error) { - // ensure that the host is an IP address +var ( + ErrBindAlreadyOpen = errors.New("bind is already open") + ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type") +) - host, _, err := net.SplitHostPort(s) - if err != nil { - return nil, err +func (fn ReceiveFunc) PrettyName() string { + name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() + // 0. cheese/taco.beansIPv6.func12.func21218-fm + name = strings.TrimSuffix(name, "-fm") + // 1. cheese/taco.beansIPv6.func12.func21218 + if idx := strings.LastIndexByte(name, '/'); idx != -1 { + name = name[idx+1:] + // 2. taco.beansIPv6.func12.func21218 } - if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 { - // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just - // trying to make sure with a small sanity test that this is a real IP address and - // not something that's likely to incur DNS lookups. - host = host[:i] + for { + var idx int + for idx = len(name) - 1; idx >= 0; idx-- { + if name[idx] < '0' || name[idx] > '9' { + break + } + } + if idx == len(name)-1 { + break + } + const dotFunc = ".func" + if !strings.HasSuffix(name[:idx+1], dotFunc) { + break + } + name = name[:idx+1-len(dotFunc)] + // 3. taco.beansIPv6.func12 + // 4. taco.beansIPv6 } - if ip := net.ParseIP(host); ip == nil { - return nil, errors.New("Failed to parse IP address: " + host) + if idx := strings.LastIndexByte(name, '.'); idx != -1 { + name = name[idx+1:] + // 5. beansIPv6 } - - // parse address and port - - addr, err := net.ResolveUDPAddr("udp", s) - if err != nil { - return nil, err + if name == "" { + return fmt.Sprintf("%p", fn) + } + if strings.HasSuffix(name, "IPv4") { + return "v4" } - ip4 := addr.IP.To4() - if ip4 != nil { - addr.IP = ip4 + if strings.HasSuffix(name, "IPv6") { + return "v6" } - return addr, err + return name } diff --git a/conn/conn_default.go b/conn/conn_default.go deleted file mode 100644 index 8be3c9d..0000000 --- a/conn/conn_default.go +++ /dev/null @@ -1,176 +0,0 @@ -// +build !linux android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. - */ - -package conn - -import ( - "net" - "os" - "syscall" -) - -/* This code is meant to be a temporary solution - * on platforms for which the sticky socket / source caching behavior - * has not yet been implemented. - * - * See conn_linux.go for an implementation on the linux platform. - */ - -type nativeBind struct { - ipv4 *net.UDPConn - ipv6 *net.UDPConn - blackhole4 bool - blackhole6 bool -} - -type NativeEndpoint net.UDPAddr - -var _ Bind = (*nativeBind)(nil) -var _ Endpoint = (*NativeEndpoint)(nil) - -func CreateEndpoint(s string) (Endpoint, error) { - addr, err := parseEndpoint(s) - return (*NativeEndpoint)(addr), err -} - -func (_ *NativeEndpoint) ClearSrc() {} - -func (e *NativeEndpoint) DstIP() net.IP { - return (*net.UDPAddr)(e).IP -} - -func (e *NativeEndpoint) SrcIP() net.IP { - return nil // not supported -} - -func (e *NativeEndpoint) DstToBytes() []byte { - addr := (*net.UDPAddr)(e) - out := addr.IP.To4() - if out == nil { - out = addr.IP - } - out = append(out, byte(addr.Port&0xff)) - out = append(out, byte((addr.Port>>8)&0xff)) - return out -} - -func (e *NativeEndpoint) DstToString() string { - return (*net.UDPAddr)(e).String() -} - -func (e *NativeEndpoint) SrcToString() string { - return "" -} - -func listenNet(network string, port int) (*net.UDPConn, int, error) { - conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) - if err != nil { - return nil, 0, err - } - - // Retrieve port. - laddr := conn.LocalAddr() - uaddr, err := net.ResolveUDPAddr( - laddr.Network(), - laddr.String(), - ) - if err != nil { - return nil, 0, err - } - return conn, uaddr.Port, nil -} - -func extractErrno(err error) error { - opErr, ok := err.(*net.OpError) - if !ok { - return nil - } - syscallErr, ok := opErr.Err.(*os.SyscallError) - if !ok { - return nil - } - return syscallErr.Err -} - -func createBind(uport uint16) (Bind, uint16, error) { - var err error - var bind nativeBind - - port := int(uport) - - bind.ipv4, port, err = listenNet("udp4", port) - if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT { - return nil, 0, err - } - - bind.ipv6, port, err = listenNet("udp6", port) - if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT { - bind.ipv4.Close() - bind.ipv4 = nil - return nil, 0, err - } - - return &bind, uint16(port), nil -} - -func (bind *nativeBind) Close() error { - var err1, err2 error - if bind.ipv4 != nil { - err1 = bind.ipv4.Close() - } - if bind.ipv6 != nil { - err2 = bind.ipv6.Close() - } - if err1 != nil { - return err1 - } - return err2 -} - -func (bind *nativeBind) LastMark() uint32 { return 0 } - -func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - if bind.ipv4 == nil { - return 0, nil, syscall.EAFNOSUPPORT - } - n, endpoint, err := bind.ipv4.ReadFromUDP(buff) - if endpoint != nil { - endpoint.IP = endpoint.IP.To4() - } - return n, (*NativeEndpoint)(endpoint), err -} - -func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - if bind.ipv6 == nil { - return 0, nil, syscall.EAFNOSUPPORT - } - n, endpoint, err := bind.ipv6.ReadFromUDP(buff) - return n, (*NativeEndpoint)(endpoint), err -} - -func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error { - var err error - nend := endpoint.(*NativeEndpoint) - if nend.IP.To4() != nil { - if bind.ipv4 == nil { - return syscall.EAFNOSUPPORT - } - if bind.blackhole4 { - return nil - } - _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) - } else { - if bind.ipv6 == nil { - return syscall.EAFNOSUPPORT - } - if bind.blackhole6 { - return nil - } - _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) - } - return err -} diff --git a/conn/conn_linux.go b/conn/conn_linux.go deleted file mode 100644 index 08c8949..0000000 --- a/conn/conn_linux.go +++ /dev/null @@ -1,571 +0,0 @@ -// +build !android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. - */ - -package conn - -import ( - "errors" - "net" - "strconv" - "sync" - "syscall" - "unsafe" - - "golang.org/x/sys/unix" -) - -const ( - FD_ERR = -1 -) - -type IPv4Source struct { - Src [4]byte - Ifindex int32 -} - -type IPv6Source struct { - src [16]byte - //ifindex belongs in dst.ZoneId -} - -type NativeEndpoint struct { - sync.Mutex - dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte - src [unsafe.Sizeof(IPv6Source{})]byte - isV6 bool -} - -func (endpoint *NativeEndpoint) Src4() *IPv4Source { return endpoint.src4() } -func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() } -func (endpoint *NativeEndpoint) IsV6() bool { return endpoint.isV6 } - -func (endpoint *NativeEndpoint) src4() *IPv4Source { - return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0])) -} - -func (endpoint *NativeEndpoint) src6() *IPv6Source { - return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0])) -} - -func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 { - return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0])) -} - -func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 { - return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0])) -} - -type nativeBind struct { - sock4 int - sock6 int - lastMark uint32 -} - -var _ Endpoint = (*NativeEndpoint)(nil) -var _ Bind = (*nativeBind)(nil) - -func CreateEndpoint(s string) (Endpoint, error) { - var end NativeEndpoint - addr, err := parseEndpoint(s) - if err != nil { - return nil, err - } - - ipv4 := addr.IP.To4() - if ipv4 != nil { - dst := end.dst4() - end.isV6 = false - dst.Port = addr.Port - copy(dst.Addr[:], ipv4) - end.ClearSrc() - return &end, nil - } - - ipv6 := addr.IP.To16() - if ipv6 != nil { - zone, err := zoneToUint32(addr.Zone) - if err != nil { - return nil, err - } - dst := end.dst6() - end.isV6 = true - dst.Port = addr.Port - dst.ZoneId = zone - copy(dst.Addr[:], ipv6[:]) - end.ClearSrc() - return &end, nil - } - - return nil, errors.New("Invalid IP address") -} - -func createBind(port uint16) (Bind, uint16, error) { - var err error - var bind nativeBind - var newPort uint16 - - // Attempt ipv6 bind, update port if successful. - bind.sock6, newPort, err = create6(port) - if err != nil { - if err != syscall.EAFNOSUPPORT { - return nil, 0, err - } - } else { - port = newPort - } - - // Attempt ipv4 bind, update port if successful. - bind.sock4, newPort, err = create4(port) - if err != nil { - if err != syscall.EAFNOSUPPORT { - unix.Close(bind.sock6) - return nil, 0, err - } - } else { - port = newPort - } - - if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR { - return nil, 0, errors.New("ipv4 and ipv6 not supported") - } - - return &bind, port, nil -} - -func (bind *nativeBind) LastMark() uint32 { - return bind.lastMark -} - -func (bind *nativeBind) SetMark(value uint32) error { - if bind.sock6 != -1 { - err := unix.SetsockoptInt( - bind.sock6, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - - if err != nil { - return err - } - } - - if bind.sock4 != -1 { - err := unix.SetsockoptInt( - bind.sock4, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - - if err != nil { - return err - } - } - - bind.lastMark = value - return nil -} - -func closeUnblock(fd int) error { - // shutdown to unblock readers and writers - unix.Shutdown(fd, unix.SHUT_RDWR) - return unix.Close(fd) -} - -func (bind *nativeBind) Close() error { - var err1, err2 error - if bind.sock6 != -1 { - err1 = closeUnblock(bind.sock6) - } - if bind.sock4 != -1 { - err2 = closeUnblock(bind.sock4) - } - - if err1 != nil { - return err1 - } - return err2 -} - -func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - var end NativeEndpoint - if bind.sock6 == -1 { - return 0, nil, syscall.EAFNOSUPPORT - } - n, err := receive6( - bind.sock6, - buff, - &end, - ) - return n, &end, err -} - -func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - var end NativeEndpoint - if bind.sock4 == -1 { - return 0, nil, syscall.EAFNOSUPPORT - } - n, err := receive4( - bind.sock4, - buff, - &end, - ) - return n, &end, err -} - -func (bind *nativeBind) Send(buff []byte, end Endpoint) error { - nend := end.(*NativeEndpoint) - if !nend.isV6 { - if bind.sock4 == -1 { - return syscall.EAFNOSUPPORT - } - return send4(bind.sock4, nend, buff) - } else { - if bind.sock6 == -1 { - return syscall.EAFNOSUPPORT - } - return send6(bind.sock6, nend, buff) - } -} - -func (end *NativeEndpoint) SrcIP() net.IP { - if !end.isV6 { - return net.IPv4( - end.src4().Src[0], - end.src4().Src[1], - end.src4().Src[2], - end.src4().Src[3], - ) - } else { - return end.src6().src[:] - } -} - -func (end *NativeEndpoint) DstIP() net.IP { - if !end.isV6 { - return net.IPv4( - end.dst4().Addr[0], - end.dst4().Addr[1], - end.dst4().Addr[2], - end.dst4().Addr[3], - ) - } else { - return end.dst6().Addr[:] - } -} - -func (end *NativeEndpoint) DstToBytes() []byte { - if !end.isV6 { - return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:] - } else { - return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:] - } -} - -func (end *NativeEndpoint) SrcToString() string { - return end.SrcIP().String() -} - -func (end *NativeEndpoint) DstToString() string { - var udpAddr net.UDPAddr - udpAddr.IP = end.DstIP() - if !end.isV6 { - udpAddr.Port = end.dst4().Port - } else { - udpAddr.Port = end.dst6().Port - } - return udpAddr.String() -} - -func (end *NativeEndpoint) ClearDst() { - for i := range end.dst { - end.dst[i] = 0 - } -} - -func (end *NativeEndpoint) ClearSrc() { - for i := range end.src { - end.src[i] = 0 - } -} - -func zoneToUint32(zone string) (uint32, error) { - if zone == "" { - return 0, nil - } - if intr, err := net.InterfaceByName(zone); err == nil { - return uint32(intr.Index), nil - } - n, err := strconv.ParseUint(zone, 10, 32) - return uint32(n), err -} - -func create4(port uint16) (int, uint16, error) { - - // create socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return FD_ERR, 0, err - } - - addr := unix.SockaddrInet4{ - Port: int(port), - } - - // set sockopts and bind - - if err := func() error { - if err := unix.SetsockoptInt( - fd, - unix.SOL_SOCKET, - unix.SO_REUSEADDR, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IP, - unix.IP_PKTINFO, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - }(); err != nil { - unix.Close(fd) - return FD_ERR, 0, err - } - - sa, err := unix.Getsockname(fd) - if err == nil { - addr.Port = sa.(*unix.SockaddrInet4).Port - } - - return fd, uint16(addr.Port), err -} - -func create6(port uint16) (int, uint16, error) { - - // create socket - - fd, err := unix.Socket( - unix.AF_INET6, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return FD_ERR, 0, err - } - - // set sockopts and bind - - addr := unix.SockaddrInet6{ - Port: int(port), - } - - if err := func() error { - - if err := unix.SetsockoptInt( - fd, - unix.SOL_SOCKET, - unix.SO_REUSEADDR, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_RECVPKTINFO, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_V6ONLY, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - - }(); err != nil { - unix.Close(fd) - return FD_ERR, 0, err - } - - sa, err := unix.Getsockname(fd) - if err == nil { - addr.Port = sa.(*unix.SockaddrInet6).Port - } - - return fd, uint16(addr.Port), err -} - -func send4(sock int, end *NativeEndpoint, buff []byte) error { - - // construct message header - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IP, - Type: unix.IP_PKTINFO, - Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet4Pktinfo{ - Spec_dst: end.src4().Src, - Ifindex: end.src4().Ifindex, - }, - } - - end.Lock() - _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - end.Unlock() - - if err == nil { - return nil - } - - // clear src and retry - - if err == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet4Pktinfo{} - end.Lock() - _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - end.Unlock() - } - - return err -} - -func send6(sock int, end *NativeEndpoint, buff []byte) error { - - // construct message header - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IPV6, - Type: unix.IPV6_PKTINFO, - Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet6Pktinfo{ - Addr: end.src6().src, - Ifindex: end.dst6().ZoneId, - }, - } - - if cmsg.pktinfo.Addr == [16]byte{} { - cmsg.pktinfo.Ifindex = 0 - } - - end.Lock() - _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - end.Unlock() - - if err == nil { - return nil - } - - // clear src and retry - - if err == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet6Pktinfo{} - end.Lock() - _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - end.Unlock() - } - - return err -} - -func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { - - // construct message header - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - } - - size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - - if err != nil { - return 0, err - } - end.isV6 = false - - if newDst4, ok := newDst.(*unix.SockaddrInet4); ok { - *end.dst4() = *newDst4 - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IP && - cmsg.cmsghdr.Type == unix.IP_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - end.src4().Src = cmsg.pktinfo.Spec_dst - end.src4().Ifindex = cmsg.pktinfo.Ifindex - } - - return size, nil -} - -func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { - - // construct message header - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - } - - size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - - if err != nil { - return 0, err - } - end.isV6 = true - - if newDst6, ok := newDst.(*unix.SockaddrInet6); ok { - *end.dst6() = *newDst6 - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && - cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { - end.src6().src = cmsg.pktinfo.Addr - end.dst6().ZoneId = cmsg.pktinfo.Ifindex - } - - return size, nil -} diff --git a/conn/conn_test.go b/conn/conn_test.go new file mode 100644 index 0000000..618d02b --- /dev/null +++ b/conn/conn_test.go @@ -0,0 +1,24 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "testing" +) + +func TestPrettyName(t *testing.T) { + var ( + recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return } + ) + + const want = "TestPrettyName" + + t.Run("ReceiveFunc.PrettyName", func(t *testing.T) { + if got := recvFunc.PrettyName(); got != want { + t.Errorf("PrettyName() = %v, want %v", got, want) + } + }) +} diff --git a/conn/controlfns.go b/conn/controlfns.go new file mode 100644 index 0000000..27421bd --- /dev/null +++ b/conn/controlfns.go @@ -0,0 +1,43 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "net" + "syscall" +) + +// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is +// the max supported by a default configuration of macOS. Some platforms will +// silently clamp the value to other maximums, such as linux clamping to +// net.core.{r,w}mem_max (see _linux.go for additional implementation that works +// around this limitation) +const socketBufferSize = 7 << 20 + +// controlFn is the callback function signature from net.ListenConfig.Control. +// It is used to apply platform specific configuration to the socket prior to +// bind. +type controlFn func(network, address string, c syscall.RawConn) error + +// controlFns is a list of functions that are called from the listen config +// that can apply socket options. +var controlFns = []controlFn{} + +// listenConfig returns a net.ListenConfig that applies the controlFns to the +// socket prior to bind. This is used to apply socket buffer sizing and packet +// information OOB configuration for sticky sockets. +func listenConfig() *net.ListenConfig { + return &net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + for _, fn := range controlFns { + if err := fn(network, address, c); err != nil { + return err + } + } + return nil + }, + } +} diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go new file mode 100644 index 0000000..f0deefa --- /dev/null +++ b/conn/controlfns_linux.go @@ -0,0 +1,109 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "fmt" + "runtime" + "syscall" + + "golang.org/x/sys/unix" +) + +// Taken from go/src/internal/syscall/unix/kernel_version_linux.go +func kernelVersion() (major, minor int) { + var uname unix.Utsname + if err := unix.Uname(&uname); err != nil { + return + } + + var ( + values [2]int + value, vi int + ) + for _, c := range uname.Release { + if '0' <= c && c <= '9' { + value = (value * 10) + int(c-'0') + } else { + // Note that we're assuming N.N.N here. + // If we see anything else, we are likely to mis-parse it. + values[vi] = value + vi++ + if vi >= len(values) { + break + } + value = 0 + } + } + + return values[0], values[1] +} + +func init() { + controlFns = append(controlFns, + + // Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by + // using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to + // fail silently - the result of failure is lower performance on very fast + // links or high latency links. + func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + // Set up to *mem_max + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize) + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize) + // Set beyond *mem_max if CAP_NET_ADMIN + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize) + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize) + }) + }, + + // Enable receiving of the packet information (IP_PKTINFO for IPv4, + // IPV6_PKTINFO for IPv6) that is used to implement sticky socket support. + func(network, address string, c syscall.RawConn) error { + var err error + switch network { + case "udp4": + if runtime.GOOS != "android" { + c.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1) + }) + } + case "udp6": + c.Control(func(fd uintptr) { + if runtime.GOOS != "android" { + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1) + if err != nil { + return + } + } + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1) + }) + default: + err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL) + } + return err + }, + + // Attempt to enable UDP_GRO + func(network, address string, c syscall.RawConn) error { + // Kernels below 5.12 are missing 98184612aca0 ("net: + // udp: Add support for getsockopt(..., ..., UDP_GRO, + // ..., ...);"), which means we can't read this back + // later. We could pipe the return value through to + // the rest of the code, but UDP_GRO is kind of buggy + // anyway, so just gate this here. + major, minor := kernelVersion() + if major < 5 || (major == 5 && minor < 12) { + return nil + } + + c.Control(func(fd uintptr) { + _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1) + }) + return nil + }, + ) +} diff --git a/conn/controlfns_unix.go b/conn/controlfns_unix.go new file mode 100644 index 0000000..b2e7570 --- /dev/null +++ b/conn/controlfns_unix.go @@ -0,0 +1,35 @@ +//go:build !windows && !linux && !wasm + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func init() { + controlFns = append(controlFns, + func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize) + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize) + }) + }, + + func(network, address string, c syscall.RawConn) error { + var err error + if network == "udp6" { + c.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1) + }) + } + return err + }, + ) +} diff --git a/conn/controlfns_windows.go b/conn/controlfns_windows.go new file mode 100644 index 0000000..5e38305 --- /dev/null +++ b/conn/controlfns_windows.go @@ -0,0 +1,23 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "syscall" + + "golang.org/x/sys/windows" +) + +func init() { + controlFns = append(controlFns, + func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize) + _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize) + }) + }, + ) +} diff --git a/conn/default.go b/conn/default.go new file mode 100644 index 0000000..2ce1579 --- /dev/null +++ b/conn/default.go @@ -0,0 +1,10 @@ +//go:build !windows + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +func NewDefaultBind() Bind { return NewStdNetBind() } diff --git a/conn/errors_default.go b/conn/errors_default.go new file mode 100644 index 0000000..3c9b223 --- /dev/null +++ b/conn/errors_default.go @@ -0,0 +1,12 @@ +//go:build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +func errShouldDisableUDPGSO(_ error) bool { + return false +} diff --git a/conn/errors_linux.go b/conn/errors_linux.go new file mode 100644 index 0000000..037d820 --- /dev/null +++ b/conn/errors_linux.go @@ -0,0 +1,26 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "errors" + "os" + + "golang.org/x/sys/unix" +) + +func errShouldDisableUDPGSO(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + // EIO is returned by udp_send_skb() if the device driver does not have + // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT. + // See: + // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 + return serr.Err == unix.EIO + } + return false +} diff --git a/conn/features_default.go b/conn/features_default.go new file mode 100644 index 0000000..9fc5088 --- /dev/null +++ b/conn/features_default.go @@ -0,0 +1,15 @@ +//go:build !linux +// +build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import "net" + +func supportsUDPOffload(_ *net.UDPConn) (txOffload, rxOffload bool) { + return +} diff --git a/conn/features_linux.go b/conn/features_linux.go new file mode 100644 index 0000000..6386023 --- /dev/null +++ b/conn/features_linux.go @@ -0,0 +1,29 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "net" + + "golang.org/x/sys/unix" +) + +func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { + rc, err := conn.SyscallConn() + if err != nil { + return + } + err = rc.Control(func(fd uintptr) { + _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) + txOffload = errSyscall == nil + opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO) + rxOffload = errSyscall == nil && opt == 1 + }) + if err != nil { + return false, false + } + return txOffload, rxOffload +} diff --git a/conn/gso_default.go b/conn/gso_default.go new file mode 100644 index 0000000..a9a3e80 --- /dev/null +++ b/conn/gso_default.go @@ -0,0 +1,21 @@ +//go:build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. +func setGSOSize(control *[]byte, gsoSize uint16) { +} + +// gsoControlSize returns the recommended buffer size for pooling sticky and UDP +// offloading control data. +const gsoControlSize = 0 diff --git a/conn/gso_linux.go b/conn/gso_linux.go new file mode 100644 index 0000000..4ee31fa --- /dev/null +++ b/conn/gso_linux.go @@ -0,0 +1,65 @@ +//go:build linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "fmt" + "unsafe" + + "golang.org/x/sys/unix" +) + +const ( + sizeOfGSOData = 2 +) + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + var ( + hdr unix.Cmsghdr + data []byte + rem = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) + if err != nil { + return 0, fmt.Errorf("error parsing socket control message: %w", err) + } + if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData { + var gso uint16 + copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData]) + return int(gso), nil + } + } + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing +// data in control untouched. +func setGSOSize(control *[]byte, gsoSize uint16) { + existingLen := len(*control) + avail := cap(*control) - existingLen + space := unix.CmsgSpace(sizeOfGSOData) + if avail < space { + return + } + *control = (*control)[:cap(*control)] + gsoControl := (*control)[existingLen:] + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0])) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + hdr.SetLen(unix.CmsgLen(sizeOfGSOData)) + copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) + *control = (*control)[:existingLen+space] +} + +// gsoControlSize returns the recommended buffer size for pooling UDP +// offloading control data. +var gsoControlSize = unix.CmsgSpace(sizeOfGSOData) diff --git a/conn/mark_default.go b/conn/mark_default.go index f57215a..72b266e 100644 --- a/conn/mark_default.go +++ b/conn/mark_default.go @@ -1,12 +1,12 @@ -// +build !linux,!openbsd,!freebsd +//go:build !linux && !openbsd && !freebsd /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn -func (bind *nativeBind) SetMark(mark uint32) error { +func (s *StdNetBind) SetMark(mark uint32) error { return nil } diff --git a/conn/mark_unix.go b/conn/mark_unix.go index 19ec2af..d0580d5 100644 --- a/conn/mark_unix.go +++ b/conn/mark_unix.go @@ -1,8 +1,8 @@ -// +build android openbsd freebsd +//go:build linux || openbsd || freebsd /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package conn @@ -26,13 +26,13 @@ func init() { } } -func (bind *nativeBind) SetMark(mark uint32) error { +func (s *StdNetBind) SetMark(mark uint32) error { var operr error if fwmarkIoctl == 0 { return nil } - if bind.ipv4 != nil { - fd, err := bind.ipv4.SyscallConn() + if s.ipv4 != nil { + fd, err := s.ipv4.SyscallConn() if err != nil { return err } @@ -46,8 +46,8 @@ func (bind *nativeBind) SetMark(mark uint32) error { return err } } - if bind.ipv6 != nil { - fd, err := bind.ipv6.SyscallConn() + if s.ipv6 != nil { + fd, err := s.ipv6.SyscallConn() if err != nil { return err } diff --git a/conn/sticky_default.go b/conn/sticky_default.go new file mode 100644 index 0000000..15b65af --- /dev/null +++ b/conn/sticky_default.go @@ -0,0 +1,42 @@ +//go:build !linux || android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import "net/netip" + +func (e *StdNetEndpoint) SrcIP() netip.Addr { + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return "" +} + +// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets +// {get,set}srcControl feature set, but use alternatively named flags and need +// ports and require testing. + +// getSrcFromControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func getSrcFromControl(control []byte, ep *StdNetEndpoint) { +} + +// setSrcControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func setSrcControl(control *[]byte, ep *StdNetEndpoint) { +} + +// stickyControlSize returns the recommended buffer size for pooling sticky +// offloading control data. +const stickyControlSize = 0 + +const StdNetSupportsStickySockets = false diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go new file mode 100644 index 0000000..adfedc1 --- /dev/null +++ b/conn/sticky_linux.go @@ -0,0 +1,112 @@ +//go:build linux && !android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "net/netip" + "unsafe" + + "golang.org/x/sys/unix" +) + +func (e *StdNetEndpoint) SrcIP() netip.Addr { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return netip.AddrFrom4(info.Spec_dst) + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + // TODO: set zone. in order to do so we need to check if the address is + // link local, and if it is perform a syscall to turn the ifindex into a + // zone string because netip uses string zones. + return netip.AddrFrom16(info.Addr) + } + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return info.Ifindex + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return int32(info.Ifindex) + } + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return e.SrcIP().String() +} + +// getSrcFromControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func getSrcFromControl(control []byte, ep *StdNetEndpoint) { + ep.ClearSrc() + + var ( + hdr unix.Cmsghdr + data []byte + rem []byte = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) + if err != nil { + return + } + + if hdr.Level == unix.IPPROTO_IP && + hdr.Type == unix.IP_PKTINFO { + + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + } + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] + + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) + return + } + + if hdr.Level == unix.IPPROTO_IPV6 && + hdr.Type == unix.IPV6_PKTINFO { + + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + } + + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] + + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) + return + } + } +} + +// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address +// and source ifindex found in ep. control's len will be set to 0 in the event +// that ep is a default value. +func setSrcControl(control *[]byte, ep *StdNetEndpoint) { + if cap(*control) < len(ep.src) { + return + } + *control = (*control)[:0] + *control = append(*control, ep.src...) +} + +// stickyControlSize returns the recommended buffer size for pooling sticky +// offloading control data. +var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) + +const StdNetSupportsStickySockets = true diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go new file mode 100644 index 0000000..1b1ee68 --- /dev/null +++ b/conn/sticky_linux_test.go @@ -0,0 +1,266 @@ +//go:build linux && !android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "context" + "net" + "net/netip" + "runtime" + "testing" + "unsafe" + + "golang.org/x/sys/unix" +) + +func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) { + var buf []byte + if addr.Is4() { + buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IP, + Type: unix.IP_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) + copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + + info := unix.Inet4Pktinfo{ + Ifindex: ifidx, + Spec_dst: addr.As4(), + } + copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo)) + } else { + buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) + copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + + info := unix.Inet6Pktinfo{ + Ifindex: uint32(ifidx), + Addr: addr.As16(), + } + copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo)) + } + + ep.src = buf +} + +func Test_setSrcControl(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + ep := &StdNetEndpoint{ + AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), + } + setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5) + + control := make([]byte, stickyControlSize) + + setSrcControl(&control, ep) + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + if hdr.Level != unix.IPPROTO_IP { + t.Errorf("unexpected level: %d", hdr.Level) + } + if hdr.Type != unix.IP_PKTINFO { + t.Errorf("unexpected type: %d", hdr.Type) + } + if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) { + t.Errorf("unexpected length: %d", hdr.Len) + } + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 { + t.Errorf("unexpected address: %v", info.Spec_dst) + } + if info.Ifindex != 5 { + t.Errorf("unexpected ifindex: %d", info.Ifindex) + } + }) + + t.Run("IPv6", func(t *testing.T) { + ep := &StdNetEndpoint{ + AddrPort: netip.MustParseAddrPort("[::1]:1234"), + } + setSrc(ep, netip.MustParseAddr("::1"), 5) + + control := make([]byte, stickyControlSize) + + setSrcControl(&control, ep) + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + if hdr.Level != unix.IPPROTO_IPV6 { + t.Errorf("unexpected level: %d", hdr.Level) + } + if hdr.Type != unix.IPV6_PKTINFO { + t.Errorf("unexpected type: %d", hdr.Type) + } + if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) { + t.Errorf("unexpected length: %d", hdr.Len) + } + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + if info.Addr != ep.SrcIP().As16() { + t.Errorf("unexpected address: %v", info.Addr) + } + if info.Ifindex != 5 { + t.Errorf("unexpected ifindex: %d", info.Ifindex) + } + }) + + t.Run("ClearOnNoSrc", func(t *testing.T) { + control := make([]byte, stickyControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = 1 + hdr.Type = 2 + hdr.Len = 3 + + setSrcControl(&control, &StdNetEndpoint{}) + + if len(control) != 0 { + t.Errorf("unexpected control: %v", control) + } + }) +} + +func Test_getSrcFromControl(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + control := make([]byte, stickyControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IP + hdr.Type = unix.IP_PKTINFO + hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Spec_dst = [4]byte{127, 0, 0, 1} + info.Ifindex = 5 + + ep := &StdNetEndpoint{} + getSrcFromControl(control, ep) + + if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) + } + }) + t.Run("IPv6", func(t *testing.T) { + control := make([]byte, stickyControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IPV6 + hdr.Type = unix.IPV6_PKTINFO + hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + info.Ifindex = 5 + + ep := &StdNetEndpoint{} + getSrcFromControl(control, ep) + + if ep.SrcIP() != netip.MustParseAddr("::1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) + } + }) + t.Run("ClearOnEmpty", func(t *testing.T) { + var control []byte + ep := &StdNetEndpoint{} + setSrc(ep, netip.MustParseAddr("::1"), 5) + + getSrcFromControl(control, ep) + if ep.SrcIP().IsValid() { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.SrcIfidx() != 0 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) + } + }) + t.Run("Multiple", func(t *testing.T) { + zeroControl := make([]byte, unix.CmsgSpace(0)) + zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0])) + zeroHdr.SetLen(unix.CmsgLen(0)) + + control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IP + hdr.Type = unix.IP_PKTINFO + hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Spec_dst = [4]byte{127, 0, 0, 1} + info.Ifindex = 5 + + combined := make([]byte, 0) + combined = append(combined, zeroControl...) + combined = append(combined, control...) + + ep := &StdNetEndpoint{} + getSrcFromControl(combined, ep) + + if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) + } + }) +} + +func Test_listenConfig(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0") + if err != nil { + t.Fatal(err) + } + defer conn.Close() + sc, err := conn.(*net.UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + if runtime.GOOS == "linux" { + var i int + sc.Control(func(fd uintptr) { + i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO) + }) + if err != nil { + t.Fatal(err) + } + if i != 1 { + t.Error("IP_PKTINFO not set!") + } + } else { + t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) + } + }) + t.Run("IPv6", func(t *testing.T) { + conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0") + if err != nil { + t.Fatal(err) + } + sc, err := conn.(*net.UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + if runtime.GOOS == "linux" { + var i int + sc.Control(func(fd uintptr) { + i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO) + }) + if err != nil { + t.Fatal(err) + } + if i != 1 { + t.Error("IPV6_PKTINFO not set!") + } + } else { + t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) + } + }) +} diff --git a/conn/winrio/rio_windows.go b/conn/winrio/rio_windows.go new file mode 100644 index 0000000..c396658 --- /dev/null +++ b/conn/winrio/rio_windows.go @@ -0,0 +1,254 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package winrio + +import ( + "log" + "sync" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + MsgDontNotify = 1 + MsgDefer = 2 + MsgWaitAll = 4 + MsgCommitOnly = 8 + + MaxCqSize = 0x8000000 + + invalidBufferId = 0xFFFFFFFF + invalidCq = 0 + invalidRq = 0 + corruptCq = 0xFFFFFFFF +) + +var extensionFunctionTable struct { + cbSize uint32 + rioReceive uintptr + rioReceiveEx uintptr + rioSend uintptr + rioSendEx uintptr + rioCloseCompletionQueue uintptr + rioCreateCompletionQueue uintptr + rioCreateRequestQueue uintptr + rioDequeueCompletion uintptr + rioDeregisterBuffer uintptr + rioNotify uintptr + rioRegisterBuffer uintptr + rioResizeCompletionQueue uintptr + rioResizeRequestQueue uintptr +} + +type Cq uintptr + +type Rq uintptr + +type BufferId uintptr + +type Buffer struct { + Id BufferId + Offset uint32 + Length uint32 +} + +type Result struct { + Status int32 + BytesTransferred uint32 + SocketContext uint64 + RequestContext uint64 +} + +type notificationCompletionType uint32 + +const ( + eventCompletion notificationCompletionType = 1 + iocpCompletion notificationCompletionType = 2 +) + +type eventNotificationCompletion struct { + completionType notificationCompletionType + event windows.Handle + notifyReset uint32 +} + +type iocpNotificationCompletion struct { + completionType notificationCompletionType + iocp windows.Handle + key uintptr + overlapped *windows.Overlapped +} + +var ( + initialized sync.Once + available bool +) + +func Initialize() bool { + initialized.Do(func() { + var ( + err error + socket windows.Handle + cq Cq + ) + defer func() { + if err == nil { + return + } + if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 { + return + } + log.Printf("Registered I/O is unavailable: %v", err) + }() + socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP) + if err != nil { + return + } + defer windows.CloseHandle(socket) + WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}} + const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024 + ob := uint32(0) + err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER, + (*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)), + (*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)), + &ob, nil, 0) + if err != nil { + return + } + + // While we should be able to stop here, after getting the function pointers, some anti-virus actually causes + // failures in RIOCreateRequestQueue, so keep going to be certain this is supported. + var iocp windows.Handle + iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) + if err != nil { + return + } + defer windows.CloseHandle(iocp) + var overlapped windows.Overlapped + cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped) + if err != nil { + return + } + defer CloseCompletionQueue(cq) + _, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0) + if err != nil { + return + } + available = true + }) + return available +} + +func Socket(af, typ, proto int32) (windows.Handle, error) { + return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO) +} + +func CloseCompletionQueue(cq Cq) { + _, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0) +} + +func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) { + notificationCompletion := &eventNotificationCompletion{ + completionType: eventCompletion, + event: event, + } + if notifyReset { + notificationCompletion.notifyReset = 1 + } + ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) + if ret == invalidCq { + return 0, err + } + return Cq(ret), nil +} + +func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) { + notificationCompletion := &iocpNotificationCompletion{ + completionType: iocpCompletion, + iocp: iocp, + key: key, + overlapped: overlapped, + } + ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) + if ret == invalidCq { + return 0, err + } + return Cq(ret), nil +} + +func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) { + ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0) + if ret == invalidCq { + return 0, err + } + return Cq(ret), nil +} + +func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) { + ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0) + if ret == invalidRq { + return 0, err + } + return Rq(ret), nil +} + +func DequeueCompletion(cq Cq, results []Result) uint32 { + var array uintptr + if len(results) > 0 { + array = uintptr(unsafe.Pointer(&results[0])) + } + ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results))) + if ret == corruptCq { + panic("cq is corrupt") + } + return uint32(ret) +} + +func DeregisterBuffer(id BufferId) { + _, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0) +} + +func RegisterBuffer(buffer []byte) (BufferId, error) { + var buf unsafe.Pointer + if len(buffer) > 0 { + buf = unsafe.Pointer(&buffer[0]) + } + return RegisterPointer(buf, uint32(len(buffer))) +} + +func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) { + ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0) + if ret == invalidBufferId { + return 0, err + } + return BufferId(ret), nil +} + +func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error { + ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext) + if ret == 0 { + return err + } + return nil +} + +func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error { + ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext) + if ret == 0 { + return err + } + return nil +} + +func Notify(cq Cq) error { + ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0) + if ret != 0 { + return windows.Errno(ret) + } + return nil +} |