aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/driver/driver_windows.go
diff options
context:
space:
mode:
Diffstat (limited to 'driver/driver_windows.go')
-rw-r--r--driver/driver_windows.go233
1 files changed, 233 insertions, 0 deletions
diff --git a/driver/driver_windows.go b/driver/driver_windows.go
new file mode 100644
index 00000000..68839a71
--- /dev/null
+++ b/driver/driver_windows.go
@@ -0,0 +1,233 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package driver
+
+import (
+ "errors"
+ "log"
+ "runtime"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
+)
+
+type loggerLevel int
+
+const (
+ logInfo loggerLevel = iota
+ logWarn
+ logErr
+)
+
+const (
+ PoolNameMax = 256
+ AdapterNameMax = 128
+)
+
+type Pool [PoolNameMax]uint16
+type Adapter struct {
+ handle uintptr
+ lastGetGuessSize uint32
+}
+
+var (
+ modwireguard = newLazyDLL("wireguard.dll", setupLogger)
+
+ procWireGuardCreateAdapter = modwireguard.NewProc("WireGuardCreateAdapter")
+ procWireGuardDeleteAdapter = modwireguard.NewProc("WireGuardDeleteAdapter")
+ procWireGuardDeletePoolDriver = modwireguard.NewProc("WireGuardDeletePoolDriver")
+ procWireGuardEnumAdapters = modwireguard.NewProc("WireGuardEnumAdapters")
+ procWireGuardFreeAdapter = modwireguard.NewProc("WireGuardFreeAdapter")
+ procWireGuardOpenAdapter = modwireguard.NewProc("WireGuardOpenAdapter")
+ procWireGuardGetAdapterLUID = modwireguard.NewProc("WireGuardGetAdapterLUID")
+ procWireGuardGetAdapterName = modwireguard.NewProc("WireGuardGetAdapterName")
+ procWireGuardGetRunningDriverVersion = modwireguard.NewProc("WireGuardGetRunningDriverVersion")
+ procWireGuardSetAdapterName = modwireguard.NewProc("WireGuardSetAdapterName")
+ procWireGuardSetAdapterLogging = modwireguard.NewProc("WireGuardSetAdapterLogging")
+)
+
+func setupLogger(dll *lazyDLL) {
+ syscall.Syscall(dll.NewProc("WireGuardSetLogger").Addr(), 1, windows.NewCallback(func(level loggerLevel, msg *uint16) int {
+ log.Println(windows.UTF16PtrToString(msg))
+ return 0
+ }), 0, 0)
+}
+
+var DefaultPool, _ = MakePool("WireGuard")
+
+func MakePool(poolName string) (pool *Pool, err error) {
+ poolName16, err := windows.UTF16FromString(poolName)
+ if err != nil {
+ return
+ }
+ if len(poolName16) > PoolNameMax {
+ err = errors.New("Pool name too long")
+ return
+ }
+ pool = &Pool{}
+ copy(pool[:], poolName16)
+ return
+}
+
+func (pool *Pool) String() string {
+ return windows.UTF16ToString(pool[:])
+}
+
+func freeAdapter(wireguard *Adapter) {
+ syscall.Syscall(procWireGuardFreeAdapter.Addr(), 1, wireguard.handle, 0, 0)
+}
+
+// OpenAdapter finds a WireGuard adapter by its name. This function returns the adapter if found, or
+// windows.ERROR_FILE_NOT_FOUND otherwise. If the adapter is found but not a WireGuard-class or a
+// member of the pool, this function returns windows.ERROR_ALREADY_EXISTS. The adapter must be
+// released after use.
+func (pool *Pool) OpenAdapter(ifname string) (wireguard *Adapter, err error) {
+ ifname16, err := windows.UTF16PtrFromString(ifname)
+ if err != nil {
+ return nil, err
+ }
+ r0, _, e1 := syscall.Syscall(procWireGuardOpenAdapter.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), 0)
+ if r0 == 0 {
+ err = e1
+ return
+ }
+ wireguard = &Adapter{handle: r0}
+ runtime.SetFinalizer(wireguard, freeAdapter)
+ return
+}
+
+// CreateAdapter creates a WireGuard adapter. ifname is the requested name of the adapter, while
+// requestedGUID is the GUID of the created network adapter, which then influences NLA generation
+// deterministically. If it is set to nil, the GUID is chosen by the system at random, and hence a
+// new NLA entry is created for each new adapter. It is called "requested" GUID because the API it
+// uses is completely undocumented, and so there could be minor interesting complications with its
+// usage. This function returns the network adapter ID and a flag if reboot is required.
+func (pool *Pool) CreateAdapter(ifname string, requestedGUID *windows.GUID) (wireguard *Adapter, rebootRequired bool, err error) {
+ var ifname16 *uint16
+ ifname16, err = windows.UTF16PtrFromString(ifname)
+ if err != nil {
+ return
+ }
+ var _p0 uint32
+ r0, _, e1 := syscall.Syscall6(procWireGuardCreateAdapter.Addr(), 4, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), uintptr(unsafe.Pointer(requestedGUID)), uintptr(unsafe.Pointer(&_p0)), 0, 0)
+ rebootRequired = _p0 != 0
+ if r0 == 0 {
+ err = e1
+ return
+ }
+ wireguard = &Adapter{handle: r0}
+ runtime.SetFinalizer(wireguard, freeAdapter)
+ return
+}
+
+// Delete deletes a WireGuard adapter. This function succeeds if the adapter was not found. It returns
+// a bool indicating whether a reboot is required.
+func (wireguard *Adapter) Delete() (rebootRequired bool, err error) {
+ var _p0 uint32
+ r1, _, e1 := syscall.Syscall(procWireGuardDeleteAdapter.Addr(), 2, wireguard.handle, uintptr(unsafe.Pointer(&_p0)), 0)
+ rebootRequired = _p0 != 0
+ if r1 == 0 {
+ err = e1
+ }
+ return
+}
+
+// DeleteMatchingAdapters deletes all WireGuard adapters, which match
+// given criteria, and returns which ones it deleted, whether a reboot
+// is required after, and which errors occurred during the process.
+func (pool *Pool) DeleteMatchingAdapters(matches func(adapter *Adapter) bool) (rebootRequired bool, errors []error) {
+ cb := func(handle uintptr, _ uintptr) int {
+ adapter := &Adapter{handle: handle}
+ if !matches(adapter) {
+ return 1
+ }
+ rebootRequired2, err := adapter.Delete()
+ if err != nil {
+ errors = append(errors, err)
+ return 1
+ }
+ rebootRequired = rebootRequired || rebootRequired2
+ return 1
+ }
+ r1, _, e1 := syscall.Syscall(procWireGuardEnumAdapters.Addr(), 3, uintptr(unsafe.Pointer(pool)), uintptr(windows.NewCallback(cb)), 0)
+ if r1 == 0 {
+ errors = append(errors, e1)
+ }
+ return
+}
+
+// Name returns the name of the WireGuard adapter.
+func (wireguard *Adapter) Name() (ifname string, err error) {
+ var ifname16 [AdapterNameMax]uint16
+ r1, _, e1 := syscall.Syscall(procWireGuardGetAdapterName.Addr(), 2, wireguard.handle, uintptr(unsafe.Pointer(&ifname16[0])), 0)
+ if r1 == 0 {
+ err = e1
+ return
+ }
+ ifname = windows.UTF16ToString(ifname16[:])
+ return
+}
+
+// DeleteDriver deletes all WireGuard adapters in a pool and if there are no more adapters in any other
+// pools, also removes WireGuard from the driver store, usually called by uninstallers.
+func (pool *Pool) DeleteDriver() (rebootRequired bool, err error) {
+ var _p0 uint32
+ r1, _, e1 := syscall.Syscall(procWireGuardDeletePoolDriver.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(&_p0)), 0)
+ rebootRequired = _p0 != 0
+ if r1 == 0 {
+ err = e1
+ }
+ return
+
+}
+
+// SetName sets name of the WireGuard adapter.
+func (wireguard *Adapter) SetName(ifname string) (err error) {
+ ifname16, err := windows.UTF16FromString(ifname)
+ if err != nil {
+ return err
+ }
+ r1, _, e1 := syscall.Syscall(procWireGuardSetAdapterName.Addr(), 2, wireguard.handle, uintptr(unsafe.Pointer(&ifname16[0])), 0)
+ if r1 == 0 {
+ err = e1
+ }
+ return
+}
+
+type AdapterLogState uint32
+
+const (
+ AdapterLogOff AdapterLogState = 0
+ AdapterLogOn AdapterLogState = 1
+ AdapterLogOnWithPrefix AdapterLogState = 2
+)
+
+// SetLogging enables or disables logging on the WireGuard adapter.
+func (wireguard *Adapter) SetLogging(logState AdapterLogState) (err error) {
+ r1, _, e1 := syscall.Syscall(procWireGuardSetAdapterLogging.Addr(), 2, wireguard.handle, uintptr(logState), 0)
+ if r1 == 0 {
+ err = e1
+ }
+ return
+}
+
+// RunningVersion returns the version of the running WireGuard driver.
+func RunningVersion() (version uint32, err error) {
+ r0, _, e1 := syscall.Syscall(procWireGuardGetRunningDriverVersion.Addr(), 0, 0, 0, 0)
+ version = uint32(r0)
+ if version == 0 {
+ err = e1
+ }
+ return
+}
+
+// LUID returns the LUID of the adapter.
+func (wireguard *Adapter) LUID() (luid winipcfg.LUID) {
+ syscall.Syscall(procWireGuardGetAdapterLUID.Addr(), 2, wireguard.handle, uintptr(unsafe.Pointer(&luid)), 0)
+ return
+}