aboutsummaryrefslogtreecommitdiffstats
path: root/ipc/namedpipe
diff options
context:
space:
mode:
Diffstat (limited to 'ipc/namedpipe')
-rw-r--r--ipc/namedpipe/file.go287
-rw-r--r--ipc/namedpipe/namedpipe.go485
-rw-r--r--ipc/namedpipe/namedpipe_test.go674
3 files changed, 1446 insertions, 0 deletions
diff --git a/ipc/namedpipe/file.go b/ipc/namedpipe/file.go
new file mode 100644
index 0000000..ab1e13d
--- /dev/null
+++ b/ipc/namedpipe/file.go
@@ -0,0 +1,287 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Copyright 2015 Microsoft
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build windows
+
+package namedpipe
+
+import (
+ "io"
+ "os"
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "time"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+type timeoutChan chan struct{}
+
+var (
+ ioInitOnce sync.Once
+ ioCompletionPort windows.Handle
+)
+
+// ioResult contains the result of an asynchronous IO operation
+type ioResult struct {
+ bytes uint32
+ err error
+}
+
+// ioOperation represents an outstanding asynchronous Win32 IO
+type ioOperation struct {
+ o windows.Overlapped
+ ch chan ioResult
+}
+
+func initIo() {
+ h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
+ if err != nil {
+ panic(err)
+ }
+ ioCompletionPort = h
+ go ioCompletionProcessor(h)
+}
+
+// file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
+// It takes ownership of this handle and will close it if it is garbage collected.
+type file struct {
+ handle windows.Handle
+ wg sync.WaitGroup
+ wgLock sync.RWMutex
+ closing atomic.Bool
+ socket bool
+ readDeadline deadlineHandler
+ writeDeadline deadlineHandler
+}
+
+type deadlineHandler struct {
+ setLock sync.Mutex
+ channel timeoutChan
+ channelLock sync.RWMutex
+ timer *time.Timer
+ timedout atomic.Bool
+}
+
+// makeFile makes a new file from an existing file handle
+func makeFile(h windows.Handle) (*file, error) {
+ f := &file{handle: h}
+ ioInitOnce.Do(initIo)
+ _, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0)
+ if err != nil {
+ return nil, err
+ }
+ err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE)
+ if err != nil {
+ return nil, err
+ }
+ f.readDeadline.channel = make(timeoutChan)
+ f.writeDeadline.channel = make(timeoutChan)
+ return f, nil
+}
+
+// closeHandle closes the resources associated with a Win32 handle
+func (f *file) closeHandle() {
+ f.wgLock.Lock()
+ // Atomically set that we are closing, releasing the resources only once.
+ if f.closing.Swap(true) == false {
+ f.wgLock.Unlock()
+ // cancel all IO and wait for it to complete
+ windows.CancelIoEx(f.handle, nil)
+ f.wg.Wait()
+ // at this point, no new IO can start
+ windows.Close(f.handle)
+ f.handle = 0
+ } else {
+ f.wgLock.Unlock()
+ }
+}
+
+// Close closes a file.
+func (f *file) Close() error {
+ f.closeHandle()
+ return nil
+}
+
+// prepareIo prepares for a new IO operation.
+// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
+func (f *file) prepareIo() (*ioOperation, error) {
+ f.wgLock.RLock()
+ if f.closing.Load() {
+ f.wgLock.RUnlock()
+ return nil, os.ErrClosed
+ }
+ f.wg.Add(1)
+ f.wgLock.RUnlock()
+ c := &ioOperation{}
+ c.ch = make(chan ioResult)
+ return c, nil
+}
+
+// ioCompletionProcessor processes completed async IOs forever
+func ioCompletionProcessor(h windows.Handle) {
+ for {
+ var bytes uint32
+ var key uintptr
+ var op *ioOperation
+ err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE)
+ if op == nil {
+ panic(err)
+ }
+ op.ch <- ioResult{bytes, err}
+ }
+}
+
+// asyncIo processes the return value from ReadFile or WriteFile, blocking until
+// the operation has actually completed.
+func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
+ if err != windows.ERROR_IO_PENDING {
+ return int(bytes), err
+ }
+
+ if f.closing.Load() {
+ windows.CancelIoEx(f.handle, &c.o)
+ }
+
+ var timeout timeoutChan
+ if d != nil {
+ d.channelLock.Lock()
+ timeout = d.channel
+ d.channelLock.Unlock()
+ }
+
+ var r ioResult
+ select {
+ case r = <-c.ch:
+ err = r.err
+ if err == windows.ERROR_OPERATION_ABORTED {
+ if f.closing.Load() {
+ err = os.ErrClosed
+ }
+ } else if err != nil && f.socket {
+ // err is from Win32. Query the overlapped structure to get the winsock error.
+ var bytes, flags uint32
+ err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
+ }
+ case <-timeout:
+ windows.CancelIoEx(f.handle, &c.o)
+ r = <-c.ch
+ err = r.err
+ if err == windows.ERROR_OPERATION_ABORTED {
+ err = os.ErrDeadlineExceeded
+ }
+ }
+
+ // runtime.KeepAlive is needed, as c is passed via native
+ // code to ioCompletionProcessor, c must remain alive
+ // until the channel read is complete.
+ runtime.KeepAlive(c)
+ return int(r.bytes), err
+}
+
+// Read reads from a file handle.
+func (f *file) Read(b []byte) (int, error) {
+ c, err := f.prepareIo()
+ if err != nil {
+ return 0, err
+ }
+ defer f.wg.Done()
+
+ if f.readDeadline.timedout.Load() {
+ return 0, os.ErrDeadlineExceeded
+ }
+
+ var bytes uint32
+ err = windows.ReadFile(f.handle, b, &bytes, &c.o)
+ n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
+ runtime.KeepAlive(b)
+
+ // Handle EOF conditions.
+ if err == nil && n == 0 && len(b) != 0 {
+ return 0, io.EOF
+ } else if err == windows.ERROR_BROKEN_PIPE {
+ return 0, io.EOF
+ } else {
+ return n, err
+ }
+}
+
+// Write writes to a file handle.
+func (f *file) Write(b []byte) (int, error) {
+ c, err := f.prepareIo()
+ if err != nil {
+ return 0, err
+ }
+ defer f.wg.Done()
+
+ if f.writeDeadline.timedout.Load() {
+ return 0, os.ErrDeadlineExceeded
+ }
+
+ var bytes uint32
+ err = windows.WriteFile(f.handle, b, &bytes, &c.o)
+ n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
+ runtime.KeepAlive(b)
+ return n, err
+}
+
+func (f *file) SetReadDeadline(deadline time.Time) error {
+ return f.readDeadline.set(deadline)
+}
+
+func (f *file) SetWriteDeadline(deadline time.Time) error {
+ return f.writeDeadline.set(deadline)
+}
+
+func (f *file) Flush() error {
+ return windows.FlushFileBuffers(f.handle)
+}
+
+func (f *file) Fd() uintptr {
+ return uintptr(f.handle)
+}
+
+func (d *deadlineHandler) set(deadline time.Time) error {
+ d.setLock.Lock()
+ defer d.setLock.Unlock()
+
+ if d.timer != nil {
+ if !d.timer.Stop() {
+ <-d.channel
+ }
+ d.timer = nil
+ }
+ d.timedout.Store(false)
+
+ select {
+ case <-d.channel:
+ d.channelLock.Lock()
+ d.channel = make(chan struct{})
+ d.channelLock.Unlock()
+ default:
+ }
+
+ if deadline.IsZero() {
+ return nil
+ }
+
+ timeoutIO := func() {
+ d.timedout.Store(true)
+ close(d.channel)
+ }
+
+ now := time.Now()
+ duration := deadline.Sub(now)
+ if deadline.After(now) {
+ // Deadline is in the future, set a timer to wait
+ d.timer = time.AfterFunc(duration, timeoutIO)
+ } else {
+ // Deadline is in the past. Cancel all pending IO now.
+ timeoutIO()
+ }
+ return nil
+}
diff --git a/ipc/namedpipe/namedpipe.go b/ipc/namedpipe/namedpipe.go
new file mode 100644
index 0000000..ef3dea1
--- /dev/null
+++ b/ipc/namedpipe/namedpipe.go
@@ -0,0 +1,485 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Copyright 2015 Microsoft
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build windows
+
+// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
+package namedpipe
+
+import (
+ "context"
+ "io"
+ "net"
+ "os"
+ "runtime"
+ "sync/atomic"
+ "time"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+type pipe struct {
+ *file
+ path string
+}
+
+type messageBytePipe struct {
+ pipe
+ writeClosed atomic.Bool
+ readEOF bool
+}
+
+type pipeAddress string
+
+func (f *pipe) LocalAddr() net.Addr {
+ return pipeAddress(f.path)
+}
+
+func (f *pipe) RemoteAddr() net.Addr {
+ return pipeAddress(f.path)
+}
+
+func (f *pipe) SetDeadline(t time.Time) error {
+ f.SetReadDeadline(t)
+ f.SetWriteDeadline(t)
+ return nil
+}
+
+// CloseWrite closes the write side of a message pipe in byte mode.
+func (f *messageBytePipe) CloseWrite() error {
+ if !f.writeClosed.CompareAndSwap(false, true) {
+ return io.ErrClosedPipe
+ }
+ err := f.file.Flush()
+ if err != nil {
+ f.writeClosed.Store(false)
+ return err
+ }
+ _, err = f.file.Write(nil)
+ if err != nil {
+ f.writeClosed.Store(false)
+ return err
+ }
+ return nil
+}
+
+// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
+// they are used to implement CloseWrite.
+func (f *messageBytePipe) Write(b []byte) (int, error) {
+ if f.writeClosed.Load() {
+ return 0, io.ErrClosedPipe
+ }
+ if len(b) == 0 {
+ return 0, nil
+ }
+ return f.file.Write(b)
+}
+
+// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
+// mode pipe will return io.EOF, as will all subsequent reads.
+func (f *messageBytePipe) Read(b []byte) (int, error) {
+ if f.readEOF {
+ return 0, io.EOF
+ }
+ n, err := f.file.Read(b)
+ if err == io.EOF {
+ // If this was the result of a zero-byte read, then
+ // it is possible that the read was due to a zero-size
+ // message. Since we are simulating CloseWrite with a
+ // zero-byte message, ensure that all future Read calls
+ // also return EOF.
+ f.readEOF = true
+ } else if err == windows.ERROR_MORE_DATA {
+ // ERROR_MORE_DATA indicates that the pipe's read mode is message mode
+ // and the message still has more bytes. Treat this as a success, since
+ // this package presents all named pipes as byte streams.
+ err = nil
+ }
+ return n, err
+}
+
+func (f *pipe) Handle() windows.Handle {
+ return f.handle
+}
+
+func (s pipeAddress) Network() string {
+ return "pipe"
+}
+
+func (s pipeAddress) String() string {
+ return string(s)
+}
+
+// tryDialPipe attempts to dial the specified pipe until cancellation or timeout.
+func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) {
+ for {
+ select {
+ case <-ctx.Done():
+ return 0, ctx.Err()
+ default:
+ path16, err := windows.UTF16PtrFromString(*path)
+ if err != nil {
+ return 0, err
+ }
+ h, err := windows.CreateFile(path16, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
+ if err == nil {
+ return h, nil
+ }
+ if err != windows.ERROR_PIPE_BUSY {
+ return h, &os.PathError{Err: err, Op: "open", Path: *path}
+ }
+ // Wait 10 msec and try again. This is a rather simplistic
+ // view, as we always try each 10 milliseconds.
+ time.Sleep(10 * time.Millisecond)
+ }
+ }
+}
+
+// DialConfig exposes various options for use in Dial and DialContext.
+type DialConfig struct {
+ ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID.
+}
+
+// DialTimeout connects to the specified named pipe by path, timing out if the
+// connection takes longer than the specified duration. If timeout is zero, then
+// we use a default timeout of 2 seconds.
+func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
+ if timeout == 0 {
+ timeout = time.Second * 2
+ }
+ absTimeout := time.Now().Add(timeout)
+ ctx, _ := context.WithDeadline(context.Background(), absTimeout)
+ conn, err := config.DialContext(ctx, path)
+ if err == context.DeadlineExceeded {
+ return nil, os.ErrDeadlineExceeded
+ }
+ return conn, err
+}
+
+// DialContext attempts to connect to the specified named pipe by path.
+func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) {
+ var err error
+ var h windows.Handle
+ h, err = tryDialPipe(ctx, &path)
+ if err != nil {
+ return nil, err
+ }
+
+ if config.ExpectedOwner != nil {
+ sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
+ if err != nil {
+ windows.Close(h)
+ return nil, err
+ }
+ realOwner, _, err := sd.Owner()
+ if err != nil {
+ windows.Close(h)
+ return nil, err
+ }
+ if !realOwner.Equals(config.ExpectedOwner) {
+ windows.Close(h)
+ return nil, windows.ERROR_ACCESS_DENIED
+ }
+ }
+
+ var flags uint32
+ err = windows.GetNamedPipeInfo(h, &flags, nil, nil, nil)
+ if err != nil {
+ windows.Close(h)
+ return nil, err
+ }
+
+ f, err := makeFile(h)
+ if err != nil {
+ windows.Close(h)
+ return nil, err
+ }
+
+ // If the pipe is in message mode, return a message byte pipe, which
+ // supports CloseWrite.
+ if flags&windows.PIPE_TYPE_MESSAGE != 0 {
+ return &messageBytePipe{
+ pipe: pipe{file: f, path: path},
+ }, nil
+ }
+ return &pipe{file: f, path: path}, nil
+}
+
+var defaultDialer DialConfig
+
+// DialTimeout calls DialConfig.DialTimeout using an empty configuration.
+func DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
+ return defaultDialer.DialTimeout(path, timeout)
+}
+
+// DialContext calls DialConfig.DialContext using an empty configuration.
+func DialContext(ctx context.Context, path string) (net.Conn, error) {
+ return defaultDialer.DialContext(ctx, path)
+}
+
+type acceptResponse struct {
+ f *file
+ err error
+}
+
+type pipeListener struct {
+ firstHandle windows.Handle
+ path string
+ config ListenConfig
+ acceptCh chan chan acceptResponse
+ closeCh chan int
+ doneCh chan int
+}
+
+func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) {
+ path16, err := windows.UTF16PtrFromString(path)
+ if err != nil {
+ return 0, &os.PathError{Op: "open", Path: path, Err: err}
+ }
+
+ var oa windows.OBJECT_ATTRIBUTES
+ oa.Length = uint32(unsafe.Sizeof(oa))
+
+ var ntPath windows.NTUnicodeString
+ if err := windows.RtlDosPathNameToNtPathName(path16, &ntPath, nil, nil); err != nil {
+ if ntstatus, ok := err.(windows.NTStatus); ok {
+ err = ntstatus.Errno()
+ }
+ return 0, &os.PathError{Op: "open", Path: path, Err: err}
+ }
+ defer windows.LocalFree(windows.Handle(unsafe.Pointer(ntPath.Buffer)))
+ oa.ObjectName = &ntPath
+
+ // The security descriptor is only needed for the first pipe.
+ if isFirstPipe {
+ if sd != nil {
+ oa.SecurityDescriptor = sd
+ } else {
+ // Construct the default named pipe security descriptor.
+ var acl *windows.ACL
+ if err := windows.RtlDefaultNpAcl(&acl); err != nil {
+ return 0, err
+ }
+ defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl)))
+ sd, err = windows.NewSecurityDescriptor()
+ if err != nil {
+ return 0, err
+ }
+ if err = sd.SetDACL(acl, true, false); err != nil {
+ return 0, err
+ }
+ oa.SecurityDescriptor = sd
+ }
+ }
+
+ typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS)
+ if c.MessageMode {
+ typ |= windows.FILE_PIPE_MESSAGE_TYPE
+ }
+
+ disposition := uint32(windows.FILE_OPEN)
+ access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
+ if isFirstPipe {
+ disposition = windows.FILE_CREATE
+ // By not asking for read or write access, the named pipe file system
+ // will put this pipe into an initially disconnected state, blocking
+ // client connections until the next call with isFirstPipe == false.
+ access = windows.SYNCHRONIZE
+ }
+
+ timeout := int64(-50 * 10000) // 50ms
+
+ var (
+ h windows.Handle
+ iosb windows.IO_STATUS_BLOCK
+ )
+ err = windows.NtCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout)
+ if err != nil {
+ if ntstatus, ok := err.(windows.NTStatus); ok {
+ err = ntstatus.Errno()
+ }
+ return 0, &os.PathError{Op: "open", Path: path, Err: err}
+ }
+
+ runtime.KeepAlive(ntPath)
+ return h, nil
+}
+
+func (l *pipeListener) makeServerPipe() (*file, error) {
+ h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
+ if err != nil {
+ return nil, err
+ }
+ f, err := makeFile(h)
+ if err != nil {
+ windows.Close(h)
+ return nil, err
+ }
+ return f, nil
+}
+
+func (l *pipeListener) makeConnectedServerPipe() (*file, error) {
+ p, err := l.makeServerPipe()
+ if err != nil {
+ return nil, err
+ }
+
+ // Wait for the client to connect.
+ ch := make(chan error)
+ go func(p *file) {
+ ch <- connectPipe(p)
+ }(p)
+
+ select {
+ case err = <-ch:
+ if err != nil {
+ p.Close()
+ p = nil
+ }
+ case <-l.closeCh:
+ // Abort the connect request by closing the handle.
+ p.Close()
+ p = nil
+ err = <-ch
+ if err == nil || err == os.ErrClosed {
+ err = net.ErrClosed
+ }
+ }
+ return p, err
+}
+
+func (l *pipeListener) listenerRoutine() {
+ closed := false
+ for !closed {
+ select {
+ case <-l.closeCh:
+ closed = true
+ case responseCh := <-l.acceptCh:
+ var (
+ p *file
+ err error
+ )
+ for {
+ p, err = l.makeConnectedServerPipe()
+ // If the connection was immediately closed by the client, try
+ // again.
+ if err != windows.ERROR_NO_DATA {
+ break
+ }
+ }
+ responseCh <- acceptResponse{p, err}
+ closed = err == net.ErrClosed
+ }
+ }
+ windows.Close(l.firstHandle)
+ l.firstHandle = 0
+ // Notify Close and Accept callers that the handle has been closed.
+ close(l.doneCh)
+}
+
+// ListenConfig contains configuration for the pipe listener.
+type ListenConfig struct {
+ // SecurityDescriptor contains a Windows security descriptor. If nil, the default from RtlDefaultNpAcl is used.
+ SecurityDescriptor *windows.SECURITY_DESCRIPTOR
+
+ // MessageMode determines whether the pipe is in byte or message mode. In either
+ // case the pipe is read in byte mode by default. The only practical difference in
+ // this implementation is that CloseWrite is only supported for message mode pipes;
+ // CloseWrite is implemented as a zero-byte write, but zero-byte writes are only
+ // transferred to the reader (and returned as io.EOF in this implementation)
+ // when the pipe is in message mode.
+ MessageMode bool
+
+ // InputBufferSize specifies the initial size of the input buffer, in bytes, which the OS will grow as needed.
+ InputBufferSize int32
+
+ // OutputBufferSize specifies the initial size of the output buffer, in bytes, which the OS will grow as needed.
+ OutputBufferSize int32
+}
+
+// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe.
+// The pipe must not already exist.
+func (c *ListenConfig) Listen(path string) (net.Listener, error) {
+ h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
+ if err != nil {
+ return nil, err
+ }
+ l := &pipeListener{
+ firstHandle: h,
+ path: path,
+ config: *c,
+ acceptCh: make(chan chan acceptResponse),
+ closeCh: make(chan int),
+ doneCh: make(chan int),
+ }
+ // The first connection is swallowed on Windows 7 & 8, so synthesize it.
+ if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) {
+ path16, err := windows.UTF16PtrFromString(path)
+ if err == nil {
+ h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
+ if err == nil {
+ windows.CloseHandle(h)
+ }
+ }
+ }
+ go l.listenerRoutine()
+ return l, nil
+}
+
+var defaultListener ListenConfig
+
+// Listen calls ListenConfig.Listen using an empty configuration.
+func Listen(path string) (net.Listener, error) {
+ return defaultListener.Listen(path)
+}
+
+func connectPipe(p *file) error {
+ c, err := p.prepareIo()
+ if err != nil {
+ return err
+ }
+ defer p.wg.Done()
+
+ err = windows.ConnectNamedPipe(p.handle, &c.o)
+ _, err = p.asyncIo(c, nil, 0, err)
+ if err != nil && err != windows.ERROR_PIPE_CONNECTED {
+ return err
+ }
+ return nil
+}
+
+func (l *pipeListener) Accept() (net.Conn, error) {
+ ch := make(chan acceptResponse)
+ select {
+ case l.acceptCh <- ch:
+ response := <-ch
+ err := response.err
+ if err != nil {
+ return nil, err
+ }
+ if l.config.MessageMode {
+ return &messageBytePipe{
+ pipe: pipe{file: response.f, path: l.path},
+ }, nil
+ }
+ return &pipe{file: response.f, path: l.path}, nil
+ case <-l.doneCh:
+ return nil, net.ErrClosed
+ }
+}
+
+func (l *pipeListener) Close() error {
+ select {
+ case l.closeCh <- 1:
+ <-l.doneCh
+ case <-l.doneCh:
+ }
+ return nil
+}
+
+func (l *pipeListener) Addr() net.Addr {
+ return pipeAddress(l.path)
+}
diff --git a/ipc/namedpipe/namedpipe_test.go b/ipc/namedpipe/namedpipe_test.go
new file mode 100644
index 0000000..998453b
--- /dev/null
+++ b/ipc/namedpipe/namedpipe_test.go
@@ -0,0 +1,674 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Copyright 2015 Microsoft
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build windows
+
+package namedpipe_test
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "errors"
+ "io"
+ "net"
+ "os"
+ "sync"
+ "syscall"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/windows"
+ "golang.zx2c4.com/wireguard/ipc/namedpipe"
+)
+
+func randomPipePath() string {
+ guid, err := windows.GenerateGUID()
+ if err != nil {
+ panic(err)
+ }
+ return `\\.\PIPE\go-namedpipe-test-` + guid.String()
+}
+
+func TestPingPong(t *testing.T) {
+ const (
+ ping = 42
+ pong = 24
+ )
+ pipePath := randomPipePath()
+ listener, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatalf("unable to listen on pipe: %v", err)
+ }
+ defer listener.Close()
+ go func() {
+ incoming, err := listener.Accept()
+ if err != nil {
+ t.Fatalf("unable to accept pipe connection: %v", err)
+ }
+ defer incoming.Close()
+ var data [1]byte
+ _, err = incoming.Read(data[:])
+ if err != nil {
+ t.Fatalf("unable to read ping from pipe: %v", err)
+ }
+ if data[0] != ping {
+ t.Fatalf("expected ping, got %d", data[0])
+ }
+ data[0] = pong
+ _, err = incoming.Write(data[:])
+ if err != nil {
+ t.Fatalf("unable to write pong to pipe: %v", err)
+ }
+ }()
+ client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err != nil {
+ t.Fatalf("unable to dial pipe: %v", err)
+ }
+ defer client.Close()
+ client.SetDeadline(time.Now().Add(time.Second * 5))
+ var data [1]byte
+ data[0] = ping
+ _, err = client.Write(data[:])
+ if err != nil {
+ t.Fatalf("unable to write ping to pipe: %v", err)
+ }
+ _, err = client.Read(data[:])
+ if err != nil {
+ t.Fatalf("unable to read pong from pipe: %v", err)
+ }
+ if data[0] != pong {
+ t.Fatalf("expected pong, got %d", data[0])
+ }
+}
+
+func TestDialUnknownFailsImmediately(t *testing.T) {
+ _, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0))
+ if !errors.Is(err, syscall.ENOENT) {
+ t.Fatalf("expected ENOENT got %v", err)
+ }
+}
+
+func TestDialListenerTimesOut(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+ pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond)
+ if err == nil {
+ pipe.Close()
+ }
+ if err != os.ErrDeadlineExceeded {
+ t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
+ }
+}
+
+func TestDialContextListenerTimesOut(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+ d := 10 * time.Millisecond
+ ctx, _ := context.WithTimeout(context.Background(), d)
+ pipe, err := namedpipe.DialContext(ctx, pipePath)
+ if err == nil {
+ pipe.Close()
+ }
+ if err != context.DeadlineExceeded {
+ t.Fatalf("expected context.DeadlineExceeded, got %v", err)
+ }
+}
+
+func TestDialListenerGetsCancelled(t *testing.T) {
+ pipePath := randomPipePath()
+ ctx, cancel := context.WithCancel(context.Background())
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+ ch := make(chan error)
+ go func(ctx context.Context, ch chan error) {
+ _, err := namedpipe.DialContext(ctx, pipePath)
+ ch <- err
+ }(ctx, ch)
+ time.Sleep(time.Millisecond * 30)
+ cancel()
+ err = <-ch
+ if err != context.Canceled {
+ t.Fatalf("expected context.Canceled, got %v", err)
+ }
+}
+
+func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
+ if windows.NewLazySystemDLL("ntdll.dll").NewProc("wine_get_version").Find() == nil {
+ t.Skip("dacls on named pipes are broken on wine")
+ }
+ pipePath := randomPipePath()
+ sd, _ := windows.SecurityDescriptorFromString("D:")
+ l, err := (&namedpipe.ListenConfig{
+ SecurityDescriptor: sd,
+ }).Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+ pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err == nil {
+ pipe.Close()
+ }
+ if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
+ t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
+ }
+}
+
+func getConnection(cfg *namedpipe.ListenConfig) (client, server net.Conn, err error) {
+ pipePath := randomPipePath()
+ if cfg == nil {
+ cfg = &namedpipe.ListenConfig{}
+ }
+ l, err := cfg.Listen(pipePath)
+ if err != nil {
+ return
+ }
+ defer l.Close()
+
+ type response struct {
+ c net.Conn
+ err error
+ }
+ ch := make(chan response)
+ go func() {
+ c, err := l.Accept()
+ ch <- response{c, err}
+ }()
+
+ c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err != nil {
+ return
+ }
+
+ r := <-ch
+ if err = r.err; err != nil {
+ c.Close()
+ return
+ }
+
+ client = c
+ server = r.c
+ return
+}
+
+func TestReadTimeout(t *testing.T) {
+ c, s, err := getConnection(nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ defer s.Close()
+
+ c.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
+
+ buf := make([]byte, 10)
+ _, err = c.Read(buf)
+ if err != os.ErrDeadlineExceeded {
+ t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
+ }
+}
+
+func server(l net.Listener, ch chan int) {
+ c, err := l.Accept()
+ if err != nil {
+ panic(err)
+ }
+ rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
+ s, err := rw.ReadString('\n')
+ if err != nil {
+ panic(err)
+ }
+ _, err = rw.WriteString("got " + s)
+ if err != nil {
+ panic(err)
+ }
+ err = rw.Flush()
+ if err != nil {
+ panic(err)
+ }
+ c.Close()
+ ch <- 1
+}
+
+func TestFullListenDialReadWrite(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+
+ ch := make(chan int)
+ go server(l, ch)
+
+ c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
+ _, err = rw.WriteString("hello world\n")
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = rw.Flush()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ s, err := rw.ReadString('\n')
+ if err != nil {
+ t.Fatal(err)
+ }
+ ms := "got hello world\n"
+ if s != ms {
+ t.Errorf("expected '%s', got '%s'", ms, s)
+ }
+
+ <-ch
+}
+
+func TestCloseAbortsListen(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ch := make(chan error)
+ go func() {
+ _, err := l.Accept()
+ ch <- err
+ }()
+
+ time.Sleep(30 * time.Millisecond)
+ l.Close()
+
+ err = <-ch
+ if err != net.ErrClosed {
+ t.Fatalf("expected net.ErrClosed, got %v", err)
+ }
+}
+
+func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) {
+ b := make([]byte, 10)
+ w.Close()
+ n, err := r.Read(b)
+ if n > 0 {
+ t.Errorf("unexpected byte count %d", n)
+ }
+ if err != io.EOF {
+ t.Errorf("expected EOF: %v", err)
+ }
+}
+
+func TestCloseClientEOFServer(t *testing.T) {
+ c, s, err := getConnection(nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ defer s.Close()
+ ensureEOFOnClose(t, c, s)
+}
+
+func TestCloseServerEOFClient(t *testing.T) {
+ c, s, err := getConnection(nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ defer s.Close()
+ ensureEOFOnClose(t, s, c)
+}
+
+func TestCloseWriteEOF(t *testing.T) {
+ cfg := &namedpipe.ListenConfig{
+ MessageMode: true,
+ }
+ c, s, err := getConnection(cfg)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ defer s.Close()
+
+ type closeWriter interface {
+ CloseWrite() error
+ }
+
+ err = c.(closeWriter).CloseWrite()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ b := make([]byte, 10)
+ _, err = s.Read(b)
+ if err != io.EOF {
+ t.Fatal(err)
+ }
+}
+
+func TestAcceptAfterCloseFails(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ l.Close()
+ _, err = l.Accept()
+ if err != net.ErrClosed {
+ t.Fatalf("expected net.ErrClosed, got %v", err)
+ }
+}
+
+func TestDialTimesOutByDefault(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+ pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds.
+ if err == nil {
+ pipe.Close()
+ }
+ if err != os.ErrDeadlineExceeded {
+ t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
+ }
+}
+
+func TestTimeoutPendingRead(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+
+ serverDone := make(chan struct{})
+
+ go func() {
+ s, err := l.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ time.Sleep(1 * time.Second)
+ s.Close()
+ close(serverDone)
+ }()
+
+ client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer client.Close()
+
+ clientErr := make(chan error)
+ go func() {
+ buf := make([]byte, 10)
+ _, err = client.Read(buf)
+ clientErr <- err
+ }()
+
+ time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline
+ client.SetReadDeadline(time.Unix(1, 0))
+
+ select {
+ case err = <-clientErr:
+ if err != os.ErrDeadlineExceeded {
+ t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
+ }
+ case <-time.After(100 * time.Millisecond):
+ t.Fatalf("timed out while waiting for read to cancel")
+ <-clientErr
+ }
+ <-serverDone
+}
+
+func TestTimeoutPendingWrite(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+
+ serverDone := make(chan struct{})
+
+ go func() {
+ s, err := l.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ time.Sleep(1 * time.Second)
+ s.Close()
+ close(serverDone)
+ }()
+
+ client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer client.Close()
+
+ clientErr := make(chan error)
+ go func() {
+ _, err = client.Write([]byte("this should timeout"))
+ clientErr <- err
+ }()
+
+ time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline
+ client.SetWriteDeadline(time.Unix(1, 0))
+
+ select {
+ case err = <-clientErr:
+ if err != os.ErrDeadlineExceeded {
+ t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
+ }
+ case <-time.After(100 * time.Millisecond):
+ t.Fatalf("timed out while waiting for write to cancel")
+ <-clientErr
+ }
+ <-serverDone
+}
+
+type CloseWriter interface {
+ CloseWrite() error
+}
+
+func TestEchoWithMessaging(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := (&namedpipe.ListenConfig{
+ MessageMode: true, // Use message mode so that CloseWrite() is supported
+ InputBufferSize: 65536, // Use 64KB buffers to improve performance
+ OutputBufferSize: 65536,
+ }).Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+
+ listenerDone := make(chan bool)
+ clientDone := make(chan bool)
+ go func() {
+ // server echo
+ conn, err := l.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
+ _, err = io.Copy(conn, conn)
+ if err != nil {
+ t.Fatal(err)
+ }
+ conn.(CloseWriter).CloseWrite()
+ close(listenerDone)
+ }()
+ client, err := namedpipe.DialTimeout(pipePath, time.Second)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer client.Close()
+
+ go func() {
+ // client read back
+ bytes := make([]byte, 2)
+ n, e := client.Read(bytes)
+ if e != nil {
+ t.Fatal(e)
+ }
+ if n != 2 || bytes[0] != 0 || bytes[1] != 1 {
+ t.Fatalf("expected 2 bytes, got %v", n)
+ }
+ close(clientDone)
+ }()
+
+ payload := make([]byte, 2)
+ payload[0] = 0
+ payload[1] = 1
+
+ n, err := client.Write(payload)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 2 {
+ t.Fatalf("expected 2 bytes, got %v", n)
+ }
+ client.(CloseWriter).CloseWrite()
+ <-listenerDone
+ <-clientDone
+}
+
+func TestConnectRace(t *testing.T) {
+ pipePath := randomPipePath()
+ l, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+ go func() {
+ for {
+ s, err := l.Accept()
+ if err == net.ErrClosed {
+ return
+ }
+
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.Close()
+ }
+ }()
+
+ for i := 0; i < 1000; i++ {
+ c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+ }
+}
+
+func TestMessageReadMode(t *testing.T) {
+ if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 {
+ t.Skipf("Skipping on Windows %d", maj)
+ }
+ var wg sync.WaitGroup
+ defer wg.Wait()
+ pipePath := randomPipePath()
+ l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+
+ msg := ([]byte)("hello world")
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ s, err := l.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = s.Write(msg)
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.Close()
+ }()
+
+ c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ mode := uint32(windows.PIPE_READMODE_MESSAGE)
+ err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ch := make([]byte, 1)
+ var vmsg []byte
+ for {
+ n, err := c.Read(ch)
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 1 {
+ t.Fatalf("expected 1, got %d", n)
+ }
+ vmsg = append(vmsg, ch[0])
+ }
+ if !bytes.Equal(msg, vmsg) {
+ t.Fatalf("expected %s, got %s", msg, vmsg)
+ }
+}
+
+func TestListenConnectRace(t *testing.T) {
+ if testing.Short() {
+ t.Skip("Skipping long race test")
+ }
+ pipePath := randomPipePath()
+ for i := 0; i < 50 && !t.Failed(); i++ {
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
+ if err == nil {
+ c.Close()
+ }
+ wg.Done()
+ }()
+ s, err := namedpipe.Listen(pipePath)
+ if err != nil {
+ t.Error(i, err)
+ } else {
+ s.Close()
+ }
+ wg.Wait()
+ }
+}