summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-10-11 14:53:36 -0600
committerJason A. Donenfeld <Jason@zx2c4.com>2021-10-11 14:53:36 -0600
commit642a56e165e74a518fe986c2cf93dea62d6029b5 (patch)
tree12293a06953758f5660bb539869b05fb5d77d657
parentrwcancel: use unix.Poll again but bump x/sys so it uses ppoll under the hood (diff)
downloadwireguard-go-642a56e165e74a518fe986c2cf93dea62d6029b5.tar.xz
wireguard-go-642a56e165e74a518fe986c2cf93dea62d6029b5.zip
memmod: import from wireguard-windows
We'll eventually be getting rid of it here, but keep it sync'd up for now. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--go.mod2
-rw-r--r--tun/wintun/memmod/memmod_windows.go124
-rw-r--r--tun/wintun/memmod/memmod_windows_32.go1
-rw-r--r--tun/wintun/memmod/memmod_windows_64.go1
-rw-r--r--tun/wintun/memmod/syscall_windows_32.go1
-rw-r--r--tun/wintun/memmod/syscall_windows_64.go1
6 files changed, 96 insertions, 34 deletions
diff --git a/go.mod b/go.mod
index 5d8388b..e543167 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
module golang.zx2c4.com/wireguard
-go 1.16
+go 1.17
require (
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519
diff --git a/tun/wintun/memmod/memmod_windows.go b/tun/wintun/memmod/memmod_windows.go
index 075c03a..da6ff9a 100644
--- a/tun/wintun/memmod/memmod_windows.go
+++ b/tun/wintun/memmod/memmod_windows.go
@@ -8,6 +8,8 @@ package memmod
import (
"errors"
"fmt"
+ "strings"
+ "sync"
"syscall"
"unsafe"
@@ -62,8 +64,7 @@ func (module *Module) copySections(address uintptr, size uintptr, oldHeaders *IM
dest = module.codeBase + uintptr(sections[i].VirtualAddress)
// NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
- var dst []byte
- unsafeSlice(unsafe.Pointer(&dst), a2p(dest), int(sectionSize))
+ dst := unsafe.Slice((*byte)(a2p(dest)), sectionSize)
for j := range dst {
dst[j] = 0
}
@@ -245,11 +246,9 @@ func (module *Module) performBaseRelocation(delta uintptr) (relocated bool, err
for relocationHdr.VirtualAddress > 0 {
dest := module.codeBase + uintptr(relocationHdr.VirtualAddress)
- var relInfos []uint16
- unsafeSlice(
- unsafe.Pointer(&relInfos),
- a2p(uintptr(unsafe.Pointer(relocationHdr))+unsafe.Sizeof(*relocationHdr)),
- int((uintptr(relocationHdr.SizeOfBlock)-unsafe.Sizeof(*relocationHdr))/unsafe.Sizeof(relInfos[0])))
+ relInfos := unsafe.Slice(
+ (*uint16)(a2p(uintptr(unsafe.Pointer(relocationHdr))+unsafe.Sizeof(*relocationHdr))),
+ (uintptr(relocationHdr.SizeOfBlock)-unsafe.Sizeof(*relocationHdr))/unsafe.Sizeof(uint16(0)))
for _, relInfo := range relInfos {
// The upper 4 bits define the type of relocation.
relType := relInfo >> 12
@@ -370,10 +369,8 @@ func (module *Module) buildNameExports() error {
if exports.NumberOfNames == 0 {
return errors.New("No functions exported by name")
}
- var nameRefs []uint32
- unsafeSlice(unsafe.Pointer(&nameRefs), a2p(module.codeBase+uintptr(exports.AddressOfNames)), int(exports.NumberOfNames))
- var ordinals []uint16
- unsafeSlice(unsafe.Pointer(&ordinals), a2p(module.codeBase+uintptr(exports.AddressOfNameOrdinals)), int(exports.NumberOfNames))
+ nameRefs := unsafe.Slice((*uint32)(a2p(module.codeBase+uintptr(exports.AddressOfNames))), exports.NumberOfNames)
+ ordinals := unsafe.Slice((*uint16)(a2p(module.codeBase+uintptr(exports.AddressOfNameOrdinals))), exports.NumberOfNames)
module.nameExports = make(map[string]uint16)
for i := range nameRefs {
nameArray := windows.BytePtrToString((*byte)(a2p(module.codeBase + uintptr(nameRefs[i]))))
@@ -382,6 +379,76 @@ func (module *Module) buildNameExports() error {
return nil
}
+type addressRange struct {
+ start uintptr
+ end uintptr
+}
+
+var loadedAddressRanges []addressRange
+var loadedAddressRangesMu sync.RWMutex
+var haveHookedRtlPcToFileHeader sync.Once
+var hookRtlPcToFileHeaderResult error
+
+func hookRtlPcToFileHeader() error {
+ var kernelBase windows.Handle
+ err := windows.GetModuleHandleEx(windows.GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, windows.StringToUTF16Ptr("kernelbase.dll"), &kernelBase)
+ if err != nil {
+ return err
+ }
+ imageBase := unsafe.Pointer(kernelBase)
+ dosHeader := (*IMAGE_DOS_HEADER)(imageBase)
+ ntHeaders := (*IMAGE_NT_HEADERS)(unsafe.Add(imageBase, dosHeader.E_lfanew))
+ importsDirectory := ntHeaders.OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT]
+ importDescriptor := (*IMAGE_IMPORT_DESCRIPTOR)(unsafe.Add(imageBase, importsDirectory.VirtualAddress))
+ for ; importDescriptor.Name != 0; importDescriptor = (*IMAGE_IMPORT_DESCRIPTOR)(unsafe.Add(unsafe.Pointer(importDescriptor), unsafe.Sizeof(*importDescriptor))) {
+ libraryName := windows.BytePtrToString((*byte)(unsafe.Add(imageBase, importDescriptor.Name)))
+ if strings.EqualFold(libraryName, "ntdll.dll") {
+ break
+ }
+ }
+ if importDescriptor.Name == 0 {
+ return errors.New("ntdll.dll not found")
+ }
+ originalThunk := (*uintptr)(unsafe.Add(imageBase, importDescriptor.OriginalFirstThunk()))
+ thunk := (*uintptr)(unsafe.Add(imageBase, importDescriptor.FirstThunk))
+ for ; *originalThunk != 0; originalThunk = (*uintptr)(unsafe.Add(unsafe.Pointer(originalThunk), unsafe.Sizeof(*originalThunk))) {
+ if *originalThunk&IMAGE_ORDINAL_FLAG == 0 {
+ function := (*IMAGE_IMPORT_BY_NAME)(unsafe.Add(imageBase, *originalThunk))
+ name := windows.BytePtrToString(&function.Name[0])
+ if name == "RtlPcToFileHeader" {
+ break
+ }
+ }
+ thunk = (*uintptr)(unsafe.Add(unsafe.Pointer(thunk), unsafe.Sizeof(*thunk)))
+ }
+ if *originalThunk == 0 {
+ return errors.New("RtlPcToFileHeader not found")
+ }
+ var oldProtect uint32
+ err = windows.VirtualProtect(uintptr(unsafe.Pointer(thunk)), unsafe.Sizeof(*thunk), windows.PAGE_READWRITE, &oldProtect)
+ if err != nil {
+ return err
+ }
+ originalRtlPcToFileHeader := *thunk
+ *thunk = windows.NewCallback(func(pcValue uintptr, baseOfImage *uintptr) uintptr {
+ loadedAddressRangesMu.RLock()
+ for i := range loadedAddressRanges {
+ if pcValue >= loadedAddressRanges[i].start && pcValue < loadedAddressRanges[i].end {
+ pcValue = *thunk
+ break
+ }
+ }
+ loadedAddressRangesMu.RUnlock()
+ ret, _, _ := syscall.Syscall(originalRtlPcToFileHeader, 2, pcValue, uintptr(unsafe.Pointer(baseOfImage)), 0)
+ return ret
+ })
+ err = windows.VirtualProtect(uintptr(unsafe.Pointer(thunk)), unsafe.Sizeof(*thunk), oldProtect, &oldProtect)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
// LoadLibrary loads module image to memory.
func LoadLibrary(data []byte) (module *Module, err error) {
addr := uintptr(unsafe.Pointer(&data[0]))
@@ -513,6 +580,18 @@ func LoadLibrary(data []byte) (module *Module, err error) {
// Register exception tables, if they exist.
module.registerExceptionHandlers()
+ // Register function PCs.
+ loadedAddressRangesMu.Lock()
+ loadedAddressRanges = append(loadedAddressRanges, addressRange{module.codeBase, module.codeBase + alignedImageSize})
+ loadedAddressRangesMu.Unlock()
+ haveHookedRtlPcToFileHeader.Do(func() {
+ hookRtlPcToFileHeaderResult = hookRtlPcToFileHeader()
+ })
+ err = hookRtlPcToFileHeaderResult
+ if err != nil {
+ return
+ }
+
// TLS callbacks are executed BEFORE the main loading.
module.executeTLS()
@@ -610,26 +689,5 @@ func a2p(addr uintptr) unsafe.Pointer {
}
func memcpy(dst, src, size uintptr) {
- var d, s []byte
- unsafeSlice(unsafe.Pointer(&d), a2p(dst), int(size))
- unsafeSlice(unsafe.Pointer(&s), a2p(src), int(size))
- copy(d, s)
-}
-
-// unsafeSlice updates the slice slicePtr to be a slice
-// referencing the provided data with its length & capacity set to
-// lenCap.
-//
-// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
-// update callers to use unsafe.Slice instead of this.
-func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
- type sliceHeader struct {
- Data unsafe.Pointer
- Len int
- Cap int
- }
- h := (*sliceHeader)(slicePtr)
- h.Data = data
- h.Len = lenCap
- h.Cap = lenCap
+ copy(unsafe.Slice((*byte)(a2p(dst)), size), unsafe.Slice((*byte)(a2p(src)), size))
}
diff --git a/tun/wintun/memmod/memmod_windows_32.go b/tun/wintun/memmod/memmod_windows_32.go
index ac76bdc..75d7ca1 100644
--- a/tun/wintun/memmod/memmod_windows_32.go
+++ b/tun/wintun/memmod/memmod_windows_32.go
@@ -1,3 +1,4 @@
+//go:build (windows && 386) || (windows && arm)
// +build windows,386 windows,arm
/* SPDX-License-Identifier: MIT
diff --git a/tun/wintun/memmod/memmod_windows_64.go b/tun/wintun/memmod/memmod_windows_64.go
index a620368..09e6e73 100644
--- a/tun/wintun/memmod/memmod_windows_64.go
+++ b/tun/wintun/memmod/memmod_windows_64.go
@@ -1,3 +1,4 @@
+//go:build (windows && amd64) || (windows && arm64)
// +build windows,amd64 windows,arm64
/* SPDX-License-Identifier: MIT
diff --git a/tun/wintun/memmod/syscall_windows_32.go b/tun/wintun/memmod/syscall_windows_32.go
index 7abbac9..0072710 100644
--- a/tun/wintun/memmod/syscall_windows_32.go
+++ b/tun/wintun/memmod/syscall_windows_32.go
@@ -1,3 +1,4 @@
+//go:build (windows && 386) || (windows && arm)
// +build windows,386 windows,arm
/* SPDX-License-Identifier: MIT
diff --git a/tun/wintun/memmod/syscall_windows_64.go b/tun/wintun/memmod/syscall_windows_64.go
index 10c6533..b475202 100644
--- a/tun/wintun/memmod/syscall_windows_64.go
+++ b/tun/wintun/memmod/syscall_windows_64.go
@@ -1,3 +1,4 @@
+//go:build (windows && amd64) || (windows && arm64)
// +build windows,amd64 windows,arm64
/* SPDX-License-Identifier: MIT