diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-10-08 20:19:57 -0600 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-10-10 23:30:06 -0600 |
commit | afe859468635c5c6a11266d266a4c210aeb67eac (patch) | |
tree | 4f1e1ee01f31cffafcc9e4804dcb431c0b7522c9 | |
parent | build: use better PnP enumeration in wg(8) (diff) | |
download | wireguard-windows-afe859468635c5c6a11266d266a4c210aeb67eac.tar.xz wireguard-windows-afe859468635c5c6a11266d266a4c210aeb67eac.zip |
memmod: hook RtlPcToFileHeader's invocation from GetModuleHandleEx
When GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS) is called
by cfgmgr32.dll's SwCreateDevice on the DLL's callback, it expects to
get the module of the DLL. But of course memory loaded modules means
there is none. This causes SwCreateDevice to fail.
GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS) internally
uses RtlPcToFileHeader. In turn, RtlPcToFileHeader looks things up in
the inverted function table, which has no stable interface across OS
releases. That means adding a proper module isn't going to work.
So instead we hook the IAT, so that we can intercept all calls to
RtlPcToFileHeader that come from GetModuleHandleEx's kernelbase.dll. If
the value to look up is within the range of a module we've memory
loaded, then we change the value to lookup to the hook function itself,
so that it winds up returning the main module.
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r-- | driver/memmod/memmod_windows.go | 84 | ||||
-rw-r--r-- | go.mod | 2 |
2 files changed, 85 insertions, 1 deletions
diff --git a/driver/memmod/memmod_windows.go b/driver/memmod/memmod_windows.go index 075c03a0..6ced43fe 100644 --- a/driver/memmod/memmod_windows.go +++ b/driver/memmod/memmod_windows.go @@ -8,6 +8,8 @@ package memmod import ( "errors" "fmt" + "strings" + "sync" "syscall" "unsafe" @@ -382,6 +384,76 @@ func (module *Module) buildNameExports() error { return nil } +type addressRange struct { + start uintptr + end uintptr +} + +var loadedAddressRanges []addressRange +var loadedAddressRangesMu sync.RWMutex +var haveHookedRtlPcToFileHeader sync.Once +var hookRtlPcToFileHeaderResult error + +func hookRtlPcToFileHeader() error { + var kernelBase windows.Handle + err := windows.GetModuleHandleEx(windows.GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, windows.StringToUTF16Ptr("kernelbase.dll"), &kernelBase) + if err != nil { + return err + } + imageBase := unsafe.Pointer(kernelBase) + dosHeader := (*IMAGE_DOS_HEADER)(imageBase) + ntHeaders := (*IMAGE_NT_HEADERS)(unsafe.Add(imageBase, dosHeader.E_lfanew)) + importsDirectory := ntHeaders.OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT] + importDescriptor := (*IMAGE_IMPORT_DESCRIPTOR)(unsafe.Add(imageBase, importsDirectory.VirtualAddress)) + for ; importDescriptor.Name != 0; importDescriptor = (*IMAGE_IMPORT_DESCRIPTOR)(unsafe.Add(unsafe.Pointer(importDescriptor), unsafe.Sizeof(*importDescriptor))) { + libraryName := windows.BytePtrToString((*byte)(unsafe.Add(imageBase, importDescriptor.Name))) + if strings.EqualFold(libraryName, "ntdll.dll") { + break + } + } + if importDescriptor.Name == 0 { + return errors.New("ntdll.dll not found") + } + originalThunk := (*uintptr)(unsafe.Add(imageBase, importDescriptor.OriginalFirstThunk())) + thunk := (*uintptr)(unsafe.Add(imageBase, importDescriptor.FirstThunk)) + for ; *originalThunk != 0; originalThunk = (*uintptr)(unsafe.Add(unsafe.Pointer(originalThunk), unsafe.Sizeof(*originalThunk))) { + if *originalThunk&IMAGE_ORDINAL_FLAG == 0 { + function := (*IMAGE_IMPORT_BY_NAME)(unsafe.Add(imageBase, *originalThunk)) + name := windows.BytePtrToString(&function.Name[0]) + if name == "RtlPcToFileHeader" { + break + } + } + thunk = (*uintptr)(unsafe.Add(unsafe.Pointer(thunk), unsafe.Sizeof(*thunk))) + } + if *originalThunk == 0 { + return errors.New("RtlPcToFileHeader not found") + } + var oldProtect uint32 + err = windows.VirtualProtect(uintptr(unsafe.Pointer(thunk)), unsafe.Sizeof(*thunk), windows.PAGE_READWRITE, &oldProtect) + if err != nil { + return err + } + originalRtlPcToFileHeader := *thunk + *thunk = windows.NewCallback(func(pcValue uintptr, baseOfImage *uintptr) uintptr { + loadedAddressRangesMu.RLock() + for i := range loadedAddressRanges { + if pcValue >= loadedAddressRanges[i].start && pcValue < loadedAddressRanges[i].end { + pcValue = *thunk + break + } + } + loadedAddressRangesMu.RUnlock() + ret, _, _ := syscall.Syscall(originalRtlPcToFileHeader, 2, pcValue, uintptr(unsafe.Pointer(baseOfImage)), 0) + return ret + }) + err = windows.VirtualProtect(uintptr(unsafe.Pointer(thunk)), unsafe.Sizeof(*thunk), oldProtect, &oldProtect) + if err != nil { + return err + } + return nil +} + // LoadLibrary loads module image to memory. func LoadLibrary(data []byte) (module *Module, err error) { addr := uintptr(unsafe.Pointer(&data[0])) @@ -513,6 +585,18 @@ func LoadLibrary(data []byte) (module *Module, err error) { // Register exception tables, if they exist. module.registerExceptionHandlers() + // Register function PCs. + loadedAddressRangesMu.Lock() + loadedAddressRanges = append(loadedAddressRanges, addressRange{module.codeBase, module.codeBase + alignedImageSize}) + loadedAddressRangesMu.Unlock() + haveHookedRtlPcToFileHeader.Do(func() { + hookRtlPcToFileHeaderResult = hookRtlPcToFileHeader() + }) + err = hookRtlPcToFileHeaderResult + if err != nil { + return + } + // TLS callbacks are executed BEFORE the main loading. module.executeTLS() @@ -1,6 +1,6 @@ module golang.zx2c4.com/wireguard/windows -go 1.16 +go 1.17 require ( github.com/lxn/walk v0.0.0-20210112085537-c389da54e794 |