diff options
Diffstat (limited to 'driver/driver_windows.go')
-rw-r--r-- | driver/driver_windows.go | 170 |
1 files changed, 170 insertions, 0 deletions
diff --git a/driver/driver_windows.go b/driver/driver_windows.go new file mode 100644 index 00000000..462c3a30 --- /dev/null +++ b/driver/driver_windows.go @@ -0,0 +1,170 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved. + */ + +package driver + +import ( + "log" + "runtime" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +type loggerLevel int + +const ( + logInfo loggerLevel = iota + logWarn + logErr +) + +const AdapterNameMax = 128 + +type Adapter struct { + handle uintptr + lastGetGuessSize uint32 +} + +var ( + modwireguard = newLazyDLL("wireguard.dll", setupLogger) + procWireGuardCreateAdapter = modwireguard.NewProc("WireGuardCreateAdapter") + procWireGuardOpenAdapter = modwireguard.NewProc("WireGuardOpenAdapter") + procWireGuardCloseAdapter = modwireguard.NewProc("WireGuardCloseAdapter") + procWireGuardDeleteDriver = modwireguard.NewProc("WireGuardDeleteDriver") + procWireGuardGetAdapterLUID = modwireguard.NewProc("WireGuardGetAdapterLUID") + procWireGuardGetRunningDriverVersion = modwireguard.NewProc("WireGuardGetRunningDriverVersion") + procWireGuardSetAdapterLogging = modwireguard.NewProc("WireGuardSetAdapterLogging") +) + +type TimestampedWriter interface { + WriteWithTimestamp(p []byte, ts int64) (n int, err error) +} + +func logMessage(level loggerLevel, timestamp uint64, msg *uint16) int { + if tw, ok := log.Default().Writer().(TimestampedWriter); ok { + tw.WriteWithTimestamp([]byte(log.Default().Prefix()+windows.UTF16PtrToString(msg)), (int64(timestamp)-116444736000000000)*100) + } else { + log.Println(windows.UTF16PtrToString(msg)) + } + return 0 +} + +func setupLogger(dll *lazyDLL) { + var callback uintptr + if runtime.GOARCH == "386" { + callback = windows.NewCallback(func(level loggerLevel, timestampLow, timestampHigh uint32, msg *uint16) int { + return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg) + }) + } else if runtime.GOARCH == "arm" { + callback = windows.NewCallback(func(level loggerLevel, _, timestampLow, timestampHigh uint32, msg *uint16) int { + return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg) + }) + } else if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" { + callback = windows.NewCallback(logMessage) + } + syscall.SyscallN(dll.NewProc("WireGuardSetLogger").Addr(), callback) +} + +func closeAdapter(wireguard *Adapter) { + syscall.SyscallN(procWireGuardCloseAdapter.Addr(), wireguard.handle) +} + +// CreateAdapter creates a WireGuard adapter. name is the cosmetic name of the adapter. +// tunnelType represents the type of adapter and should be "WireGuard". requestedGUID is +// the GUID of the created network adapter, which then influences NLA generation +// deterministically. If it is set to nil, the GUID is chosen by the system at random, +// and hence a new NLA entry is created for each new adapter. +func CreateAdapter(name, tunnelType string, requestedGUID *windows.GUID) (wireguard *Adapter, err error) { + var name16 *uint16 + name16, err = windows.UTF16PtrFromString(name) + if err != nil { + return + } + var tunnelType16 *uint16 + tunnelType16, err = windows.UTF16PtrFromString(tunnelType) + if err != nil { + return + } + r0, _, e1 := syscall.SyscallN(procWireGuardCreateAdapter.Addr(), uintptr(unsafe.Pointer(name16)), uintptr(unsafe.Pointer(tunnelType16)), uintptr(unsafe.Pointer(requestedGUID))) + if r0 == 0 { + err = e1 + return + } + wireguard = &Adapter{handle: r0} + runtime.SetFinalizer(wireguard, closeAdapter) + return +} + +// OpenAdapter opens an existing WireGuard adapter by name. +func OpenAdapter(name string) (wireguard *Adapter, err error) { + var name16 *uint16 + name16, err = windows.UTF16PtrFromString(name) + if err != nil { + return + } + r0, _, e1 := syscall.SyscallN(procWireGuardOpenAdapter.Addr(), uintptr(unsafe.Pointer(name16))) + if r0 == 0 { + err = e1 + return + } + wireguard = &Adapter{handle: r0} + runtime.SetFinalizer(wireguard, closeAdapter) + return +} + +// Close closes a WireGuard adapter. +func (wireguard *Adapter) Close() (err error) { + runtime.SetFinalizer(wireguard, nil) + r1, _, e1 := syscall.SyscallN(procWireGuardCloseAdapter.Addr(), wireguard.handle) + if r1 == 0 { + err = e1 + } + return +} + +// Uninstall removes the driver from the system if no drivers are currently in use. +func Uninstall() (err error) { + r1, _, e1 := syscall.SyscallN(procWireGuardDeleteDriver.Addr()) + if r1 == 0 { + err = e1 + } + return +} + +type AdapterLogState uint32 + +const ( + AdapterLogOff AdapterLogState = 0 + AdapterLogOn AdapterLogState = 1 + AdapterLogOnWithPrefix AdapterLogState = 2 +) + +// SetLogging enables or disables logging on the WireGuard adapter. +func (wireguard *Adapter) SetLogging(logState AdapterLogState) (err error) { + r1, _, e1 := syscall.SyscallN(procWireGuardSetAdapterLogging.Addr(), wireguard.handle, uintptr(logState)) + if r1 == 0 { + err = e1 + } + return +} + +// RunningVersion returns the version of the loaded driver. +func RunningVersion() (version uint32, err error) { + r0, _, e1 := syscall.SyscallN(procWireGuardGetRunningDriverVersion.Addr()) + version = uint32(r0) + if version == 0 { + err = e1 + } + return +} + +// LUID returns the LUID of the adapter. +func (wireguard *Adapter) LUID() (luid winipcfg.LUID) { + syscall.SyscallN(procWireGuardGetAdapterLUID.Addr(), wireguard.handle, uintptr(unsafe.Pointer(&luid))) + return +} |