summaryrefslogtreecommitdiffstats
path: root/conn
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-02-22 18:47:41 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2021-02-25 15:08:08 +0100
commit3c11c0308e4e9fae76e1531f4f49a39f1ae24253 (patch)
tree6fbc8fba7161b0835e9e1bfba6d68bc82ad78301 /conn
parentglobal: remove TODO name graffiti (diff)
downloadwireguard-go-3c11c0308e4e9fae76e1531f4f49a39f1ae24253.tar.xz
wireguard-go-3c11c0308e4e9fae76e1531f4f49a39f1ae24253.zip
conn: implement RIO for fast Windows UDP sockets
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'conn')
-rw-r--r--conn/bind_std.go2
-rw-r--r--conn/bind_windows.go581
-rw-r--r--conn/boundif_windows.go59
-rw-r--r--conn/default.go2
-rw-r--r--conn/winrio/rio_windows.go243
5 files changed, 827 insertions, 60 deletions
diff --git a/conn/bind_std.go b/conn/bind_std.go
index 193c4fe..28d1464 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -128,6 +128,8 @@ func (bind *StdNetBind) Close() error {
err2 = bind.ipv6.Close()
bind.ipv6 = nil
}
+ bind.blackhole4 = false
+ bind.blackhole6 = false
if err1 != nil {
return err1
}
diff --git a/conn/bind_windows.go b/conn/bind_windows.go
new file mode 100644
index 0000000..1e2712e
--- /dev/null
+++ b/conn/bind_windows.go
@@ -0,0 +1,581 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "encoding/binary"
+ "io"
+ "net"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+
+ "golang.zx2c4.com/wireguard/conn/winrio"
+)
+
+const (
+ packetsPerRing = 1024
+ bytesPerPacket = 2048 - 32
+ receiveSpins = 15
+)
+
+type ringPacket struct {
+ addr WinRingEndpoint
+ data [bytesPerPacket]byte
+}
+
+type ringBuffer struct {
+ packets uintptr
+ head, tail uint32
+ id winrio.BufferId
+ iocp windows.Handle
+ isFull bool
+ cq winrio.Cq
+ mu sync.Mutex
+ overlapped windows.Overlapped
+}
+
+func (rb *ringBuffer) Push() *ringPacket {
+ for rb.isFull {
+ panic("ring is full")
+ }
+ ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
+ rb.tail += 1
+ if rb.tail == rb.head {
+ rb.isFull = true
+ }
+ return ret
+}
+
+func (rb *ringBuffer) Return(count uint32) {
+ if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull {
+ return
+ }
+ rb.head += count
+ rb.isFull = false
+}
+
+type afWinRingBind struct {
+ sock windows.Handle
+ rx, tx ringBuffer
+ rq winrio.Rq
+ mu sync.Mutex
+ blackhole bool
+}
+
+// WinRingBind uses Windows registered I/O for fast ring buffered networking.
+type WinRingBind struct {
+ v4, v6 afWinRingBind
+ mu sync.RWMutex
+ isOpen uint32
+}
+
+func NewDefaultBind() Bind { return NewWinRingBind() }
+
+func NewWinRingBind() Bind {
+ if !winrio.Initialize() {
+ return NewStdNetBind()
+ }
+ return new(WinRingBind)
+}
+
+type WinRingEndpoint struct {
+ family uint16
+ data [30]byte
+}
+
+var _ Bind = (*WinRingBind)(nil)
+var _ Endpoint = (*WinRingEndpoint)(nil)
+
+func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
+ host, port, err := net.SplitHostPort(s)
+ if err != nil {
+ return nil, err
+ }
+ host16, err := windows.UTF16PtrFromString(host)
+ if err != nil {
+ return nil, err
+ }
+ port16, err := windows.UTF16PtrFromString(port)
+ if err != nil {
+ return nil, err
+ }
+ hints := windows.AddrinfoW{
+ Flags: windows.AI_NUMERICHOST,
+ Family: windows.AF_UNSPEC,
+ Socktype: windows.SOCK_DGRAM,
+ Protocol: windows.IPPROTO_UDP,
+ }
+ var addrinfo *windows.AddrinfoW
+ err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo)
+ if err != nil {
+ return nil, err
+ }
+ defer windows.FreeAddrInfoW(addrinfo)
+ if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
+ return nil, windows.ERROR_INVALID_ADDRESS
+ }
+ var src []byte
+ var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
+ unsafeSlice(unsafe.Pointer(&src), unsafe.Pointer(addrinfo.Addr), int(addrinfo.Addrlen))
+ copy(dst[:], src)
+ return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
+}
+
+func (*WinRingEndpoint) ClearSrc() {}
+
+func (e *WinRingEndpoint) DstIP() net.IP {
+ switch e.family {
+ case windows.AF_INET:
+ return append([]byte{}, e.data[2:6]...)
+ case windows.AF_INET6:
+ return append([]byte{}, e.data[6:22]...)
+ }
+ return nil
+}
+
+func (e *WinRingEndpoint) SrcIP() net.IP {
+ return nil // not supported
+}
+
+func (e *WinRingEndpoint) DstToBytes() []byte {
+ switch e.family {
+ case windows.AF_INET:
+ b := make([]byte, 0, 6)
+ b = append(b, e.data[2:6]...)
+ b = append(b, e.data[1], e.data[0])
+ return b
+ case windows.AF_INET6:
+ b := make([]byte, 0, 18)
+ b = append(b, e.data[6:22]...)
+ b = append(b, e.data[1], e.data[0])
+ return b
+ }
+ return nil
+}
+
+func (e *WinRingEndpoint) DstToString() string {
+ switch e.family {
+ case windows.AF_INET:
+ addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
+ return addr.String()
+ case windows.AF_INET6:
+ var zone string
+ if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
+ zone = strconv.FormatUint(uint64(scope), 10)
+ }
+ addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
+ return addr.String()
+ }
+ return ""
+}
+
+func (e *WinRingEndpoint) SrcToString() string {
+ return ""
+}
+
+func (ring *ringBuffer) CloseAndZero() {
+ if ring.cq != 0 {
+ winrio.CloseCompletionQueue(ring.cq)
+ ring.cq = 0
+ }
+ if ring.iocp != 0 {
+ windows.CloseHandle(ring.iocp)
+ ring.iocp = 0
+ }
+ if ring.id != 0 {
+ winrio.DeregisterBuffer(ring.id)
+ ring.id = 0
+ }
+ if ring.packets != 0 {
+ windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
+ ring.packets = 0
+ }
+}
+
+func (bind *afWinRingBind) CloseAndZero() {
+ bind.rx.CloseAndZero()
+ bind.tx.CloseAndZero()
+ if bind.sock != 0 {
+ windows.CloseHandle(bind.sock)
+ bind.sock = 0
+ }
+ bind.blackhole = false
+}
+
+func (bind *WinRingBind) closeAndZero() {
+ atomic.StoreUint32(&bind.isOpen, 0)
+ bind.v4.CloseAndZero()
+ bind.v6.CloseAndZero()
+}
+
+func (ring *ringBuffer) Open() error {
+ var err error
+ packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
+ ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
+ if err != nil {
+ return err
+ }
+ ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
+ if err != nil {
+ return err
+ }
+ ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
+ if err != nil {
+ return err
+ }
+ ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) {
+ var err error
+ bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
+ if err != nil {
+ return nil, err
+ }
+ err = bind.rx.Open()
+ if err != nil {
+ return nil, err
+ }
+ err = bind.tx.Open()
+ if err != nil {
+ return nil, err
+ }
+ bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0)
+ if err != nil {
+ return nil, err
+ }
+ err = windows.Bind(bind.sock, sa)
+ if err != nil {
+ return nil, err
+ }
+ sa, err = windows.Getsockname(bind.sock)
+ if err != nil {
+ return nil, err
+ }
+ return sa, nil
+}
+
+func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) {
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+ defer func() {
+ if err != nil {
+ bind.closeAndZero()
+ }
+ }()
+ if atomic.LoadUint32(&bind.isOpen) != 0 {
+ return 0, ErrBindAlreadyOpen
+ }
+ var sa windows.Sockaddr
+ sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
+ if err != nil {
+ return 0, err
+ }
+ sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
+ if err != nil {
+ return 0, err
+ }
+ selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
+ for i := 0; i < packetsPerRing; i++ {
+ err = bind.v4.InsertReceiveRequest()
+ if err != nil {
+ return 0, err
+ }
+ err = bind.v6.InsertReceiveRequest()
+ if err != nil {
+ return 0, err
+ }
+ }
+ atomic.StoreUint32(&bind.isOpen, 1)
+ return
+}
+
+func (bind *WinRingBind) Close() error {
+ bind.mu.RLock()
+ if atomic.LoadUint32(&bind.isOpen) != 1 {
+ bind.mu.RUnlock()
+ return nil
+ }
+ atomic.StoreUint32(&bind.isOpen, 2)
+ windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
+ windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
+ windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
+ windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil)
+ bind.mu.RUnlock()
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+ bind.closeAndZero()
+ return nil
+}
+
+func (bind *WinRingBind) SetMark(mark uint32) error {
+ return nil
+}
+
+func (bind *afWinRingBind) InsertReceiveRequest() error {
+ packet := bind.rx.Push()
+ dataBuffer := &winrio.Buffer{
+ Id: bind.rx.id,
+ Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets),
+ Length: uint32(len(packet.data)),
+ }
+ addressBuffer := &winrio.Buffer{
+ Id: bind.rx.id,
+ Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets),
+ Length: uint32(unsafe.Sizeof(packet.addr)),
+ }
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+ return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
+}
+
+//go:linkname procyield runtime.procyield
+func procyield(cycles uint32)
+
+func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) {
+ if atomic.LoadUint32(isOpen) != 1 {
+ return 0, nil, net.ErrClosed
+ }
+ bind.rx.mu.Lock()
+ defer bind.rx.mu.Unlock()
+ var count uint32
+ var results [1]winrio.Result
+ for tries := 0; count == 0 && tries < receiveSpins; tries++ {
+ if tries > 0 {
+ if atomic.LoadUint32(isOpen) != 1 {
+ return 0, nil, net.ErrClosed
+ }
+ procyield(1)
+ }
+ count = winrio.DequeueCompletion(bind.rx.cq, results[:])
+ }
+ if count == 0 {
+ err := winrio.Notify(bind.rx.cq)
+ if err != nil {
+ return 0, nil, err
+ }
+ var bytes uint32
+ var key uintptr
+ var overlapped *windows.Overlapped
+ err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
+ if err != nil {
+ return 0, nil, err
+ }
+ if atomic.LoadUint32(isOpen) != 1 {
+ return 0, nil, net.ErrClosed
+ }
+ count = winrio.DequeueCompletion(bind.rx.cq, results[:])
+ if count == 0 {
+ return 0, nil, io.ErrNoProgress
+
+ }
+ }
+ bind.rx.Return(1)
+ err := bind.InsertReceiveRequest()
+ if err != nil {
+ return 0, nil, err
+ }
+ if results[0].Status != 0 {
+ return 0, nil, windows.Errno(results[0].Status)
+ }
+ packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
+ ep := packet.addr
+ n := copy(buf, packet.data[:results[0].BytesTransferred])
+ return n, &ep, nil
+}
+
+func (bind *WinRingBind) ReceiveIPv4(buf []byte) (int, Endpoint, error) {
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
+ return bind.v4.Receive(buf, &bind.isOpen)
+}
+
+func (bind *WinRingBind) ReceiveIPv6(buf []byte) (int, Endpoint, error) {
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
+ return bind.v6.Receive(buf, &bind.isOpen)
+}
+
+func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error {
+ if atomic.LoadUint32(isOpen) != 1 {
+ return net.ErrClosed
+ }
+ if len(buf) > bytesPerPacket {
+ return io.ErrShortBuffer
+ }
+ bind.tx.mu.Lock()
+ defer bind.tx.mu.Unlock()
+ var results [packetsPerRing]winrio.Result
+ count := winrio.DequeueCompletion(bind.tx.cq, results[:])
+ if count == 0 && bind.tx.isFull {
+ err := winrio.Notify(bind.tx.cq)
+ if err != nil {
+ return err
+ }
+ var bytes uint32
+ var key uintptr
+ var overlapped *windows.Overlapped
+ err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
+ if err != nil {
+ return err
+ }
+ if atomic.LoadUint32(isOpen) != 1 {
+ return net.ErrClosed
+ }
+ count = winrio.DequeueCompletion(bind.tx.cq, results[:])
+ if count == 0 {
+ return io.ErrNoProgress
+ }
+ }
+ if count > 0 {
+ bind.tx.Return(count)
+ }
+ packet := bind.tx.Push()
+ packet.addr = *nend
+ copy(packet.data[:], buf)
+ dataBuffer := &winrio.Buffer{
+ Id: bind.tx.id,
+ Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets),
+ Length: uint32(len(buf)),
+ }
+ addressBuffer := &winrio.Buffer{
+ Id: bind.tx.id,
+ Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets),
+ Length: uint32(unsafe.Sizeof(packet.addr)),
+ }
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+ return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
+}
+
+func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
+ nend, ok := endpoint.(*WinRingEndpoint)
+ if !ok {
+ return ErrWrongEndpointType
+ }
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
+ switch nend.family {
+ case windows.AF_INET:
+ if bind.v4.blackhole {
+ return nil
+ }
+ return bind.v4.Send(buf, nend, &bind.isOpen)
+ case windows.AF_INET6:
+ if bind.v6.blackhole {
+ return nil
+ }
+ return bind.v6.Send(buf, nend, &bind.isOpen)
+ }
+ return nil
+}
+
+func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+ sysconn, err := bind.ipv4.SyscallConn()
+ if err != nil {
+ return err
+ }
+ err2 := sysconn.Control(func(fd uintptr) {
+ err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex)
+ })
+ if err2 != nil {
+ return err2
+ }
+ if err != nil {
+ return err
+ }
+ bind.blackhole4 = blackhole
+ return nil
+}
+
+func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+ sysconn, err := bind.ipv6.SyscallConn()
+ if err != nil {
+ return err
+ }
+ err2 := sysconn.Control(func(fd uintptr) {
+ err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex)
+ })
+ if err2 != nil {
+ return err2
+ }
+ if err != nil {
+ return err
+ }
+ bind.blackhole6 = blackhole
+ return nil
+}
+func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
+ if atomic.LoadUint32(&bind.isOpen) != 1 {
+ return net.ErrClosed
+ }
+ err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
+ if err != nil {
+ return err
+ }
+ bind.v4.blackhole = blackhole
+ return nil
+}
+
+func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
+ if atomic.LoadUint32(&bind.isOpen) != 1 {
+ return net.ErrClosed
+ }
+ err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
+ if err != nil {
+ return err
+ }
+ bind.v6.blackhole = blackhole
+ return nil
+}
+
+func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error {
+ const IP_UNICAST_IF = 31
+ /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
+ var bytes [4]byte
+ binary.BigEndian.PutUint32(bytes[:], interfaceIndex)
+ interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
+ err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex))
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error {
+ const IPV6_UNICAST_IF = 31
+ return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
+}
+
+// unsafeSlice updates the slice slicePtr to be a slice
+// referencing the provided data with its length & capacity set to
+// lenCap.
+//
+// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
+// update callers to use unsafe.Slice instead of this.
+func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
+ type sliceHeader struct {
+ Data unsafe.Pointer
+ Len int
+ Cap int
+ }
+ h := (*sliceHeader)(slicePtr)
+ h.Data = data
+ h.Len = lenCap
+ h.Cap = lenCap
+}
diff --git a/conn/boundif_windows.go b/conn/boundif_windows.go
deleted file mode 100644
index 6f6fdd8..0000000
--- a/conn/boundif_windows.go
+++ /dev/null
@@ -1,59 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package conn
-
-import (
- "encoding/binary"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-const (
- sockoptIP_UNICAST_IF = 31
- sockoptIPV6_UNICAST_IF = 31
-)
-
-func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
- /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
- bytes := make([]byte, 4)
- binary.BigEndian.PutUint32(bytes, interfaceIndex)
- interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
-
- sysconn, err := bind.ipv4.SyscallConn()
- if err != nil {
- return err
- }
- err2 := sysconn.Control(func(fd uintptr) {
- err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, sockoptIP_UNICAST_IF, int(interfaceIndex))
- })
- if err2 != nil {
- return err2
- }
- if err != nil {
- return err
- }
- bind.blackhole4 = blackhole
- return nil
-}
-
-func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
- sysconn, err := bind.ipv6.SyscallConn()
- if err != nil {
- return err
- }
- err2 := sysconn.Control(func(fd uintptr) {
- err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, sockoptIPV6_UNICAST_IF, int(interfaceIndex))
- })
- if err2 != nil {
- return err2
- }
- if err != nil {
- return err
- }
- bind.blackhole6 = blackhole
- return nil
-}
diff --git a/conn/default.go b/conn/default.go
index cd9bfb0..161454a 100644
--- a/conn/default.go
+++ b/conn/default.go
@@ -1,4 +1,4 @@
-// +build !linux
+// +build !linux,!windows
/* SPDX-License-Identifier: MIT
*
diff --git a/conn/winrio/rio_windows.go b/conn/winrio/rio_windows.go
new file mode 100644
index 0000000..1785a02
--- /dev/null
+++ b/conn/winrio/rio_windows.go
@@ -0,0 +1,243 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package winrio
+
+import (
+ "log"
+ "sync"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+const (
+ MsgDontNotify = 1
+ MsgDefer = 2
+ MsgWaitAll = 4
+ MsgCommitOnly = 8
+
+ MaxCqSize = 0x8000000
+
+ invalidBufferId = 0xFFFFFFFF
+ invalidCq = 0
+ invalidRq = 0
+ corruptCq = 0xFFFFFFFF
+)
+
+var extensionFunctionTable struct {
+ cbSize uint32
+ rioReceive uintptr
+ rioReceiveEx uintptr
+ rioSend uintptr
+ rioSendEx uintptr
+ rioCloseCompletionQueue uintptr
+ rioCreateCompletionQueue uintptr
+ rioCreateRequestQueue uintptr
+ rioDequeueCompletion uintptr
+ rioDeregisterBuffer uintptr
+ rioNotify uintptr
+ rioRegisterBuffer uintptr
+ rioResizeCompletionQueue uintptr
+ rioResizeRequestQueue uintptr
+}
+
+type Cq uintptr
+
+type Rq uintptr
+
+type BufferId uintptr
+
+type Buffer struct {
+ Id BufferId
+ Offset uint32
+ Length uint32
+}
+
+type Result struct {
+ Status int32
+ BytesTransferred uint32
+ SocketContext uint64
+ RequestContext uint64
+}
+
+type notificationCompletionType uint32
+
+const (
+ eventCompletion notificationCompletionType = 1
+ iocpCompletion notificationCompletionType = 2
+)
+
+type eventNotificationCompletion struct {
+ completionType notificationCompletionType
+ event windows.Handle
+ notifyReset uint32
+}
+
+type iocpNotificationCompletion struct {
+ completionType notificationCompletionType
+ iocp windows.Handle
+ key uintptr
+ overlapped *windows.Overlapped
+}
+
+var initialized sync.Once
+var available bool
+
+func Initialize() bool {
+ initialized.Do(func() {
+ var (
+ err error
+ socket windows.Handle
+ cq Cq
+ )
+ defer func() {
+ if err == nil {
+ return
+ }
+ if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 {
+ return
+ }
+ log.Printf("Registered I/O is unavailable: %v", err)
+ }()
+ socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
+ if err != nil {
+ return
+ }
+ defer windows.CloseHandle(socket)
+ var WSAID_MULTIPLE_RIO = &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
+ const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
+ ob := uint32(0)
+ err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
+ (*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)),
+ (*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)),
+ &ob, nil, 0)
+ if err != nil {
+ return
+ }
+ // While we should be able to stop here, after getting the function pointers, some anti-virus actually causes
+ // failures in RIOCreateRequestQueue, so keep going to be certain this is supported.
+ cq, err = CreatePolledCompletionQueue(2)
+ if err != nil {
+ return
+ }
+ defer CloseCompletionQueue(cq)
+ _, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0)
+ if err != nil {
+ return
+ }
+ available = true
+ })
+ return available
+}
+
+func Socket(af, typ, proto int32) (windows.Handle, error) {
+ return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO)
+}
+
+func CloseCompletionQueue(cq Cq) {
+ _, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0)
+}
+
+func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) {
+ notificationCompletion := &eventNotificationCompletion{
+ completionType: eventCompletion,
+ event: event,
+ }
+ if notifyReset {
+ notificationCompletion.notifyReset = 1
+ }
+ ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
+ if ret == invalidCq {
+ return 0, err
+ }
+ return Cq(ret), nil
+}
+
+func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) {
+ notificationCompletion := &iocpNotificationCompletion{
+ completionType: iocpCompletion,
+ iocp: iocp,
+ overlapped: overlapped,
+ }
+ ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
+ if ret == invalidCq {
+ return 0, err
+ }
+ return Cq(ret), nil
+}
+
+func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) {
+ ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0)
+ if ret == invalidCq {
+ return 0, err
+ }
+ return Cq(ret), nil
+}
+
+func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) {
+ ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0)
+ if ret == invalidRq {
+ return 0, err
+ }
+ return Rq(ret), nil
+}
+
+func DequeueCompletion(cq Cq, results []Result) uint32 {
+ var array uintptr
+ if len(results) > 0 {
+ array = uintptr(unsafe.Pointer(&results[0]))
+ }
+ ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results)))
+ if ret == corruptCq {
+ panic("cq is corrupt")
+ }
+ return uint32(ret)
+}
+
+func DeregisterBuffer(id BufferId) {
+ _, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0)
+}
+
+func RegisterBuffer(buffer []byte) (BufferId, error) {
+ var buf unsafe.Pointer
+ if len(buffer) > 0 {
+ buf = unsafe.Pointer(&buffer[0])
+ }
+ return RegisterPointer(buf, uint32(len(buffer)))
+}
+
+func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) {
+ ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0)
+ if ret == invalidBufferId {
+ return 0, err
+ }
+ return BufferId(ret), nil
+}
+
+func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
+ ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
+ if ret == 0 {
+ return err
+ }
+ return nil
+}
+
+func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
+ ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
+ if ret == 0 {
+ return err
+ }
+ return nil
+}
+
+func Notify(cq Cq) error {
+ ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0)
+ if ret != 0 {
+ return windows.Errno(ret)
+ }
+ return nil
+}