aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2019-08-30 13:21:47 -0600
committerJason A. Donenfeld <Jason@zx2c4.com>2019-08-30 13:21:47 -0600
commite4b957183c4a330f020f5188f3b30b59355efb80 (patch)
treed6006fcdd00f381eefdcddac8b21235668f55d65
parentwintun: put mutex into private namespace (diff)
downloadwireguard-go-e4b957183c4a330f020f5188f3b30b59355efb80.tar.xz
wireguard-go-e4b957183c4a330f020f5188f3b30b59355efb80.zip
winpipe: enforce ownership of client connection
-rw-r--r--ipc/winpipe/pipe.go22
-rw-r--r--ipc/winpipe/sd.go15
-rw-r--r--ipc/winpipe/zsyscall_windows.go16
3 files changed, 46 insertions, 7 deletions
diff --git a/ipc/winpipe/pipe.go b/ipc/winpipe/pipe.go
index 1e99a93..39ccfa4 100644
--- a/ipc/winpipe/pipe.go
+++ b/ipc/winpipe/pipe.go
@@ -211,7 +211,7 @@ func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) {
// DialPipe connects to a named pipe by path, timing out if the connection
// takes longer than the specified duration. If timeout is nil, then we use
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
-func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
+func DialPipe(path string, timeout *time.Duration, expectedOwner *syscall.SID) (net.Conn, error) {
var absTimeout time.Time
if timeout != nil {
absTimeout = time.Now().Add(*timeout)
@@ -219,7 +219,7 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
absTimeout = time.Now().Add(time.Second * 2)
}
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
- conn, err := DialPipeContext(ctx, path)
+ conn, err := DialPipeContext(ctx, path, expectedOwner)
if err == context.DeadlineExceeded {
return nil, ErrTimeout
}
@@ -228,7 +228,7 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
// cancellation or timeout.
-func DialPipeContext(ctx context.Context, path string) (net.Conn, error) {
+func DialPipeContext(ctx context.Context, path string, expectedOwner *syscall.SID) (net.Conn, error) {
var err error
var h syscall.Handle
h, err = tryDialPipe(ctx, &path)
@@ -236,9 +236,25 @@ func DialPipeContext(ctx context.Context, path string) (net.Conn, error) {
return nil, err
}
+ if expectedOwner != nil {
+ var realOwner *syscall.SID
+ var realSd uintptr
+ err = getSecurityInfo(h, SE_FILE_OBJECT, OWNER_SECURITY_INFORMATION, &realOwner, nil, nil, nil, &realSd)
+ if err != nil {
+ syscall.Close(h)
+ return nil, err
+ }
+ defer localFree(realSd)
+ if !equalSid(realOwner, expectedOwner) {
+ syscall.Close(h)
+ return nil, syscall.ERROR_ACCESS_DENIED
+ }
+ }
+
var flags uint32
err = getNamedPipeInfo(h, &flags, nil, nil, nil)
if err != nil {
+ syscall.Close(h)
return nil, err
}
diff --git a/ipc/winpipe/sd.go b/ipc/winpipe/sd.go
index 75686b2..4456917 100644
--- a/ipc/winpipe/sd.go
+++ b/ipc/winpipe/sd.go
@@ -12,9 +12,16 @@ import (
"unsafe"
)
-//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
-//sys localFree(mem uintptr) = LocalFree
-//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
+//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
+//sys localFree(mem uintptr) = LocalFree
+//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
+//sys getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) = advapi32.GetSecurityInfo
+//sys equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) = advapi32.EqualSid
+
+const (
+ SE_FILE_OBJECT = 1
+ OWNER_SECURITY_INFORMATION = 1
+)
func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
var sdBuffer uintptr
@@ -26,4 +33,4 @@ func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
sd := make([]byte, getSecurityDescriptorLength(sdBuffer))
copy(sd, (*[0xffff]byte)(unsafe.Pointer(sdBuffer))[:len(sd)])
return sd, nil
-}
+} \ No newline at end of file
diff --git a/ipc/winpipe/zsyscall_windows.go b/ipc/winpipe/zsyscall_windows.go
index b8eedb4..ecf3e84 100644
--- a/ipc/winpipe/zsyscall_windows.go
+++ b/ipc/winpipe/zsyscall_windows.go
@@ -55,6 +55,8 @@ var (
procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
procLocalFree = modkernel32.NewProc("LocalFree")
procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength")
+ procGetSecurityInfo = modadvapi32.NewProc("GetSecurityInfo")
+ procEqualSid = modadvapi32.NewProc("EqualSid")
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
@@ -206,6 +208,20 @@ func getSecurityDescriptorLength(sd uintptr) (len uint32) {
return
}
+func getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) {
+ r0, _, _ := syscall.Syscall9(procGetSecurityInfo.Addr(), 8, uintptr(handle), uintptr(objectType), uintptr(securityInformation), uintptr(unsafe.Pointer(owner)), uintptr(unsafe.Pointer(group)), uintptr(unsafe.Pointer(dacl)), uintptr(unsafe.Pointer(sacl)), uintptr(unsafe.Pointer(sd)), 0)
+ if r0 != 0 {
+ ret = syscall.Errno(r0)
+ }
+ return
+}
+
+func equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) {
+ r0, _, _ := syscall.Syscall(procEqualSid.Addr(), 2, uintptr(unsafe.Pointer(sid1)), uintptr(unsafe.Pointer(sid2)), 0)
+ isEqual = r0 != 0
+ return
+}
+
func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 {