aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/driver/dll_windows.go
diff options
context:
space:
mode:
Diffstat (limited to 'driver/dll_windows.go')
-rw-r--r--driver/dll_windows.go91
1 files changed, 91 insertions, 0 deletions
diff --git a/driver/dll_windows.go b/driver/dll_windows.go
new file mode 100644
index 00000000..5dcb849e
--- /dev/null
+++ b/driver/dll_windows.go
@@ -0,0 +1,91 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
+ */
+
+package driver
+
+import (
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL {
+ return &lazyDLL{Name: name, onLoad: onLoad}
+}
+
+func (d *lazyDLL) NewProc(name string) *lazyProc {
+ return &lazyProc{dll: d, Name: name}
+}
+
+type lazyProc struct {
+ Name string
+ mu sync.Mutex
+ dll *lazyDLL
+ addr uintptr
+}
+
+func (p *lazyProc) Find() error {
+ if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr))) != nil {
+ return nil
+ }
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.addr != 0 {
+ return nil
+ }
+
+ err := p.dll.Load()
+ if err != nil {
+ return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err)
+ }
+ addr, err := p.nameToAddr()
+ if err != nil {
+ return fmt.Errorf("Error getting %v address: %w", p.Name, err)
+ }
+
+ atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr)), unsafe.Pointer(addr))
+ return nil
+}
+
+func (p *lazyProc) Addr() uintptr {
+ err := p.Find()
+ if err != nil {
+ panic(err)
+ }
+ return p.addr
+}
+
+// Version returns the version of the driver DLL.
+func Version() string {
+ if modwireguard.Load() != nil {
+ return "unknown"
+ }
+ resInfo, err := windows.FindResource(modwireguard.Base, windows.ResourceID(1), windows.RT_VERSION)
+ if err != nil {
+ return "unknown"
+ }
+ data, err := windows.LoadResourceData(modwireguard.Base, resInfo)
+ if err != nil {
+ return "unknown"
+ }
+
+ var fixedInfo *windows.VS_FIXEDFILEINFO
+ fixedInfoLen := uint32(unsafe.Sizeof(*fixedInfo))
+ err = windows.VerQueryValue(unsafe.Pointer(&data[0]), `\`, unsafe.Pointer(&fixedInfo), &fixedInfoLen)
+ if err != nil {
+ return "unknown"
+ }
+ version := fmt.Sprintf("%d.%d", (fixedInfo.FileVersionMS>>16)&0xff, (fixedInfo.FileVersionMS>>0)&0xff)
+ if nextNibble := (fixedInfo.FileVersionLS >> 16) & 0xff; nextNibble != 0 {
+ version += fmt.Sprintf(".%d", nextNibble)
+ }
+ if nextNibble := (fixedInfo.FileVersionLS >> 0) & 0xff; nextNibble != 0 {
+ version += fmt.Sprintf(".%d", nextNibble)
+ }
+ return version
+}