aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-10-08 20:19:57 -0600
committerJason A. Donenfeld <Jason@zx2c4.com>2021-10-10 23:30:06 -0600
commitafe859468635c5c6a11266d266a4c210aeb67eac (patch)
tree4f1e1ee01f31cffafcc9e4804dcb431c0b7522c9
parentbuild: use better PnP enumeration in wg(8) (diff)
downloadwireguard-windows-afe859468635c5c6a11266d266a4c210aeb67eac.tar.xz
wireguard-windows-afe859468635c5c6a11266d266a4c210aeb67eac.zip
memmod: hook RtlPcToFileHeader's invocation from GetModuleHandleEx
When GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS) is called by cfgmgr32.dll's SwCreateDevice on the DLL's callback, it expects to get the module of the DLL. But of course memory loaded modules means there is none. This causes SwCreateDevice to fail. GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS) internally uses RtlPcToFileHeader. In turn, RtlPcToFileHeader looks things up in the inverted function table, which has no stable interface across OS releases. That means adding a proper module isn't going to work. So instead we hook the IAT, so that we can intercept all calls to RtlPcToFileHeader that come from GetModuleHandleEx's kernelbase.dll. If the value to look up is within the range of a module we've memory loaded, then we change the value to lookup to the hook function itself, so that it winds up returning the main module. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--driver/memmod/memmod_windows.go84
-rw-r--r--go.mod2
2 files changed, 85 insertions, 1 deletions
diff --git a/driver/memmod/memmod_windows.go b/driver/memmod/memmod_windows.go
index 075c03a0..6ced43fe 100644
--- a/driver/memmod/memmod_windows.go
+++ b/driver/memmod/memmod_windows.go
@@ -8,6 +8,8 @@ package memmod
import (
"errors"
"fmt"
+ "strings"
+ "sync"
"syscall"
"unsafe"
@@ -382,6 +384,76 @@ func (module *Module) buildNameExports() error {
return nil
}
+type addressRange struct {
+ start uintptr
+ end uintptr
+}
+
+var loadedAddressRanges []addressRange
+var loadedAddressRangesMu sync.RWMutex
+var haveHookedRtlPcToFileHeader sync.Once
+var hookRtlPcToFileHeaderResult error
+
+func hookRtlPcToFileHeader() error {
+ var kernelBase windows.Handle
+ err := windows.GetModuleHandleEx(windows.GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, windows.StringToUTF16Ptr("kernelbase.dll"), &kernelBase)
+ if err != nil {
+ return err
+ }
+ imageBase := unsafe.Pointer(kernelBase)
+ dosHeader := (*IMAGE_DOS_HEADER)(imageBase)
+ ntHeaders := (*IMAGE_NT_HEADERS)(unsafe.Add(imageBase, dosHeader.E_lfanew))
+ importsDirectory := ntHeaders.OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT]
+ importDescriptor := (*IMAGE_IMPORT_DESCRIPTOR)(unsafe.Add(imageBase, importsDirectory.VirtualAddress))
+ for ; importDescriptor.Name != 0; importDescriptor = (*IMAGE_IMPORT_DESCRIPTOR)(unsafe.Add(unsafe.Pointer(importDescriptor), unsafe.Sizeof(*importDescriptor))) {
+ libraryName := windows.BytePtrToString((*byte)(unsafe.Add(imageBase, importDescriptor.Name)))
+ if strings.EqualFold(libraryName, "ntdll.dll") {
+ break
+ }
+ }
+ if importDescriptor.Name == 0 {
+ return errors.New("ntdll.dll not found")
+ }
+ originalThunk := (*uintptr)(unsafe.Add(imageBase, importDescriptor.OriginalFirstThunk()))
+ thunk := (*uintptr)(unsafe.Add(imageBase, importDescriptor.FirstThunk))
+ for ; *originalThunk != 0; originalThunk = (*uintptr)(unsafe.Add(unsafe.Pointer(originalThunk), unsafe.Sizeof(*originalThunk))) {
+ if *originalThunk&IMAGE_ORDINAL_FLAG == 0 {
+ function := (*IMAGE_IMPORT_BY_NAME)(unsafe.Add(imageBase, *originalThunk))
+ name := windows.BytePtrToString(&function.Name[0])
+ if name == "RtlPcToFileHeader" {
+ break
+ }
+ }
+ thunk = (*uintptr)(unsafe.Add(unsafe.Pointer(thunk), unsafe.Sizeof(*thunk)))
+ }
+ if *originalThunk == 0 {
+ return errors.New("RtlPcToFileHeader not found")
+ }
+ var oldProtect uint32
+ err = windows.VirtualProtect(uintptr(unsafe.Pointer(thunk)), unsafe.Sizeof(*thunk), windows.PAGE_READWRITE, &oldProtect)
+ if err != nil {
+ return err
+ }
+ originalRtlPcToFileHeader := *thunk
+ *thunk = windows.NewCallback(func(pcValue uintptr, baseOfImage *uintptr) uintptr {
+ loadedAddressRangesMu.RLock()
+ for i := range loadedAddressRanges {
+ if pcValue >= loadedAddressRanges[i].start && pcValue < loadedAddressRanges[i].end {
+ pcValue = *thunk
+ break
+ }
+ }
+ loadedAddressRangesMu.RUnlock()
+ ret, _, _ := syscall.Syscall(originalRtlPcToFileHeader, 2, pcValue, uintptr(unsafe.Pointer(baseOfImage)), 0)
+ return ret
+ })
+ err = windows.VirtualProtect(uintptr(unsafe.Pointer(thunk)), unsafe.Sizeof(*thunk), oldProtect, &oldProtect)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
// LoadLibrary loads module image to memory.
func LoadLibrary(data []byte) (module *Module, err error) {
addr := uintptr(unsafe.Pointer(&data[0]))
@@ -513,6 +585,18 @@ func LoadLibrary(data []byte) (module *Module, err error) {
// Register exception tables, if they exist.
module.registerExceptionHandlers()
+ // Register function PCs.
+ loadedAddressRangesMu.Lock()
+ loadedAddressRanges = append(loadedAddressRanges, addressRange{module.codeBase, module.codeBase + alignedImageSize})
+ loadedAddressRangesMu.Unlock()
+ haveHookedRtlPcToFileHeader.Do(func() {
+ hookRtlPcToFileHeaderResult = hookRtlPcToFileHeader()
+ })
+ err = hookRtlPcToFileHeaderResult
+ if err != nil {
+ return
+ }
+
// TLS callbacks are executed BEFORE the main loading.
module.executeTLS()
diff --git a/go.mod b/go.mod
index db521359..84d14667 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
module golang.zx2c4.com/wireguard/windows
-go 1.16
+go 1.17
require (
github.com/lxn/walk v0.0.0-20210112085537-c389da54e794