aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2019-08-05 10:38:04 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2019-08-05 20:12:19 +0200
commita1346c069998c66bdf27201937bed4ecd3c9ae9e (patch)
tree86dbfa6633a7ccdb508ba298f97a07a18b775f09
parentelevate: do not show UAC prompt for frictionless UX (diff)
downloadwireguard-windows-a1346c069998c66bdf27201937bed4ecd3c9ae9e.tar.xz
wireguard-windows-a1346c069998c66bdf27201937bed4ecd3c9ae9e.zip
elevate: do not rely on undocumented ldr function
Diffstat (limited to '')
-rw-r--r--elevate/shellexecute.go31
-rw-r--r--elevate/syscall_windows.go40
-rw-r--r--elevate/zsyscall_windows.go20
3 files changed, 72 insertions, 19 deletions
diff --git a/elevate/shellexecute.go b/elevate/shellexecute.go
index c3dc84eb..d784a2ed 100644
--- a/elevate/shellexecute.go
+++ b/elevate/shellexecute.go
@@ -22,6 +22,29 @@ const (
cSEE_MASK_DEFAULT = 0
)
+/* We could use the undocumented LdrFindEntryForAddress function instead, but that's undocumented, and we're trying
+ * to be as rock-solid as possible here. */
+func findCurrentDataTableEntry() (entry *cLDR_DATA_TABLE_ENTRY, err error) {
+ ourBase, err := getModuleHandle(nil) /* This is the same as peb->ImageBaseAddress, but that member is undocumented. */
+ if err != nil {
+ return
+ }
+ peb := rtlGetCurrentPeb()
+ if peb == nil || peb.Ldr == nil {
+ err = windows.ERROR_INVALID_ADDRESS
+ return
+ }
+ for cur := peb.Ldr.InMemoryOrderModuleList.Flink; cur != &peb.Ldr.InMemoryOrderModuleList; cur = cur.Flink {
+ entry = (*cLDR_DATA_TABLE_ENTRY)(unsafe.Pointer(uintptr(unsafe.Pointer(cur)) - unsafe.Offsetof(cLDR_DATA_TABLE_ENTRY{}.InMemoryOrderLinks)))
+ if entry.DllBase == ourBase {
+ return
+ }
+ }
+ entry = nil
+ err = windows.ERROR_OBJECT_NOT_FOUND
+ return
+}
+
func ShellExecute(program string, arguments string, directory string, show int32) (err error) {
var (
program16 *uint16
@@ -68,16 +91,10 @@ func ShellExecute(program string, arguments string, directory string, show int32
return
}
}
-
- moduleHandle, err := getModuleHandle(nil)
+ dataTableEntry, err := findCurrentDataTableEntry()
if err != nil {
return
}
- var dataTableEntry *cLDR_DATA_TABLE_ENTRY
- if ret := ldrFindEntryForAddress(moduleHandle, &dataTableEntry); ret != 0 {
- err = syscall.Errno(windows.ERROR_INTERNAL_ERROR)
- return
- }
var windowsDirectory [windows.MAX_PATH]uint16
if _, err = getWindowsDirectory(&windowsDirectory[0], windows.MAX_PATH); err != nil {
return
diff --git a/elevate/syscall_windows.go b/elevate/syscall_windows.go
index c73be812..c7def8fa 100644
--- a/elevate/syscall_windows.go
+++ b/elevate/syscall_windows.go
@@ -23,9 +23,17 @@ type cUNICODE_STRING struct {
Buffer *uint16
}
+type cLIST_ENTRY struct {
+ Flink *cLIST_ENTRY
+ Blink *cLIST_ENTRY
+}
+
+/* The below three structs have several "reserved" members. These are of course well-known and extensively reverse-
+ * engineered, but the below shows only the documented and therefore stable fields from Microsoft's winternl.h header */
+
type cLDR_DATA_TABLE_ENTRY struct {
Reserved1 [2]uintptr
- InMemoryOrderLinks [2]uintptr
+ InMemoryOrderLinks cLIST_ENTRY
Reserved2 [2]uintptr
DllBase uintptr
Reserved3 [2]uintptr
@@ -36,6 +44,34 @@ type cLDR_DATA_TABLE_ENTRY struct {
TimeDateStamp uint32
}
+type cPEB_LDR_DATA struct {
+ Reserved1 [8]byte
+ Reserved2 [3]uintptr
+ InMemoryOrderModuleList cLIST_ENTRY
+}
+
+type cPEB struct {
+ Reserved1 [2]byte
+ BeingDebugged byte
+ Reserved2 [1]byte
+ Reserved3 [2]uintptr
+ Ldr *cPEB_LDR_DATA
+ ProcessParameters uintptr
+ Reserved4 [3]uintptr
+ AtlThunkSListPtr uintptr
+ Reserved5 uintptr
+ Reserved6 uint32
+ Reserved7 uintptr
+ Reserved8 uint32
+ AtlThunkSListPtr32 uint32
+ Reserved9 [45]uintptr
+ Reserved10 [96]byte
+ PostProcessInitRoutine uintptr
+ Reserved11 [128]byte
+ Reserved12 [1]uintptr
+ SessionId uint32
+}
+
const (
cCLSCTX_LOCAL_SERVER = 4
cCOINIT_APARTMENTTHREADED = 2
@@ -45,7 +81,7 @@ const (
//sys getWindowsDirectory(windowsDirectory *uint16, inLen uint32) (outLen uint32, err error) [failretval==0] = kernel32.GetWindowsDirectoryW
//sys rtlInitUnicodeString(destinationString *cUNICODE_STRING, sourceString *uint16) = ntdll.RtlInitUnicodeString
-//sys ldrFindEntryForAddress(moduleHandle uintptr, entry **cLDR_DATA_TABLE_ENTRY) (ntstatus uint32) = ntdll.LdrFindEntryForAddress
+//sys rtlGetCurrentPeb() (peb *cPEB) = ntdll.RtlGetCurrentPeb
//sys coInitializeEx(reserved uintptr, coInit uint32) (ret error) = ole32.CoInitializeEx
//sys coUninitialize() = ole32.CoUninitialize
diff --git a/elevate/zsyscall_windows.go b/elevate/zsyscall_windows.go
index b16b5f5d..a3c5400d 100644
--- a/elevate/zsyscall_windows.go
+++ b/elevate/zsyscall_windows.go
@@ -41,13 +41,13 @@ var (
modntdll = windows.NewLazySystemDLL("ntdll.dll")
modole32 = windows.NewLazySystemDLL("ole32.dll")
- procGetModuleHandleW = modkernel32.NewProc("GetModuleHandleW")
- procGetWindowsDirectoryW = modkernel32.NewProc("GetWindowsDirectoryW")
- procRtlInitUnicodeString = modntdll.NewProc("RtlInitUnicodeString")
- procLdrFindEntryForAddress = modntdll.NewProc("LdrFindEntryForAddress")
- procCoInitializeEx = modole32.NewProc("CoInitializeEx")
- procCoUninitialize = modole32.NewProc("CoUninitialize")
- procCoGetObject = modole32.NewProc("CoGetObject")
+ procGetModuleHandleW = modkernel32.NewProc("GetModuleHandleW")
+ procGetWindowsDirectoryW = modkernel32.NewProc("GetWindowsDirectoryW")
+ procRtlInitUnicodeString = modntdll.NewProc("RtlInitUnicodeString")
+ procRtlGetCurrentPeb = modntdll.NewProc("RtlGetCurrentPeb")
+ procCoInitializeEx = modole32.NewProc("CoInitializeEx")
+ procCoUninitialize = modole32.NewProc("CoUninitialize")
+ procCoGetObject = modole32.NewProc("CoGetObject")
)
func getModuleHandle(moduleName *uint16) (moduleHandle uintptr, err error) {
@@ -81,9 +81,9 @@ func rtlInitUnicodeString(destinationString *cUNICODE_STRING, sourceString *uint
return
}
-func ldrFindEntryForAddress(moduleHandle uintptr, entry **cLDR_DATA_TABLE_ENTRY) (ntstatus uint32) {
- r0, _, _ := syscall.Syscall(procLdrFindEntryForAddress.Addr(), 2, uintptr(moduleHandle), uintptr(unsafe.Pointer(entry)), 0)
- ntstatus = uint32(r0)
+func rtlGetCurrentPeb() (peb *cPEB) {
+ r0, _, _ := syscall.Syscall(procRtlGetCurrentPeb.Addr(), 0, 0, 0, 0)
+ peb = (*cPEB)(unsafe.Pointer(r0))
return
}