aboutsummaryrefslogtreecommitdiffstats
path: root/conn
diff options
context:
space:
mode:
Diffstat (limited to 'conn')
-rw-r--r--conn/bind_std.go544
-rw-r--r--conn/bind_std_test.go250
-rw-r--r--conn/bind_windows.go601
-rw-r--r--conn/bindtest/bindtest.go136
-rw-r--r--conn/boundif_android.go10
-rw-r--r--conn/boundif_windows.go59
-rw-r--r--conn/conn.go132
-rw-r--r--conn/conn_default.go176
-rw-r--r--conn/conn_linux.go571
-rw-r--r--conn/conn_test.go24
-rw-r--r--conn/controlfns.go43
-rw-r--r--conn/controlfns_linux.go109
-rw-r--r--conn/controlfns_unix.go35
-rw-r--r--conn/controlfns_windows.go23
-rw-r--r--conn/default.go10
-rw-r--r--conn/errors_default.go12
-rw-r--r--conn/errors_linux.go26
-rw-r--r--conn/features_default.go15
-rw-r--r--conn/features_linux.go29
-rw-r--r--conn/gso_default.go21
-rw-r--r--conn/gso_linux.go65
-rw-r--r--conn/mark_default.go6
-rw-r--r--conn/mark_unix.go14
-rw-r--r--conn/sticky_default.go42
-rw-r--r--conn/sticky_linux.go112
-rw-r--r--conn/sticky_linux_test.go266
-rw-r--r--conn/winrio/rio_windows.go254
27 files changed, 2709 insertions, 876 deletions
diff --git a/conn/bind_std.go b/conn/bind_std.go
new file mode 100644
index 0000000..f5c8816
--- /dev/null
+++ b/conn/bind_std.go
@@ -0,0 +1,544 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "net/netip"
+ "runtime"
+ "strconv"
+ "sync"
+ "syscall"
+
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+)
+
+var (
+ _ Bind = (*StdNetBind)(nil)
+)
+
+// StdNetBind implements Bind for all platforms. While Windows has its own Bind
+// (see bind_windows.go), it may fall back to StdNetBind.
+// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
+// methods for sending and receiving multiple datagrams per-syscall. See the
+// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
+type StdNetBind struct {
+ mu sync.Mutex // protects all fields except as specified
+ ipv4 *net.UDPConn
+ ipv6 *net.UDPConn
+ ipv4PC *ipv4.PacketConn // will be nil on non-Linux
+ ipv6PC *ipv6.PacketConn // will be nil on non-Linux
+ ipv4TxOffload bool
+ ipv4RxOffload bool
+ ipv6TxOffload bool
+ ipv6RxOffload bool
+
+ // these two fields are not guarded by mu
+ udpAddrPool sync.Pool
+ msgsPool sync.Pool
+
+ blackhole4 bool
+ blackhole6 bool
+}
+
+func NewStdNetBind() Bind {
+ return &StdNetBind{
+ udpAddrPool: sync.Pool{
+ New: func() any {
+ return &net.UDPAddr{
+ IP: make([]byte, 16),
+ }
+ },
+ },
+
+ msgsPool: sync.Pool{
+ New: func() any {
+ // ipv6.Message and ipv4.Message are interchangeable as they are
+ // both aliases for x/net/internal/socket.Message.
+ msgs := make([]ipv6.Message, IdealBatchSize)
+ for i := range msgs {
+ msgs[i].Buffers = make(net.Buffers, 1)
+ msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
+ }
+ return &msgs
+ },
+ },
+ }
+}
+
+type StdNetEndpoint struct {
+ // AddrPort is the endpoint destination.
+ netip.AddrPort
+ // src is the current sticky source address and interface index, if
+ // supported. Typically this is a PKTINFO structure from/for control
+ // messages, see unix.PKTINFO for an example.
+ src []byte
+}
+
+var (
+ _ Bind = (*StdNetBind)(nil)
+ _ Endpoint = &StdNetEndpoint{}
+)
+
+func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
+ e, err := netip.ParseAddrPort(s)
+ if err != nil {
+ return nil, err
+ }
+ return &StdNetEndpoint{
+ AddrPort: e,
+ }, nil
+}
+
+func (e *StdNetEndpoint) ClearSrc() {
+ if e.src != nil {
+ // Truncate src, no need to reallocate.
+ e.src = e.src[:0]
+ }
+}
+
+func (e *StdNetEndpoint) DstIP() netip.Addr {
+ return e.AddrPort.Addr()
+}
+
+// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
+
+func (e *StdNetEndpoint) DstToBytes() []byte {
+ b, _ := e.AddrPort.MarshalBinary()
+ return b
+}
+
+func (e *StdNetEndpoint) DstToString() string {
+ return e.AddrPort.String()
+}
+
+func listenNet(network string, port int) (*net.UDPConn, int, error) {
+ conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // Retrieve port.
+ laddr := conn.LocalAddr()
+ uaddr, err := net.ResolveUDPAddr(
+ laddr.Network(),
+ laddr.String(),
+ )
+ if err != nil {
+ return nil, 0, err
+ }
+ return conn.(*net.UDPConn), uaddr.Port, nil
+}
+
+func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ var err error
+ var tries int
+
+ if s.ipv4 != nil || s.ipv6 != nil {
+ return nil, 0, ErrBindAlreadyOpen
+ }
+
+ // Attempt to open ipv4 and ipv6 listeners on the same port.
+ // If uport is 0, we can retry on failure.
+again:
+ port := int(uport)
+ var v4conn, v6conn *net.UDPConn
+ var v4pc *ipv4.PacketConn
+ var v6pc *ipv6.PacketConn
+
+ v4conn, port, err = listenNet("udp4", port)
+ if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
+ return nil, 0, err
+ }
+
+ // Listen on the same port as we're using for ipv4.
+ v6conn, port, err = listenNet("udp6", port)
+ if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
+ v4conn.Close()
+ tries++
+ goto again
+ }
+ if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
+ v4conn.Close()
+ return nil, 0, err
+ }
+ var fns []ReceiveFunc
+ if v4conn != nil {
+ s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ v4pc = ipv4.NewPacketConn(v4conn)
+ s.ipv4PC = v4pc
+ }
+ fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
+ s.ipv4 = v4conn
+ }
+ if v6conn != nil {
+ s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ v6pc = ipv6.NewPacketConn(v6conn)
+ s.ipv6PC = v6pc
+ }
+ fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
+ s.ipv6 = v6conn
+ }
+ if len(fns) == 0 {
+ return nil, 0, syscall.EAFNOSUPPORT
+ }
+
+ return fns, uint16(port), nil
+}
+
+func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
+ for i := range *msgs {
+ (*msgs)[i].OOB = (*msgs)[i].OOB[:0]
+ (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
+ }
+ s.msgsPool.Put(msgs)
+}
+
+func (s *StdNetBind) getMessages() *[]ipv6.Message {
+ return s.msgsPool.Get().(*[]ipv6.Message)
+}
+
+var (
+ // If compilation fails here these are no longer the same underlying type.
+ _ ipv6.Message = ipv4.Message{}
+)
+
+type batchReader interface {
+ ReadBatch([]ipv6.Message, int) (int, error)
+}
+
+type batchWriter interface {
+ WriteBatch([]ipv6.Message, int) (int, error)
+}
+
+func (s *StdNetBind) receiveIP(
+ br batchReader,
+ conn *net.UDPConn,
+ rxOffload bool,
+ bufs [][]byte,
+ sizes []int,
+ eps []Endpoint,
+) (n int, err error) {
+ msgs := s.getMessages()
+ for i := range bufs {
+ (*msgs)[i].Buffers[0] = bufs[i]
+ (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
+ }
+ defer s.putMessages(msgs)
+ var numMsgs int
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ if rxOffload {
+ readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
+ numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
+ if err != nil {
+ return 0, err
+ }
+ numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
+ if err != nil {
+ return 0, err
+ }
+ } else {
+ numMsgs, err = br.ReadBatch(*msgs, 0)
+ if err != nil {
+ return 0, err
+ }
+ }
+ } else {
+ msg := &(*msgs)[0]
+ msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+ if err != nil {
+ return 0, err
+ }
+ numMsgs = 1
+ }
+ for i := 0; i < numMsgs; i++ {
+ msg := &(*msgs)[i]
+ sizes[i] = msg.N
+ if sizes[i] == 0 {
+ continue
+ }
+ addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+ ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
+ getSrcFromControl(msg.OOB[:msg.NN], ep)
+ eps[i] = ep
+ }
+ return numMsgs, nil
+}
+
+func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
+ return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
+ }
+}
+
+func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
+ return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
+ }
+}
+
+// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
+// rename the IdealBatchSize constant to BatchSize.
+func (s *StdNetBind) BatchSize() int {
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ return IdealBatchSize
+ }
+ return 1
+}
+
+func (s *StdNetBind) Close() error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ var err1, err2 error
+ if s.ipv4 != nil {
+ err1 = s.ipv4.Close()
+ s.ipv4 = nil
+ s.ipv4PC = nil
+ }
+ if s.ipv6 != nil {
+ err2 = s.ipv6.Close()
+ s.ipv6 = nil
+ s.ipv6PC = nil
+ }
+ s.blackhole4 = false
+ s.blackhole6 = false
+ s.ipv4TxOffload = false
+ s.ipv4RxOffload = false
+ s.ipv6TxOffload = false
+ s.ipv6RxOffload = false
+ if err1 != nil {
+ return err1
+ }
+ return err2
+}
+
+type ErrUDPGSODisabled struct {
+ onLaddr string
+ RetryErr error
+}
+
+func (e ErrUDPGSODisabled) Error() string {
+ return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
+}
+
+func (e ErrUDPGSODisabled) Unwrap() error {
+ return e.RetryErr
+}
+
+func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
+ s.mu.Lock()
+ blackhole := s.blackhole4
+ conn := s.ipv4
+ offload := s.ipv4TxOffload
+ br := batchWriter(s.ipv4PC)
+ is6 := false
+ if endpoint.DstIP().Is6() {
+ blackhole = s.blackhole6
+ conn = s.ipv6
+ br = s.ipv6PC
+ is6 = true
+ offload = s.ipv6TxOffload
+ }
+ s.mu.Unlock()
+
+ if blackhole {
+ return nil
+ }
+ if conn == nil {
+ return syscall.EAFNOSUPPORT
+ }
+
+ msgs := s.getMessages()
+ defer s.putMessages(msgs)
+ ua := s.udpAddrPool.Get().(*net.UDPAddr)
+ defer s.udpAddrPool.Put(ua)
+ if is6 {
+ as16 := endpoint.DstIP().As16()
+ copy(ua.IP, as16[:])
+ ua.IP = ua.IP[:16]
+ } else {
+ as4 := endpoint.DstIP().As4()
+ copy(ua.IP, as4[:])
+ ua.IP = ua.IP[:4]
+ }
+ ua.Port = int(endpoint.(*StdNetEndpoint).Port())
+ var (
+ retried bool
+ err error
+ )
+retry:
+ if offload {
+ n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
+ err = s.send(conn, br, (*msgs)[:n])
+ if err != nil && offload && errShouldDisableUDPGSO(err) {
+ offload = false
+ s.mu.Lock()
+ if is6 {
+ s.ipv6TxOffload = false
+ } else {
+ s.ipv4TxOffload = false
+ }
+ s.mu.Unlock()
+ retried = true
+ goto retry
+ }
+ } else {
+ for i := range bufs {
+ (*msgs)[i].Addr = ua
+ (*msgs)[i].Buffers[0] = bufs[i]
+ setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
+ }
+ err = s.send(conn, br, (*msgs)[:len(bufs)])
+ }
+ if retried {
+ return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
+ }
+ return err
+}
+
+func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
+ var (
+ n int
+ err error
+ start int
+ )
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ for {
+ n, err = pc.WriteBatch(msgs[start:], 0)
+ if err != nil || n == len(msgs[start:]) {
+ break
+ }
+ start += n
+ }
+ } else {
+ for _, msg := range msgs {
+ _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
+ if err != nil {
+ break
+ }
+ }
+ }
+ return err
+}
+
+const (
+ // Exceeding these values results in EMSGSIZE. They account for layer3 and
+ // layer4 headers. IPv6 does not need to account for itself as the payload
+ // length field is self excluding.
+ maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
+ maxIPv6PayloadLen = 1<<16 - 1 - 8
+
+ // This is a hard limit imposed by the kernel.
+ udpSegmentMaxDatagrams = 64
+)
+
+type setGSOFunc func(control *[]byte, gsoSize uint16)
+
+func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
+ var (
+ base = -1 // index of msg we are currently coalescing into
+ gsoSize int // segmentation size of msgs[base]
+ dgramCnt int // number of dgrams coalesced into msgs[base]
+ endBatch bool // tracking flag to start a new batch on next iteration of bufs
+ )
+ maxPayloadLen := maxIPv4PayloadLen
+ if ep.DstIP().Is6() {
+ maxPayloadLen = maxIPv6PayloadLen
+ }
+ for i, buf := range bufs {
+ if i > 0 {
+ msgLen := len(buf)
+ baseLenBefore := len(msgs[base].Buffers[0])
+ freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
+ if msgLen+baseLenBefore <= maxPayloadLen &&
+ msgLen <= gsoSize &&
+ msgLen <= freeBaseCap &&
+ dgramCnt < udpSegmentMaxDatagrams &&
+ !endBatch {
+ msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
+ if i == len(bufs)-1 {
+ setGSO(&msgs[base].OOB, uint16(gsoSize))
+ }
+ dgramCnt++
+ if msgLen < gsoSize {
+ // A smaller than gsoSize packet on the tail is legal, but
+ // it must end the batch.
+ endBatch = true
+ }
+ continue
+ }
+ }
+ if dgramCnt > 1 {
+ setGSO(&msgs[base].OOB, uint16(gsoSize))
+ }
+ // Reset prior to incrementing base since we are preparing to start a
+ // new potential batch.
+ endBatch = false
+ base++
+ gsoSize = len(buf)
+ setSrcControl(&msgs[base].OOB, ep)
+ msgs[base].Buffers[0] = buf
+ msgs[base].Addr = addr
+ dgramCnt = 1
+ }
+ return base + 1
+}
+
+type getGSOFunc func(control []byte) (int, error)
+
+func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
+ for i := firstMsgAt; i < len(msgs); i++ {
+ msg := &msgs[i]
+ if msg.N == 0 {
+ return n, err
+ }
+ var (
+ gsoSize int
+ start int
+ end = msg.N
+ numToSplit = 1
+ )
+ gsoSize, err = getGSO(msg.OOB[:msg.NN])
+ if err != nil {
+ return n, err
+ }
+ if gsoSize > 0 {
+ numToSplit = (msg.N + gsoSize - 1) / gsoSize
+ end = gsoSize
+ }
+ for j := 0; j < numToSplit; j++ {
+ if n > i {
+ return n, errors.New("splitting coalesced packet resulted in overflow")
+ }
+ copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
+ msgs[n].N = copied
+ msgs[n].Addr = msg.Addr
+ start = end
+ end += gsoSize
+ if end > msg.N {
+ end = msg.N
+ }
+ n++
+ }
+ if i != n-1 {
+ // It is legal for bytes to move within msg.Buffers[0] as a result
+ // of splitting, so we only zero the source msg len when it is not
+ // the destination of the last split operation above.
+ msg.N = 0
+ }
+ }
+ return n, nil
+}
diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go
new file mode 100644
index 0000000..34a3c9a
--- /dev/null
+++ b/conn/bind_std_test.go
@@ -0,0 +1,250 @@
+package conn
+
+import (
+ "encoding/binary"
+ "net"
+ "testing"
+
+ "golang.org/x/net/ipv6"
+)
+
+func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
+ bind := NewStdNetBind().(*StdNetBind)
+ fns, _, err := bind.Open(0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ bind.Close()
+ bufs := make([][]byte, 1)
+ bufs[0] = make([]byte, 1)
+ sizes := make([]int, 1)
+ eps := make([]Endpoint, 1)
+ for _, fn := range fns {
+ // The ReceiveFuncs must not access conn-related fields on StdNetBind
+ // unguarded. Close() nils the conn-related fields resulting in a panic
+ // if they violate the mutex.
+ fn(bufs, sizes, eps)
+ }
+}
+
+func mockSetGSOSize(control *[]byte, gsoSize uint16) {
+ *control = (*control)[:cap(*control)]
+ binary.LittleEndian.PutUint16(*control, gsoSize)
+}
+
+func Test_coalesceMessages(t *testing.T) {
+ cases := []struct {
+ name string
+ buffs [][]byte
+ wantLens []int
+ wantGSO []int
+ }{
+ {
+ name: "one message no coalesce",
+ buffs: [][]byte{
+ make([]byte, 1, 1),
+ },
+ wantLens: []int{1},
+ wantGSO: []int{0},
+ },
+ {
+ name: "two messages equal len coalesce",
+ buffs: [][]byte{
+ make([]byte, 1, 2),
+ make([]byte, 1, 1),
+ },
+ wantLens: []int{2},
+ wantGSO: []int{1},
+ },
+ {
+ name: "two messages unequal len coalesce",
+ buffs: [][]byte{
+ make([]byte, 2, 3),
+ make([]byte, 1, 1),
+ },
+ wantLens: []int{3},
+ wantGSO: []int{2},
+ },
+ {
+ name: "three messages second unequal len coalesce",
+ buffs: [][]byte{
+ make([]byte, 2, 3),
+ make([]byte, 1, 1),
+ make([]byte, 2, 2),
+ },
+ wantLens: []int{3, 2},
+ wantGSO: []int{2, 0},
+ },
+ {
+ name: "three messages limited cap coalesce",
+ buffs: [][]byte{
+ make([]byte, 2, 4),
+ make([]byte, 2, 2),
+ make([]byte, 2, 2),
+ },
+ wantLens: []int{4, 2},
+ wantGSO: []int{2, 0},
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ addr := &net.UDPAddr{
+ IP: net.ParseIP("127.0.0.1").To4(),
+ Port: 1,
+ }
+ msgs := make([]ipv6.Message, len(tt.buffs))
+ for i := range msgs {
+ msgs[i].Buffers = make([][]byte, 1)
+ msgs[i].OOB = make([]byte, 0, 2)
+ }
+ got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
+ if got != len(tt.wantLens) {
+ t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
+ }
+ for i := 0; i < got; i++ {
+ if msgs[i].Addr != addr {
+ t.Errorf("msgs[%d].Addr != passed addr", i)
+ }
+ gotLen := len(msgs[i].Buffers[0])
+ if gotLen != tt.wantLens[i] {
+ t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
+ }
+ gotGSO, err := mockGetGSOSize(msgs[i].OOB)
+ if err != nil {
+ t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
+ }
+ if gotGSO != tt.wantGSO[i] {
+ t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
+ }
+ }
+ })
+ }
+}
+
+func mockGetGSOSize(control []byte) (int, error) {
+ if len(control) < 2 {
+ return 0, nil
+ }
+ return int(binary.LittleEndian.Uint16(control)), nil
+}
+
+func Test_splitCoalescedMessages(t *testing.T) {
+ newMsg := func(n, gso int) ipv6.Message {
+ msg := ipv6.Message{
+ Buffers: [][]byte{make([]byte, 1<<16-1)},
+ N: n,
+ OOB: make([]byte, 2),
+ }
+ binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
+ if gso > 0 {
+ msg.NN = 2
+ }
+ return msg
+ }
+
+ cases := []struct {
+ name string
+ msgs []ipv6.Message
+ firstMsgAt int
+ wantNumEval int
+ wantMsgLens []int
+ wantErr bool
+ }{
+ {
+ name: "second last split last empty",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(3, 1),
+ newMsg(0, 0),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 3,
+ wantMsgLens: []int{1, 1, 1, 0},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last empty",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(0, 0),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 1,
+ wantMsgLens: []int{1, 0, 0, 0},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last no split",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(1, 0),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 2,
+ wantMsgLens: []int{1, 1, 0, 0},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last split",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(3, 1),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 4,
+ wantMsgLens: []int{1, 1, 1, 1},
+ wantErr: false,
+ },
+ {
+ name: "second last split last split",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(2, 1),
+ newMsg(2, 1),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 4,
+ wantMsgLens: []int{1, 1, 1, 1},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last split overflow",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(4, 1),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 4,
+ wantMsgLens: []int{1, 1, 1, 1},
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
+ if err != nil && !tt.wantErr {
+ t.Fatalf("err: %v", err)
+ }
+ if got != tt.wantNumEval {
+ t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
+ }
+ for i, msg := range tt.msgs {
+ if msg.N != tt.wantMsgLens[i] {
+ t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
+ }
+ }
+ })
+ }
+}
diff --git a/conn/bind_windows.go b/conn/bind_windows.go
new file mode 100644
index 0000000..a3b8460
--- /dev/null
+++ b/conn/bind_windows.go
@@ -0,0 +1,601 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "encoding/binary"
+ "io"
+ "net"
+ "net/netip"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+
+ "golang.zx2c4.com/wireguard/conn/winrio"
+)
+
+const (
+ packetsPerRing = 1024
+ bytesPerPacket = 2048 - 32
+ receiveSpins = 15
+)
+
+type ringPacket struct {
+ addr WinRingEndpoint
+ data [bytesPerPacket]byte
+}
+
+type ringBuffer struct {
+ packets uintptr
+ head, tail uint32
+ id winrio.BufferId
+ iocp windows.Handle
+ isFull bool
+ cq winrio.Cq
+ mu sync.Mutex
+ overlapped windows.Overlapped
+}
+
+func (rb *ringBuffer) Push() *ringPacket {
+ for rb.isFull {
+ panic("ring is full")
+ }
+ ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
+ rb.tail += 1
+ if rb.tail%packetsPerRing == rb.head%packetsPerRing {
+ rb.isFull = true
+ }
+ return ret
+}
+
+func (rb *ringBuffer) Return(count uint32) {
+ if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull {
+ return
+ }
+ rb.head += count
+ rb.isFull = false
+}
+
+type afWinRingBind struct {
+ sock windows.Handle
+ rx, tx ringBuffer
+ rq winrio.Rq
+ mu sync.Mutex
+ blackhole bool
+}
+
+// WinRingBind uses Windows registered I/O for fast ring buffered networking.
+type WinRingBind struct {
+ v4, v6 afWinRingBind
+ mu sync.RWMutex
+ isOpen atomic.Uint32 // 0, 1, or 2
+}
+
+func NewDefaultBind() Bind { return NewWinRingBind() }
+
+func NewWinRingBind() Bind {
+ if !winrio.Initialize() {
+ return NewStdNetBind()
+ }
+ return new(WinRingBind)
+}
+
+type WinRingEndpoint struct {
+ family uint16
+ data [30]byte
+}
+
+var (
+ _ Bind = (*WinRingBind)(nil)
+ _ Endpoint = (*WinRingEndpoint)(nil)
+)
+
+func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
+ host, port, err := net.SplitHostPort(s)
+ if err != nil {
+ return nil, err
+ }
+ host16, err := windows.UTF16PtrFromString(host)
+ if err != nil {
+ return nil, err
+ }
+ port16, err := windows.UTF16PtrFromString(port)
+ if err != nil {
+ return nil, err
+ }
+ hints := windows.AddrinfoW{
+ Flags: windows.AI_NUMERICHOST,
+ Family: windows.AF_UNSPEC,
+ Socktype: windows.SOCK_DGRAM,
+ Protocol: windows.IPPROTO_UDP,
+ }
+ var addrinfo *windows.AddrinfoW
+ err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo)
+ if err != nil {
+ return nil, err
+ }
+ defer windows.FreeAddrInfoW(addrinfo)
+ if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
+ return nil, windows.ERROR_INVALID_ADDRESS
+ }
+ var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
+ copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen))
+ return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
+}
+
+func (*WinRingEndpoint) ClearSrc() {}
+
+func (e *WinRingEndpoint) DstIP() netip.Addr {
+ switch e.family {
+ case windows.AF_INET:
+ return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
+ case windows.AF_INET6:
+ return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
+ }
+ return netip.Addr{}
+}
+
+func (e *WinRingEndpoint) SrcIP() netip.Addr {
+ return netip.Addr{} // not supported
+}
+
+func (e *WinRingEndpoint) DstToBytes() []byte {
+ switch e.family {
+ case windows.AF_INET:
+ b := make([]byte, 0, 6)
+ b = append(b, e.data[2:6]...)
+ b = append(b, e.data[1], e.data[0])
+ return b
+ case windows.AF_INET6:
+ b := make([]byte, 0, 18)
+ b = append(b, e.data[6:22]...)
+ b = append(b, e.data[1], e.data[0])
+ return b
+ }
+ return nil
+}
+
+func (e *WinRingEndpoint) DstToString() string {
+ switch e.family {
+ case windows.AF_INET:
+ return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
+ case windows.AF_INET6:
+ var zone string
+ if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
+ zone = strconv.FormatUint(uint64(scope), 10)
+ }
+ return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
+ }
+ return ""
+}
+
+func (e *WinRingEndpoint) SrcToString() string {
+ return ""
+}
+
+func (ring *ringBuffer) CloseAndZero() {
+ if ring.cq != 0 {
+ winrio.CloseCompletionQueue(ring.cq)
+ ring.cq = 0
+ }
+ if ring.iocp != 0 {
+ windows.CloseHandle(ring.iocp)
+ ring.iocp = 0
+ }
+ if ring.id != 0 {
+ winrio.DeregisterBuffer(ring.id)
+ ring.id = 0
+ }
+ if ring.packets != 0 {
+ windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
+ ring.packets = 0
+ }
+ ring.head = 0
+ ring.tail = 0
+ ring.isFull = false
+}
+
+func (bind *afWinRingBind) CloseAndZero() {
+ bind.rx.CloseAndZero()
+ bind.tx.CloseAndZero()
+ if bind.sock != 0 {
+ windows.CloseHandle(bind.sock)
+ bind.sock = 0
+ }
+ bind.blackhole = false
+}
+
+func (bind *WinRingBind) closeAndZero() {
+ bind.isOpen.Store(0)
+ bind.v4.CloseAndZero()
+ bind.v6.CloseAndZero()
+}
+
+func (ring *ringBuffer) Open() error {
+ var err error
+ packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
+ ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
+ if err != nil {
+ return err
+ }
+ ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
+ if err != nil {
+ return err
+ }
+ ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
+ if err != nil {
+ return err
+ }
+ ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) {
+ var err error
+ bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
+ if err != nil {
+ return nil, err
+ }
+ err = bind.rx.Open()
+ if err != nil {
+ return nil, err
+ }
+ err = bind.tx.Open()
+ if err != nil {
+ return nil, err
+ }
+ bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0)
+ if err != nil {
+ return nil, err
+ }
+ err = windows.Bind(bind.sock, sa)
+ if err != nil {
+ return nil, err
+ }
+ sa, err = windows.Getsockname(bind.sock)
+ if err != nil {
+ return nil, err
+ }
+ return sa, nil
+}
+
+func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+ defer func() {
+ if err != nil {
+ bind.closeAndZero()
+ }
+ }()
+ if bind.isOpen.Load() != 0 {
+ return nil, 0, ErrBindAlreadyOpen
+ }
+ var sa windows.Sockaddr
+ sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
+ if err != nil {
+ return nil, 0, err
+ }
+ sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
+ if err != nil {
+ return nil, 0, err
+ }
+ selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
+ for i := 0; i < packetsPerRing; i++ {
+ err = bind.v4.InsertReceiveRequest()
+ if err != nil {
+ return nil, 0, err
+ }
+ err = bind.v6.InsertReceiveRequest()
+ if err != nil {
+ return nil, 0, err
+ }
+ }
+ bind.isOpen.Store(1)
+ return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
+}
+
+func (bind *WinRingBind) Close() error {
+ bind.mu.RLock()
+ if bind.isOpen.Load() != 1 {
+ bind.mu.RUnlock()
+ return nil
+ }
+ bind.isOpen.Store(2)
+ windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
+ windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
+ windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
+ windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil)
+ bind.mu.RUnlock()
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+ bind.closeAndZero()
+ return nil
+}
+
+// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
+// rename the IdealBatchSize constant to BatchSize.
+func (bind *WinRingBind) BatchSize() int {
+ // TODO: implement batching in and out of the ring
+ return 1
+}
+
+func (bind *WinRingBind) SetMark(mark uint32) error {
+ return nil
+}
+
+func (bind *afWinRingBind) InsertReceiveRequest() error {
+ packet := bind.rx.Push()
+ dataBuffer := &winrio.Buffer{
+ Id: bind.rx.id,
+ Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets),
+ Length: uint32(len(packet.data)),
+ }
+ addressBuffer := &winrio.Buffer{
+ Id: bind.rx.id,
+ Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets),
+ Length: uint32(unsafe.Sizeof(packet.addr)),
+ }
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+ return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
+}
+
+//go:linkname procyield runtime.procyield
+func procyield(cycles uint32)
+
+func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
+ if isOpen.Load() != 1 {
+ return 0, nil, net.ErrClosed
+ }
+ bind.rx.mu.Lock()
+ defer bind.rx.mu.Unlock()
+
+ var err error
+ var count uint32
+ var results [1]winrio.Result
+retry:
+ count = 0
+ for tries := 0; count == 0 && tries < receiveSpins; tries++ {
+ if tries > 0 {
+ if isOpen.Load() != 1 {
+ return 0, nil, net.ErrClosed
+ }
+ procyield(1)
+ }
+ count = winrio.DequeueCompletion(bind.rx.cq, results[:])
+ }
+ if count == 0 {
+ err = winrio.Notify(bind.rx.cq)
+ if err != nil {
+ return 0, nil, err
+ }
+ var bytes uint32
+ var key uintptr
+ var overlapped *windows.Overlapped
+ err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
+ if err != nil {
+ return 0, nil, err
+ }
+ if isOpen.Load() != 1 {
+ return 0, nil, net.ErrClosed
+ }
+ count = winrio.DequeueCompletion(bind.rx.cq, results[:])
+ if count == 0 {
+ return 0, nil, io.ErrNoProgress
+ }
+ }
+ bind.rx.Return(1)
+ err = bind.InsertReceiveRequest()
+ if err != nil {
+ return 0, nil, err
+ }
+ // We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
+ // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
+ // attacker bandwidth, just like the rest of the receive path.
+ if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
+ if isOpen.Load() != 1 {
+ return 0, nil, net.ErrClosed
+ }
+ goto retry
+ }
+ if results[0].Status != 0 {
+ return 0, nil, windows.Errno(results[0].Status)
+ }
+ packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
+ ep := packet.addr
+ n := copy(buf, packet.data[:results[0].BytesTransferred])
+ return n, &ep, nil
+}
+
+func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
+ n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
+ sizes[0] = n
+ eps[0] = ep
+ return 1, err
+}
+
+func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
+ n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
+ sizes[0] = n
+ eps[0] = ep
+ return 1, err
+}
+
+func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
+ if isOpen.Load() != 1 {
+ return net.ErrClosed
+ }
+ if len(buf) > bytesPerPacket {
+ return io.ErrShortBuffer
+ }
+ bind.tx.mu.Lock()
+ defer bind.tx.mu.Unlock()
+ var results [packetsPerRing]winrio.Result
+ count := winrio.DequeueCompletion(bind.tx.cq, results[:])
+ if count == 0 && bind.tx.isFull {
+ err := winrio.Notify(bind.tx.cq)
+ if err != nil {
+ return err
+ }
+ var bytes uint32
+ var key uintptr
+ var overlapped *windows.Overlapped
+ err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
+ if err != nil {
+ return err
+ }
+ if isOpen.Load() != 1 {
+ return net.ErrClosed
+ }
+ count = winrio.DequeueCompletion(bind.tx.cq, results[:])
+ if count == 0 {
+ return io.ErrNoProgress
+ }
+ }
+ if count > 0 {
+ bind.tx.Return(count)
+ }
+ packet := bind.tx.Push()
+ packet.addr = *nend
+ copy(packet.data[:], buf)
+ dataBuffer := &winrio.Buffer{
+ Id: bind.tx.id,
+ Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets),
+ Length: uint32(len(buf)),
+ }
+ addressBuffer := &winrio.Buffer{
+ Id: bind.tx.id,
+ Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets),
+ Length: uint32(unsafe.Sizeof(packet.addr)),
+ }
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+ return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
+}
+
+func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
+ nend, ok := endpoint.(*WinRingEndpoint)
+ if !ok {
+ return ErrWrongEndpointType
+ }
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
+ for _, buf := range bufs {
+ switch nend.family {
+ case windows.AF_INET:
+ if bind.v4.blackhole {
+ continue
+ }
+ if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
+ return err
+ }
+ case windows.AF_INET6:
+ if bind.v6.blackhole {
+ continue
+ }
+ if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ sysconn, err := s.ipv4.SyscallConn()
+ if err != nil {
+ return err
+ }
+ err2 := sysconn.Control(func(fd uintptr) {
+ err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex)
+ })
+ if err2 != nil {
+ return err2
+ }
+ if err != nil {
+ return err
+ }
+ s.blackhole4 = blackhole
+ return nil
+}
+
+func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ sysconn, err := s.ipv6.SyscallConn()
+ if err != nil {
+ return err
+ }
+ err2 := sysconn.Control(func(fd uintptr) {
+ err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex)
+ })
+ if err2 != nil {
+ return err2
+ }
+ if err != nil {
+ return err
+ }
+ s.blackhole6 = blackhole
+ return nil
+}
+
+func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
+ if bind.isOpen.Load() != 1 {
+ return net.ErrClosed
+ }
+ err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
+ if err != nil {
+ return err
+ }
+ bind.v4.blackhole = blackhole
+ return nil
+}
+
+func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
+ if bind.isOpen.Load() != 1 {
+ return net.ErrClosed
+ }
+ err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
+ if err != nil {
+ return err
+ }
+ bind.v6.blackhole = blackhole
+ return nil
+}
+
+func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error {
+ const IP_UNICAST_IF = 31
+ /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
+ var bytes [4]byte
+ binary.BigEndian.PutUint32(bytes[:], interfaceIndex)
+ interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
+ err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex))
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error {
+ const IPV6_UNICAST_IF = 31
+ return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
+}
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
new file mode 100644
index 0000000..46e20e6
--- /dev/null
+++ b/conn/bindtest/bindtest.go
@@ -0,0 +1,136 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package bindtest
+
+import (
+ "fmt"
+ "math/rand"
+ "net"
+ "net/netip"
+ "os"
+
+ "golang.zx2c4.com/wireguard/conn"
+)
+
+type ChannelBind struct {
+ rx4, tx4 *chan []byte
+ rx6, tx6 *chan []byte
+ closeSignal chan bool
+ source4, source6 ChannelEndpoint
+ target4, target6 ChannelEndpoint
+}
+
+type ChannelEndpoint uint16
+
+var (
+ _ conn.Bind = (*ChannelBind)(nil)
+ _ conn.Endpoint = (*ChannelEndpoint)(nil)
+)
+
+func NewChannelBinds() [2]conn.Bind {
+ arx4 := make(chan []byte, 8192)
+ brx4 := make(chan []byte, 8192)
+ arx6 := make(chan []byte, 8192)
+ brx6 := make(chan []byte, 8192)
+ var binds [2]ChannelBind
+ binds[0].rx4 = &arx4
+ binds[0].tx4 = &brx4
+ binds[1].rx4 = &brx4
+ binds[1].tx4 = &arx4
+ binds[0].rx6 = &arx6
+ binds[0].tx6 = &brx6
+ binds[1].rx6 = &brx6
+ binds[1].tx6 = &arx6
+ binds[0].target4 = ChannelEndpoint(1)
+ binds[1].target4 = ChannelEndpoint(2)
+ binds[0].target6 = ChannelEndpoint(3)
+ binds[1].target6 = ChannelEndpoint(4)
+ binds[0].source4 = binds[1].target4
+ binds[0].source6 = binds[1].target6
+ binds[1].source4 = binds[0].target4
+ binds[1].source6 = binds[0].target6
+ return [2]conn.Bind{&binds[0], &binds[1]}
+}
+
+func (c ChannelEndpoint) ClearSrc() {}
+
+func (c ChannelEndpoint) SrcToString() string { return "" }
+
+func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) }
+
+func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
+
+func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
+
+func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
+
+func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+ c.closeSignal = make(chan bool)
+ fns = append(fns, c.makeReceiveFunc(*c.rx4))
+ fns = append(fns, c.makeReceiveFunc(*c.rx6))
+ if rand.Uint32()&1 == 0 {
+ return fns, uint16(c.source4), nil
+ } else {
+ return fns, uint16(c.source6), nil
+ }
+}
+
+func (c *ChannelBind) Close() error {
+ if c.closeSignal != nil {
+ select {
+ case <-c.closeSignal:
+ default:
+ close(c.closeSignal)
+ }
+ }
+ return nil
+}
+
+func (c *ChannelBind) BatchSize() int { return 1 }
+
+func (c *ChannelBind) SetMark(mark uint32) error { return nil }
+
+func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
+ return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
+ select {
+ case <-c.closeSignal:
+ return 0, net.ErrClosed
+ case rx := <-ch:
+ copied := copy(bufs[0], rx)
+ sizes[0] = copied
+ eps[0] = c.target6
+ return 1, nil
+ }
+ }
+}
+
+func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
+ for _, b := range bufs {
+ select {
+ case <-c.closeSignal:
+ return net.ErrClosed
+ default:
+ bc := make([]byte, len(b))
+ copy(bc, b)
+ if ep.(ChannelEndpoint) == c.target4 {
+ *c.tx4 <- bc
+ } else if ep.(ChannelEndpoint) == c.target6 {
+ *c.tx6 <- bc
+ } else {
+ return os.ErrInvalid
+ }
+ }
+ }
+ return nil
+}
+
+func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
+ addr, err := netip.ParseAddrPort(s)
+ if err != nil {
+ return nil, err
+ }
+ return ChannelEndpoint(addr.Port()), nil
+}
diff --git a/conn/boundif_android.go b/conn/boundif_android.go
index 3e10607..be69b2a 100644
--- a/conn/boundif_android.go
+++ b/conn/boundif_android.go
@@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn
-func (bind *nativeBind) PeekLookAtSocketFd4() (fd int, err error) {
- sysconn, err := bind.ipv4.SyscallConn()
+func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
+ sysconn, err := s.ipv4.SyscallConn()
if err != nil {
return -1, err
}
@@ -19,8 +19,8 @@ func (bind *nativeBind) PeekLookAtSocketFd4() (fd int, err error) {
return
}
-func (bind *nativeBind) PeekLookAtSocketFd6() (fd int, err error) {
- sysconn, err := bind.ipv6.SyscallConn()
+func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
+ sysconn, err := s.ipv6.SyscallConn()
if err != nil {
return -1, err
}
diff --git a/conn/boundif_windows.go b/conn/boundif_windows.go
deleted file mode 100644
index 53a8f09..0000000
--- a/conn/boundif_windows.go
+++ /dev/null
@@ -1,59 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
- */
-
-package conn
-
-import (
- "encoding/binary"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-const (
- sockoptIP_UNICAST_IF = 31
- sockoptIPV6_UNICAST_IF = 31
-)
-
-func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
- /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
- bytes := make([]byte, 4)
- binary.BigEndian.PutUint32(bytes, interfaceIndex)
- interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
-
- sysconn, err := bind.ipv4.SyscallConn()
- if err != nil {
- return err
- }
- err2 := sysconn.Control(func(fd uintptr) {
- err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, sockoptIP_UNICAST_IF, int(interfaceIndex))
- })
- if err2 != nil {
- return err2
- }
- if err != nil {
- return err
- }
- bind.blackhole4 = blackhole
- return nil
-}
-
-func (bind *nativeBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
- sysconn, err := bind.ipv6.SyscallConn()
- if err != nil {
- return err
- }
- err2 := sysconn.Control(func(fd uintptr) {
- err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, sockoptIPV6_UNICAST_IF, int(interfaceIndex))
- })
- if err2 != nil {
- return err2
- }
- if err != nil {
- return err
- }
- bind.blackhole6 = blackhole
- return nil
-}
diff --git a/conn/conn.go b/conn/conn.go
index ad91d2d..1304657 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
// Package conn implements WireGuard's network connections.
@@ -8,49 +8,53 @@ package conn
import (
"errors"
- "net"
+ "fmt"
+ "net/netip"
+ "reflect"
+ "runtime"
"strings"
)
+const (
+ IdealBatchSize = 128 // maximum number of packets handled per read and write
+)
+
+// A ReceiveFunc receives at least one packet from the network and writes them
+// into packets. On a successful read it returns the number of elements of
+// sizes, packets, and endpoints that should be evaluated. Some elements of
+// sizes may be zero, and callers should ignore them. Callers must pass a sizes
+// and eps slice with a length greater than or equal to the length of packets.
+// These lengths must not exceed the length of the associated Bind.BatchSize().
+type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
+
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
//
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
// depending on the platform-specific implementation.
type Bind interface {
- // LastMark reports the last mark set for this Bind.
- LastMark() uint32
+ // Open puts the Bind into a listening state on a given port and reports the actual
+ // port that it bound to. Passing zero results in a random selection.
+ // fns is the set of functions that will be called to receive packets.
+ Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
+
+ // Close closes the Bind listener.
+ // All fns returned by Open must return net.ErrClosed after a call to Close.
+ Close() error
// SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error
- // ReceiveIPv6 reads an IPv6 UDP packet into b.
- //
- // It reports the number of bytes read, n,
- // the packet source address ep,
- // and any error.
- ReceiveIPv6(buff []byte) (n int, ep Endpoint, err error)
-
- // ReceiveIPv4 reads an IPv4 UDP packet into b.
- //
- // It reports the number of bytes read, n,
- // the packet source address ep,
- // and any error.
- ReceiveIPv4(b []byte) (n int, ep Endpoint, err error)
-
- // Send writes a packet b to address ep.
- Send(b []byte, ep Endpoint) error
+ // Send writes one or more packets in bufs to address ep. The length of
+ // bufs must not exceed BatchSize().
+ Send(bufs [][]byte, ep Endpoint) error
- // Close closes the Bind connection.
- Close() error
-}
+ // ParseEndpoint creates a new endpoint from a string.
+ ParseEndpoint(s string) (Endpoint, error)
-// CreateBind creates a Bind bound to a port.
-//
-// The value actualPort reports the actual port number the Bind
-// object gets bound to.
-func CreateBind(port uint16) (b Bind, actualPort uint16, err error) {
- return createBind(port)
+ // BatchSize is the number of buffers expected to be passed to
+ // the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
+ BatchSize() int
}
// BindSocketToInterface is implemented by Bind objects that support being
@@ -69,43 +73,61 @@ type PeekLookAtSocketFd interface {
// An Endpoint maintains the source/destination caching for a peer.
//
-// dst : the remote address of a peer ("endpoint" in uapi terminology)
-// src : the local address from which datagrams originate going to the peer
+// dst: the remote address of a peer ("endpoint" in uapi terminology)
+// src: the local address from which datagrams originate going to the peer
type Endpoint interface {
ClearSrc() // clears the source address
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
- DstIP() net.IP
- SrcIP() net.IP
+ DstIP() netip.Addr
+ SrcIP() netip.Addr
}
-func parseEndpoint(s string) (*net.UDPAddr, error) {
- // ensure that the host is an IP address
+var (
+ ErrBindAlreadyOpen = errors.New("bind is already open")
+ ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
+)
- host, _, err := net.SplitHostPort(s)
- if err != nil {
- return nil, err
+func (fn ReceiveFunc) PrettyName() string {
+ name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
+ // 0. cheese/taco.beansIPv6.func12.func21218-fm
+ name = strings.TrimSuffix(name, "-fm")
+ // 1. cheese/taco.beansIPv6.func12.func21218
+ if idx := strings.LastIndexByte(name, '/'); idx != -1 {
+ name = name[idx+1:]
+ // 2. taco.beansIPv6.func12.func21218
}
- if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
- // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
- // trying to make sure with a small sanity test that this is a real IP address and
- // not something that's likely to incur DNS lookups.
- host = host[:i]
+ for {
+ var idx int
+ for idx = len(name) - 1; idx >= 0; idx-- {
+ if name[idx] < '0' || name[idx] > '9' {
+ break
+ }
+ }
+ if idx == len(name)-1 {
+ break
+ }
+ const dotFunc = ".func"
+ if !strings.HasSuffix(name[:idx+1], dotFunc) {
+ break
+ }
+ name = name[:idx+1-len(dotFunc)]
+ // 3. taco.beansIPv6.func12
+ // 4. taco.beansIPv6
}
- if ip := net.ParseIP(host); ip == nil {
- return nil, errors.New("Failed to parse IP address: " + host)
+ if idx := strings.LastIndexByte(name, '.'); idx != -1 {
+ name = name[idx+1:]
+ // 5. beansIPv6
}
-
- // parse address and port
-
- addr, err := net.ResolveUDPAddr("udp", s)
- if err != nil {
- return nil, err
+ if name == "" {
+ return fmt.Sprintf("%p", fn)
+ }
+ if strings.HasSuffix(name, "IPv4") {
+ return "v4"
}
- ip4 := addr.IP.To4()
- if ip4 != nil {
- addr.IP = ip4
+ if strings.HasSuffix(name, "IPv6") {
+ return "v6"
}
- return addr, err
+ return name
}
diff --git a/conn/conn_default.go b/conn/conn_default.go
deleted file mode 100644
index 8be3c9d..0000000
--- a/conn/conn_default.go
+++ /dev/null
@@ -1,176 +0,0 @@
-// +build !linux android
-
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
- */
-
-package conn
-
-import (
- "net"
- "os"
- "syscall"
-)
-
-/* This code is meant to be a temporary solution
- * on platforms for which the sticky socket / source caching behavior
- * has not yet been implemented.
- *
- * See conn_linux.go for an implementation on the linux platform.
- */
-
-type nativeBind struct {
- ipv4 *net.UDPConn
- ipv6 *net.UDPConn
- blackhole4 bool
- blackhole6 bool
-}
-
-type NativeEndpoint net.UDPAddr
-
-var _ Bind = (*nativeBind)(nil)
-var _ Endpoint = (*NativeEndpoint)(nil)
-
-func CreateEndpoint(s string) (Endpoint, error) {
- addr, err := parseEndpoint(s)
- return (*NativeEndpoint)(addr), err
-}
-
-func (_ *NativeEndpoint) ClearSrc() {}
-
-func (e *NativeEndpoint) DstIP() net.IP {
- return (*net.UDPAddr)(e).IP
-}
-
-func (e *NativeEndpoint) SrcIP() net.IP {
- return nil // not supported
-}
-
-func (e *NativeEndpoint) DstToBytes() []byte {
- addr := (*net.UDPAddr)(e)
- out := addr.IP.To4()
- if out == nil {
- out = addr.IP
- }
- out = append(out, byte(addr.Port&0xff))
- out = append(out, byte((addr.Port>>8)&0xff))
- return out
-}
-
-func (e *NativeEndpoint) DstToString() string {
- return (*net.UDPAddr)(e).String()
-}
-
-func (e *NativeEndpoint) SrcToString() string {
- return ""
-}
-
-func listenNet(network string, port int) (*net.UDPConn, int, error) {
- conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
- if err != nil {
- return nil, 0, err
- }
-
- // Retrieve port.
- laddr := conn.LocalAddr()
- uaddr, err := net.ResolveUDPAddr(
- laddr.Network(),
- laddr.String(),
- )
- if err != nil {
- return nil, 0, err
- }
- return conn, uaddr.Port, nil
-}
-
-func extractErrno(err error) error {
- opErr, ok := err.(*net.OpError)
- if !ok {
- return nil
- }
- syscallErr, ok := opErr.Err.(*os.SyscallError)
- if !ok {
- return nil
- }
- return syscallErr.Err
-}
-
-func createBind(uport uint16) (Bind, uint16, error) {
- var err error
- var bind nativeBind
-
- port := int(uport)
-
- bind.ipv4, port, err = listenNet("udp4", port)
- if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
- return nil, 0, err
- }
-
- bind.ipv6, port, err = listenNet("udp6", port)
- if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
- bind.ipv4.Close()
- bind.ipv4 = nil
- return nil, 0, err
- }
-
- return &bind, uint16(port), nil
-}
-
-func (bind *nativeBind) Close() error {
- var err1, err2 error
- if bind.ipv4 != nil {
- err1 = bind.ipv4.Close()
- }
- if bind.ipv6 != nil {
- err2 = bind.ipv6.Close()
- }
- if err1 != nil {
- return err1
- }
- return err2
-}
-
-func (bind *nativeBind) LastMark() uint32 { return 0 }
-
-func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
- if bind.ipv4 == nil {
- return 0, nil, syscall.EAFNOSUPPORT
- }
- n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
- if endpoint != nil {
- endpoint.IP = endpoint.IP.To4()
- }
- return n, (*NativeEndpoint)(endpoint), err
-}
-
-func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
- if bind.ipv6 == nil {
- return 0, nil, syscall.EAFNOSUPPORT
- }
- n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
- return n, (*NativeEndpoint)(endpoint), err
-}
-
-func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error {
- var err error
- nend := endpoint.(*NativeEndpoint)
- if nend.IP.To4() != nil {
- if bind.ipv4 == nil {
- return syscall.EAFNOSUPPORT
- }
- if bind.blackhole4 {
- return nil
- }
- _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
- } else {
- if bind.ipv6 == nil {
- return syscall.EAFNOSUPPORT
- }
- if bind.blackhole6 {
- return nil
- }
- _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
- }
- return err
-}
diff --git a/conn/conn_linux.go b/conn/conn_linux.go
deleted file mode 100644
index 08c8949..0000000
--- a/conn/conn_linux.go
+++ /dev/null
@@ -1,571 +0,0 @@
-// +build !android
-
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
- */
-
-package conn
-
-import (
- "errors"
- "net"
- "strconv"
- "sync"
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/unix"
-)
-
-const (
- FD_ERR = -1
-)
-
-type IPv4Source struct {
- Src [4]byte
- Ifindex int32
-}
-
-type IPv6Source struct {
- src [16]byte
- //ifindex belongs in dst.ZoneId
-}
-
-type NativeEndpoint struct {
- sync.Mutex
- dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
- src [unsafe.Sizeof(IPv6Source{})]byte
- isV6 bool
-}
-
-func (endpoint *NativeEndpoint) Src4() *IPv4Source { return endpoint.src4() }
-func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
-func (endpoint *NativeEndpoint) IsV6() bool { return endpoint.isV6 }
-
-func (endpoint *NativeEndpoint) src4() *IPv4Source {
- return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
-}
-
-func (endpoint *NativeEndpoint) src6() *IPv6Source {
- return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0]))
-}
-
-func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 {
- return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
-}
-
-func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
- return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
-}
-
-type nativeBind struct {
- sock4 int
- sock6 int
- lastMark uint32
-}
-
-var _ Endpoint = (*NativeEndpoint)(nil)
-var _ Bind = (*nativeBind)(nil)
-
-func CreateEndpoint(s string) (Endpoint, error) {
- var end NativeEndpoint
- addr, err := parseEndpoint(s)
- if err != nil {
- return nil, err
- }
-
- ipv4 := addr.IP.To4()
- if ipv4 != nil {
- dst := end.dst4()
- end.isV6 = false
- dst.Port = addr.Port
- copy(dst.Addr[:], ipv4)
- end.ClearSrc()
- return &end, nil
- }
-
- ipv6 := addr.IP.To16()
- if ipv6 != nil {
- zone, err := zoneToUint32(addr.Zone)
- if err != nil {
- return nil, err
- }
- dst := end.dst6()
- end.isV6 = true
- dst.Port = addr.Port
- dst.ZoneId = zone
- copy(dst.Addr[:], ipv6[:])
- end.ClearSrc()
- return &end, nil
- }
-
- return nil, errors.New("Invalid IP address")
-}
-
-func createBind(port uint16) (Bind, uint16, error) {
- var err error
- var bind nativeBind
- var newPort uint16
-
- // Attempt ipv6 bind, update port if successful.
- bind.sock6, newPort, err = create6(port)
- if err != nil {
- if err != syscall.EAFNOSUPPORT {
- return nil, 0, err
- }
- } else {
- port = newPort
- }
-
- // Attempt ipv4 bind, update port if successful.
- bind.sock4, newPort, err = create4(port)
- if err != nil {
- if err != syscall.EAFNOSUPPORT {
- unix.Close(bind.sock6)
- return nil, 0, err
- }
- } else {
- port = newPort
- }
-
- if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR {
- return nil, 0, errors.New("ipv4 and ipv6 not supported")
- }
-
- return &bind, port, nil
-}
-
-func (bind *nativeBind) LastMark() uint32 {
- return bind.lastMark
-}
-
-func (bind *nativeBind) SetMark(value uint32) error {
- if bind.sock6 != -1 {
- err := unix.SetsockoptInt(
- bind.sock6,
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
-
- if err != nil {
- return err
- }
- }
-
- if bind.sock4 != -1 {
- err := unix.SetsockoptInt(
- bind.sock4,
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
-
- if err != nil {
- return err
- }
- }
-
- bind.lastMark = value
- return nil
-}
-
-func closeUnblock(fd int) error {
- // shutdown to unblock readers and writers
- unix.Shutdown(fd, unix.SHUT_RDWR)
- return unix.Close(fd)
-}
-
-func (bind *nativeBind) Close() error {
- var err1, err2 error
- if bind.sock6 != -1 {
- err1 = closeUnblock(bind.sock6)
- }
- if bind.sock4 != -1 {
- err2 = closeUnblock(bind.sock4)
- }
-
- if err1 != nil {
- return err1
- }
- return err2
-}
-
-func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
- var end NativeEndpoint
- if bind.sock6 == -1 {
- return 0, nil, syscall.EAFNOSUPPORT
- }
- n, err := receive6(
- bind.sock6,
- buff,
- &end,
- )
- return n, &end, err
-}
-
-func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
- var end NativeEndpoint
- if bind.sock4 == -1 {
- return 0, nil, syscall.EAFNOSUPPORT
- }
- n, err := receive4(
- bind.sock4,
- buff,
- &end,
- )
- return n, &end, err
-}
-
-func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
- nend := end.(*NativeEndpoint)
- if !nend.isV6 {
- if bind.sock4 == -1 {
- return syscall.EAFNOSUPPORT
- }
- return send4(bind.sock4, nend, buff)
- } else {
- if bind.sock6 == -1 {
- return syscall.EAFNOSUPPORT
- }
- return send6(bind.sock6, nend, buff)
- }
-}
-
-func (end *NativeEndpoint) SrcIP() net.IP {
- if !end.isV6 {
- return net.IPv4(
- end.src4().Src[0],
- end.src4().Src[1],
- end.src4().Src[2],
- end.src4().Src[3],
- )
- } else {
- return end.src6().src[:]
- }
-}
-
-func (end *NativeEndpoint) DstIP() net.IP {
- if !end.isV6 {
- return net.IPv4(
- end.dst4().Addr[0],
- end.dst4().Addr[1],
- end.dst4().Addr[2],
- end.dst4().Addr[3],
- )
- } else {
- return end.dst6().Addr[:]
- }
-}
-
-func (end *NativeEndpoint) DstToBytes() []byte {
- if !end.isV6 {
- return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
- } else {
- return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
- }
-}
-
-func (end *NativeEndpoint) SrcToString() string {
- return end.SrcIP().String()
-}
-
-func (end *NativeEndpoint) DstToString() string {
- var udpAddr net.UDPAddr
- udpAddr.IP = end.DstIP()
- if !end.isV6 {
- udpAddr.Port = end.dst4().Port
- } else {
- udpAddr.Port = end.dst6().Port
- }
- return udpAddr.String()
-}
-
-func (end *NativeEndpoint) ClearDst() {
- for i := range end.dst {
- end.dst[i] = 0
- }
-}
-
-func (end *NativeEndpoint) ClearSrc() {
- for i := range end.src {
- end.src[i] = 0
- }
-}
-
-func zoneToUint32(zone string) (uint32, error) {
- if zone == "" {
- return 0, nil
- }
- if intr, err := net.InterfaceByName(zone); err == nil {
- return uint32(intr.Index), nil
- }
- n, err := strconv.ParseUint(zone, 10, 32)
- return uint32(n), err
-}
-
-func create4(port uint16) (int, uint16, error) {
-
- // create socket
-
- fd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM,
- 0,
- )
-
- if err != nil {
- return FD_ERR, 0, err
- }
-
- addr := unix.SockaddrInet4{
- Port: int(port),
- }
-
- // set sockopts and bind
-
- if err := func() error {
- if err := unix.SetsockoptInt(
- fd,
- unix.SOL_SOCKET,
- unix.SO_REUSEADDR,
- 1,
- ); err != nil {
- return err
- }
-
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IP,
- unix.IP_PKTINFO,
- 1,
- ); err != nil {
- return err
- }
-
- return unix.Bind(fd, &addr)
- }(); err != nil {
- unix.Close(fd)
- return FD_ERR, 0, err
- }
-
- sa, err := unix.Getsockname(fd)
- if err == nil {
- addr.Port = sa.(*unix.SockaddrInet4).Port
- }
-
- return fd, uint16(addr.Port), err
-}
-
-func create6(port uint16) (int, uint16, error) {
-
- // create socket
-
- fd, err := unix.Socket(
- unix.AF_INET6,
- unix.SOCK_DGRAM,
- 0,
- )
-
- if err != nil {
- return FD_ERR, 0, err
- }
-
- // set sockopts and bind
-
- addr := unix.SockaddrInet6{
- Port: int(port),
- }
-
- if err := func() error {
-
- if err := unix.SetsockoptInt(
- fd,
- unix.SOL_SOCKET,
- unix.SO_REUSEADDR,
- 1,
- ); err != nil {
- return err
- }
-
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IPV6,
- unix.IPV6_RECVPKTINFO,
- 1,
- ); err != nil {
- return err
- }
-
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IPV6,
- unix.IPV6_V6ONLY,
- 1,
- ); err != nil {
- return err
- }
-
- return unix.Bind(fd, &addr)
-
- }(); err != nil {
- unix.Close(fd)
- return FD_ERR, 0, err
- }
-
- sa, err := unix.Getsockname(fd)
- if err == nil {
- addr.Port = sa.(*unix.SockaddrInet6).Port
- }
-
- return fd, uint16(addr.Port), err
-}
-
-func send4(sock int, end *NativeEndpoint, buff []byte) error {
-
- // construct message header
-
- cmsg := struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
- }{
- unix.Cmsghdr{
- Level: unix.IPPROTO_IP,
- Type: unix.IP_PKTINFO,
- Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
- },
- unix.Inet4Pktinfo{
- Spec_dst: end.src4().Src,
- Ifindex: end.src4().Ifindex,
- },
- }
-
- end.Lock()
- _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
- end.Unlock()
-
- if err == nil {
- return nil
- }
-
- // clear src and retry
-
- if err == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet4Pktinfo{}
- end.Lock()
- _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
- end.Unlock()
- }
-
- return err
-}
-
-func send6(sock int, end *NativeEndpoint, buff []byte) error {
-
- // construct message header
-
- cmsg := struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
- }{
- unix.Cmsghdr{
- Level: unix.IPPROTO_IPV6,
- Type: unix.IPV6_PKTINFO,
- Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
- },
- unix.Inet6Pktinfo{
- Addr: end.src6().src,
- Ifindex: end.dst6().ZoneId,
- },
- }
-
- if cmsg.pktinfo.Addr == [16]byte{} {
- cmsg.pktinfo.Ifindex = 0
- }
-
- end.Lock()
- _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
- end.Unlock()
-
- if err == nil {
- return nil
- }
-
- // clear src and retry
-
- if err == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet6Pktinfo{}
- end.Lock()
- _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
- end.Unlock()
- }
-
- return err
-}
-
-func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
-
- // construct message header
-
- var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
- }
-
- size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
-
- if err != nil {
- return 0, err
- }
- end.isV6 = false
-
- if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
- *end.dst4() = *newDst4
- }
-
- // update source cache
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
- cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
- end.src4().Src = cmsg.pktinfo.Spec_dst
- end.src4().Ifindex = cmsg.pktinfo.Ifindex
- }
-
- return size, nil
-}
-
-func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
-
- // construct message header
-
- var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
- }
-
- size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
-
- if err != nil {
- return 0, err
- }
- end.isV6 = true
-
- if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
- *end.dst6() = *newDst6
- }
-
- // update source cache
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
- cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
- end.src6().src = cmsg.pktinfo.Addr
- end.dst6().ZoneId = cmsg.pktinfo.Ifindex
- }
-
- return size, nil
-}
diff --git a/conn/conn_test.go b/conn/conn_test.go
new file mode 100644
index 0000000..618d02b
--- /dev/null
+++ b/conn/conn_test.go
@@ -0,0 +1,24 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "testing"
+)
+
+func TestPrettyName(t *testing.T) {
+ var (
+ recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
+ )
+
+ const want = "TestPrettyName"
+
+ t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
+ if got := recvFunc.PrettyName(); got != want {
+ t.Errorf("PrettyName() = %v, want %v", got, want)
+ }
+ })
+}
diff --git a/conn/controlfns.go b/conn/controlfns.go
new file mode 100644
index 0000000..27421bd
--- /dev/null
+++ b/conn/controlfns.go
@@ -0,0 +1,43 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net"
+ "syscall"
+)
+
+// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
+// the max supported by a default configuration of macOS. Some platforms will
+// silently clamp the value to other maximums, such as linux clamping to
+// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
+// around this limitation)
+const socketBufferSize = 7 << 20
+
+// controlFn is the callback function signature from net.ListenConfig.Control.
+// It is used to apply platform specific configuration to the socket prior to
+// bind.
+type controlFn func(network, address string, c syscall.RawConn) error
+
+// controlFns is a list of functions that are called from the listen config
+// that can apply socket options.
+var controlFns = []controlFn{}
+
+// listenConfig returns a net.ListenConfig that applies the controlFns to the
+// socket prior to bind. This is used to apply socket buffer sizing and packet
+// information OOB configuration for sticky sockets.
+func listenConfig() *net.ListenConfig {
+ return &net.ListenConfig{
+ Control: func(network, address string, c syscall.RawConn) error {
+ for _, fn := range controlFns {
+ if err := fn(network, address, c); err != nil {
+ return err
+ }
+ }
+ return nil
+ },
+ }
+}
diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go
new file mode 100644
index 0000000..f0deefa
--- /dev/null
+++ b/conn/controlfns_linux.go
@@ -0,0 +1,109 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "fmt"
+ "runtime"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+// Taken from go/src/internal/syscall/unix/kernel_version_linux.go
+func kernelVersion() (major, minor int) {
+ var uname unix.Utsname
+ if err := unix.Uname(&uname); err != nil {
+ return
+ }
+
+ var (
+ values [2]int
+ value, vi int
+ )
+ for _, c := range uname.Release {
+ if '0' <= c && c <= '9' {
+ value = (value * 10) + int(c-'0')
+ } else {
+ // Note that we're assuming N.N.N here.
+ // If we see anything else, we are likely to mis-parse it.
+ values[vi] = value
+ vi++
+ if vi >= len(values) {
+ break
+ }
+ value = 0
+ }
+ }
+
+ return values[0], values[1]
+}
+
+func init() {
+ controlFns = append(controlFns,
+
+ // Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
+ // using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
+ // fail silently - the result of failure is lower performance on very fast
+ // links or high latency links.
+ func(network, address string, c syscall.RawConn) error {
+ return c.Control(func(fd uintptr) {
+ // Set up to *mem_max
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
+ // Set beyond *mem_max if CAP_NET_ADMIN
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
+ })
+ },
+
+ // Enable receiving of the packet information (IP_PKTINFO for IPv4,
+ // IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
+ func(network, address string, c syscall.RawConn) error {
+ var err error
+ switch network {
+ case "udp4":
+ if runtime.GOOS != "android" {
+ c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
+ })
+ }
+ case "udp6":
+ c.Control(func(fd uintptr) {
+ if runtime.GOOS != "android" {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
+ if err != nil {
+ return
+ }
+ }
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
+ })
+ default:
+ err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
+ }
+ return err
+ },
+
+ // Attempt to enable UDP_GRO
+ func(network, address string, c syscall.RawConn) error {
+ // Kernels below 5.12 are missing 98184612aca0 ("net:
+ // udp: Add support for getsockopt(..., ..., UDP_GRO,
+ // ..., ...);"), which means we can't read this back
+ // later. We could pipe the return value through to
+ // the rest of the code, but UDP_GRO is kind of buggy
+ // anyway, so just gate this here.
+ major, minor := kernelVersion()
+ if major < 5 || (major == 5 && minor < 12) {
+ return nil
+ }
+
+ c.Control(func(fd uintptr) {
+ _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
+ })
+ return nil
+ },
+ )
+}
diff --git a/conn/controlfns_unix.go b/conn/controlfns_unix.go
new file mode 100644
index 0000000..b2e7570
--- /dev/null
+++ b/conn/controlfns_unix.go
@@ -0,0 +1,35 @@
+//go:build !windows && !linux && !wasm
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ controlFns = append(controlFns,
+ func(network, address string, c syscall.RawConn) error {
+ return c.Control(func(fd uintptr) {
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
+ _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
+ })
+ },
+
+ func(network, address string, c syscall.RawConn) error {
+ var err error
+ if network == "udp6" {
+ c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
+ })
+ }
+ return err
+ },
+ )
+}
diff --git a/conn/controlfns_windows.go b/conn/controlfns_windows.go
new file mode 100644
index 0000000..5e38305
--- /dev/null
+++ b/conn/controlfns_windows.go
@@ -0,0 +1,23 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/windows"
+)
+
+func init() {
+ controlFns = append(controlFns,
+ func(network, address string, c syscall.RawConn) error {
+ return c.Control(func(fd uintptr) {
+ _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize)
+ _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize)
+ })
+ },
+ )
+}
diff --git a/conn/default.go b/conn/default.go
new file mode 100644
index 0000000..2ce1579
--- /dev/null
+++ b/conn/default.go
@@ -0,0 +1,10 @@
+//go:build !windows
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+func NewDefaultBind() Bind { return NewStdNetBind() }
diff --git a/conn/errors_default.go b/conn/errors_default.go
new file mode 100644
index 0000000..3c9b223
--- /dev/null
+++ b/conn/errors_default.go
@@ -0,0 +1,12 @@
+//go:build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+func errShouldDisableUDPGSO(_ error) bool {
+ return false
+}
diff --git a/conn/errors_linux.go b/conn/errors_linux.go
new file mode 100644
index 0000000..037d820
--- /dev/null
+++ b/conn/errors_linux.go
@@ -0,0 +1,26 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "errors"
+ "os"
+
+ "golang.org/x/sys/unix"
+)
+
+func errShouldDisableUDPGSO(err error) bool {
+ var serr *os.SyscallError
+ if errors.As(err, &serr) {
+ // EIO is returned by udp_send_skb() if the device driver does not have
+ // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
+ // See:
+ // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
+ // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
+ return serr.Err == unix.EIO
+ }
+ return false
+}
diff --git a/conn/features_default.go b/conn/features_default.go
new file mode 100644
index 0000000..9fc5088
--- /dev/null
+++ b/conn/features_default.go
@@ -0,0 +1,15 @@
+//go:build !linux
+// +build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import "net"
+
+func supportsUDPOffload(_ *net.UDPConn) (txOffload, rxOffload bool) {
+ return
+}
diff --git a/conn/features_linux.go b/conn/features_linux.go
new file mode 100644
index 0000000..6386023
--- /dev/null
+++ b/conn/features_linux.go
@@ -0,0 +1,29 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net"
+
+ "golang.org/x/sys/unix"
+)
+
+func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
+ rc, err := conn.SyscallConn()
+ if err != nil {
+ return
+ }
+ err = rc.Control(func(fd uintptr) {
+ _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
+ txOffload = errSyscall == nil
+ opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
+ rxOffload = errSyscall == nil && opt == 1
+ })
+ if err != nil {
+ return false, false
+ }
+ return txOffload, rxOffload
+}
diff --git a/conn/gso_default.go b/conn/gso_default.go
new file mode 100644
index 0000000..a9a3e80
--- /dev/null
+++ b/conn/gso_default.go
@@ -0,0 +1,21 @@
+//go:build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
+func getGSOSize(control []byte) (int, error) {
+ return 0, nil
+}
+
+// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
+func setGSOSize(control *[]byte, gsoSize uint16) {
+}
+
+// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
+// offloading control data.
+const gsoControlSize = 0
diff --git a/conn/gso_linux.go b/conn/gso_linux.go
new file mode 100644
index 0000000..4ee31fa
--- /dev/null
+++ b/conn/gso_linux.go
@@ -0,0 +1,65 @@
+//go:build linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "fmt"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+const (
+ sizeOfGSOData = 2
+)
+
+// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
+func getGSOSize(control []byte) (int, error) {
+ var (
+ hdr unix.Cmsghdr
+ data []byte
+ rem = control
+ err error
+ )
+
+ for len(rem) > unix.SizeofCmsghdr {
+ hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
+ if err != nil {
+ return 0, fmt.Errorf("error parsing socket control message: %w", err)
+ }
+ if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
+ var gso uint16
+ copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
+ return int(gso), nil
+ }
+ }
+ return 0, nil
+}
+
+// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
+// data in control untouched.
+func setGSOSize(control *[]byte, gsoSize uint16) {
+ existingLen := len(*control)
+ avail := cap(*control) - existingLen
+ space := unix.CmsgSpace(sizeOfGSOData)
+ if avail < space {
+ return
+ }
+ *control = (*control)[:cap(*control)]
+ gsoControl := (*control)[existingLen:]
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
+ hdr.Level = unix.SOL_UDP
+ hdr.Type = unix.UDP_SEGMENT
+ hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
+ copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
+ *control = (*control)[:existingLen+space]
+}
+
+// gsoControlSize returns the recommended buffer size for pooling UDP
+// offloading control data.
+var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
diff --git a/conn/mark_default.go b/conn/mark_default.go
index f57215a..72b266e 100644
--- a/conn/mark_default.go
+++ b/conn/mark_default.go
@@ -1,12 +1,12 @@
-// +build !linux,!openbsd,!freebsd
+//go:build !linux && !openbsd && !freebsd
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn
-func (bind *nativeBind) SetMark(mark uint32) error {
+func (s *StdNetBind) SetMark(mark uint32) error {
return nil
}
diff --git a/conn/mark_unix.go b/conn/mark_unix.go
index 19ec2af..d0580d5 100644
--- a/conn/mark_unix.go
+++ b/conn/mark_unix.go
@@ -1,8 +1,8 @@
-// +build android openbsd freebsd
+//go:build linux || openbsd || freebsd
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package conn
@@ -26,13 +26,13 @@ func init() {
}
}
-func (bind *nativeBind) SetMark(mark uint32) error {
+func (s *StdNetBind) SetMark(mark uint32) error {
var operr error
if fwmarkIoctl == 0 {
return nil
}
- if bind.ipv4 != nil {
- fd, err := bind.ipv4.SyscallConn()
+ if s.ipv4 != nil {
+ fd, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
@@ -46,8 +46,8 @@ func (bind *nativeBind) SetMark(mark uint32) error {
return err
}
}
- if bind.ipv6 != nil {
- fd, err := bind.ipv6.SyscallConn()
+ if s.ipv6 != nil {
+ fd, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
diff --git a/conn/sticky_default.go b/conn/sticky_default.go
new file mode 100644
index 0000000..15b65af
--- /dev/null
+++ b/conn/sticky_default.go
@@ -0,0 +1,42 @@
+//go:build !linux || android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import "net/netip"
+
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+ return netip.Addr{}
+}
+
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+ return 0
+}
+
+func (e *StdNetEndpoint) SrcToString() string {
+ return ""
+}
+
+// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
+// {get,set}srcControl feature set, but use alternatively named flags and need
+// ports and require testing.
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+}
+
+// setSrcControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+}
+
+// stickyControlSize returns the recommended buffer size for pooling sticky
+// offloading control data.
+const stickyControlSize = 0
+
+const StdNetSupportsStickySockets = false
diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go
new file mode 100644
index 0000000..adfedc1
--- /dev/null
+++ b/conn/sticky_linux.go
@@ -0,0 +1,112 @@
+//go:build linux && !android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net/netip"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+ switch len(e.src) {
+ case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ return netip.AddrFrom4(info.Spec_dst)
+ case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ // TODO: set zone. in order to do so we need to check if the address is
+ // link local, and if it is perform a syscall to turn the ifindex into a
+ // zone string because netip uses string zones.
+ return netip.AddrFrom16(info.Addr)
+ }
+ return netip.Addr{}
+}
+
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+ switch len(e.src) {
+ case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ return info.Ifindex
+ case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ return int32(info.Ifindex)
+ }
+ return 0
+}
+
+func (e *StdNetEndpoint) SrcToString() string {
+ return e.SrcIP().String()
+}
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+ ep.ClearSrc()
+
+ var (
+ hdr unix.Cmsghdr
+ data []byte
+ rem []byte = control
+ err error
+ )
+
+ for len(rem) > unix.SizeofCmsghdr {
+ hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
+ if err != nil {
+ return
+ }
+
+ if hdr.Level == unix.IPPROTO_IP &&
+ hdr.Type == unix.IP_PKTINFO {
+
+ if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
+ ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ }
+ ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
+
+ hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
+ copy(ep.src, hdrBuf)
+ copy(ep.src[unix.CmsgLen(0):], data)
+ return
+ }
+
+ if hdr.Level == unix.IPPROTO_IPV6 &&
+ hdr.Type == unix.IPV6_PKTINFO {
+
+ if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
+ ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
+ }
+
+ ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
+
+ hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
+ copy(ep.src, hdrBuf)
+ copy(ep.src[unix.CmsgLen(0):], data)
+ return
+ }
+ }
+}
+
+// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
+// and source ifindex found in ep. control's len will be set to 0 in the event
+// that ep is a default value.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+ if cap(*control) < len(ep.src) {
+ return
+ }
+ *control = (*control)[:0]
+ *control = append(*control, ep.src...)
+}
+
+// stickyControlSize returns the recommended buffer size for pooling sticky
+// offloading control data.
+var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
+
+const StdNetSupportsStickySockets = true
diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go
new file mode 100644
index 0000000..1b1ee68
--- /dev/null
+++ b/conn/sticky_linux_test.go
@@ -0,0 +1,266 @@
+//go:build linux && !android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "context"
+ "net"
+ "net/netip"
+ "runtime"
+ "testing"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
+ var buf []byte
+ if addr.Is4() {
+ buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ hdr := unix.Cmsghdr{
+ Level: unix.IPPROTO_IP,
+ Type: unix.IP_PKTINFO,
+ }
+ hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
+ copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
+
+ info := unix.Inet4Pktinfo{
+ Ifindex: ifidx,
+ Spec_dst: addr.As4(),
+ }
+ copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
+ } else {
+ buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
+ hdr := unix.Cmsghdr{
+ Level: unix.IPPROTO_IPV6,
+ Type: unix.IPV6_PKTINFO,
+ }
+ hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
+ copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
+
+ info := unix.Inet6Pktinfo{
+ Ifindex: uint32(ifidx),
+ Addr: addr.As16(),
+ }
+ copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
+ }
+
+ ep.src = buf
+}
+
+func Test_setSrcControl(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ ep := &StdNetEndpoint{
+ AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
+ }
+ setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
+
+ control := make([]byte, stickyControlSize)
+
+ setSrcControl(&control, ep)
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ if hdr.Level != unix.IPPROTO_IP {
+ t.Errorf("unexpected level: %d", hdr.Level)
+ }
+ if hdr.Type != unix.IP_PKTINFO {
+ t.Errorf("unexpected type: %d", hdr.Type)
+ }
+ if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
+ t.Errorf("unexpected length: %d", hdr.Len)
+ }
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
+ t.Errorf("unexpected address: %v", info.Spec_dst)
+ }
+ if info.Ifindex != 5 {
+ t.Errorf("unexpected ifindex: %d", info.Ifindex)
+ }
+ })
+
+ t.Run("IPv6", func(t *testing.T) {
+ ep := &StdNetEndpoint{
+ AddrPort: netip.MustParseAddrPort("[::1]:1234"),
+ }
+ setSrc(ep, netip.MustParseAddr("::1"), 5)
+
+ control := make([]byte, stickyControlSize)
+
+ setSrcControl(&control, ep)
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ if hdr.Level != unix.IPPROTO_IPV6 {
+ t.Errorf("unexpected level: %d", hdr.Level)
+ }
+ if hdr.Type != unix.IPV6_PKTINFO {
+ t.Errorf("unexpected type: %d", hdr.Type)
+ }
+ if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
+ t.Errorf("unexpected length: %d", hdr.Len)
+ }
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ if info.Addr != ep.SrcIP().As16() {
+ t.Errorf("unexpected address: %v", info.Addr)
+ }
+ if info.Ifindex != 5 {
+ t.Errorf("unexpected ifindex: %d", info.Ifindex)
+ }
+ })
+
+ t.Run("ClearOnNoSrc", func(t *testing.T) {
+ control := make([]byte, stickyControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = 1
+ hdr.Type = 2
+ hdr.Len = 3
+
+ setSrcControl(&control, &StdNetEndpoint{})
+
+ if len(control) != 0 {
+ t.Errorf("unexpected control: %v", control)
+ }
+ })
+}
+
+func Test_getSrcFromControl(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ control := make([]byte, stickyControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IP
+ hdr.Type = unix.IP_PKTINFO
+ hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Spec_dst = [4]byte{127, 0, 0, 1}
+ info.Ifindex = 5
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(control, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+ t.Run("IPv6", func(t *testing.T) {
+ control := make([]byte, stickyControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IPV6
+ hdr.Type = unix.IPV6_PKTINFO
+ hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
+ info.Ifindex = 5
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(control, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("::1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+ t.Run("ClearOnEmpty", func(t *testing.T) {
+ var control []byte
+ ep := &StdNetEndpoint{}
+ setSrc(ep, netip.MustParseAddr("::1"), 5)
+
+ getSrcFromControl(control, ep)
+ if ep.SrcIP().IsValid() {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 0 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+ t.Run("Multiple", func(t *testing.T) {
+ zeroControl := make([]byte, unix.CmsgSpace(0))
+ zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
+ zeroHdr.SetLen(unix.CmsgLen(0))
+
+ control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IP
+ hdr.Type = unix.IP_PKTINFO
+ hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Spec_dst = [4]byte{127, 0, 0, 1}
+ info.Ifindex = 5
+
+ combined := make([]byte, 0)
+ combined = append(combined, zeroControl...)
+ combined = append(combined, control...)
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(combined, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+}
+
+func Test_listenConfig(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ sc, err := conn.(*net.UDPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if runtime.GOOS == "linux" {
+ var i int
+ sc.Control(func(fd uintptr) {
+ i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if i != 1 {
+ t.Error("IP_PKTINFO not set!")
+ }
+ } else {
+ t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
+ }
+ })
+ t.Run("IPv6", func(t *testing.T) {
+ conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ sc, err := conn.(*net.UDPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if runtime.GOOS == "linux" {
+ var i int
+ sc.Control(func(fd uintptr) {
+ i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if i != 1 {
+ t.Error("IPV6_PKTINFO not set!")
+ }
+ } else {
+ t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
+ }
+ })
+}
diff --git a/conn/winrio/rio_windows.go b/conn/winrio/rio_windows.go
new file mode 100644
index 0000000..c396658
--- /dev/null
+++ b/conn/winrio/rio_windows.go
@@ -0,0 +1,254 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+ */
+
+package winrio
+
+import (
+ "log"
+ "sync"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+const (
+ MsgDontNotify = 1
+ MsgDefer = 2
+ MsgWaitAll = 4
+ MsgCommitOnly = 8
+
+ MaxCqSize = 0x8000000
+
+ invalidBufferId = 0xFFFFFFFF
+ invalidCq = 0
+ invalidRq = 0
+ corruptCq = 0xFFFFFFFF
+)
+
+var extensionFunctionTable struct {
+ cbSize uint32
+ rioReceive uintptr
+ rioReceiveEx uintptr
+ rioSend uintptr
+ rioSendEx uintptr
+ rioCloseCompletionQueue uintptr
+ rioCreateCompletionQueue uintptr
+ rioCreateRequestQueue uintptr
+ rioDequeueCompletion uintptr
+ rioDeregisterBuffer uintptr
+ rioNotify uintptr
+ rioRegisterBuffer uintptr
+ rioResizeCompletionQueue uintptr
+ rioResizeRequestQueue uintptr
+}
+
+type Cq uintptr
+
+type Rq uintptr
+
+type BufferId uintptr
+
+type Buffer struct {
+ Id BufferId
+ Offset uint32
+ Length uint32
+}
+
+type Result struct {
+ Status int32
+ BytesTransferred uint32
+ SocketContext uint64
+ RequestContext uint64
+}
+
+type notificationCompletionType uint32
+
+const (
+ eventCompletion notificationCompletionType = 1
+ iocpCompletion notificationCompletionType = 2
+)
+
+type eventNotificationCompletion struct {
+ completionType notificationCompletionType
+ event windows.Handle
+ notifyReset uint32
+}
+
+type iocpNotificationCompletion struct {
+ completionType notificationCompletionType
+ iocp windows.Handle
+ key uintptr
+ overlapped *windows.Overlapped
+}
+
+var (
+ initialized sync.Once
+ available bool
+)
+
+func Initialize() bool {
+ initialized.Do(func() {
+ var (
+ err error
+ socket windows.Handle
+ cq Cq
+ )
+ defer func() {
+ if err == nil {
+ return
+ }
+ if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 {
+ return
+ }
+ log.Printf("Registered I/O is unavailable: %v", err)
+ }()
+ socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
+ if err != nil {
+ return
+ }
+ defer windows.CloseHandle(socket)
+ WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
+ const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
+ ob := uint32(0)
+ err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
+ (*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)),
+ (*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)),
+ &ob, nil, 0)
+ if err != nil {
+ return
+ }
+
+ // While we should be able to stop here, after getting the function pointers, some anti-virus actually causes
+ // failures in RIOCreateRequestQueue, so keep going to be certain this is supported.
+ var iocp windows.Handle
+ iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
+ if err != nil {
+ return
+ }
+ defer windows.CloseHandle(iocp)
+ var overlapped windows.Overlapped
+ cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped)
+ if err != nil {
+ return
+ }
+ defer CloseCompletionQueue(cq)
+ _, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0)
+ if err != nil {
+ return
+ }
+ available = true
+ })
+ return available
+}
+
+func Socket(af, typ, proto int32) (windows.Handle, error) {
+ return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO)
+}
+
+func CloseCompletionQueue(cq Cq) {
+ _, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0)
+}
+
+func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) {
+ notificationCompletion := &eventNotificationCompletion{
+ completionType: eventCompletion,
+ event: event,
+ }
+ if notifyReset {
+ notificationCompletion.notifyReset = 1
+ }
+ ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
+ if ret == invalidCq {
+ return 0, err
+ }
+ return Cq(ret), nil
+}
+
+func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) {
+ notificationCompletion := &iocpNotificationCompletion{
+ completionType: iocpCompletion,
+ iocp: iocp,
+ key: key,
+ overlapped: overlapped,
+ }
+ ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
+ if ret == invalidCq {
+ return 0, err
+ }
+ return Cq(ret), nil
+}
+
+func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) {
+ ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0)
+ if ret == invalidCq {
+ return 0, err
+ }
+ return Cq(ret), nil
+}
+
+func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) {
+ ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0)
+ if ret == invalidRq {
+ return 0, err
+ }
+ return Rq(ret), nil
+}
+
+func DequeueCompletion(cq Cq, results []Result) uint32 {
+ var array uintptr
+ if len(results) > 0 {
+ array = uintptr(unsafe.Pointer(&results[0]))
+ }
+ ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results)))
+ if ret == corruptCq {
+ panic("cq is corrupt")
+ }
+ return uint32(ret)
+}
+
+func DeregisterBuffer(id BufferId) {
+ _, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0)
+}
+
+func RegisterBuffer(buffer []byte) (BufferId, error) {
+ var buf unsafe.Pointer
+ if len(buffer) > 0 {
+ buf = unsafe.Pointer(&buffer[0])
+ }
+ return RegisterPointer(buf, uint32(len(buffer)))
+}
+
+func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) {
+ ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0)
+ if ret == invalidBufferId {
+ return 0, err
+ }
+ return BufferId(ret), nil
+}
+
+func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
+ ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
+ if ret == 0 {
+ return err
+ }
+ return nil
+}
+
+func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
+ ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
+ if ret == 0 {
+ return err
+ }
+ return nil
+}
+
+func Notify(cq Cq) error {
+ ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0)
+ if ret != 0 {
+ return windows.Errno(ret)
+ }
+ return nil
+}