diff options
Diffstat (limited to 'driver/dll_windows.go')
-rw-r--r-- | driver/dll_windows.go | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/driver/dll_windows.go b/driver/dll_windows.go new file mode 100644 index 00000000..5dcb849e --- /dev/null +++ b/driver/dll_windows.go @@ -0,0 +1,91 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved. + */ + +package driver + +import ( + "fmt" + "sync" + "sync/atomic" + "unsafe" + + "golang.org/x/sys/windows" +) + +func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL { + return &lazyDLL{Name: name, onLoad: onLoad} +} + +func (d *lazyDLL) NewProc(name string) *lazyProc { + return &lazyProc{dll: d, Name: name} +} + +type lazyProc struct { + Name string + mu sync.Mutex + dll *lazyDLL + addr uintptr +} + +func (p *lazyProc) Find() error { + if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr))) != nil { + return nil + } + p.mu.Lock() + defer p.mu.Unlock() + if p.addr != 0 { + return nil + } + + err := p.dll.Load() + if err != nil { + return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err) + } + addr, err := p.nameToAddr() + if err != nil { + return fmt.Errorf("Error getting %v address: %w", p.Name, err) + } + + atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr)), unsafe.Pointer(addr)) + return nil +} + +func (p *lazyProc) Addr() uintptr { + err := p.Find() + if err != nil { + panic(err) + } + return p.addr +} + +// Version returns the version of the driver DLL. +func Version() string { + if modwireguard.Load() != nil { + return "unknown" + } + resInfo, err := windows.FindResource(modwireguard.Base, windows.ResourceID(1), windows.RT_VERSION) + if err != nil { + return "unknown" + } + data, err := windows.LoadResourceData(modwireguard.Base, resInfo) + if err != nil { + return "unknown" + } + + var fixedInfo *windows.VS_FIXEDFILEINFO + fixedInfoLen := uint32(unsafe.Sizeof(*fixedInfo)) + err = windows.VerQueryValue(unsafe.Pointer(&data[0]), `\`, unsafe.Pointer(&fixedInfo), &fixedInfoLen) + if err != nil { + return "unknown" + } + version := fmt.Sprintf("%d.%d", (fixedInfo.FileVersionMS>>16)&0xff, (fixedInfo.FileVersionMS>>0)&0xff) + if nextNibble := (fixedInfo.FileVersionLS >> 16) & 0xff; nextNibble != 0 { + version += fmt.Sprintf(".%d", nextNibble) + } + if nextNibble := (fixedInfo.FileVersionLS >> 0) & 0xff; nextNibble != 0 { + version += fmt.Sprintf(".%d", nextNibble) + } + return version +} |