From cea56d3a253408c4915ed3db5a15c499288cb89b Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Sat, 3 Aug 2019 09:25:35 +0200 Subject: elevate: do not show UAC prompt for frictionless UX --- elevate/mksyscall.go | 8 +++ elevate/shellexecute.go | 126 ++++++++++++++++++++++++++++++++++++++++++++ elevate/syscall_windows.go | 52 ++++++++++++++++++ elevate/zsyscall_windows.go | 109 ++++++++++++++++++++++++++++++++++++++ main.go | 3 +- 5 files changed, 297 insertions(+), 1 deletion(-) create mode 100644 elevate/mksyscall.go create mode 100644 elevate/shellexecute.go create mode 100644 elevate/syscall_windows.go create mode 100644 elevate/zsyscall_windows.go diff --git a/elevate/mksyscall.go b/elevate/mksyscall.go new file mode 100644 index 00000000..fda63f27 --- /dev/null +++ b/elevate/mksyscall.go @@ -0,0 +1,8 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package elevate + +//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go syscall_windows.go diff --git a/elevate/shellexecute.go b/elevate/shellexecute.go new file mode 100644 index 00000000..c3dc84eb --- /dev/null +++ b/elevate/shellexecute.go @@ -0,0 +1,126 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package elevate + +import ( + "path/filepath" + "runtime" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" +) + +const ( + releaseOffset = 2 + shellExecuteOffset = 9 + + cSEE_MASK_DEFAULT = 0 +) + +func ShellExecute(program string, arguments string, directory string, show int32) (err error) { + var ( + program16 *uint16 + arguments16 *uint16 + directory16 *uint16 + ) + + if len(program) > 0 { + program16, _ = windows.UTF16PtrFromString(program) + } + if len(arguments) > 0 { + arguments16, _ = windows.UTF16PtrFromString(arguments) + } + if len(directory) > 0 { + directory16, _ = windows.UTF16PtrFromString(directory) + } + + defer func() { + if err != nil { + err = windows.ShellExecute(0, windows.StringToUTF16Ptr("runas"), program16, arguments16, directory16, show) + } + }() + + processToken, err := windows.OpenCurrentProcessToken() + if err != nil { + return + } + defer processToken.Close() + if processToken.IsElevated() { + err = windows.ERROR_SUCCESS + return + } + + key, err := registry.OpenKey(registry.LOCAL_MACHINE, "SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\UAC\\COMAutoApprovalList", registry.QUERY_VALUE) + if err == nil { + var autoApproved uint64 + autoApproved, _, err = key.GetIntegerValue("{3E5FC7F9-9A51-4367-9063-A120244FBEC7}") + key.Close() + if err != nil { + return + } + if uint32(autoApproved) == 0 { + err = windows.ERROR_ACCESS_DENIED + return + } + } + + moduleHandle, err := getModuleHandle(nil) + if err != nil { + return + } + var dataTableEntry *cLDR_DATA_TABLE_ENTRY + if ret := ldrFindEntryForAddress(moduleHandle, &dataTableEntry); ret != 0 { + err = syscall.Errno(windows.ERROR_INTERNAL_ERROR) + return + } + var windowsDirectory [windows.MAX_PATH]uint16 + if _, err = getWindowsDirectory(&windowsDirectory[0], windows.MAX_PATH); err != nil { + return + } + originalPath := dataTableEntry.FullDllName.Buffer + explorerPath := windows.StringToUTF16Ptr(filepath.Join(windows.UTF16ToString(windowsDirectory[:]), "explorer.exe")) + rtlInitUnicodeString(&dataTableEntry.FullDllName, explorerPath) + defer func() { + rtlInitUnicodeString(&dataTableEntry.FullDllName, originalPath) + runtime.KeepAlive(explorerPath) + }() + + if err = coInitializeEx(0, cCOINIT_APARTMENTTHREADED); err == nil { + defer coUninitialize() + } + + var interfacePointer **[0xffff]uintptr + if err = coGetObject( + windows.StringToUTF16Ptr("Elevation:Administrator!new:{3E5FC7F9-9A51-4367-9063-A120244FBEC7}"), + &cBIND_OPTS3{ + cbStruct: uint32(unsafe.Sizeof(cBIND_OPTS3{})), + dwClassContext: cCLSCTX_LOCAL_SERVER, + }, + &windows.GUID{0x6EDD6D74, 0xC007, 0x4E75, [8]byte{0xB7, 0x6A, 0xE5, 0x74, 0x09, 0x95, 0xE2, 0x4C}}, + &interfacePointer, + ); err != nil { + return + } + + defer syscall.Syscall((*interfacePointer)[releaseOffset], 1, uintptr(unsafe.Pointer(interfacePointer)), 0, 0) + + if ret, _, _ := syscall.Syscall6((*interfacePointer)[shellExecuteOffset], 6, + uintptr(unsafe.Pointer(interfacePointer)), + uintptr(unsafe.Pointer(program16)), + uintptr(unsafe.Pointer(arguments16)), + uintptr(unsafe.Pointer(directory16)), + cSEE_MASK_DEFAULT, + uintptr(show), + ); ret != uintptr(windows.ERROR_SUCCESS) { + err = syscall.Errno(ret) + return + } + + err = nil + return +} diff --git a/elevate/syscall_windows.go b/elevate/syscall_windows.go new file mode 100644 index 00000000..c73be812 --- /dev/null +++ b/elevate/syscall_windows.go @@ -0,0 +1,52 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package elevate + +type cBIND_OPTS3 struct { + cbStruct uint32 + grfFlags uint32 + grfMode uint32 + dwTickCountDeadline uint32 + dwTrackFlags uint32 + dwClassContext uint32 + locale uint32 + pServerInfo *uintptr + hwnd *uintptr +} + +type cUNICODE_STRING struct { + Length uint16 + MaximumLength uint16 + Buffer *uint16 +} + +type cLDR_DATA_TABLE_ENTRY struct { + Reserved1 [2]uintptr + InMemoryOrderLinks [2]uintptr + Reserved2 [2]uintptr + DllBase uintptr + Reserved3 [2]uintptr + FullDllName cUNICODE_STRING + Reserved4 [8]byte + Reserved5 [3]uintptr + Reserved6 uintptr + TimeDateStamp uint32 +} + +const ( + cCLSCTX_LOCAL_SERVER = 4 + cCOINIT_APARTMENTTHREADED = 2 +) + +//sys getModuleHandle(moduleName *uint16) (moduleHandle uintptr, err error) [failretval==0] = kernel32.GetModuleHandleW +//sys getWindowsDirectory(windowsDirectory *uint16, inLen uint32) (outLen uint32, err error) [failretval==0] = kernel32.GetWindowsDirectoryW + +//sys rtlInitUnicodeString(destinationString *cUNICODE_STRING, sourceString *uint16) = ntdll.RtlInitUnicodeString +//sys ldrFindEntryForAddress(moduleHandle uintptr, entry **cLDR_DATA_TABLE_ENTRY) (ntstatus uint32) = ntdll.LdrFindEntryForAddress + +//sys coInitializeEx(reserved uintptr, coInit uint32) (ret error) = ole32.CoInitializeEx +//sys coUninitialize() = ole32.CoUninitialize +//sys coGetObject(name *uint16, bindOpts *cBIND_OPTS3, guid *windows.GUID, functionTable ***[0xffff]uintptr) (ret error) = ole32.CoGetObject diff --git a/elevate/zsyscall_windows.go b/elevate/zsyscall_windows.go new file mode 100644 index 00000000..b16b5f5d --- /dev/null +++ b/elevate/zsyscall_windows.go @@ -0,0 +1,109 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package elevate + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return nil + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + modntdll = windows.NewLazySystemDLL("ntdll.dll") + modole32 = windows.NewLazySystemDLL("ole32.dll") + + procGetModuleHandleW = modkernel32.NewProc("GetModuleHandleW") + procGetWindowsDirectoryW = modkernel32.NewProc("GetWindowsDirectoryW") + procRtlInitUnicodeString = modntdll.NewProc("RtlInitUnicodeString") + procLdrFindEntryForAddress = modntdll.NewProc("LdrFindEntryForAddress") + procCoInitializeEx = modole32.NewProc("CoInitializeEx") + procCoUninitialize = modole32.NewProc("CoUninitialize") + procCoGetObject = modole32.NewProc("CoGetObject") +) + +func getModuleHandle(moduleName *uint16) (moduleHandle uintptr, err error) { + r0, _, e1 := syscall.Syscall(procGetModuleHandleW.Addr(), 1, uintptr(unsafe.Pointer(moduleName)), 0, 0) + moduleHandle = uintptr(r0) + if moduleHandle == 0 { + if e1 != 0 { + err = errnoErr(e1) + } else { + err = syscall.EINVAL + } + } + return +} + +func getWindowsDirectory(windowsDirectory *uint16, inLen uint32) (outLen uint32, err error) { + r0, _, e1 := syscall.Syscall(procGetWindowsDirectoryW.Addr(), 2, uintptr(unsafe.Pointer(windowsDirectory)), uintptr(inLen), 0) + outLen = uint32(r0) + if outLen == 0 { + if e1 != 0 { + err = errnoErr(e1) + } else { + err = syscall.EINVAL + } + } + return +} + +func rtlInitUnicodeString(destinationString *cUNICODE_STRING, sourceString *uint16) { + syscall.Syscall(procRtlInitUnicodeString.Addr(), 2, uintptr(unsafe.Pointer(destinationString)), uintptr(unsafe.Pointer(sourceString)), 0) + return +} + +func ldrFindEntryForAddress(moduleHandle uintptr, entry **cLDR_DATA_TABLE_ENTRY) (ntstatus uint32) { + r0, _, _ := syscall.Syscall(procLdrFindEntryForAddress.Addr(), 2, uintptr(moduleHandle), uintptr(unsafe.Pointer(entry)), 0) + ntstatus = uint32(r0) + return +} + +func coInitializeEx(reserved uintptr, coInit uint32) (ret error) { + r0, _, _ := syscall.Syscall(procCoInitializeEx.Addr(), 2, uintptr(reserved), uintptr(coInit), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func coUninitialize() { + syscall.Syscall(procCoUninitialize.Addr(), 0, 0, 0, 0) + return +} + +func coGetObject(name *uint16, bindOpts *cBIND_OPTS3, guid *windows.GUID, functionTable ***[0xffff]uintptr) (ret error) { + r0, _, _ := syscall.Syscall6(procCoGetObject.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(bindOpts)), uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(functionTable)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} diff --git a/main.go b/main.go index 348f1df4..813663a5 100644 --- a/main.go +++ b/main.go @@ -15,6 +15,7 @@ import ( "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/tun/wintun" + "golang.zx2c4.com/wireguard/windows/elevate" "golang.zx2c4.com/wireguard/windows/manager" "golang.zx2c4.com/wireguard/windows/ringlogger" "golang.zx2c4.com/wireguard/windows/services" @@ -84,7 +85,7 @@ func execElevatedManagerServiceInstaller() error { if err != nil { return err } - err = windows.ShellExecute(0, windows.StringToUTF16Ptr("runas"), windows.StringToUTF16Ptr(path), windows.StringToUTF16Ptr("/installmanagerservice"), nil, windows.SW_SHOW) + err = elevate.ShellExecute(path, "/installmanagerservice", "", windows.SW_SHOW) if err != nil { return err } -- cgit v1.2.3-59-g8ed1b