aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2019-08-03 09:25:35 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2019-08-05 20:12:19 +0200
commitcea56d3a253408c4915ed3db5a15c499288cb89b (patch)
tree094f6a07f7619ad83c7679388a32835418fb25bc
parentui: remove SetFocus hack from EditDialog (diff)
downloadwireguard-windows-cea56d3a253408c4915ed3db5a15c499288cb89b.tar.xz
wireguard-windows-cea56d3a253408c4915ed3db5a15c499288cb89b.zip
elevate: do not show UAC prompt for frictionless UX
-rw-r--r--elevate/mksyscall.go8
-rw-r--r--elevate/shellexecute.go126
-rw-r--r--elevate/syscall_windows.go52
-rw-r--r--elevate/zsyscall_windows.go109
-rw-r--r--main.go3
5 files changed, 297 insertions, 1 deletions
diff --git a/elevate/mksyscall.go b/elevate/mksyscall.go
new file mode 100644
index 00000000..fda63f27
--- /dev/null
+++ b/elevate/mksyscall.go
@@ -0,0 +1,8 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package elevate
+
+//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go syscall_windows.go
diff --git a/elevate/shellexecute.go b/elevate/shellexecute.go
new file mode 100644
index 00000000..c3dc84eb
--- /dev/null
+++ b/elevate/shellexecute.go
@@ -0,0 +1,126 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package elevate
+
+import (
+ "path/filepath"
+ "runtime"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/registry"
+)
+
+const (
+ releaseOffset = 2
+ shellExecuteOffset = 9
+
+ cSEE_MASK_DEFAULT = 0
+)
+
+func ShellExecute(program string, arguments string, directory string, show int32) (err error) {
+ var (
+ program16 *uint16
+ arguments16 *uint16
+ directory16 *uint16
+ )
+
+ if len(program) > 0 {
+ program16, _ = windows.UTF16PtrFromString(program)
+ }
+ if len(arguments) > 0 {
+ arguments16, _ = windows.UTF16PtrFromString(arguments)
+ }
+ if len(directory) > 0 {
+ directory16, _ = windows.UTF16PtrFromString(directory)
+ }
+
+ defer func() {
+ if err != nil {
+ err = windows.ShellExecute(0, windows.StringToUTF16Ptr("runas"), program16, arguments16, directory16, show)
+ }
+ }()
+
+ processToken, err := windows.OpenCurrentProcessToken()
+ if err != nil {
+ return
+ }
+ defer processToken.Close()
+ if processToken.IsElevated() {
+ err = windows.ERROR_SUCCESS
+ return
+ }
+
+ key, err := registry.OpenKey(registry.LOCAL_MACHINE, "SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\UAC\\COMAutoApprovalList", registry.QUERY_VALUE)
+ if err == nil {
+ var autoApproved uint64
+ autoApproved, _, err = key.GetIntegerValue("{3E5FC7F9-9A51-4367-9063-A120244FBEC7}")
+ key.Close()
+ if err != nil {
+ return
+ }
+ if uint32(autoApproved) == 0 {
+ err = windows.ERROR_ACCESS_DENIED
+ return
+ }
+ }
+
+ moduleHandle, err := getModuleHandle(nil)
+ if err != nil {
+ return
+ }
+ var dataTableEntry *cLDR_DATA_TABLE_ENTRY
+ if ret := ldrFindEntryForAddress(moduleHandle, &dataTableEntry); ret != 0 {
+ err = syscall.Errno(windows.ERROR_INTERNAL_ERROR)
+ return
+ }
+ var windowsDirectory [windows.MAX_PATH]uint16
+ if _, err = getWindowsDirectory(&windowsDirectory[0], windows.MAX_PATH); err != nil {
+ return
+ }
+ originalPath := dataTableEntry.FullDllName.Buffer
+ explorerPath := windows.StringToUTF16Ptr(filepath.Join(windows.UTF16ToString(windowsDirectory[:]), "explorer.exe"))
+ rtlInitUnicodeString(&dataTableEntry.FullDllName, explorerPath)
+ defer func() {
+ rtlInitUnicodeString(&dataTableEntry.FullDllName, originalPath)
+ runtime.KeepAlive(explorerPath)
+ }()
+
+ if err = coInitializeEx(0, cCOINIT_APARTMENTTHREADED); err == nil {
+ defer coUninitialize()
+ }
+
+ var interfacePointer **[0xffff]uintptr
+ if err = coGetObject(
+ windows.StringToUTF16Ptr("Elevation:Administrator!new:{3E5FC7F9-9A51-4367-9063-A120244FBEC7}"),
+ &cBIND_OPTS3{
+ cbStruct: uint32(unsafe.Sizeof(cBIND_OPTS3{})),
+ dwClassContext: cCLSCTX_LOCAL_SERVER,
+ },
+ &windows.GUID{0x6EDD6D74, 0xC007, 0x4E75, [8]byte{0xB7, 0x6A, 0xE5, 0x74, 0x09, 0x95, 0xE2, 0x4C}},
+ &interfacePointer,
+ ); err != nil {
+ return
+ }
+
+ defer syscall.Syscall((*interfacePointer)[releaseOffset], 1, uintptr(unsafe.Pointer(interfacePointer)), 0, 0)
+
+ if ret, _, _ := syscall.Syscall6((*interfacePointer)[shellExecuteOffset], 6,
+ uintptr(unsafe.Pointer(interfacePointer)),
+ uintptr(unsafe.Pointer(program16)),
+ uintptr(unsafe.Pointer(arguments16)),
+ uintptr(unsafe.Pointer(directory16)),
+ cSEE_MASK_DEFAULT,
+ uintptr(show),
+ ); ret != uintptr(windows.ERROR_SUCCESS) {
+ err = syscall.Errno(ret)
+ return
+ }
+
+ err = nil
+ return
+}
diff --git a/elevate/syscall_windows.go b/elevate/syscall_windows.go
new file mode 100644
index 00000000..c73be812
--- /dev/null
+++ b/elevate/syscall_windows.go
@@ -0,0 +1,52 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package elevate
+
+type cBIND_OPTS3 struct {
+ cbStruct uint32
+ grfFlags uint32
+ grfMode uint32
+ dwTickCountDeadline uint32
+ dwTrackFlags uint32
+ dwClassContext uint32
+ locale uint32
+ pServerInfo *uintptr
+ hwnd *uintptr
+}
+
+type cUNICODE_STRING struct {
+ Length uint16
+ MaximumLength uint16
+ Buffer *uint16
+}
+
+type cLDR_DATA_TABLE_ENTRY struct {
+ Reserved1 [2]uintptr
+ InMemoryOrderLinks [2]uintptr
+ Reserved2 [2]uintptr
+ DllBase uintptr
+ Reserved3 [2]uintptr
+ FullDllName cUNICODE_STRING
+ Reserved4 [8]byte
+ Reserved5 [3]uintptr
+ Reserved6 uintptr
+ TimeDateStamp uint32
+}
+
+const (
+ cCLSCTX_LOCAL_SERVER = 4
+ cCOINIT_APARTMENTTHREADED = 2
+)
+
+//sys getModuleHandle(moduleName *uint16) (moduleHandle uintptr, err error) [failretval==0] = kernel32.GetModuleHandleW
+//sys getWindowsDirectory(windowsDirectory *uint16, inLen uint32) (outLen uint32, err error) [failretval==0] = kernel32.GetWindowsDirectoryW
+
+//sys rtlInitUnicodeString(destinationString *cUNICODE_STRING, sourceString *uint16) = ntdll.RtlInitUnicodeString
+//sys ldrFindEntryForAddress(moduleHandle uintptr, entry **cLDR_DATA_TABLE_ENTRY) (ntstatus uint32) = ntdll.LdrFindEntryForAddress
+
+//sys coInitializeEx(reserved uintptr, coInit uint32) (ret error) = ole32.CoInitializeEx
+//sys coUninitialize() = ole32.CoUninitialize
+//sys coGetObject(name *uint16, bindOpts *cBIND_OPTS3, guid *windows.GUID, functionTable ***[0xffff]uintptr) (ret error) = ole32.CoGetObject
diff --git a/elevate/zsyscall_windows.go b/elevate/zsyscall_windows.go
new file mode 100644
index 00000000..b16b5f5d
--- /dev/null
+++ b/elevate/zsyscall_windows.go
@@ -0,0 +1,109 @@
+// Code generated by 'go generate'; DO NOT EDIT.
+
+package elevate
+
+import (
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+var _ unsafe.Pointer
+
+// Do the interface allocations only once for common
+// Errno values.
+const (
+ errnoERROR_IO_PENDING = 997
+)
+
+var (
+ errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
+)
+
+// errnoErr returns common boxed Errno values, to prevent
+// allocations at runtime.
+func errnoErr(e syscall.Errno) error {
+ switch e {
+ case 0:
+ return nil
+ case errnoERROR_IO_PENDING:
+ return errERROR_IO_PENDING
+ }
+ // TODO: add more here, after collecting data on the common
+ // error values see on Windows. (perhaps when running
+ // all.bat?)
+ return e
+}
+
+var (
+ modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
+ modntdll = windows.NewLazySystemDLL("ntdll.dll")
+ modole32 = windows.NewLazySystemDLL("ole32.dll")
+
+ procGetModuleHandleW = modkernel32.NewProc("GetModuleHandleW")
+ procGetWindowsDirectoryW = modkernel32.NewProc("GetWindowsDirectoryW")
+ procRtlInitUnicodeString = modntdll.NewProc("RtlInitUnicodeString")
+ procLdrFindEntryForAddress = modntdll.NewProc("LdrFindEntryForAddress")
+ procCoInitializeEx = modole32.NewProc("CoInitializeEx")
+ procCoUninitialize = modole32.NewProc("CoUninitialize")
+ procCoGetObject = modole32.NewProc("CoGetObject")
+)
+
+func getModuleHandle(moduleName *uint16) (moduleHandle uintptr, err error) {
+ r0, _, e1 := syscall.Syscall(procGetModuleHandleW.Addr(), 1, uintptr(unsafe.Pointer(moduleName)), 0, 0)
+ moduleHandle = uintptr(r0)
+ if moduleHandle == 0 {
+ if e1 != 0 {
+ err = errnoErr(e1)
+ } else {
+ err = syscall.EINVAL
+ }
+ }
+ return
+}
+
+func getWindowsDirectory(windowsDirectory *uint16, inLen uint32) (outLen uint32, err error) {
+ r0, _, e1 := syscall.Syscall(procGetWindowsDirectoryW.Addr(), 2, uintptr(unsafe.Pointer(windowsDirectory)), uintptr(inLen), 0)
+ outLen = uint32(r0)
+ if outLen == 0 {
+ if e1 != 0 {
+ err = errnoErr(e1)
+ } else {
+ err = syscall.EINVAL
+ }
+ }
+ return
+}
+
+func rtlInitUnicodeString(destinationString *cUNICODE_STRING, sourceString *uint16) {
+ syscall.Syscall(procRtlInitUnicodeString.Addr(), 2, uintptr(unsafe.Pointer(destinationString)), uintptr(unsafe.Pointer(sourceString)), 0)
+ return
+}
+
+func ldrFindEntryForAddress(moduleHandle uintptr, entry **cLDR_DATA_TABLE_ENTRY) (ntstatus uint32) {
+ r0, _, _ := syscall.Syscall(procLdrFindEntryForAddress.Addr(), 2, uintptr(moduleHandle), uintptr(unsafe.Pointer(entry)), 0)
+ ntstatus = uint32(r0)
+ return
+}
+
+func coInitializeEx(reserved uintptr, coInit uint32) (ret error) {
+ r0, _, _ := syscall.Syscall(procCoInitializeEx.Addr(), 2, uintptr(reserved), uintptr(coInit), 0)
+ if r0 != 0 {
+ ret = syscall.Errno(r0)
+ }
+ return
+}
+
+func coUninitialize() {
+ syscall.Syscall(procCoUninitialize.Addr(), 0, 0, 0, 0)
+ return
+}
+
+func coGetObject(name *uint16, bindOpts *cBIND_OPTS3, guid *windows.GUID, functionTable ***[0xffff]uintptr) (ret error) {
+ r0, _, _ := syscall.Syscall6(procCoGetObject.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(bindOpts)), uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(functionTable)), 0, 0)
+ if r0 != 0 {
+ ret = syscall.Errno(r0)
+ }
+ return
+}
diff --git a/main.go b/main.go
index 348f1df4..813663a5 100644
--- a/main.go
+++ b/main.go
@@ -15,6 +15,7 @@ import (
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/tun/wintun"
+ "golang.zx2c4.com/wireguard/windows/elevate"
"golang.zx2c4.com/wireguard/windows/manager"
"golang.zx2c4.com/wireguard/windows/ringlogger"
"golang.zx2c4.com/wireguard/windows/services"
@@ -84,7 +85,7 @@ func execElevatedManagerServiceInstaller() error {
if err != nil {
return err
}
- err = windows.ShellExecute(0, windows.StringToUTF16Ptr("runas"), windows.StringToUTF16Ptr(path), windows.StringToUTF16Ptr("/installmanagerservice"), nil, windows.SW_SHOW)
+ err = elevate.ShellExecute(path, "/installmanagerservice", "", windows.SW_SHOW)
if err != nil {
return err
}