aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--Makefile19
-rw-r--r--README.md2
-rw-r--r--build.bat10
-rw-r--r--conf/dnsresolver_windows.go11
-rw-r--r--conf/parser.go72
-rw-r--r--conf/writer.go67
-rw-r--r--docs/adminregistry.md12
-rw-r--r--docs/attacksurface.md20
-rw-r--r--docs/buildrun.md2
-rw-r--r--docs/enterprise.md4
-rw-r--r--docs/netquirk.md2
-rw-r--r--driver/configuration_windows.go191
-rw-r--r--driver/dll_fromfile_windows.go54
-rw-r--r--driver/dll_fromrsrc_windows.go60
-rw-r--r--driver/dll_windows.go59
-rw-r--r--driver/driver_windows.go233
-rw-r--r--driver/memmod/memmod_windows.go622
-rw-r--r--driver/memmod/memmod_windows_32.go16
-rw-r--r--driver/memmod/memmod_windows_386.go8
-rw-r--r--driver/memmod/memmod_windows_64.go36
-rw-r--r--driver/memmod/memmod_windows_amd64.go8
-rw-r--r--driver/memmod/memmod_windows_arm.go8
-rw-r--r--driver/memmod/memmod_windows_arm64.go8
-rw-r--r--driver/memmod/syscall_windows.go392
-rw-r--r--driver/memmod/syscall_windows_32.go96
-rw-r--r--driver/memmod/syscall_windows_64.go95
-rw-r--r--embeddable-dll-service/csharp/README.md2
-rw-r--r--main.go14
-rw-r--r--manager/interfacecleanup.go51
-rw-r--r--manager/ipc_driver.go59
-rw-r--r--manager/ipc_server.go83
-rw-r--r--manager/service.go2
-rw-r--r--resources.rc1
-rw-r--r--services/errors.go12
-rw-r--r--tunnel/addressconfig.go14
-rw-r--r--tunnel/defaultroutemonitor.go15
-rw-r--r--tunnel/interfacewatcher.go45
-rw-r--r--tunnel/mtumonitor.go113
-rw-r--r--tunnel/service.go175
-rw-r--r--tunnel/winipcfg/types.go31
40 files changed, 2561 insertions, 163 deletions
diff --git a/Makefile b/Makefile
index bcb0919c..f7fb9545 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,4 @@
-GOFLAGS := -tags load_wintun_from_rsrc -ldflags="-H windowsgui -s -w" -v -trimpath
+GOFLAGS := -tags load_wintun_from_rsrc,load_wgnt_from_rsrc -ldflags="-H windowsgui -s -w" -v -trimpath
export GOOS := windows
export PATH := $(CURDIR)/.deps/go/bin:$(PATH)
@@ -10,7 +10,7 @@ RCFLAGS := -DWIREGUARD_VERSION_ARRAY=$(subst $(space),$(comma),$(wordlist 1,4,$(
rwildcard=$(foreach d,$(filter-out .deps,$(wildcard $1*)),$(call rwildcard,$d/,$2) $(filter $(subst *,%,$2),$d))
SOURCE_FILES := $(call rwildcard,,*.go) .deps/go/prepared go.mod go.sum
-RESOURCE_FILES := resources.rc version/version.go manifest.xml $(patsubst %.svg,%.ico,$(wildcard ui/icon/*.svg)) .deps/wintun/prepared
+RESOURCE_FILES := resources.rc version/version.go manifest.xml $(patsubst %.svg,%.ico,$(wildcard ui/icon/*.svg)) .deps/wintun/prepared .deps/wireguard-nt/prepared
DEPLOYMENT_HOST ?= winvm
DEPLOYMENT_PATH ?= Desktop
@@ -27,6 +27,7 @@ endef
$(eval $(call download,go.tar.gz,https://golang.org/dl/go1.17rc1.linux-amd64.tar.gz,bfbd3881a01ca3826777b1c40f241acacd45b14730d373259cd673d74e15e534))
$(eval $(call download,wintun.zip,https://www.wintun.net/builds/wintun-0.13.zip,34afe7d0de1fdb781af3defc0a75fd8c97daa756279b42dd6be6a1bd8ccdc7f0))
+$(eval $(call download,wireguard-nt.zip,https://download.wireguard.com/wireguard-nt/wireguard-nt-0.1.zip,00478a0a2e24d3c0638193b063cb273c956014cb3ddd81307cbe61b07fdeb692))
.deps/go/prepared: .distfiles/go.tar.gz $(wildcard go-patches/*.patch)
mkdir -p .deps
@@ -42,20 +43,26 @@ $(eval $(call download,wintun.zip,https://www.wintun.net/builds/wintun-0.13.zip,
bsdtar -C .deps -xf .distfiles/wintun.zip
touch $@
+.deps/wireguard-nt/prepared: .distfiles/wireguard-nt.zip
+ mkdir -p .deps
+ rm -rf .deps/wireguard-nt
+ bsdtar -C .deps -xf .distfiles/wireguard-nt.zip
+ touch $@
+
%.ico: %.svg
convert -background none $< -define icon:auto-resize="256,192,128,96,64,48,40,32,24,20,16" -compress zip $@
resources_amd64.syso: $(RESOURCE_FILES)
- x86_64-w64-mingw32-windres $(RCFLAGS) -I .deps/wintun/bin/amd64 -i $< -o $@
+ x86_64-w64-mingw32-windres $(RCFLAGS) -I .deps/wintun/bin/amd64 -I .deps/wireguard-nt/bin/amd64 -i $< -o $@
resources_386.syso: $(RESOURCE_FILES)
- i686-w64-mingw32-windres $(RCFLAGS) -I .deps/wintun/bin/x86 -i $< -o $@
+ i686-w64-mingw32-windres $(RCFLAGS) -I .deps/wintun/bin/x86 -I .deps/wireguard-nt/bin/x86 -i $< -o $@
resources_arm.syso: $(RESOURCE_FILES)
- armv7-w64-mingw32-windres $(RCFLAGS) -I .deps/wintun/bin/arm -i $< -o $@
+ armv7-w64-mingw32-windres $(RCFLAGS) -I .deps/wintun/bin/arm -I .deps/wireguard-nt/bin/arm -i $< -o $@
resources_arm64.syso: $(RESOURCE_FILES)
- aarch64-w64-mingw32-windres $(RCFLAGS) -I .deps/wintun/bin/arm64 -i $< -o $@
+ aarch64-w64-mingw32-windres $(RCFLAGS) -I .deps/wintun/bin/arm64 -I .deps/wireguard-nt/bin/arm64 -i $< -o $@
amd64/wireguard.exe: export GOARCH := amd64
amd64/wireguard.exe: resources_amd64.syso $(SOURCE_FILES)
diff --git a/README.md b/README.md
index 76abf9dd..ba1dd41b 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# [WireGuard](https://www.wireguard.com/) for Windows
-This is a fully-featured WireGuard client for Windows that uses [Wintun](https://www.wintun.net/). It is the only official and recommended way of using WireGuard on Windows.
+This is a fully-featured WireGuard client for Windows that uses [Wintun](https://www.wintun.net/) or [WireGuardNT](https://git.zx2c4.com/wireguard-nt/about/). It is the only official and recommended way of using WireGuard on Windows.
## Download &amp; Install
diff --git a/build.bat b/build.bat
index 9b468850..f1d3a9a3 100644
--- a/build.bat
+++ b/build.bat
@@ -20,10 +20,11 @@ if exist .deps\prepared goto :render
call :download imagemagick.zip https://download.wireguard.com/windows-toolchain/distfiles/ImageMagick-7.0.8-42-portable-Q16-x64.zip 584e069f56456ce7dde40220948ff9568ac810688c892c5dfb7f6db902aa05aa "convert.exe colors.xml delegates.xml" || goto :error
rem Mirror of https://sourceforge.net/projects/ezwinports/files/make-4.2.1-without-guile-w32-bin.zip
call :download make.zip https://download.wireguard.com/windows-toolchain/distfiles/make-4.2.1-without-guile-w32-bin.zip 30641be9602712be76212b99df7209f4f8f518ba764cf564262bc9d6e4047cc7 "--strip-components 1 bin" || goto :error
- call :download wireguard-tools.zip https://git.zx2c4.com/wireguard-tools/snapshot/wireguard-tools-542b7c0f2474fca14e18c23148790f76728dd46a.zip 10ac31af150220850c82b7ed7b65a0e39b60557ed366fcb351da0e0ff2700aea "--exclude wg-quick --strip-components 1" || goto :error
+ call :download wireguard-tools.zip https://git.zx2c4.com/wireguard-tools/snapshot/wireguard-tools-52597c351554ec7f39c0817d45771d1f30572f4b.zip c869f539dee0f5d4028d2d33f931c548c3aa5a22fe01ee4b28ffa01ffb33b8c9 "--exclude wg-quick --strip-components 1" || goto :error
rem Mirror of https://sourceforge.net/projects/gnuwin32/files/patch/2.5.9-7/patch-2.5.9-7-bin.zip with fixed manifest
call :download patch.zip https://download.wireguard.com/windows-toolchain/distfiles/patch-2.5.9-7-bin-fixed-manifest.zip 25977006ca9713f2662a5d0a2ed3a5a138225b8be3757035bd7da9dcf985d0a1 "--strip-components 1 bin" || goto :error
call :download wintun.zip https://www.wintun.net/builds/wintun-0.13.zip 34afe7d0de1fdb781af3defc0a75fd8c97daa756279b42dd6be6a1bd8ccdc7f0 || goto :error
+ call :download wireguard-nt.zip https://download.wireguard.com/wireguard-nt/wireguard-nt-0.1.zip 00478a0a2e24d3c0638193b063cb273c956014cb3ddd81307cbe61b07fdeb692 || goto :error
echo [+] Patching go
for %%a in ("..\go-patches\*.patch") do .\patch.exe -f -N -r- -d go -p1 --binary < "%%a" || goto :error
copy /y NUL prepared > NUL || goto :error
@@ -76,13 +77,14 @@ if exist .deps\prepared goto :render
set GOARCH=%~3
mkdir %1 >NUL 2>&1
echo [+] Assembling resources %1
- %~2-w64-mingw32-windres -I ".deps\wintun\bin\%~1" -DWIREGUARD_VERSION_ARRAY=%WIREGUARD_VERSION_ARRAY% -DWIREGUARD_VERSION_STR=%WIREGUARD_VERSION% -i resources.rc -o "resources_%~3.syso" -O coff -c 65001 || exit /b %errorlevel%
+ %~2-w64-mingw32-windres -I ".deps\wintun\bin\%~1" -I ".deps\wireguard-nt\bin\%~1" -DWIREGUARD_VERSION_ARRAY=%WIREGUARD_VERSION_ARRAY% -DWIREGUARD_VERSION_STR=%WIREGUARD_VERSION% -i resources.rc -o "resources_%~3.syso" -O coff -c 65001 || exit /b %errorlevel%
echo [+] Building program %1
- go build -tags load_wintun_from_rsrc -ldflags="-H windowsgui -s -w" -trimpath -v -o "%~1\wireguard.exe" || exit /b 1
+ go build -tags load_wintun_from_rsrc,load_wgnt_from_rsrc -ldflags="-H windowsgui -s -w" -trimpath -v -o "%~1\wireguard.exe" || exit /b 1
if not exist "%~1\wg.exe" (
echo [+] Building command line tools %1
del .deps\src\*.exe .deps\src\*.o .deps\src\wincompat\*.o .deps\src\wincompat\*.lib 2> NUL
- make --no-print-directory -C .deps\src PLATFORM=windows CC=%~2-w64-mingw32-gcc WINDRES=%~2-w64-mingw32-windres V=1 LDFLAGS=-s RUNSTATEDIR= SYSTEMDUNITDIR= -j%NUMBER_OF_PROCESSORS% || exit /b 1
+ set LDFLAGS=-s
+ make --no-print-directory -C .deps\src PLATFORM=windows CC=%~2-w64-mingw32-gcc WINDRES=%~2-w64-mingw32-windres V=1 RUNSTATEDIR= SYSTEMDUNITDIR= -j%NUMBER_OF_PROCESSORS% || exit /b 1
move /Y .deps\src\wg.exe "%~1\wg.exe" > NUL || exit /b 1
)
goto :eof
diff --git a/conf/dnsresolver_windows.go b/conf/dnsresolver_windows.go
index b17be849..b8e2dc57 100644
--- a/conf/dnsresolver_windows.go
+++ b/conf/dnsresolver_windows.go
@@ -88,3 +88,14 @@ func resolveHostnameOnce(name string) (resolvedIPString string, err error) {
err = windows.WSAHOST_NOT_FOUND
return
}
+
+func (config *Config) ResolveEndpoints() error {
+ for i := range config.Peers {
+ var err error
+ config.Peers[i].Endpoint.Host, err = resolveHostname(config.Peers[i].Endpoint.Host)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/conf/parser.go b/conf/parser.go
index 9c1295bf..2a1c0894 100644
--- a/conf/parser.go
+++ b/conf/parser.go
@@ -15,7 +15,9 @@ import (
"strings"
"time"
+ "golang.org/x/sys/windows"
"golang.org/x/text/encoding/unicode"
+ "golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/l18n"
)
@@ -521,3 +523,73 @@ func FromUAPI(reader io.Reader, existingConfig *Config) (*Config, error) {
return &conf, nil
}
+
+func FromDriverConfiguration(interfaze *driver.Interface, existingConfig *Config) *Config {
+ conf := Config{
+ Name: existingConfig.Name,
+ Interface: Interface{
+ Addresses: existingConfig.Interface.Addresses,
+ DNS: existingConfig.Interface.DNS,
+ DNSSearch: existingConfig.Interface.DNSSearch,
+ MTU: existingConfig.Interface.MTU,
+ PreUp: existingConfig.Interface.PreUp,
+ PostUp: existingConfig.Interface.PostUp,
+ PreDown: existingConfig.Interface.PreDown,
+ PostDown: existingConfig.Interface.PostDown,
+ TableOff: existingConfig.Interface.TableOff,
+ },
+ }
+ if interfaze.Flags&driver.InterfaceHasPrivateKey != 0 {
+ conf.Interface.PrivateKey = interfaze.PrivateKey
+ }
+ if interfaze.Flags&driver.InterfaceHasListenPort != 0 {
+ conf.Interface.ListenPort = interfaze.ListenPort
+ }
+ var p *driver.Peer
+ for i := uint32(0); i < interfaze.PeerCount; i++ {
+ if p == nil {
+ p = interfaze.FirstPeer()
+ } else {
+ p = p.NextPeer()
+ }
+ peer := Peer{}
+ if p.Flags&driver.PeerHasPublicKey != 0 {
+ peer.PublicKey = p.PublicKey
+ }
+ if p.Flags&driver.PeerHasPresharedKey != 0 {
+ peer.PresharedKey = p.PresharedKey
+ }
+ if p.Flags&driver.PeerHasEndpoint != 0 {
+ peer.Endpoint.Port = p.Endpoint.Port()
+ peer.Endpoint.Host = p.Endpoint.IP().String()
+ }
+ if p.Flags&driver.PeerHasPersistentKeepalive != 0 {
+ peer.PersistentKeepalive = p.PersistentKeepalive
+ }
+ peer.TxBytes = Bytes(p.TxBytes)
+ peer.RxBytes = Bytes(p.RxBytes)
+ if p.LastHandshake != 0 {
+ peer.LastHandshakeTime = HandshakeTime((p.LastHandshake - 116444736000000000) * 100)
+ }
+ var a *driver.AllowedIP
+ for j := uint32(0); j < p.AllowedIPsCount; j++ {
+ if a == nil {
+ a = p.FirstAllowedIP()
+ } else {
+ a = a.NextAllowedIP()
+ }
+ var ip net.IP
+ if a.AddressFamily == windows.AF_INET {
+ ip = a.Address[:4]
+ } else if a.AddressFamily == windows.AF_INET6 {
+ ip = a.Address[:16]
+ }
+ peer.AllowedIPs = append(peer.AllowedIPs, IPCidr{
+ IP: ip,
+ Cidr: a.Cidr,
+ })
+ }
+ conf.Peers = append(conf.Peers, peer)
+ }
+ return &conf
+}
diff --git a/conf/writer.go b/conf/writer.go
index ddf54aa5..1b2d82a4 100644
--- a/conf/writer.go
+++ b/conf/writer.go
@@ -7,7 +7,13 @@ package conf
import (
"fmt"
+ "net"
"strings"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+ "golang.zx2c4.com/wireguard/windows/driver"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
func (conf *Config) ToWgQuick() string {
@@ -85,7 +91,7 @@ func (conf *Config) ToWgQuick() string {
return output.String()
}
-func (conf *Config) ToUAPI() (uapi string, dnsErr error) {
+func (conf *Config) ToUAPI() string {
var output strings.Builder
output.WriteString(fmt.Sprintf("private_key=%s\n", conf.Interface.PrivateKey.HexString()))
@@ -105,13 +111,7 @@ func (conf *Config) ToUAPI() (uapi string, dnsErr error) {
}
if !peer.Endpoint.IsEmpty() {
- var resolvedIP string
- resolvedIP, dnsErr = resolveHostname(peer.Endpoint.Host)
- if dnsErr != nil {
- return
- }
- resolvedEndpoint := Endpoint{resolvedIP, peer.Endpoint.Port}
- output.WriteString(fmt.Sprintf("endpoint=%s\n", resolvedEndpoint.String()))
+ output.WriteString(fmt.Sprintf("endpoint=%s\n", peer.Endpoint.String()))
}
output.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.PersistentKeepalive))
@@ -123,5 +123,54 @@ func (conf *Config) ToUAPI() (uapi string, dnsErr error) {
}
}
}
- return output.String(), nil
+ return output.String()
+}
+
+func (config *Config) ToDriverConfiguration() (*driver.Interface, uint32) {
+ preallocation := unsafe.Sizeof(driver.Interface{}) + uintptr(len(config.Peers))*unsafe.Sizeof(driver.Peer{})
+ for i := range config.Peers {
+ preallocation += uintptr(len(config.Peers[i].AllowedIPs)) * unsafe.Sizeof(driver.AllowedIP{})
+ }
+ var c driver.ConfigBuilder
+ c.Preallocate(uint32(preallocation))
+ c.AppendInterface(&driver.Interface{
+ Flags: driver.InterfaceHasPrivateKey | driver.InterfaceHasListenPort,
+ ListenPort: config.Interface.ListenPort,
+ PrivateKey: config.Interface.PrivateKey,
+ PeerCount: uint32(len(config.Peers)),
+ })
+ for i := range config.Peers {
+ flags := driver.PeerHasPublicKey | driver.PeerHasPersistentKeepalive
+ if !config.Peers[i].PresharedKey.IsZero() {
+ flags |= driver.PeerHasPresharedKey
+ }
+ var endpoint winipcfg.RawSockaddrInet
+ if !config.Peers[i].Endpoint.IsEmpty() {
+ flags |= driver.PeerHasEndpoint
+ endpoint.SetIP(net.ParseIP(config.Peers[i].Endpoint.Host), config.Peers[i].Endpoint.Port)
+ }
+ c.AppendPeer(&driver.Peer{
+ Flags: flags,
+ PublicKey: config.Peers[i].PublicKey,
+ PresharedKey: config.Peers[i].PresharedKey,
+ PersistentKeepalive: config.Peers[i].PersistentKeepalive,
+ Endpoint: endpoint,
+ AllowedIPsCount: uint32(len(config.Peers[i].AllowedIPs)),
+ })
+ for j := range config.Peers[i].AllowedIPs {
+ var family winipcfg.AddressFamily
+ if config.Peers[i].AllowedIPs[j].IP.To4() != nil {
+ family = windows.AF_INET
+ } else {
+ family = windows.AF_INET6
+ }
+ a := &driver.AllowedIP{
+ AddressFamily: family,
+ Cidr: config.Peers[i].AllowedIPs[j].Cidr,
+ }
+ copy(a.Address[:], config.Peers[i].AllowedIPs[j].IP)
+ c.AppendAllowedIP(a)
+ }
+ }
+ return c.Interface()
}
diff --git a/docs/adminregistry.md b/docs/adminregistry.md
index 9196a93f..34033446 100644
--- a/docs/adminregistry.md
+++ b/docs/adminregistry.md
@@ -51,3 +51,15 @@ overlapping routes, but for now, this key provides a manual override.
```
> reg add HKLM\Software\WireGuard /v MultipleSimultaneousTunnels /t REG_DWORD /d 1 /f
```
+
+#### `HKLM\Software\WireGuard\ExperimentalKernelDriver`
+
+When this key is set to `DWORD(1)`, an experimental kernel driver from the
+[WireGuardNT](https://git.zx2c4.com/wireguard-nt/about/) project is used instead
+of the much slower wireguard-go/Wintun implementation. There are significant
+performance gains, but do note that this is _currently_ considered experimental,
+and hence is not recommended.
+
+```
+> reg add HKLM\Software\WireGuard /v ExperimentalKernelDriver /t REG_DWORD /d 1 /f
+```
diff --git a/docs/attacksurface.md b/docs/attacksurface.md
index 1700c1f2..53bcd7c6 100644
--- a/docs/attacksurface.md
+++ b/docs/attacksurface.md
@@ -12,14 +12,24 @@ Wintun is a kernel driver. It exposes:
- There are also various ndis OID calls, accessible to certain users, which hit further code.
- IOCTLs are added to the NDIS device file, and those IOCTLs are restricted to `O:SYD:P(A;;FA;;;SY)(A;;FA;;;BA)S:(ML;;NWNRNX;;;HI)`. The IOCTL allows userspace to register a pair of rings and event objects, which Wintun then locks the pages of with a double mapping and takes a reference to the event object. It parses the contents of the ring to send and receive layer 3 packets, each of which it minimally parses to determine IP family.
+### WireGuardNT
+
+WireGuardNT is a kernel driver. It exposes:
+
+ - A miniport driver to the ndis stack, meaning any process on the system that can access the network stack in a reasonable way can send and receive packets, hitting those related ndis handlers.
+ - A UDP port parsing WireGuard packets.
+ - There are also various ndis OID calls, accessible to certain users, which hit further code.
+ - A PNP and Close notifier added to the NDIS device file.
+ - IOCTLs are added to the NDIS device file, and those IOCTLs are restricted to `O:SYD:P(A;;FA;;;SY)(A;;FA;;;BA)S:(ML;;NWNRNX;;;HI)`. The IOCTL allows userspace to get and set configuration, adapter state, and read log messages from a ring buffer.
+
### Tunnel Service
-The tunnel service is a userspace service running as Local System, responsible for creating UDP sockets, creating Wintun adapters, and speaking the WireGuard protocol between the two. It exposes:
+The tunnel service is a userspace service running as Local System, responsible for either A) creating UDP sockets, creating Wintun adapters, and speaking the WireGuard protocol between the two, or B) creating WireGuardNT adapters and configuring them. It exposes:
- - A listening pipe in `\\.\pipe\ProtectedPrefix\Administrators\WireGuard\%s`, where `%s` is some basename of an already valid filename. Its DACL is set to `O:SYD:P(A;;GA;;;SY)(A;;GA;;;BA)S:(ML;;NWNRNX;;;HI)`. If the config file used by the tunnel service is not DPAPI-encrypted and it is owned by a SID other than "Local System" then an additional ACE is added giving that file owner SID access to the named pipe. This pipe gives access to private keys and allows for reconfiguration of the interface, as well as rebinding to different ports (below 1024, even). Clients who connect to the pipe run `GetSecurityInfo` to verify that it is owned by "Local System".
- - A global mutex is used for Wintun interface creation, with the same DACL as the pipe, but first CreatePrivateNamespace is called with a "Local System" SID.
- - It handles data from its two UDP sockets, accessible to the public Internet.
- - It handles data from Wintun, accessible to all users who can do anything with the network stack.
+ - In case A) a listening pipe in `\\.\pipe\ProtectedPrefix\Administrators\WireGuard\%s`, where `%s` is some basename of an already valid filename. Its DACL is set to `O:SYD:P(A;;GA;;;SY)(A;;GA;;;BA)S:(ML;;NWNRNX;;;HI)`. If the config file used by the tunnel service is not DPAPI-encrypted and it is owned by a SID other than "Local System" then an additional ACE is added giving that file owner SID access to the named pipe. This pipe gives access to private keys and allows for reconfiguration of the interface, as well as rebinding to different ports (below 1024, even). Clients who connect to the pipe run `GetSecurityInfo` to verify that it is owned by "Local System".
+ - A global mutex is used for Wintun/WireGuardNT interface creation, with the same DACL as the pipe, but first CreatePrivateNamespace is called with a "Local System" SID.
+ - In case A) it handles data from its two UDP sockets, accessible to the public Internet.
+ - In case A) it handles data from Wintun, accessible to all users who can do anything with the network stack.
- After some initial setup, it uses `AdjustTokenPrivileges` to remove all privileges, except for `SeLoadDriverPrivilege`, so that it can remove the interface when shutting down. This latter point is rather unfortunate, as `SeLoadDriverPrivilege` can be used for all sorts of interesting escalation. Future work includes forking an additional process or the like so that we can drop this from the main tunnel process.
### Manager Service
diff --git a/docs/buildrun.md b/docs/buildrun.md
index 265c4d68..209c57ac 100644
--- a/docs/buildrun.md
+++ b/docs/buildrun.md
@@ -18,7 +18,7 @@ After you've built the application, run `amd64\wireguard.exe` or `x86\wireguard.
C:\Projects\wireguard-windows> amd64\wireguard.exe
```
-Since WireGuard requires the Wintun driver to be installed, and this generally requires a valid Microsoft signature, you may benefit from first installing a release of WireGuard for Windows from the official [wireguard.com](https://www.wireguard.com/install/) builds, which bundles a Microsoft-signed Wintun, and then subsequently run your own wireguard.exe. Alternatively, you can craft your own installer using the `quickinstall.bat` script.
+Since WireGuard requires a driver to be installed, and this generally requires a valid Microsoft signature, you may benefit from first installing a release of WireGuard for Windows from the official [wireguard.com](https://www.wireguard.com/install/) builds, which bundles a Microsoft-signed driver, and then subsequently run your own wireguard.exe. Alternatively, you can craft your own installer using the `quickinstall.bat` script.
### Optional: Localizing
diff --git a/docs/enterprise.md b/docs/enterprise.md
index 37a0ca44..8e4ca59e 100644
--- a/docs/enterprise.md
+++ b/docs/enterprise.md
@@ -83,9 +83,9 @@ Or, to log the status of that command:
> wireguard /update 2> C:\path\to\update\log.txt
```
-### Wintun Adapters
+### Network Adapters
-The tunnel service creates a Wintun adapter at startup and destroys it at shutdown. It may be desirable, however, to remove all Wintun adapters created in WireGuard's pool and uninstall the driver if no other applications are using Wintun. This can be accomplished using the command:
+The tunnel service creates a network adapter at startup and destroys it at shutdown. It may be desirable, however, to remove all network adapters created in WireGuard's pool and uninstall the driver if no other applications are using our network adapters. This can be accomplished using the command:
```text
> wireguard /removealladapters
diff --git a/docs/netquirk.md b/docs/netquirk.md
index 1fe2ad72..c0aa7bf3 100644
--- a/docs/netquirk.md
+++ b/docs/netquirk.md
@@ -30,4 +30,4 @@ Windows assigns a unique GUID to each new WireGuard adapter. The application tak
### Adapter Lifetime
-WireGuard's Wintun adapter is created dynamically when a tunnel is started and destroyed when a tunnel is stopped. This means that additional filters, address families, or protocols should be bound to the adapter programmatically, possibly through use of dangerous script execution in thet configuration file or by way of automatic NDIS layer binding.
+WireGuard's network adapter is created dynamically when a tunnel is started and destroyed when a tunnel is stopped. This means that additional filters, address families, or protocols should be bound to the adapter programmatically, possibly through use of dangerous script execution in thet configuration file or by way of automatic NDIS layer binding.
diff --git a/driver/configuration_windows.go b/driver/configuration_windows.go
new file mode 100644
index 00000000..6ff67edc
--- /dev/null
+++ b/driver/configuration_windows.go
@@ -0,0 +1,191 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package driver
+
+import (
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
+)
+
+type AdapterState uint32
+
+const (
+ AdapterStateDown AdapterState = 0
+ AdapterStateUp AdapterState = 1
+)
+
+type AllowedIP struct {
+ Address [16]byte
+ AddressFamily winipcfg.AddressFamily
+ Cidr uint8
+ _ [4]byte
+}
+
+type PeerFlag uint32
+
+const (
+ PeerHasPublicKey PeerFlag = 1 << 0
+ PeerHasPresharedKey PeerFlag = 1 << 1
+ PeerHasPersistentKeepalive PeerFlag = 1 << 2
+ PeerHasEndpoint PeerFlag = 1 << 3
+ PeerReplaceAllowedIPs PeerFlag = 1 << 5
+ PeerRemove PeerFlag = 1 << 6
+ PeerUpdate PeerFlag = 1 << 7
+)
+
+type Peer struct {
+ Flags PeerFlag
+ _ uint32
+ PublicKey [32]byte
+ PresharedKey [32]byte
+ PersistentKeepalive uint16
+ _ uint16
+ Endpoint winipcfg.RawSockaddrInet
+ TxBytes uint64
+ RxBytes uint64
+ LastHandshake uint64
+ AllowedIPsCount uint32
+ _ [4]byte
+}
+
+type InterfaceFlag uint32
+
+const (
+ InterfaceHasPublicKey InterfaceFlag = 1 << 0
+ InterfaceHasPrivateKey InterfaceFlag = 1 << 1
+ InterfaceHasListenPort InterfaceFlag = 1 << 2
+ InterfaceReplacePeers InterfaceFlag = 1 << 3
+)
+
+type Interface struct {
+ Flags InterfaceFlag
+ ListenPort uint16
+ PrivateKey [32]byte
+ PublicKey [32]byte
+ PeerCount uint32
+ _ [4]byte
+}
+
+var (
+ procWireGuardSetAdapterState = modwireguard.NewProc("WireGuardSetAdapterState")
+ procWireGuardGetAdapterState = modwireguard.NewProc("WireGuardGetAdapterState")
+ procWireGuardSetConfiguration = modwireguard.NewProc("WireGuardSetConfiguration")
+ procWireGuardGetConfiguration = modwireguard.NewProc("WireGuardGetConfiguration")
+)
+
+func (wireguard *Adapter) SetAdapterState(adapterState AdapterState) (err error) {
+ r0, _, e1 := syscall.Syscall(procWireGuardSetAdapterState.Addr(), 2, wireguard.handle, uintptr(adapterState), 0)
+ if r0 == 0 {
+ err = e1
+ }
+ return
+}
+
+func (wireguard *Adapter) AdapterState() (adapterState AdapterState, err error) {
+ r0, _, e1 := syscall.Syscall(procWireGuardGetAdapterState.Addr(), 2, wireguard.handle, uintptr(unsafe.Pointer(&adapterState)), 0)
+ if r0 == 0 {
+ err = e1
+ }
+ return
+}
+
+func (wireguard *Adapter) SetConfiguration(interfaze *Interface, size uint32) (err error) {
+ r0, _, e1 := syscall.Syscall(procWireGuardSetConfiguration.Addr(), 3, wireguard.handle, uintptr(unsafe.Pointer(interfaze)), uintptr(size))
+ if r0 == 0 {
+ err = e1
+ }
+ return
+}
+
+func (wireguard *Adapter) Configuration() (interfaze *Interface, err error) {
+ size := wireguard.lastGetGuessSize
+ if size == 0 {
+ size = 512
+ }
+ for {
+ buf := make([]byte, size)
+ r0, _, e1 := syscall.Syscall(procWireGuardGetConfiguration.Addr(), 3, wireguard.handle, uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&size)))
+ if r0 != 0 {
+ wireguard.lastGetGuessSize = size
+ return (*Interface)(unsafe.Pointer(&buf[0])), nil
+ }
+ if e1 != windows.ERROR_MORE_DATA {
+ return nil, e1
+ }
+ }
+}
+
+func (interfaze *Interface) FirstPeer() *Peer {
+ return (*Peer)(unsafe.Pointer(uintptr(unsafe.Pointer(interfaze)) + unsafe.Sizeof(*interfaze)))
+}
+
+func (peer *Peer) NextPeer() *Peer {
+ return (*Peer)(unsafe.Pointer(uintptr(unsafe.Pointer(peer)) + unsafe.Sizeof(*peer) + uintptr(peer.AllowedIPsCount)*unsafe.Sizeof(AllowedIP{})))
+}
+
+func (peer *Peer) FirstAllowedIP() *AllowedIP {
+ return (*AllowedIP)(unsafe.Pointer(uintptr(unsafe.Pointer(peer)) + unsafe.Sizeof(*peer)))
+}
+
+func (allowedIP *AllowedIP) NextAllowedIP() *AllowedIP {
+ return (*AllowedIP)(unsafe.Pointer(uintptr(unsafe.Pointer(allowedIP)) + unsafe.Sizeof(*allowedIP)))
+}
+
+type ConfigBuilder struct {
+ buffer []byte
+}
+
+func (builder *ConfigBuilder) Preallocate(size uint32) {
+ if builder.buffer == nil {
+ builder.buffer = make([]byte, 0, size)
+ }
+}
+
+func (builder *ConfigBuilder) AppendInterface(interfaze *Interface) {
+ var newBytes []byte
+ unsafeSlice(unsafe.Pointer(&newBytes), unsafe.Pointer(interfaze), int(unsafe.Sizeof(*interfaze)))
+ builder.buffer = append(builder.buffer, newBytes...)
+}
+
+func (builder *ConfigBuilder) AppendPeer(peer *Peer) {
+ var newBytes []byte
+ unsafeSlice(unsafe.Pointer(&newBytes), unsafe.Pointer(peer), int(unsafe.Sizeof(*peer)))
+ builder.buffer = append(builder.buffer, newBytes...)
+}
+
+func (builder *ConfigBuilder) AppendAllowedIP(allowedIP *AllowedIP) {
+ var newBytes []byte
+ unsafeSlice(unsafe.Pointer(&newBytes), unsafe.Pointer(allowedIP), int(unsafe.Sizeof(*allowedIP)))
+ builder.buffer = append(builder.buffer, newBytes...)
+}
+
+func (builder *ConfigBuilder) Interface() (*Interface, uint32) {
+ if builder.buffer == nil {
+ return nil, 0
+ }
+ return (*Interface)(unsafe.Pointer(&builder.buffer[0])), uint32(len(builder.buffer))
+}
+
+// unsafeSlice updates the slice slicePtr to be a slice
+// referencing the provided data with its length & capacity set to
+// lenCap.
+//
+// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
+// update callers to use unsafe.Slice instead of this.
+func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
+ type sliceHeader struct {
+ Data unsafe.Pointer
+ Len int
+ Cap int
+ }
+ h := (*sliceHeader)(slicePtr)
+ h.Data = data
+ h.Len = lenCap
+ h.Cap = lenCap
+}
diff --git a/driver/dll_fromfile_windows.go b/driver/dll_fromfile_windows.go
new file mode 100644
index 00000000..c956ad5f
--- /dev/null
+++ b/driver/dll_fromfile_windows.go
@@ -0,0 +1,54 @@
+// +build !load_wgnt_from_rsrc
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package driver
+
+import (
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+type lazyDLL struct {
+ Name string
+ mu sync.Mutex
+ module windows.Handle
+ onLoad func(d *lazyDLL)
+}
+
+func (d *lazyDLL) Load() error {
+ if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
+ return nil
+ }
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ if d.module != 0 {
+ return nil
+ }
+
+ const (
+ LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200
+ LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
+ )
+ module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32)
+ if err != nil {
+ return fmt.Errorf("Unable to load library: %w", err)
+ }
+
+ atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
+ if d.onLoad != nil {
+ d.onLoad(d)
+ }
+ return nil
+}
+
+func (p *lazyProc) nameToAddr() (uintptr, error) {
+ return windows.GetProcAddress(p.dll.module, p.Name)
+}
diff --git a/driver/dll_fromrsrc_windows.go b/driver/dll_fromrsrc_windows.go
new file mode 100644
index 00000000..ff74f4e7
--- /dev/null
+++ b/driver/dll_fromrsrc_windows.go
@@ -0,0 +1,60 @@
+// +build load_wgnt_from_rsrc
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package driver
+
+import (
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+ "golang.zx2c4.com/wireguard/windows/driver/memmod"
+)
+
+type lazyDLL struct {
+ Name string
+ mu sync.Mutex
+ module *memmod.Module
+ onLoad func(d *lazyDLL)
+}
+
+func (d *lazyDLL) Load() error {
+ if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
+ return nil
+ }
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ if d.module != nil {
+ return nil
+ }
+
+ const ourModule windows.Handle = 0
+ resInfo, err := windows.FindResource(ourModule, d.Name, windows.RT_RCDATA)
+ if err != nil {
+ return fmt.Errorf("Unable to find \"%v\" RCDATA resource: %w", d.Name, err)
+ }
+ data, err := windows.LoadResourceData(ourModule, resInfo)
+ if err != nil {
+ return fmt.Errorf("Unable to load resource: %w", err)
+ }
+ module, err := memmod.LoadLibrary(data)
+ if err != nil {
+ return fmt.Errorf("Unable to load library: %w", err)
+ }
+
+ atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
+ if d.onLoad != nil {
+ d.onLoad(d)
+ }
+ return nil
+}
+
+func (p *lazyProc) nameToAddr() (uintptr, error) {
+ return p.dll.module.ProcAddressByName(p.Name)
+}
diff --git a/driver/dll_windows.go b/driver/dll_windows.go
new file mode 100644
index 00000000..968a90cb
--- /dev/null
+++ b/driver/dll_windows.go
@@ -0,0 +1,59 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package driver
+
+import (
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "unsafe"
+)
+
+func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL {
+ return &lazyDLL{Name: name, onLoad: onLoad}
+}
+
+func (d *lazyDLL) NewProc(name string) *lazyProc {
+ return &lazyProc{dll: d, Name: name}
+}
+
+type lazyProc struct {
+ Name string
+ mu sync.Mutex
+ dll *lazyDLL
+ addr uintptr
+}
+
+func (p *lazyProc) Find() error {
+ if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr))) != nil {
+ return nil
+ }
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.addr != 0 {
+ return nil
+ }
+
+ err := p.dll.Load()
+ if err != nil {
+ return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err)
+ }
+ addr, err := p.nameToAddr()
+ if err != nil {
+ return fmt.Errorf("Error getting %v address: %w", p.Name, err)
+ }
+
+ atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr)), unsafe.Pointer(addr))
+ return nil
+}
+
+func (p *lazyProc) Addr() uintptr {
+ err := p.Find()
+ if err != nil {
+ panic(err)
+ }
+ return p.addr
+}
diff --git a/driver/driver_windows.go b/driver/driver_windows.go
new file mode 100644
index 00000000..68839a71
--- /dev/null
+++ b/driver/driver_windows.go
@@ -0,0 +1,233 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package driver
+
+import (
+ "errors"
+ "log"
+ "runtime"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
+)
+
+type loggerLevel int
+
+const (
+ logInfo loggerLevel = iota
+ logWarn
+ logErr
+)
+
+const (
+ PoolNameMax = 256
+ AdapterNameMax = 128
+)
+
+type Pool [PoolNameMax]uint16
+type Adapter struct {
+ handle uintptr
+ lastGetGuessSize uint32
+}
+
+var (
+ modwireguard = newLazyDLL("wireguard.dll", setupLogger)
+
+ procWireGuardCreateAdapter = modwireguard.NewProc("WireGuardCreateAdapter")
+ procWireGuardDeleteAdapter = modwireguard.NewProc("WireGuardDeleteAdapter")
+ procWireGuardDeletePoolDriver = modwireguard.NewProc("WireGuardDeletePoolDriver")
+ procWireGuardEnumAdapters = modwireguard.NewProc("WireGuardEnumAdapters")
+ procWireGuardFreeAdapter = modwireguard.NewProc("WireGuardFreeAdapter")
+ procWireGuardOpenAdapter = modwireguard.NewProc("WireGuardOpenAdapter")
+ procWireGuardGetAdapterLUID = modwireguard.NewProc("WireGuardGetAdapterLUID")
+ procWireGuardGetAdapterName = modwireguard.NewProc("WireGuardGetAdapterName")
+ procWireGuardGetRunningDriverVersion = modwireguard.NewProc("WireGuardGetRunningDriverVersion")
+ procWireGuardSetAdapterName = modwireguard.NewProc("WireGuardSetAdapterName")
+ procWireGuardSetAdapterLogging = modwireguard.NewProc("WireGuardSetAdapterLogging")
+)
+
+func setupLogger(dll *lazyDLL) {
+ syscall.Syscall(dll.NewProc("WireGuardSetLogger").Addr(), 1, windows.NewCallback(func(level loggerLevel, msg *uint16) int {
+ log.Println(windows.UTF16PtrToString(msg))
+ return 0
+ }), 0, 0)
+}
+
+var DefaultPool, _ = MakePool("WireGuard")
+
+func MakePool(poolName string) (pool *Pool, err error) {
+ poolName16, err := windows.UTF16FromString(poolName)
+ if err != nil {
+ return
+ }
+ if len(poolName16) > PoolNameMax {
+ err = errors.New("Pool name too long")
+ return
+ }
+ pool = &Pool{}
+ copy(pool[:], poolName16)
+ return
+}
+
+func (pool *Pool) String() string {
+ return windows.UTF16ToString(pool[:])
+}
+
+func freeAdapter(wireguard *Adapter) {
+ syscall.Syscall(procWireGuardFreeAdapter.Addr(), 1, wireguard.handle, 0, 0)
+}
+
+// OpenAdapter finds a WireGuard adapter by its name. This function returns the adapter if found, or
+// windows.ERROR_FILE_NOT_FOUND otherwise. If the adapter is found but not a WireGuard-class or a
+// member of the pool, this function returns windows.ERROR_ALREADY_EXISTS. The adapter must be
+// released after use.
+func (pool *Pool) OpenAdapter(ifname string) (wireguard *Adapter, err error) {
+ ifname16, err := windows.UTF16PtrFromString(ifname)
+ if err != nil {
+ return nil, err
+ }
+ r0, _, e1 := syscall.Syscall(procWireGuardOpenAdapter.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), 0)
+ if r0 == 0 {
+ err = e1
+ return
+ }
+ wireguard = &Adapter{handle: r0}
+ runtime.SetFinalizer(wireguard, freeAdapter)
+ return
+}
+
+// CreateAdapter creates a WireGuard adapter. ifname is the requested name of the adapter, while
+// requestedGUID is the GUID of the created network adapter, which then influences NLA generation
+// deterministically. If it is set to nil, the GUID is chosen by the system at random, and hence a
+// new NLA entry is created for each new adapter. It is called "requested" GUID because the API it
+// uses is completely undocumented, and so there could be minor interesting complications with its
+// usage. This function returns the network adapter ID and a flag if reboot is required.
+func (pool *Pool) CreateAdapter(ifname string, requestedGUID *windows.GUID) (wireguard *Adapter, rebootRequired bool, err error) {
+ var ifname16 *uint16
+ ifname16, err = windows.UTF16PtrFromString(ifname)
+ if err != nil {
+ return
+ }
+ var _p0 uint32
+ r0, _, e1 := syscall.Syscall6(procWireGuardCreateAdapter.Addr(), 4, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), uintptr(unsafe.Pointer(requestedGUID)), uintptr(unsafe.Pointer(&_p0)), 0, 0)
+ rebootRequired = _p0 != 0
+ if r0 == 0 {
+ err = e1
+ return
+ }
+ wireguard = &Adapter{handle: r0}
+ runtime.SetFinalizer(wireguard, freeAdapter)
+ return
+}
+
+// Delete deletes a WireGuard adapter. This function succeeds if the adapter was not found. It returns
+// a bool indicating whether a reboot is required.
+func (wireguard *Adapter) Delete() (rebootRequired bool, err error) {
+ var _p0 uint32
+ r1, _, e1 := syscall.Syscall(procWireGuardDeleteAdapter.Addr(), 2, wireguard.handle, uintptr(unsafe.Pointer(&_p0)), 0)
+ rebootRequired = _p0 != 0
+ if r1 == 0 {
+ err = e1
+ }
+ return
+}
+
+// DeleteMatchingAdapters deletes all WireGuard adapters, which match
+// given criteria, and returns which ones it deleted, whether a reboot
+// is required after, and which errors occurred during the process.
+func (pool *Pool) DeleteMatchingAdapters(matches func(adapter *Adapter) bool) (rebootRequired bool, errors []error) {
+ cb := func(handle uintptr, _ uintptr) int {
+ adapter := &Adapter{handle: handle}
+ if !matches(adapter) {
+ return 1
+ }
+ rebootRequired2, err := adapter.Delete()
+ if err != nil {
+ errors = append(errors, err)
+ return 1
+ }
+ rebootRequired = rebootRequired || rebootRequired2
+ return 1
+ }
+ r1, _, e1 := syscall.Syscall(procWireGuardEnumAdapters.Addr(), 3, uintptr(unsafe.Pointer(pool)), uintptr(windows.NewCallback(cb)), 0)
+ if r1 == 0 {
+ errors = append(errors, e1)
+ }
+ return
+}
+
+// Name returns the name of the WireGuard adapter.
+func (wireguard *Adapter) Name() (ifname string, err error) {
+ var ifname16 [AdapterNameMax]uint16
+ r1, _, e1 := syscall.Syscall(procWireGuardGetAdapterName.Addr(), 2, wireguard.handle, uintptr(unsafe.Pointer(&ifname16[0])), 0)
+ if r1 == 0 {
+ err = e1
+ return
+ }
+ ifname = windows.UTF16ToString(ifname16[:])
+ return
+}
+
+// DeleteDriver deletes all WireGuard adapters in a pool and if there are no more adapters in any other
+// pools, also removes WireGuard from the driver store, usually called by uninstallers.
+func (pool *Pool) DeleteDriver() (rebootRequired bool, err error) {
+ var _p0 uint32
+ r1, _, e1 := syscall.Syscall(procWireGuardDeletePoolDriver.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(&_p0)), 0)
+ rebootRequired = _p0 != 0
+ if r1 == 0 {
+ err = e1
+ }
+ return
+
+}
+
+// SetName sets name of the WireGuard adapter.
+func (wireguard *Adapter) SetName(ifname string) (err error) {
+ ifname16, err := windows.UTF16FromString(ifname)
+ if err != nil {
+ return err
+ }
+ r1, _, e1 := syscall.Syscall(procWireGuardSetAdapterName.Addr(), 2, wireguard.handle, uintptr(unsafe.Pointer(&ifname16[0])), 0)
+ if r1 == 0 {
+ err = e1
+ }
+ return
+}
+
+type AdapterLogState uint32
+
+const (
+ AdapterLogOff AdapterLogState = 0
+ AdapterLogOn AdapterLogState = 1
+ AdapterLogOnWithPrefix AdapterLogState = 2
+)
+
+// SetLogging enables or disables logging on the WireGuard adapter.
+func (wireguard *Adapter) SetLogging(logState AdapterLogState) (err error) {
+ r1, _, e1 := syscall.Syscall(procWireGuardSetAdapterLogging.Addr(), 2, wireguard.handle, uintptr(logState), 0)
+ if r1 == 0 {
+ err = e1
+ }
+ return
+}
+
+// RunningVersion returns the version of the running WireGuard driver.
+func RunningVersion() (version uint32, err error) {
+ r0, _, e1 := syscall.Syscall(procWireGuardGetRunningDriverVersion.Addr(), 0, 0, 0, 0)
+ version = uint32(r0)
+ if version == 0 {
+ err = e1
+ }
+ return
+}
+
+// LUID returns the LUID of the adapter.
+func (wireguard *Adapter) LUID() (luid winipcfg.LUID) {
+ syscall.Syscall(procWireGuardGetAdapterLUID.Addr(), 2, wireguard.handle, uintptr(unsafe.Pointer(&luid)), 0)
+ return
+}
diff --git a/driver/memmod/memmod_windows.go b/driver/memmod/memmod_windows.go
new file mode 100644
index 00000000..59450e78
--- /dev/null
+++ b/driver/memmod/memmod_windows.go
@@ -0,0 +1,622 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package memmod
+
+import (
+ "errors"
+ "fmt"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+type addressList struct {
+ next *addressList
+ address uintptr
+}
+
+func (head *addressList) free() {
+ for node := head; node != nil; node = node.next {
+ windows.VirtualFree(node.address, 0, windows.MEM_RELEASE)
+ }
+}
+
+type Module struct {
+ headers *IMAGE_NT_HEADERS
+ codeBase uintptr
+ modules []windows.Handle
+ initialized bool
+ isDLL bool
+ isRelocated bool
+ nameExports map[string]uint16
+ entry uintptr
+ blockedMemory *addressList
+}
+
+func (module *Module) headerDirectory(idx int) *IMAGE_DATA_DIRECTORY {
+ return &module.headers.OptionalHeader.DataDirectory[idx]
+}
+
+func (module *Module) copySections(address uintptr, size uintptr, oldHeaders *IMAGE_NT_HEADERS) error {
+ sections := module.headers.Sections()
+ for i := range sections {
+ if sections[i].SizeOfRawData == 0 {
+ // Section doesn't contain data in the dll itself, but may define uninitialized data.
+ sectionSize := oldHeaders.OptionalHeader.SectionAlignment
+ if sectionSize == 0 {
+ continue
+ }
+ dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress),
+ uintptr(sectionSize),
+ windows.MEM_COMMIT,
+ windows.PAGE_READWRITE)
+ if err != nil {
+ return fmt.Errorf("Error allocating section: %w", err)
+ }
+
+ // Always use position from file to support alignments smaller than page size (allocation above will align to page size).
+ dest = module.codeBase + uintptr(sections[i].VirtualAddress)
+ // NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
+ sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
+ var dst []byte
+ unsafeSlice(unsafe.Pointer(&dst), a2p(dest), int(sectionSize))
+ for j := range dst {
+ dst[j] = 0
+ }
+ continue
+ }
+
+ if size < uintptr(sections[i].PointerToRawData+sections[i].SizeOfRawData) {
+ return errors.New("Incomplete section")
+ }
+
+ // Commit memory block and copy data from dll.
+ dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress),
+ uintptr(sections[i].SizeOfRawData),
+ windows.MEM_COMMIT,
+ windows.PAGE_READWRITE)
+ if err != nil {
+ return fmt.Errorf("Error allocating memory block: %w", err)
+ }
+
+ // Always use position from file to support alignments smaller than page size (allocation above will align to page size).
+ memcpy(
+ module.codeBase+uintptr(sections[i].VirtualAddress),
+ address+uintptr(sections[i].PointerToRawData),
+ uintptr(sections[i].SizeOfRawData))
+ // NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
+ sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
+ }
+
+ return nil
+}
+
+func (module *Module) realSectionSize(section *IMAGE_SECTION_HEADER) uintptr {
+ size := section.SizeOfRawData
+ if size != 0 {
+ return uintptr(size)
+ }
+ if (section.Characteristics & IMAGE_SCN_CNT_INITIALIZED_DATA) != 0 {
+ return uintptr(module.headers.OptionalHeader.SizeOfInitializedData)
+ }
+ if (section.Characteristics & IMAGE_SCN_CNT_UNINITIALIZED_DATA) != 0 {
+ return uintptr(module.headers.OptionalHeader.SizeOfUninitializedData)
+ }
+ return 0
+}
+
+type sectionFinalizeData struct {
+ address uintptr
+ alignedAddress uintptr
+ size uintptr
+ characteristics uint32
+ last bool
+}
+
+func (module *Module) finalizeSection(sectionData *sectionFinalizeData) error {
+ if sectionData.size == 0 {
+ return nil
+ }
+
+ if (sectionData.characteristics & IMAGE_SCN_MEM_DISCARDABLE) != 0 {
+ // Section is not needed any more and can safely be freed.
+ if sectionData.address == sectionData.alignedAddress &&
+ (sectionData.last ||
+ (sectionData.size%uintptr(module.headers.OptionalHeader.SectionAlignment)) == 0) {
+ // Only allowed to decommit whole pages.
+ windows.VirtualFree(sectionData.address, sectionData.size, windows.MEM_DECOMMIT)
+ }
+ return nil
+ }
+
+ // determine protection flags based on characteristics
+ var ProtectionFlags = [8]uint32{
+ windows.PAGE_NOACCESS, // not writeable, not readable, not executable
+ windows.PAGE_EXECUTE, // not writeable, not readable, executable
+ windows.PAGE_READONLY, // not writeable, readable, not executable
+ windows.PAGE_EXECUTE_READ, // not writeable, readable, executable
+ windows.PAGE_WRITECOPY, // writeable, not readable, not executable
+ windows.PAGE_EXECUTE_WRITECOPY, // writeable, not readable, executable
+ windows.PAGE_READWRITE, // writeable, readable, not executable
+ windows.PAGE_EXECUTE_READWRITE, // writeable, readable, executable
+ }
+ protect := ProtectionFlags[sectionData.characteristics>>29]
+ if (sectionData.characteristics & IMAGE_SCN_MEM_NOT_CACHED) != 0 {
+ protect |= windows.PAGE_NOCACHE
+ }
+
+ // Change memory access flags.
+ var oldProtect uint32
+ err := windows.VirtualProtect(sectionData.address, sectionData.size, protect, &oldProtect)
+ if err != nil {
+ return fmt.Errorf("Error protecting memory page: %w", err)
+ }
+
+ return nil
+}
+
+func (module *Module) finalizeSections() error {
+ sections := module.headers.Sections()
+ imageOffset := module.headers.OptionalHeader.imageOffset()
+ sectionData := sectionFinalizeData{}
+ sectionData.address = uintptr(sections[0].PhysicalAddress()) | imageOffset
+ sectionData.alignedAddress = alignDown(sectionData.address, uintptr(module.headers.OptionalHeader.SectionAlignment))
+ sectionData.size = module.realSectionSize(&sections[0])
+ sections[0].SetVirtualSize(uint32(sectionData.size))
+ sectionData.characteristics = sections[0].Characteristics
+
+ // Loop through all sections and change access flags.
+ for i := uint16(1); i < module.headers.FileHeader.NumberOfSections; i++ {
+ sectionAddress := uintptr(sections[i].PhysicalAddress()) | imageOffset
+ alignedAddress := alignDown(sectionAddress, uintptr(module.headers.OptionalHeader.SectionAlignment))
+ sectionSize := module.realSectionSize(&sections[i])
+ sections[i].SetVirtualSize(uint32(sectionSize))
+ // Combine access flags of all sections that share a page.
+ // TODO: We currently share flags of a trailing large section with the page of a first small section. This should be optimized.
+ if sectionData.alignedAddress == alignedAddress || sectionData.address+sectionData.size > alignedAddress {
+ // Section shares page with previous.
+ if (sections[i].Characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 || (sectionData.characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 {
+ sectionData.characteristics = (sectionData.characteristics | sections[i].Characteristics) &^ IMAGE_SCN_MEM_DISCARDABLE
+ } else {
+ sectionData.characteristics |= sections[i].Characteristics
+ }
+ sectionData.size = sectionAddress + sectionSize - sectionData.address
+ continue
+ }
+
+ err := module.finalizeSection(&sectionData)
+ if err != nil {
+ return fmt.Errorf("Error finalizing section: %w", err)
+ }
+ sectionData.address = sectionAddress
+ sectionData.alignedAddress = alignedAddress
+ sectionData.size = sectionSize
+ sectionData.characteristics = sections[i].Characteristics
+ }
+ sectionData.last = true
+ err := module.finalizeSection(&sectionData)
+ if err != nil {
+ return fmt.Errorf("Error finalizing section: %w", err)
+ }
+ return nil
+}
+
+func (module *Module) executeTLS() {
+ directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_TLS)
+ if directory.VirtualAddress == 0 {
+ return
+ }
+
+ tls := (*IMAGE_TLS_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
+ callback := tls.AddressOfCallbacks
+ if callback != 0 {
+ for {
+ f := *(*uintptr)(a2p(callback))
+ if f == 0 {
+ break
+ }
+ syscall.Syscall(f, 3, module.codeBase, uintptr(DLL_PROCESS_ATTACH), uintptr(0))
+ callback += unsafe.Sizeof(f)
+ }
+ }
+}
+
+func (module *Module) performBaseRelocation(delta uintptr) (relocated bool, err error) {
+ directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_BASERELOC)
+ if directory.Size == 0 {
+ return delta == 0, nil
+ }
+
+ relocationHdr := (*IMAGE_BASE_RELOCATION)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
+ for relocationHdr.VirtualAddress > 0 {
+ dest := module.codeBase + uintptr(relocationHdr.VirtualAddress)
+
+ var relInfos []uint16
+ unsafeSlice(
+ unsafe.Pointer(&relInfos),
+ a2p(uintptr(unsafe.Pointer(relocationHdr))+unsafe.Sizeof(*relocationHdr)),
+ int((uintptr(relocationHdr.SizeOfBlock)-unsafe.Sizeof(*relocationHdr))/unsafe.Sizeof(relInfos[0])))
+ for _, relInfo := range relInfos {
+ // The upper 4 bits define the type of relocation.
+ relType := relInfo >> 12
+ // The lower 12 bits define the offset.
+ relOffset := uintptr(relInfo & 0xfff)
+
+ switch relType {
+ case IMAGE_REL_BASED_ABSOLUTE:
+ // Skip relocation.
+
+ case IMAGE_REL_BASED_LOW:
+ *(*uint16)(a2p(dest + relOffset)) += uint16(delta & 0xffff)
+ break
+
+ case IMAGE_REL_BASED_HIGH:
+ *(*uint16)(a2p(dest + relOffset)) += uint16(uint32(delta) >> 16)
+ break
+
+ case IMAGE_REL_BASED_HIGHLOW:
+ *(*uint32)(a2p(dest + relOffset)) += uint32(delta)
+
+ case IMAGE_REL_BASED_DIR64:
+ *(*uint64)(a2p(dest + relOffset)) += uint64(delta)
+
+ case IMAGE_REL_BASED_THUMB_MOV32:
+ inst := *(*uint32)(a2p(dest + relOffset))
+ imm16 := ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) +
+ ((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff)
+ if (inst & 0x8000fbf0) != 0x0000f240 {
+ return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVW", inst)
+ }
+ imm16 += uint32(delta) & 0xffff
+ hiDelta := (uint32(delta&0xffff0000) >> 16) + ((imm16 & 0xffff0000) >> 16)
+ *(*uint32)(a2p(dest + relOffset)) = (inst & 0x8f00fbf0) + ((imm16 >> 1) & 0x0400) +
+ ((imm16 >> 12) & 0x000f) +
+ ((imm16 << 20) & 0x70000000) +
+ ((imm16 << 16) & 0xff0000)
+ if hiDelta != 0 {
+ inst = *(*uint32)(a2p(dest + relOffset + 4))
+ imm16 = ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) +
+ ((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff)
+ if (inst & 0x8000fbf0) != 0x0000f2c0 {
+ return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVT", inst)
+ }
+ imm16 += hiDelta
+ if imm16 > 0xffff {
+ return false, fmt.Errorf("Resulting immediate value won't fit: %08x", imm16)
+ }
+ *(*uint32)(a2p(dest + relOffset + 4)) = (inst & 0x8f00fbf0) +
+ ((imm16 >> 1) & 0x0400) +
+ ((imm16 >> 12) & 0x000f) +
+ ((imm16 << 20) & 0x70000000) +
+ ((imm16 << 16) & 0xff0000)
+ }
+
+ default:
+ return false, fmt.Errorf("Unsupported relocation: %v", relType)
+ }
+ }
+
+ // Advance to next relocation block.
+ relocationHdr = (*IMAGE_BASE_RELOCATION)(a2p(uintptr(unsafe.Pointer(relocationHdr)) + uintptr(relocationHdr.SizeOfBlock)))
+ }
+ return true, nil
+}
+
+func (module *Module) buildImportTable() error {
+ directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_IMPORT)
+ if directory.Size == 0 {
+ return nil
+ }
+
+ module.modules = make([]windows.Handle, 0, 16)
+ importDesc := (*IMAGE_IMPORT_DESCRIPTOR)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
+ for importDesc.Name != 0 {
+ handle, err := windows.LoadLibraryEx(windows.BytePtrToString((*byte)(a2p(module.codeBase+uintptr(importDesc.Name)))), 0, windows.LOAD_LIBRARY_SEARCH_SYSTEM32)
+ if err != nil {
+ return fmt.Errorf("Error loading module: %w", err)
+ }
+ var thunkRef, funcRef *uintptr
+ if importDesc.OriginalFirstThunk() != 0 {
+ thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.OriginalFirstThunk())))
+ funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
+ } else {
+ // No hint table.
+ thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
+ funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
+ }
+ for *thunkRef != 0 {
+ if IMAGE_SNAP_BY_ORDINAL(*thunkRef) {
+ *funcRef, err = windows.GetProcAddressByOrdinal(handle, IMAGE_ORDINAL(*thunkRef))
+ } else {
+ thunkData := (*IMAGE_IMPORT_BY_NAME)(a2p(module.codeBase + *thunkRef))
+ *funcRef, err = windows.GetProcAddress(handle, windows.BytePtrToString(&thunkData.Name[0]))
+ }
+ if err != nil {
+ windows.FreeLibrary(handle)
+ return fmt.Errorf("Error getting function address: %w", err)
+ }
+ thunkRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(thunkRef)) + unsafe.Sizeof(*thunkRef)))
+ funcRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(funcRef)) + unsafe.Sizeof(*funcRef)))
+ }
+ module.modules = append(module.modules, handle)
+ importDesc = (*IMAGE_IMPORT_DESCRIPTOR)(a2p(uintptr(unsafe.Pointer(importDesc)) + unsafe.Sizeof(*importDesc)))
+ }
+ return nil
+}
+
+func (module *Module) buildNameExports() error {
+ directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
+ if directory.Size == 0 {
+ return errors.New("No export table found")
+ }
+ exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
+ if exports.NumberOfNames == 0 || exports.NumberOfFunctions == 0 {
+ return errors.New("No functions exported")
+ }
+ if exports.NumberOfNames == 0 {
+ return errors.New("No functions exported by name")
+ }
+ var nameRefs []uint32
+ unsafeSlice(unsafe.Pointer(&nameRefs), a2p(module.codeBase+uintptr(exports.AddressOfNames)), int(exports.NumberOfNames))
+ var ordinals []uint16
+ unsafeSlice(unsafe.Pointer(&ordinals), a2p(module.codeBase+uintptr(exports.AddressOfNameOrdinals)), int(exports.NumberOfNames))
+ module.nameExports = make(map[string]uint16)
+ for i := range nameRefs {
+ nameArray := windows.BytePtrToString((*byte)(a2p(module.codeBase + uintptr(nameRefs[i]))))
+ module.nameExports[nameArray] = ordinals[i]
+ }
+ return nil
+}
+
+// LoadLibrary loads module image to memory.
+func LoadLibrary(data []byte) (module *Module, err error) {
+ addr := uintptr(unsafe.Pointer(&data[0]))
+ size := uintptr(len(data))
+ if size < unsafe.Sizeof(IMAGE_DOS_HEADER{}) {
+ return nil, errors.New("Incomplete IMAGE_DOS_HEADER")
+ }
+ dosHeader := (*IMAGE_DOS_HEADER)(a2p(addr))
+ if dosHeader.E_magic != IMAGE_DOS_SIGNATURE {
+ return nil, fmt.Errorf("Not an MS-DOS binary (provided: %x, expected: %x)", dosHeader.E_magic, IMAGE_DOS_SIGNATURE)
+ }
+ if (size < uintptr(dosHeader.E_lfanew)+unsafe.Sizeof(IMAGE_NT_HEADERS{})) {
+ return nil, errors.New("Incomplete IMAGE_NT_HEADERS")
+ }
+ oldHeader := (*IMAGE_NT_HEADERS)(a2p(addr + uintptr(dosHeader.E_lfanew)))
+ if oldHeader.Signature != IMAGE_NT_SIGNATURE {
+ return nil, fmt.Errorf("Not an NT binary (provided: %x, expected: %x)", oldHeader.Signature, IMAGE_NT_SIGNATURE)
+ }
+ if oldHeader.FileHeader.Machine != imageFileProcess {
+ return nil, fmt.Errorf("Foreign platform (provided: %x, expected: %x)", oldHeader.FileHeader.Machine, imageFileProcess)
+ }
+ if (oldHeader.OptionalHeader.SectionAlignment & 1) != 0 {
+ return nil, errors.New("Unaligned section")
+ }
+ lastSectionEnd := uintptr(0)
+ sections := oldHeader.Sections()
+ optionalSectionSize := oldHeader.OptionalHeader.SectionAlignment
+ for i := range sections {
+ var endOfSection uintptr
+ if sections[i].SizeOfRawData == 0 {
+ // Section without data in the DLL
+ endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(optionalSectionSize)
+ } else {
+ endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(sections[i].SizeOfRawData)
+ }
+ if endOfSection > lastSectionEnd {
+ lastSectionEnd = endOfSection
+ }
+ }
+ alignedImageSize := alignUp(uintptr(oldHeader.OptionalHeader.SizeOfImage), uintptr(oldHeader.OptionalHeader.SectionAlignment))
+ if alignedImageSize != alignUp(lastSectionEnd, uintptr(oldHeader.OptionalHeader.SectionAlignment)) {
+ return nil, errors.New("Section is not page-aligned")
+ }
+
+ module = &Module{isDLL: (oldHeader.FileHeader.Characteristics & IMAGE_FILE_DLL) != 0}
+ defer func() {
+ if err != nil {
+ module.Free()
+ module = nil
+ }
+ }()
+
+ // Reserve memory for image of library.
+ // TODO: Is it correct to commit the complete memory region at once? Calling DllEntry raises an exception if we don't.
+ module.codeBase, err = windows.VirtualAlloc(oldHeader.OptionalHeader.ImageBase,
+ alignedImageSize,
+ windows.MEM_RESERVE|windows.MEM_COMMIT,
+ windows.PAGE_READWRITE)
+ if err != nil {
+ // Try to allocate memory at arbitrary position.
+ module.codeBase, err = windows.VirtualAlloc(0,
+ alignedImageSize,
+ windows.MEM_RESERVE|windows.MEM_COMMIT,
+ windows.PAGE_READWRITE)
+ if err != nil {
+ err = fmt.Errorf("Error allocating code: %w", err)
+ return
+ }
+ }
+ err = module.check4GBBoundaries(alignedImageSize)
+ if err != nil {
+ err = fmt.Errorf("Error reallocating code: %w", err)
+ return
+ }
+
+ if size < uintptr(oldHeader.OptionalHeader.SizeOfHeaders) {
+ err = errors.New("Incomplete headers")
+ return
+ }
+ // Commit memory for headers.
+ headers, err := windows.VirtualAlloc(module.codeBase,
+ uintptr(oldHeader.OptionalHeader.SizeOfHeaders),
+ windows.MEM_COMMIT,
+ windows.PAGE_READWRITE)
+ if err != nil {
+ err = fmt.Errorf("Error allocating headers: %w", err)
+ return
+ }
+ // Copy PE header to code.
+ memcpy(headers, addr, uintptr(oldHeader.OptionalHeader.SizeOfHeaders))
+ module.headers = (*IMAGE_NT_HEADERS)(a2p(headers + uintptr(dosHeader.E_lfanew)))
+
+ // Update position.
+ module.headers.OptionalHeader.ImageBase = module.codeBase
+
+ // Copy sections from DLL file block to new memory location.
+ err = module.copySections(addr, size, oldHeader)
+ if err != nil {
+ err = fmt.Errorf("Error copying sections: %w", err)
+ return
+ }
+
+ // Adjust base address of imported data.
+ locationDelta := module.headers.OptionalHeader.ImageBase - oldHeader.OptionalHeader.ImageBase
+ if locationDelta != 0 {
+ module.isRelocated, err = module.performBaseRelocation(locationDelta)
+ if err != nil {
+ err = fmt.Errorf("Error relocating module: %w", err)
+ return
+ }
+ } else {
+ module.isRelocated = true
+ }
+
+ // Load required dlls and adjust function table of imports.
+ err = module.buildImportTable()
+ if err != nil {
+ err = fmt.Errorf("Error building import table: %w", err)
+ return
+ }
+
+ // Mark memory pages depending on section headers and release sections that are marked as "discardable".
+ err = module.finalizeSections()
+ if err != nil {
+ err = fmt.Errorf("Error finalizing sections: %w", err)
+ return
+ }
+
+ // TLS callbacks are executed BEFORE the main loading.
+ module.executeTLS()
+
+ // Get entry point of loaded module.
+ if module.headers.OptionalHeader.AddressOfEntryPoint != 0 {
+ module.entry = module.codeBase + uintptr(module.headers.OptionalHeader.AddressOfEntryPoint)
+ if module.isDLL {
+ // Notify library about attaching to process.
+ r0, _, _ := syscall.Syscall(module.entry, 3, module.codeBase, uintptr(DLL_PROCESS_ATTACH), 0)
+ successful := r0 != 0
+ if !successful {
+ err = windows.ERROR_DLL_INIT_FAILED
+ return
+ }
+ module.initialized = true
+ }
+ }
+
+ module.buildNameExports()
+ return
+}
+
+// Free releases module resources and unloads it.
+func (module *Module) Free() {
+ if module.initialized {
+ // Notify library about detaching from process.
+ syscall.Syscall(module.entry, 3, module.codeBase, uintptr(DLL_PROCESS_DETACH), 0)
+ module.initialized = false
+ }
+ if module.modules != nil {
+ // Free previously opened libraries.
+ for _, handle := range module.modules {
+ windows.FreeLibrary(handle)
+ }
+ module.modules = nil
+ }
+ if module.codeBase != 0 {
+ windows.VirtualFree(module.codeBase, 0, windows.MEM_RELEASE)
+ module.codeBase = 0
+ }
+ if module.blockedMemory != nil {
+ module.blockedMemory.free()
+ module.blockedMemory = nil
+ }
+}
+
+// ProcAddressByName returns function address by exported name.
+func (module *Module) ProcAddressByName(name string) (uintptr, error) {
+ directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
+ if directory.Size == 0 {
+ return 0, errors.New("No export table found")
+ }
+ exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
+ if module.nameExports == nil {
+ return 0, errors.New("No functions exported by name")
+ }
+ if idx, ok := module.nameExports[name]; ok {
+ if uint32(idx) > exports.NumberOfFunctions {
+ return 0, errors.New("Ordinal number too high")
+ }
+ // AddressOfFunctions contains the RVAs to the "real" functions.
+ return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil
+ }
+ return 0, errors.New("Function not found by name")
+}
+
+// ProcAddressByOrdinal returns function address by exported ordinal.
+func (module *Module) ProcAddressByOrdinal(ordinal uint16) (uintptr, error) {
+ directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
+ if directory.Size == 0 {
+ return 0, errors.New("No export table found")
+ }
+ exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
+ if uint32(ordinal) < exports.Base {
+ return 0, errors.New("Ordinal number too low")
+ }
+ idx := ordinal - uint16(exports.Base)
+ if uint32(idx) > exports.NumberOfFunctions {
+ return 0, errors.New("Ordinal number too high")
+ }
+ // AddressOfFunctions contains the RVAs to the "real" functions.
+ return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil
+}
+
+func alignDown(value, alignment uintptr) uintptr {
+ return value & ^(alignment - 1)
+}
+
+func alignUp(value, alignment uintptr) uintptr {
+ return (value + alignment - 1) & ^(alignment - 1)
+}
+
+func a2p(addr uintptr) unsafe.Pointer {
+ return unsafe.Pointer(addr)
+}
+
+func memcpy(dst, src, size uintptr) {
+ var d, s []byte
+ unsafeSlice(unsafe.Pointer(&d), a2p(dst), int(size))
+ unsafeSlice(unsafe.Pointer(&s), a2p(src), int(size))
+ copy(d, s)
+}
+
+// unsafeSlice updates the slice slicePtr to be a slice
+// referencing the provided data with its length & capacity set to
+// lenCap.
+//
+// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
+// update callers to use unsafe.Slice instead of this.
+func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
+ type sliceHeader struct {
+ Data unsafe.Pointer
+ Len int
+ Cap int
+ }
+ h := (*sliceHeader)(slicePtr)
+ h.Data = data
+ h.Len = lenCap
+ h.Cap = lenCap
+}
diff --git a/driver/memmod/memmod_windows_32.go b/driver/memmod/memmod_windows_32.go
new file mode 100644
index 00000000..ac76bdcc
--- /dev/null
+++ b/driver/memmod/memmod_windows_32.go
@@ -0,0 +1,16 @@
+// +build windows,386 windows,arm
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package memmod
+
+func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr {
+ return 0
+}
+
+func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) {
+ return
+}
diff --git a/driver/memmod/memmod_windows_386.go b/driver/memmod/memmod_windows_386.go
new file mode 100644
index 00000000..475c5c52
--- /dev/null
+++ b/driver/memmod/memmod_windows_386.go
@@ -0,0 +1,8 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package memmod
+
+const imageFileProcess = IMAGE_FILE_MACHINE_I386
diff --git a/driver/memmod/memmod_windows_64.go b/driver/memmod/memmod_windows_64.go
new file mode 100644
index 00000000..a6203682
--- /dev/null
+++ b/driver/memmod/memmod_windows_64.go
@@ -0,0 +1,36 @@
+// +build windows,amd64 windows,arm64
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package memmod
+
+import (
+ "fmt"
+
+ "golang.org/x/sys/windows"
+)
+
+func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr {
+ return uintptr(opthdr.ImageBase & 0xffffffff00000000)
+}
+
+func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) {
+ for (module.codeBase >> 32) < ((module.codeBase + alignedImageSize) >> 32) {
+ node := &addressList{
+ next: module.blockedMemory,
+ address: module.codeBase,
+ }
+ module.blockedMemory = node
+ module.codeBase, err = windows.VirtualAlloc(0,
+ alignedImageSize,
+ windows.MEM_RESERVE|windows.MEM_COMMIT,
+ windows.PAGE_READWRITE)
+ if err != nil {
+ return fmt.Errorf("Error allocating memory block: %w", err)
+ }
+ }
+ return
+}
diff --git a/driver/memmod/memmod_windows_amd64.go b/driver/memmod/memmod_windows_amd64.go
new file mode 100644
index 00000000..a021a633
--- /dev/null
+++ b/driver/memmod/memmod_windows_amd64.go
@@ -0,0 +1,8 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package memmod
+
+const imageFileProcess = IMAGE_FILE_MACHINE_AMD64
diff --git a/driver/memmod/memmod_windows_arm.go b/driver/memmod/memmod_windows_arm.go
new file mode 100644
index 00000000..4637a01d
--- /dev/null
+++ b/driver/memmod/memmod_windows_arm.go
@@ -0,0 +1,8 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package memmod
+
+const imageFileProcess = IMAGE_FILE_MACHINE_ARMNT
diff --git a/driver/memmod/memmod_windows_arm64.go b/driver/memmod/memmod_windows_arm64.go
new file mode 100644
index 00000000..b8f12596
--- /dev/null
+++ b/driver/memmod/memmod_windows_arm64.go
@@ -0,0 +1,8 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package memmod
+
+const imageFileProcess = IMAGE_FILE_MACHINE_ARM64
diff --git a/driver/memmod/syscall_windows.go b/driver/memmod/syscall_windows.go
new file mode 100644
index 00000000..b79be69e
--- /dev/null
+++ b/driver/memmod/syscall_windows.go
@@ -0,0 +1,392 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package memmod
+
+import "unsafe"
+
+const (
+ IMAGE_DOS_SIGNATURE = 0x5A4D // MZ
+ IMAGE_OS2_SIGNATURE = 0x454E // NE
+ IMAGE_OS2_SIGNATURE_LE = 0x454C // LE
+ IMAGE_VXD_SIGNATURE = 0x454C // LE
+ IMAGE_NT_SIGNATURE = 0x00004550 // PE00
+)
+
+// DOS .EXE header
+type IMAGE_DOS_HEADER struct {
+ E_magic uint16 // Magic number
+ E_cblp uint16 // Bytes on last page of file
+ E_cp uint16 // Pages in file
+ E_crlc uint16 // Relocations
+ E_cparhdr uint16 // Size of header in paragraphs
+ E_minalloc uint16 // Minimum extra paragraphs needed
+ E_maxalloc uint16 // Maximum extra paragraphs needed
+ E_ss uint16 // Initial (relative) SS value
+ E_sp uint16 // Initial SP value
+ E_csum uint16 // Checksum
+ E_ip uint16 // Initial IP value
+ E_cs uint16 // Initial (relative) CS value
+ E_lfarlc uint16 // File address of relocation table
+ E_ovno uint16 // Overlay number
+ E_res [4]uint16 // Reserved words
+ E_oemid uint16 // OEM identifier (for e_oeminfo)
+ E_oeminfo uint16 // OEM information; e_oemid specific
+ E_res2 [10]uint16 // Reserved words
+ E_lfanew int32 // File address of new exe header
+}
+
+// File header format
+type IMAGE_FILE_HEADER struct {
+ Machine uint16
+ NumberOfSections uint16
+ TimeDateStamp uint32
+ PointerToSymbolTable uint32
+ NumberOfSymbols uint32
+ SizeOfOptionalHeader uint16
+ Characteristics uint16
+}
+
+const (
+ IMAGE_SIZEOF_FILE_HEADER = 20
+
+ IMAGE_FILE_RELOCS_STRIPPED = 0x0001 // Relocation info stripped from file.
+ IMAGE_FILE_EXECUTABLE_IMAGE = 0x0002 // File is executable (i.e. no unresolved external references).
+ IMAGE_FILE_LINE_NUMS_STRIPPED = 0x0004 // Line nunbers stripped from file.
+ IMAGE_FILE_LOCAL_SYMS_STRIPPED = 0x0008 // Local symbols stripped from file.
+ IMAGE_FILE_AGGRESIVE_WS_TRIM = 0x0010 // Aggressively trim working set
+ IMAGE_FILE_LARGE_ADDRESS_AWARE = 0x0020 // App can handle >2gb addresses
+ IMAGE_FILE_BYTES_REVERSED_LO = 0x0080 // Bytes of machine word are reversed.
+ IMAGE_FILE_32BIT_MACHINE = 0x0100 // 32 bit word machine.
+ IMAGE_FILE_DEBUG_STRIPPED = 0x0200 // Debugging info stripped from file in .DBG file
+ IMAGE_FILE_REMOVABLE_RUN_FROM_SWAP = 0x0400 // If Image is on removable media, copy and run from the swap file.
+ IMAGE_FILE_NET_RUN_FROM_SWAP = 0x0800 // If Image is on Net, copy and run from the swap file.
+ IMAGE_FILE_SYSTEM = 0x1000 // System File.
+ IMAGE_FILE_DLL = 0x2000 // File is a DLL.
+ IMAGE_FILE_UP_SYSTEM_ONLY = 0x4000 // File should only be run on a UP machine
+ IMAGE_FILE_BYTES_REVERSED_HI = 0x8000 // Bytes of machine word are reversed.
+
+ IMAGE_FILE_MACHINE_UNKNOWN = 0
+ IMAGE_FILE_MACHINE_TARGET_HOST = 0x0001 // Useful for indicating we want to interact with the host and not a WoW guest.
+ IMAGE_FILE_MACHINE_I386 = 0x014c // Intel 386.
+ IMAGE_FILE_MACHINE_R3000 = 0x0162 // MIPS little-endian, 0x160 big-endian
+ IMAGE_FILE_MACHINE_R4000 = 0x0166 // MIPS little-endian
+ IMAGE_FILE_MACHINE_R10000 = 0x0168 // MIPS little-endian
+ IMAGE_FILE_MACHINE_WCEMIPSV2 = 0x0169 // MIPS little-endian WCE v2
+ IMAGE_FILE_MACHINE_ALPHA = 0x0184 // Alpha_AXP
+ IMAGE_FILE_MACHINE_SH3 = 0x01a2 // SH3 little-endian
+ IMAGE_FILE_MACHINE_SH3DSP = 0x01a3
+ IMAGE_FILE_MACHINE_SH3E = 0x01a4 // SH3E little-endian
+ IMAGE_FILE_MACHINE_SH4 = 0x01a6 // SH4 little-endian
+ IMAGE_FILE_MACHINE_SH5 = 0x01a8 // SH5
+ IMAGE_FILE_MACHINE_ARM = 0x01c0 // ARM Little-Endian
+ IMAGE_FILE_MACHINE_THUMB = 0x01c2 // ARM Thumb/Thumb-2 Little-Endian
+ IMAGE_FILE_MACHINE_ARMNT = 0x01c4 // ARM Thumb-2 Little-Endian
+ IMAGE_FILE_MACHINE_AM33 = 0x01d3
+ IMAGE_FILE_MACHINE_POWERPC = 0x01F0 // IBM PowerPC Little-Endian
+ IMAGE_FILE_MACHINE_POWERPCFP = 0x01f1
+ IMAGE_FILE_MACHINE_IA64 = 0x0200 // Intel 64
+ IMAGE_FILE_MACHINE_MIPS16 = 0x0266 // MIPS
+ IMAGE_FILE_MACHINE_ALPHA64 = 0x0284 // ALPHA64
+ IMAGE_FILE_MACHINE_MIPSFPU = 0x0366 // MIPS
+ IMAGE_FILE_MACHINE_MIPSFPU16 = 0x0466 // MIPS
+ IMAGE_FILE_MACHINE_AXP64 = IMAGE_FILE_MACHINE_ALPHA64
+ IMAGE_FILE_MACHINE_TRICORE = 0x0520 // Infineon
+ IMAGE_FILE_MACHINE_CEF = 0x0CEF
+ IMAGE_FILE_MACHINE_EBC = 0x0EBC // EFI Byte Code
+ IMAGE_FILE_MACHINE_AMD64 = 0x8664 // AMD64 (K8)
+ IMAGE_FILE_MACHINE_M32R = 0x9041 // M32R little-endian
+ IMAGE_FILE_MACHINE_ARM64 = 0xAA64 // ARM64 Little-Endian
+ IMAGE_FILE_MACHINE_CEE = 0xC0EE
+)
+
+// Directory format
+type IMAGE_DATA_DIRECTORY struct {
+ VirtualAddress uint32
+ Size uint32
+}
+
+const IMAGE_NUMBEROF_DIRECTORY_ENTRIES = 16
+
+type IMAGE_NT_HEADERS struct {
+ Signature uint32
+ FileHeader IMAGE_FILE_HEADER
+ OptionalHeader IMAGE_OPTIONAL_HEADER
+}
+
+func (ntheader *IMAGE_NT_HEADERS) Sections() []IMAGE_SECTION_HEADER {
+ return (*[0xffff]IMAGE_SECTION_HEADER)(unsafe.Pointer(
+ (uintptr)(unsafe.Pointer(ntheader)) +
+ unsafe.Offsetof(ntheader.OptionalHeader) +
+ uintptr(ntheader.FileHeader.SizeOfOptionalHeader)))[:ntheader.FileHeader.NumberOfSections]
+}
+
+const (
+ IMAGE_DIRECTORY_ENTRY_EXPORT = 0 // Export Directory
+ IMAGE_DIRECTORY_ENTRY_IMPORT = 1 // Import Directory
+ IMAGE_DIRECTORY_ENTRY_RESOURCE = 2 // Resource Directory
+ IMAGE_DIRECTORY_ENTRY_EXCEPTION = 3 // Exception Directory
+ IMAGE_DIRECTORY_ENTRY_SECURITY = 4 // Security Directory
+ IMAGE_DIRECTORY_ENTRY_BASERELOC = 5 // Base Relocation Table
+ IMAGE_DIRECTORY_ENTRY_DEBUG = 6 // Debug Directory
+ IMAGE_DIRECTORY_ENTRY_COPYRIGHT = 7 // (X86 usage)
+ IMAGE_DIRECTORY_ENTRY_ARCHITECTURE = 7 // Architecture Specific Data
+ IMAGE_DIRECTORY_ENTRY_GLOBALPTR = 8 // RVA of GP
+ IMAGE_DIRECTORY_ENTRY_TLS = 9 // TLS Directory
+ IMAGE_DIRECTORY_ENTRY_LOAD_CONFIG = 10 // Load Configuration Directory
+ IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT = 11 // Bound Import Directory in headers
+ IMAGE_DIRECTORY_ENTRY_IAT = 12 // Import Address Table
+ IMAGE_DIRECTORY_ENTRY_DELAY_IMPORT = 13 // Delay Load Import Descriptors
+ IMAGE_DIRECTORY_ENTRY_COM_DESCRIPTOR = 14 // COM Runtime descriptor
+)
+
+const IMAGE_SIZEOF_SHORT_NAME = 8
+
+// Section header format
+type IMAGE_SECTION_HEADER struct {
+ Name [IMAGE_SIZEOF_SHORT_NAME]byte
+ physicalAddressOrVirtualSize uint32
+ VirtualAddress uint32
+ SizeOfRawData uint32
+ PointerToRawData uint32
+ PointerToRelocations uint32
+ PointerToLinenumbers uint32
+ NumberOfRelocations uint16
+ NumberOfLinenumbers uint16
+ Characteristics uint32
+}
+
+func (ishdr *IMAGE_SECTION_HEADER) PhysicalAddress() uint32 {
+ return ishdr.physicalAddressOrVirtualSize
+}
+
+func (ishdr *IMAGE_SECTION_HEADER) SetPhysicalAddress(addr uint32) {
+ ishdr.physicalAddressOrVirtualSize = addr
+}
+
+func (ishdr *IMAGE_SECTION_HEADER) VirtualSize() uint32 {
+ return ishdr.physicalAddressOrVirtualSize
+}
+
+func (ishdr *IMAGE_SECTION_HEADER) SetVirtualSize(addr uint32) {
+ ishdr.physicalAddressOrVirtualSize = addr
+}
+
+const (
+ // Dll characteristics.
+ IMAGE_DLL_CHARACTERISTICS_HIGH_ENTROPY_VA = 0x0020
+ IMAGE_DLL_CHARACTERISTICS_DYNAMIC_BASE = 0x0040
+ IMAGE_DLL_CHARACTERISTICS_FORCE_INTEGRITY = 0x0080
+ IMAGE_DLL_CHARACTERISTICS_NX_COMPAT = 0x0100
+ IMAGE_DLL_CHARACTERISTICS_NO_ISOLATION = 0x0200
+ IMAGE_DLL_CHARACTERISTICS_NO_SEH = 0x0400
+ IMAGE_DLL_CHARACTERISTICS_NO_BIND = 0x0800
+ IMAGE_DLL_CHARACTERISTICS_APPCONTAINER = 0x1000
+ IMAGE_DLL_CHARACTERISTICS_WDM_DRIVER = 0x2000
+ IMAGE_DLL_CHARACTERISTICS_GUARD_CF = 0x4000
+ IMAGE_DLL_CHARACTERISTICS_TERMINAL_SERVER_AWARE = 0x8000
+)
+
+const (
+ // Section characteristics.
+ IMAGE_SCN_TYPE_REG = 0x00000000 // Reserved.
+ IMAGE_SCN_TYPE_DSECT = 0x00000001 // Reserved.
+ IMAGE_SCN_TYPE_NOLOAD = 0x00000002 // Reserved.
+ IMAGE_SCN_TYPE_GROUP = 0x00000004 // Reserved.
+ IMAGE_SCN_TYPE_NO_PAD = 0x00000008 // Reserved.
+ IMAGE_SCN_TYPE_COPY = 0x00000010 // Reserved.
+
+ IMAGE_SCN_CNT_CODE = 0x00000020 // Section contains code.
+ IMAGE_SCN_CNT_INITIALIZED_DATA = 0x00000040 // Section contains initialized data.
+ IMAGE_SCN_CNT_UNINITIALIZED_DATA = 0x00000080 // Section contains uninitialized data.
+
+ IMAGE_SCN_LNK_OTHER = 0x00000100 // Reserved.
+ IMAGE_SCN_LNK_INFO = 0x00000200 // Section contains comments or some other type of information.
+ IMAGE_SCN_TYPE_OVER = 0x00000400 // Reserved.
+ IMAGE_SCN_LNK_REMOVE = 0x00000800 // Section contents will not become part of image.
+ IMAGE_SCN_LNK_COMDAT = 0x00001000 // Section contents comdat.
+ IMAGE_SCN_MEM_PROTECTED = 0x00004000 // Obsolete.
+ IMAGE_SCN_NO_DEFER_SPEC_EXC = 0x00004000 // Reset speculative exceptions handling bits in the TLB entries for this section.
+ IMAGE_SCN_GPREL = 0x00008000 // Section content can be accessed relative to GP
+ IMAGE_SCN_MEM_FARDATA = 0x00008000
+ IMAGE_SCN_MEM_SYSHEAP = 0x00010000 // Obsolete.
+ IMAGE_SCN_MEM_PURGEABLE = 0x00020000
+ IMAGE_SCN_MEM_16BIT = 0x00020000
+ IMAGE_SCN_MEM_LOCKED = 0x00040000
+ IMAGE_SCN_MEM_PRELOAD = 0x00080000
+
+ IMAGE_SCN_ALIGN_1BYTES = 0x00100000 //
+ IMAGE_SCN_ALIGN_2BYTES = 0x00200000 //
+ IMAGE_SCN_ALIGN_4BYTES = 0x00300000 //
+ IMAGE_SCN_ALIGN_8BYTES = 0x00400000 //
+ IMAGE_SCN_ALIGN_16BYTES = 0x00500000 // Default alignment if no others are specified.
+ IMAGE_SCN_ALIGN_32BYTES = 0x00600000 //
+ IMAGE_SCN_ALIGN_64BYTES = 0x00700000 //
+ IMAGE_SCN_ALIGN_128BYTES = 0x00800000 //
+ IMAGE_SCN_ALIGN_256BYTES = 0x00900000 //
+ IMAGE_SCN_ALIGN_512BYTES = 0x00A00000 //
+ IMAGE_SCN_ALIGN_1024BYTES = 0x00B00000 //
+ IMAGE_SCN_ALIGN_2048BYTES = 0x00C00000 //
+ IMAGE_SCN_ALIGN_4096BYTES = 0x00D00000 //
+ IMAGE_SCN_ALIGN_8192BYTES = 0x00E00000 //
+ IMAGE_SCN_ALIGN_MASK = 0x00F00000
+
+ IMAGE_SCN_LNK_NRELOC_OVFL = 0x01000000 // Section contains extended relocations.
+ IMAGE_SCN_MEM_DISCARDABLE = 0x02000000 // Section can be discarded.
+ IMAGE_SCN_MEM_NOT_CACHED = 0x04000000 // Section is not cachable.
+ IMAGE_SCN_MEM_NOT_PAGED = 0x08000000 // Section is not pageable.
+ IMAGE_SCN_MEM_SHARED = 0x10000000 // Section is shareable.
+ IMAGE_SCN_MEM_EXECUTE = 0x20000000 // Section is executable.
+ IMAGE_SCN_MEM_READ = 0x40000000 // Section is readable.
+ IMAGE_SCN_MEM_WRITE = 0x80000000 // Section is writeable.
+
+ // TLS Characteristic Flags
+ IMAGE_SCN_SCALE_INDEX = 0x00000001 // Tls index is scaled.
+)
+
+// Based relocation format
+type IMAGE_BASE_RELOCATION struct {
+ VirtualAddress uint32
+ SizeOfBlock uint32
+}
+
+const (
+ IMAGE_REL_BASED_ABSOLUTE = 0
+ IMAGE_REL_BASED_HIGH = 1
+ IMAGE_REL_BASED_LOW = 2
+ IMAGE_REL_BASED_HIGHLOW = 3
+ IMAGE_REL_BASED_HIGHADJ = 4
+ IMAGE_REL_BASED_MACHINE_SPECIFIC_5 = 5
+ IMAGE_REL_BASED_RESERVED = 6
+ IMAGE_REL_BASED_MACHINE_SPECIFIC_7 = 7
+ IMAGE_REL_BASED_MACHINE_SPECIFIC_8 = 8
+ IMAGE_REL_BASED_MACHINE_SPECIFIC_9 = 9
+ IMAGE_REL_BASED_DIR64 = 10
+
+ IMAGE_REL_BASED_IA64_IMM64 = 9
+
+ IMAGE_REL_BASED_MIPS_JMPADDR = 5
+ IMAGE_REL_BASED_MIPS_JMPADDR16 = 9
+
+ IMAGE_REL_BASED_ARM_MOV32 = 5
+ IMAGE_REL_BASED_THUMB_MOV32 = 7
+)
+
+// Export Format
+type IMAGE_EXPORT_DIRECTORY struct {
+ Characteristics uint32
+ TimeDateStamp uint32
+ MajorVersion uint16
+ MinorVersion uint16
+ Name uint32
+ Base uint32
+ NumberOfFunctions uint32
+ NumberOfNames uint32
+ AddressOfFunctions uint32 // RVA from base of image
+ AddressOfNames uint32 // RVA from base of image
+ AddressOfNameOrdinals uint32 // RVA from base of image
+}
+
+type IMAGE_IMPORT_BY_NAME struct {
+ Hint uint16
+ Name [1]byte
+}
+
+func IMAGE_ORDINAL(ordinal uintptr) uintptr {
+ return ordinal & 0xffff
+}
+
+func IMAGE_SNAP_BY_ORDINAL(ordinal uintptr) bool {
+ return (ordinal & IMAGE_ORDINAL_FLAG) != 0
+}
+
+// Thread Local Storage
+type IMAGE_TLS_DIRECTORY struct {
+ StartAddressOfRawData uintptr
+ EndAddressOfRawData uintptr
+ AddressOfIndex uintptr // PDWORD
+ AddressOfCallbacks uintptr // PIMAGE_TLS_CALLBACK *;
+ SizeOfZeroFill uint32
+ Characteristics uint32
+}
+
+type IMAGE_IMPORT_DESCRIPTOR struct {
+ characteristicsOrOriginalFirstThunk uint32 // 0 for terminating null import descriptor
+ // RVA to original unbound IAT (PIMAGE_THUNK_DATA)
+ TimeDateStamp uint32 // 0 if not bound,
+ // -1 if bound, and real date\time stamp
+ // in IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT (new BIND)
+ // O.W. date/time stamp of DLL bound to (Old BIND)
+ ForwarderChain uint32 // -1 if no forwarders
+ Name uint32
+ FirstThunk uint32 // RVA to IAT (if bound this IAT has actual addresses)
+}
+
+func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) Characteristics() uint32 {
+ return imgimpdesc.characteristicsOrOriginalFirstThunk
+}
+
+func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) OriginalFirstThunk() uint32 {
+ return imgimpdesc.characteristicsOrOriginalFirstThunk
+}
+
+type IMAGE_DELAYLOAD_DESCRIPTOR struct {
+ Attributes uint32
+ DllNameRVA uint32
+ ModuleHandleRVA uint32
+ ImportAddressTableRVA uint32
+ ImportNameTableRVA uint32
+ BoundImportAddressTableRVA uint32
+ UnloadInformationTableRVA uint32
+ TimeDateStamp uint32
+}
+
+type IMAGE_LOAD_CONFIG_CODE_INTEGRITY struct {
+ Flags uint16
+ Catalog uint16
+ CatalogOffset uint32
+ Reserved uint32
+}
+
+const (
+ IMAGE_GUARD_CF_INSTRUMENTED = 0x00000100
+ IMAGE_GUARD_CFW_INSTRUMENTED = 0x00000200
+ IMAGE_GUARD_CF_FUNCTION_TABLE_PRESENT = 0x00000400
+ IMAGE_GUARD_SECURITY_COOKIE_UNUSED = 0x00000800
+ IMAGE_GUARD_PROTECT_DELAYLOAD_IAT = 0x00001000
+ IMAGE_GUARD_DELAYLOAD_IAT_IN_ITS_OWN_SECTION = 0x00002000
+ IMAGE_GUARD_CF_EXPORT_SUPPRESSION_INFO_PRESENT = 0x00004000
+ IMAGE_GUARD_CF_ENABLE_EXPORT_SUPPRESSION = 0x00008000
+ IMAGE_GUARD_CF_LONGJUMP_TABLE_PRESENT = 0x00010000
+ IMAGE_GUARD_RF_INSTRUMENTED = 0x00020000
+ IMAGE_GUARD_RF_ENABLE = 0x00040000
+ IMAGE_GUARD_RF_STRICT = 0x00080000
+ IMAGE_GUARD_RETPOLINE_PRESENT = 0x00100000
+ IMAGE_GUARD_EH_CONTINUATION_TABLE_PRESENT = 0x00400000
+ IMAGE_GUARD_XFG_ENABLED = 0x00800000
+ IMAGE_GUARD_CF_FUNCTION_TABLE_SIZE_MASK = 0xF0000000
+ IMAGE_GUARD_CF_FUNCTION_TABLE_SIZE_SHIFT = 28
+)
+
+const (
+ DLL_PROCESS_ATTACH = 1
+ DLL_THREAD_ATTACH = 2
+ DLL_THREAD_DETACH = 3
+ DLL_PROCESS_DETACH = 0
+)
+
+type SYSTEM_INFO struct {
+ ProcessorArchitecture uint16
+ Reserved uint16
+ PageSize uint32
+ MinimumApplicationAddress uintptr
+ MaximumApplicationAddress uintptr
+ ActiveProcessorMask uintptr
+ NumberOfProcessors uint32
+ ProcessorType uint32
+ AllocationGranularity uint32
+ ProcessorLevel uint16
+ ProcessorRevision uint16
+}
diff --git a/driver/memmod/syscall_windows_32.go b/driver/memmod/syscall_windows_32.go
new file mode 100644
index 00000000..7abbac98
--- /dev/null
+++ b/driver/memmod/syscall_windows_32.go
@@ -0,0 +1,96 @@
+// +build windows,386 windows,arm
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package memmod
+
+// Optional header format
+type IMAGE_OPTIONAL_HEADER struct {
+ Magic uint16
+ MajorLinkerVersion uint8
+ MinorLinkerVersion uint8
+ SizeOfCode uint32
+ SizeOfInitializedData uint32
+ SizeOfUninitializedData uint32
+ AddressOfEntryPoint uint32
+ BaseOfCode uint32
+ BaseOfData uint32
+ ImageBase uintptr
+ SectionAlignment uint32
+ FileAlignment uint32
+ MajorOperatingSystemVersion uint16
+ MinorOperatingSystemVersion uint16
+ MajorImageVersion uint16
+ MinorImageVersion uint16
+ MajorSubsystemVersion uint16
+ MinorSubsystemVersion uint16
+ Win32VersionValue uint32
+ SizeOfImage uint32
+ SizeOfHeaders uint32
+ CheckSum uint32
+ Subsystem uint16
+ DllCharacteristics uint16
+ SizeOfStackReserve uintptr
+ SizeOfStackCommit uintptr
+ SizeOfHeapReserve uintptr
+ SizeOfHeapCommit uintptr
+ LoaderFlags uint32
+ NumberOfRvaAndSizes uint32
+ DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY
+}
+
+const IMAGE_ORDINAL_FLAG uintptr = 0x80000000
+
+type IMAGE_LOAD_CONFIG_DIRECTORY struct {
+ Size uint32
+ TimeDateStamp uint32
+ MajorVersion uint16
+ MinorVersion uint16
+ GlobalFlagsClear uint32
+ GlobalFlagsSet uint32
+ CriticalSectionDefaultTimeout uint32
+ DeCommitFreeBlockThreshold uint32
+ DeCommitTotalFreeThreshold uint32
+ LockPrefixTable uint32
+ MaximumAllocationSize uint32
+ VirtualMemoryThreshold uint32
+ ProcessHeapFlags uint32
+ ProcessAffinityMask uint32
+ CSDVersion uint16
+ DependentLoadFlags uint16
+ EditList uint32
+ SecurityCookie uint32
+ SEHandlerTable uint32
+ SEHandlerCount uint32
+ GuardCFCheckFunctionPointer uint32
+ GuardCFDispatchFunctionPointer uint32
+ GuardCFFunctionTable uint32
+ GuardCFFunctionCount uint32
+ GuardFlags uint32
+ CodeIntegrity IMAGE_LOAD_CONFIG_CODE_INTEGRITY
+ GuardAddressTakenIatEntryTable uint32
+ GuardAddressTakenIatEntryCount uint32
+ GuardLongJumpTargetTable uint32
+ GuardLongJumpTargetCount uint32
+ DynamicValueRelocTable uint32
+ CHPEMetadataPointer uint32
+ GuardRFFailureRoutine uint32
+ GuardRFFailureRoutineFunctionPointer uint32
+ DynamicValueRelocTableOffset uint32
+ DynamicValueRelocTableSection uint16
+ Reserved2 uint16
+ GuardRFVerifyStackPointerFunctionPointer uint32
+ HotPatchTableOffset uint32
+ Reserved3 uint32
+ EnclaveConfigurationPointer uint32
+ VolatileMetadataPointer uint32
+ GuardEHContinuationTable uint32
+ GuardEHContinuationCount uint32
+ GuardXFGCheckFunctionPointer uint32
+ GuardXFGDispatchFunctionPointer uint32
+ GuardXFGTableDispatchFunctionPointer uint32
+ CastGuardOsDeterminedFailureMode uint32
+}
diff --git a/driver/memmod/syscall_windows_64.go b/driver/memmod/syscall_windows_64.go
new file mode 100644
index 00000000..10c65332
--- /dev/null
+++ b/driver/memmod/syscall_windows_64.go
@@ -0,0 +1,95 @@
+// +build windows,amd64 windows,arm64
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package memmod
+
+// Optional header format
+type IMAGE_OPTIONAL_HEADER struct {
+ Magic uint16
+ MajorLinkerVersion uint8
+ MinorLinkerVersion uint8
+ SizeOfCode uint32
+ SizeOfInitializedData uint32
+ SizeOfUninitializedData uint32
+ AddressOfEntryPoint uint32
+ BaseOfCode uint32
+ ImageBase uintptr
+ SectionAlignment uint32
+ FileAlignment uint32
+ MajorOperatingSystemVersion uint16
+ MinorOperatingSystemVersion uint16
+ MajorImageVersion uint16
+ MinorImageVersion uint16
+ MajorSubsystemVersion uint16
+ MinorSubsystemVersion uint16
+ Win32VersionValue uint32
+ SizeOfImage uint32
+ SizeOfHeaders uint32
+ CheckSum uint32
+ Subsystem uint16
+ DllCharacteristics uint16
+ SizeOfStackReserve uintptr
+ SizeOfStackCommit uintptr
+ SizeOfHeapReserve uintptr
+ SizeOfHeapCommit uintptr
+ LoaderFlags uint32
+ NumberOfRvaAndSizes uint32
+ DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY
+}
+
+const IMAGE_ORDINAL_FLAG uintptr = 0x8000000000000000
+
+type IMAGE_LOAD_CONFIG_DIRECTORY struct {
+ Size uint32
+ TimeDateStamp uint32
+ MajorVersion uint16
+ MinorVersion uint16
+ GlobalFlagsClear uint32
+ GlobalFlagsSet uint32
+ CriticalSectionDefaultTimeout uint32
+ DeCommitFreeBlockThreshold uint64
+ DeCommitTotalFreeThreshold uint64
+ LockPrefixTable uint64
+ MaximumAllocationSize uint64
+ VirtualMemoryThreshold uint64
+ ProcessAffinityMask uint64
+ ProcessHeapFlags uint32
+ CSDVersion uint16
+ DependentLoadFlags uint16
+ EditList uint64
+ SecurityCookie uint64
+ SEHandlerTable uint64
+ SEHandlerCount uint64
+ GuardCFCheckFunctionPointer uint64
+ GuardCFDispatchFunctionPointer uint64
+ GuardCFFunctionTable uint64
+ GuardCFFunctionCount uint64
+ GuardFlags uint32
+ CodeIntegrity IMAGE_LOAD_CONFIG_CODE_INTEGRITY
+ GuardAddressTakenIatEntryTable uint64
+ GuardAddressTakenIatEntryCount uint64
+ GuardLongJumpTargetTable uint64
+ GuardLongJumpTargetCount uint64
+ DynamicValueRelocTable uint64
+ CHPEMetadataPointer uint64
+ GuardRFFailureRoutine uint64
+ GuardRFFailureRoutineFunctionPointer uint64
+ DynamicValueRelocTableOffset uint32
+ DynamicValueRelocTableSection uint16
+ Reserved2 uint16
+ GuardRFVerifyStackPointerFunctionPointer uint64
+ HotPatchTableOffset uint32
+ Reserved3 uint32
+ EnclaveConfigurationPointer uint64
+ VolatileMetadataPointer uint64
+ GuardEHContinuationTable uint64
+ GuardEHContinuationCount uint64
+ GuardXFGCheckFunctionPointer uint64
+ GuardXFGDispatchFunctionPointer uint64
+ GuardXFGTableDispatchFunctionPointer uint64
+ CastGuardOsDeterminedFailureMode uint64
+}
diff --git a/embeddable-dll-service/csharp/README.md b/embeddable-dll-service/csharp/README.md
index 071d15f6..493dcb5d 100644
--- a/embeddable-dll-service/csharp/README.md
+++ b/embeddable-dll-service/csharp/README.md
@@ -12,4 +12,4 @@ The code in this repository can be built in Visual Studio 2019 by opening the .s
> .\build.bat
```
-In addition, `tunnel.dll` requires `wintun.dll`, which can be downloaded from [wintun.net](https://www.wintun.net).
+In addition, `tunnel.dll` requires `wintun.dll`, which can be downloaded from [wintun.net](https://www.wintun.net), and `wireguard.dll`, which can be downloaded from [the wireguard-nt download server](https://download.wireguard.com/wireguard-nt/).
diff --git a/main.go b/main.go
index 7aa00643..6abdd11f 100644
--- a/main.go
+++ b/main.go
@@ -19,6 +19,8 @@ import (
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/tun"
+ "golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/elevate"
"golang.zx2c4.com/wireguard/windows/l18n"
"golang.zx2c4.com/wireguard/windows/manager"
@@ -313,10 +315,18 @@ func main() {
if len(os.Args) != 2 {
usage()
}
- rebootRequired, err := tun.WintunPool.DeleteDriver()
+ var rebootRequiredDriver, rebootRequiredWintun bool
+ var err error
+ if conf.AdminBool("ExperimentalKernelDriver") {
+ rebootRequiredDriver, err = driver.DefaultPool.DeleteDriver()
+ if err != nil {
+ fatal(err)
+ }
+ }
+ rebootRequiredWintun, err = tun.WintunPool.DeleteDriver()
if err != nil {
fatal(err)
- } else if rebootRequired {
+ } else if rebootRequiredWintun || rebootRequiredDriver {
log.Println("A reboot may be required")
}
return
diff --git a/manager/interfacecleanup.go b/manager/interfacecleanup.go
index c270f4ab..32976b79 100644
--- a/manager/interfacecleanup.go
+++ b/manager/interfacecleanup.go
@@ -11,33 +11,66 @@ import (
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/mgr"
- "golang.zx2c4.com/wireguard/tun/wintun"
-
"golang.zx2c4.com/wireguard/tun"
+ "golang.zx2c4.com/wireguard/tun/wintun"
+ "golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/services"
)
-func cleanupStaleWintunInterfaces() {
+func cleanupStaleNetworkInterfaces() {
m, err := mgr.Connect()
if err != nil {
return
}
defer m.Disconnect()
- tun.WintunPool.DeleteMatchingAdapters(func(wintun *wintun.Adapter) bool {
- interfaceName, err := wintun.Name()
+ if conf.AdminBool("ExperimentalKernelDriver") {
+ driver.DefaultPool.DeleteMatchingAdapters(func(wintun *driver.Adapter) bool {
+ interfaceName, err := wintun.Name()
+ if err != nil {
+ log.Printf("Removing network adapter because determining interface name failed: %v", err)
+ return true
+ }
+ serviceName, err := services.ServiceNameOfTunnel(interfaceName)
+ if err != nil {
+ log.Printf("Removing network adapter ‘%s’ because determining tunnel service name failed: %v", interfaceName, err)
+ return true
+ }
+ service, err := m.OpenService(serviceName)
+ if err == windows.ERROR_SERVICE_DOES_NOT_EXIST {
+ log.Printf("Removing network adapter ‘%s’ because no service for it exists", interfaceName)
+ return true
+ } else if err != nil {
+ return false
+ }
+ defer service.Close()
+ status, err := service.Query()
+ if err != nil {
+ return false
+ }
+ if status.State == svc.Stopped {
+ log.Printf("Removing network adapter ‘%s’ because its service is stopped", interfaceName)
+ return true
+ }
+ return false
+ })
+ }
+
+ tun.WintunPool.DeleteMatchingAdapters(func(adapter *wintun.Adapter) bool {
+ interfaceName, err := adapter.Name()
if err != nil {
- log.Printf("Removing Wintun interface because determining interface name failed: %v", err)
+ log.Printf("Removing network adapter because determining interface name failed: %v", err)
return true
}
serviceName, err := services.ServiceNameOfTunnel(interfaceName)
if err != nil {
- log.Printf("Removing Wintun interface ‘%s’ because determining tunnel service name failed: %v", interfaceName, err)
+ log.Printf("Removing network adapter ‘%s’ because determining tunnel service name failed: %v", interfaceName, err)
return true
}
service, err := m.OpenService(serviceName)
if err == windows.ERROR_SERVICE_DOES_NOT_EXIST {
- log.Printf("Removing Wintun interface ‘%s’ because no service for it exists", interfaceName)
+ log.Printf("Removing network adapter ‘%s’ because no service for it exists", interfaceName)
return true
} else if err != nil {
return false
@@ -48,7 +81,7 @@ func cleanupStaleWintunInterfaces() {
return false
}
if status.State == svc.Stopped {
- log.Printf("Removing Wintun interface ‘%s’ because its service is stopped", interfaceName)
+ log.Printf("Removing network adapter ‘%s’ because its service is stopped", interfaceName)
return true
}
return false
diff --git a/manager/ipc_driver.go b/manager/ipc_driver.go
new file mode 100644
index 00000000..dae213de
--- /dev/null
+++ b/manager/ipc_driver.go
@@ -0,0 +1,59 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package manager
+
+import (
+ "sync"
+
+ "golang.zx2c4.com/wireguard/windows/driver"
+)
+
+type lockedDriverAdapter struct {
+ *driver.Adapter
+ sync.Mutex
+}
+
+var driverAdapters = make(map[string]*lockedDriverAdapter)
+var driverAdaptersLock sync.RWMutex
+
+func findDriverAdapter(tunnelName string) (*lockedDriverAdapter, error) {
+ driverAdaptersLock.RLock()
+ driverAdapter, ok := driverAdapters[tunnelName]
+ if ok {
+ driverAdapter.Lock()
+ driverAdaptersLock.RUnlock()
+ return driverAdapter, nil
+ }
+ driverAdaptersLock.RUnlock()
+ driverAdaptersLock.Lock()
+ defer driverAdaptersLock.Unlock()
+ driverAdapter, ok = driverAdapters[tunnelName]
+ if ok {
+ driverAdapter.Lock()
+ return driverAdapter, nil
+ }
+ driverAdapter = &lockedDriverAdapter{}
+ var err error
+ driverAdapter.Adapter, err = driver.DefaultPool.OpenAdapter(tunnelName)
+ if err != nil {
+ return nil, err
+ }
+ driverAdapters[tunnelName] = driverAdapter
+ driverAdapter.Lock()
+ return driverAdapter, nil
+}
+
+func releaseDriverAdapter(tunnelName string) {
+ driverAdaptersLock.Lock()
+ defer driverAdaptersLock.Unlock()
+ driverAdapter, ok := driverAdapters[tunnelName]
+ if !ok {
+ return
+ }
+ driverAdapter.Lock()
+ delete(driverAdapters, tunnelName)
+ driverAdapter.Unlock()
+}
diff --git a/manager/ipc_server.go b/manager/ipc_server.go
index 9a36e60f..b9b1eb8c 100644
--- a/manager/ipc_server.go
+++ b/manager/ipc_server.go
@@ -47,41 +47,64 @@ func (s *ManagerService) StoredConfig(tunnelName string) (*conf.Config, error) {
}
func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) {
- storedConfig, err := conf.LoadFromName(tunnelName)
- if err != nil {
- return nil, err
- }
- pipe, err := connectTunnelServicePipe(tunnelName)
- if err != nil {
- return nil, err
- }
- pipe.SetDeadline(time.Now().Add(time.Second * 2))
- _, err = pipe.Write([]byte("get=1\n\n"))
- if err == windows.ERROR_NO_DATA {
- log.Println("IPC pipe closed unexpectedly, so reopening")
- pipe.Unlock()
- disconnectTunnelServicePipe(tunnelName)
- pipe, err = connectTunnelServicePipe(tunnelName)
+ if conf.AdminBool("ExperimentalKernelDriver") {
+ storedConfig, err := conf.LoadFromName(tunnelName)
+ if err != nil {
+ return nil, err
+ }
+ driverAdapter, err := findDriverAdapter(tunnelName)
+ if err != nil {
+ return nil, err
+ }
+ runtimeConfig, err := driverAdapter.Configuration()
+ if err != nil {
+ driverAdapter.Unlock()
+ releaseDriverAdapter(tunnelName)
+ return nil, err
+ }
+ conf := conf.FromDriverConfiguration(runtimeConfig, storedConfig)
+ driverAdapter.Unlock()
+ if s.elevatedToken == 0 {
+ conf.Redact()
+ }
+ return conf, nil
+ } else {
+ storedConfig, err := conf.LoadFromName(tunnelName)
+ if err != nil {
+ return nil, err
+ }
+ pipe, err := connectTunnelServicePipe(tunnelName)
if err != nil {
return nil, err
}
pipe.SetDeadline(time.Now().Add(time.Second * 2))
_, err = pipe.Write([]byte("get=1\n\n"))
- }
- if err != nil {
+ if err == windows.ERROR_NO_DATA {
+ log.Println("IPC pipe closed unexpectedly, so reopening")
+ pipe.Unlock()
+ disconnectTunnelServicePipe(tunnelName)
+ pipe, err = connectTunnelServicePipe(tunnelName)
+ if err != nil {
+ return nil, err
+ }
+ pipe.SetDeadline(time.Now().Add(time.Second * 2))
+ _, err = pipe.Write([]byte("get=1\n\n"))
+ }
+ if err != nil {
+ pipe.Unlock()
+ disconnectTunnelServicePipe(tunnelName)
+ return nil, err
+ }
+ conf, err := conf.FromUAPI(pipe, storedConfig)
pipe.Unlock()
- disconnectTunnelServicePipe(tunnelName)
- return nil, err
- }
- conf, err := conf.FromUAPI(pipe, storedConfig)
- pipe.Unlock()
- if err != nil {
- return nil, err
- }
- if s.elevatedToken == 0 {
- conf.Redact()
+ if err != nil {
+ return nil, err
+ }
+ if s.elevatedToken == 0 {
+ conf.Redact()
+ }
+ return conf, nil
}
- return conf, nil
}
func (s *ManagerService) Start(tunnelName string) error {
@@ -116,7 +139,7 @@ func (s *ManagerService) Start(tunnelName string) error {
}
}()
}
- time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces)
+ time.AfterFunc(time.Second*10, cleanupStaleNetworkInterfaces)
// After that process is started -- it's somewhat asynchronous -- we install the new one.
c, err := conf.LoadFromName(tunnelName)
@@ -131,7 +154,7 @@ func (s *ManagerService) Start(tunnelName string) error {
}
func (s *ManagerService) Stop(tunnelName string) error {
- time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces)
+ time.AfterFunc(time.Second*10, cleanupStaleNetworkInterfaces)
err := UninstallTunnel(tunnelName)
if err == windows.ERROR_SERVICE_DOES_NOT_EXIST {
diff --git a/manager/service.go b/manager/service.go
index da6ff497..9555d386 100644
--- a/manager/service.go
+++ b/manager/service.go
@@ -254,7 +254,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
}()
}
- time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces)
+ time.AfterFunc(time.Second*10, cleanupStaleNetworkInterfaces)
go checkForUpdates()
var sessionsPointer *windows.WTS_SESSION_INFO
diff --git a/resources.rc b/resources.rc
index 464492c0..aa4bb07d 100644
--- a/resources.rc
+++ b/resources.rc
@@ -15,6 +15,7 @@ CREATEPROCESS_MANIFEST_RESOURCE_ID RT_MANIFEST manifest.xml
7 ICON ui/icon/wireguard.ico
8 ICON ui/icon/dot.ico
wintun.dll RCDATA wintun.dll
+wireguard.dll RCDATA wireguard.dll
#define VERSIONINFO_TEMPLATE(block_id, lang_id, codepage_id, file_desc, comments) \
VS_VERSION_INFO VERSIONINFO \
diff --git a/services/errors.go b/services/errors.go
index 674c083b..8986ebcb 100644
--- a/services/errors.go
+++ b/services/errors.go
@@ -17,12 +17,14 @@ const (
ErrorSuccess Error = iota
ErrorRingloggerOpen
ErrorLoadConfiguration
- ErrorCreateWintun
+ ErrorCreateNetworkAdapter
ErrorUAPIListen
ErrorDNSLookup
ErrorFirewall
ErrorDeviceSetConfig
+ ErrorDeviceBringUp
ErrorBindSocketsToDefaultRoutes
+ ErrorMonitorMTUChanges
ErrorSetNetConfig
ErrorDetermineExecutablePath
ErrorTrackTunnels
@@ -42,8 +44,8 @@ func (e Error) Error() string {
return "Unable to determine path of running executable"
case ErrorLoadConfiguration:
return "Unable to load configuration from path"
- case ErrorCreateWintun:
- return "Unable to create Wintun interface"
+ case ErrorCreateNetworkAdapter:
+ return "Unable to create network adapter"
case ErrorUAPIListen:
return "Unable to listen on named pipe"
case ErrorDNSLookup:
@@ -52,8 +54,12 @@ func (e Error) Error() string {
return "Unable to enable firewall rules"
case ErrorDeviceSetConfig:
return "Unable to set device configuration"
+ case ErrorDeviceBringUp:
+ return "Unable to bring up adapter"
case ErrorBindSocketsToDefaultRoutes:
return "Unable to bind sockets to default route"
+ case ErrorMonitorMTUChanges:
+ return "Unable to monitor default route MTU for changes"
case ErrorSetNetConfig:
return "Unable to set interface addresses, routes, dns, and/or interface settings"
case ErrorTrackTunnels:
diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go
index 0dec95d0..fba7d770 100644
--- a/tunnel/addressconfig.go
+++ b/tunnel/addressconfig.go
@@ -12,8 +12,6 @@ import (
"sort"
"golang.org/x/sys/windows"
- "golang.zx2c4.com/wireguard/tun"
-
"golang.zx2c4.com/wireguard/windows/conf"
"golang.zx2c4.com/wireguard/windows/tunnel/firewall"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
@@ -57,9 +55,7 @@ func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, add
}
}
-func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *tun.NativeTun) error {
- luid := winipcfg.LUID(tun.LUID())
-
+func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID, clamper mtuClamper) error {
estimatedRouteCount := 0
for _, peer := range conf.Peers {
estimatedRouteCount += len(peer.AllowedIPs)
@@ -151,7 +147,9 @@ func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *t
}
if conf.Interface.MTU > 0 {
ipif.NLMTU = uint32(conf.Interface.MTU)
- tun.ForceMTU(int(ipif.NLMTU))
+ if clamper != nil {
+ clamper.ForceMTU(int(ipif.NLMTU))
+ }
}
if family == windows.AF_INET {
if foundDefault4 {
@@ -174,7 +172,7 @@ func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *t
return luid.SetDNS(family, conf.Interface.DNS, conf.Interface.DNSSearch)
}
-func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error {
+func enableFirewall(conf *conf.Config, luid winipcfg.LUID) error {
doNotRestrict := true
if len(conf.Peers) == 1 && !conf.Interface.TableOff {
nextallowedip:
@@ -191,5 +189,5 @@ func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error {
}
}
log.Println("Enabling firewall rules")
- return firewall.EnableFirewall(tun.LUID(), doNotRestrict, conf.Interface.DNS)
+ return firewall.EnableFirewall(uint64(luid), doNotRestrict, conf.Interface.DNS)
}
diff --git a/tunnel/defaultroutemonitor.go b/tunnel/defaultroutemonitor.go
index aa0db675..ac4241c9 100644
--- a/tunnel/defaultroutemonitor.go
+++ b/tunnel/defaultroutemonitor.go
@@ -11,9 +11,7 @@ import (
"time"
"golang.org/x/sys/windows"
-
"golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
@@ -61,14 +59,17 @@ func bindSocketRoute(family winipcfg.AddressFamily, binder conn.BindSocketToInte
return nil
}
-func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketToInterface, autoMTU bool, blackholeWhenLoop bool, tun *tun.NativeTun) ([]winipcfg.ChangeCallback, error) {
+type mtuClamper interface {
+ ForceMTU(mtu int)
+}
+
+func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketToInterface, autoMTU bool, blackholeWhenLoop bool, clamper mtuClamper, ourLUID winipcfg.LUID) ([]winipcfg.ChangeCallback, error) {
var minMTU uint32
if family == windows.AF_INET {
minMTU = 576
} else if family == windows.AF_INET6 {
minMTU = 1280
}
- ourLUID := winipcfg.LUID(tun.LUID())
lastLUID := winipcfg.LUID(0)
lastIndex := ^uint32(0)
lastMTU := uint32(0)
@@ -103,7 +104,11 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketT
if err != nil {
return err
}
- tun.ForceMTU(int(iface.NLMTU)) // TODO: having one MTU for both v4 and v6 kind of breaks the windows model, so right now this just gets the second one which is... bad.
+
+ // Having one MTU for both v4 and v6 kind of breaks the Windows model, so right now this just gets the
+ // second one which looks bad. However, internally, it doesn't seem like the Windows stack differentiates
+ // anyway, so it's probably fine.
+ clamper.ForceMTU(int(iface.NLMTU))
lastMTU = mtu
}
return nil
diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go
index e12e5929..32132e93 100644
--- a/tunnel/interfacewatcher.go
+++ b/tunnel/interfacewatcher.go
@@ -12,8 +12,6 @@ import (
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/tun"
-
"golang.zx2c4.com/wireguard/windows/conf"
"golang.zx2c4.com/wireguard/windows/services"
"golang.zx2c4.com/wireguard/windows/tunnel/firewall"
@@ -31,9 +29,10 @@ type interfaceWatcherEvent struct {
type interfaceWatcher struct {
errors chan interfaceWatcherError
- binder conn.BindSocketToInterface
- conf *conf.Config
- tun *tun.NativeTun
+ binder conn.BindSocketToInterface
+ clamper mtuClamper
+ conf *conf.Config
+ luid winipcfg.LUID
setupMutex sync.Mutex
interfaceChangeCallback winipcfg.ChangeCallback
@@ -100,15 +99,24 @@ func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) {
}
var err error
- log.Printf("Monitoring default %s routes", ipversion)
- *changeCallbacks, err = monitorDefaultRoutes(family, iw.binder, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.tun)
- if err != nil {
- iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err}
- return
+ if iw.binder != nil && iw.clamper != nil {
+ log.Printf("Monitoring default %s routes", ipversion)
+ *changeCallbacks, err = monitorDefaultRoutes(family, iw.binder, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.clamper, iw.luid)
+ if err != nil {
+ iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err}
+ return
+ }
+ } else if iw.conf.Interface.MTU == 0 {
+ log.Printf("Monitoring MTU of default %s routes", ipversion)
+ *changeCallbacks, err = monitorMTU(family, iw.luid)
+ if err != nil {
+ iw.errors <- interfaceWatcherError{services.ErrorMonitorMTUChanges, err}
+ return
+ }
}
log.Printf("Setting device %s addresses", ipversion)
- err = configureInterface(family, iw.conf, iw.tun)
+ err = configureInterface(family, iw.conf, iw.luid, iw.clamper)
if err != nil {
iw.errors <- interfaceWatcherError{services.ErrorSetNetConfig, err}
return
@@ -127,11 +135,11 @@ func watchInterface() (*interfaceWatcher, error) {
if notificationType != winipcfg.MibAddInstance {
return
}
- if iw.tun == nil {
+ if iw.luid == 0 {
iw.storedEvents = append(iw.storedEvents, interfaceWatcherEvent{iface.InterfaceLUID, iface.Family})
return
}
- if iface.InterfaceLUID != winipcfg.LUID(iw.tun.LUID()) {
+ if iface.InterfaceLUID != iw.luid {
return
}
iw.setup(iface.Family)
@@ -142,13 +150,13 @@ func watchInterface() (*interfaceWatcher, error) {
return iw, nil
}
-func (iw *interfaceWatcher) Configure(binder conn.BindSocketToInterface, conf *conf.Config, tun *tun.NativeTun) {
+func (iw *interfaceWatcher) Configure(binder conn.BindSocketToInterface, clamper mtuClamper, conf *conf.Config, luid winipcfg.LUID) {
iw.setupMutex.Lock()
defer iw.setupMutex.Unlock()
- iw.binder, iw.conf, iw.tun = binder, conf, tun
+ iw.binder, iw.clamper, iw.conf, iw.luid = binder, clamper, conf, luid
for _, event := range iw.storedEvents {
- if event.luid == winipcfg.LUID(iw.tun.LUID()) {
+ if event.luid == luid {
iw.setup(event.family)
}
}
@@ -160,7 +168,7 @@ func (iw *interfaceWatcher) Destroy() {
changeCallbacks4 := iw.changeCallbacks4
changeCallbacks6 := iw.changeCallbacks6
interfaceChangeCallback := iw.interfaceChangeCallback
- tun := iw.tun
+ luid := iw.luid
iw.setupMutex.Unlock()
if interfaceChangeCallback != nil {
@@ -186,10 +194,9 @@ func (iw *interfaceWatcher) Destroy() {
changeCallbacks6 = changeCallbacks6[1:]
}
firewall.DisableFirewall()
- if tun != nil && iw.tun == tun {
+ if luid != 0 && iw.luid == luid {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active
// routes, so to be certain, just remove everything before destroying.
- luid := winipcfg.LUID(tun.LUID())
luid.FlushRoutes(windows.AF_INET)
luid.FlushIPAddresses(windows.AF_INET)
luid.FlushDNS(windows.AF_INET)
diff --git a/tunnel/mtumonitor.go b/tunnel/mtumonitor.go
new file mode 100644
index 00000000..766ca1b8
--- /dev/null
+++ b/tunnel/mtumonitor.go
@@ -0,0 +1,113 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package tunnel
+
+import (
+ "golang.org/x/sys/windows"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
+)
+
+func findDefaultLUID(family winipcfg.AddressFamily, ourLUID winipcfg.LUID, lastLUID *winipcfg.LUID, lastIndex *uint32) error {
+ r, err := winipcfg.GetIPForwardTable2(family)
+ if err != nil {
+ return err
+ }
+ lowestMetric := ^uint32(0)
+ index := uint32(0)
+ luid := winipcfg.LUID(0)
+ for i := range r {
+ if r[i].DestinationPrefix.PrefixLength != 0 || r[i].InterfaceLUID == ourLUID {
+ continue
+ }
+ ifrow, err := r[i].InterfaceLUID.Interface()
+ if err != nil || ifrow.OperStatus != winipcfg.IfOperStatusUp {
+ continue
+ }
+
+ iface, err := r[i].InterfaceLUID.IPInterface(family)
+ if err != nil {
+ continue
+ }
+
+ if r[i].Metric+iface.Metric < lowestMetric {
+ lowestMetric = r[i].Metric + iface.Metric
+ index = r[i].InterfaceIndex
+ luid = r[i].InterfaceLUID
+ }
+ }
+ if luid == *lastLUID && index == *lastIndex {
+ return nil
+ }
+ *lastLUID = luid
+ *lastIndex = index
+ return nil
+}
+
+func monitorMTU(family winipcfg.AddressFamily, ourLUID winipcfg.LUID) ([]winipcfg.ChangeCallback, error) {
+ var minMTU uint32
+ if family == windows.AF_INET {
+ minMTU = 576
+ } else if family == windows.AF_INET6 {
+ minMTU = 1280
+ }
+ lastLUID := winipcfg.LUID(0)
+ lastIndex := ^uint32(0)
+ lastMTU := uint32(0)
+ doIt := func() error {
+ err := findDefaultLUID(family, ourLUID, &lastLUID, &lastIndex)
+ if err != nil {
+ return err
+ }
+ mtu := uint32(0)
+ if lastLUID != 0 {
+ iface, err := lastLUID.Interface()
+ if err != nil {
+ return err
+ }
+ if iface.MTU > 0 {
+ mtu = iface.MTU
+ }
+ }
+ if mtu > 0 && lastMTU != mtu {
+ iface, err := ourLUID.IPInterface(family)
+ if err != nil {
+ return err
+ }
+ iface.NLMTU = mtu - 80
+ if iface.NLMTU < minMTU {
+ iface.NLMTU = minMTU
+ }
+ err = iface.Set()
+ if err != nil {
+ return err
+ }
+ lastMTU = mtu
+ }
+ return nil
+ }
+ err := doIt()
+ if err != nil {
+ return nil, err
+ }
+ cbr, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) {
+ if route != nil && route.DestinationPrefix.PrefixLength == 0 {
+ doIt()
+ }
+ })
+ if err != nil {
+ return nil, err
+ }
+ cbi, err := winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
+ if notificationType == winipcfg.MibParameterNotification {
+ doIt()
+ }
+ })
+ if err != nil {
+ cbr.Unregister()
+ return nil, err
+ }
+ return []winipcfg.ChangeCallback{cbr, cbi}, nil
+}
diff --git a/tunnel/service.go b/tunnel/service.go
index 63cd243f..a595994c 100644
--- a/tunnel/service.go
+++ b/tunnel/service.go
@@ -21,11 +21,12 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
-
"golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/elevate"
"golang.zx2c4.com/wireguard/windows/ringlogger"
"golang.zx2c4.com/wireguard/windows/services"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"golang.zx2c4.com/wireguard/windows/version"
)
@@ -40,6 +41,9 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
var uapi net.Listener
var watcher *interfaceWatcher
var nativeTun *tun.NativeTun
+ var wintun tun.Device
+ var adapter *driver.Adapter
+ var luid winipcfg.LUID
var config *conf.Config
var err error
serviceError := services.ErrorSuccess
@@ -127,10 +131,10 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
if m, err := mgr.Connect(); err == nil {
if lockStatus, err := m.LockStatus(); err == nil && lockStatus.IsLocked {
- /* If we don't do this, then the Wintun installation will block forever, because
- * installing a Wintun device starts a service too. Apparently at boot time, Windows
- * 8.1 locks the SCM for each service start, creating a deadlock if we don't announce
- * that we're running before starting additional services.
+ /* If we don't do this, then the driver installation will block forever, because
+ * installing a network adapter starts the driver service too. Apparently at boot time,
+ * Windows 8.1 locks the SCM for each service start, creating a deadlock if we don't
+ * announce that we're running before starting additional services.
*/
log.Printf("SCM locked for %v by %s, marking service as started", lockStatus.Age, lockStatus.Owner)
changes <- svc.Status{State: svc.Running}
@@ -146,34 +150,81 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
}
log.Println("Resolving DNS names")
- uapiConf, err := config.ToUAPI()
+ err = config.ResolveEndpoints()
if err != nil {
serviceError = services.ErrorDNSLookup
return
}
- log.Println("Creating Wintun interface")
- var wintun tun.Device
- for i := 0; i < 5; i++ {
- if i > 0 {
- time.Sleep(time.Second)
- log.Printf("Retrying Wintun creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
+ log.Println("Creating network adapter")
+ if conf.AdminBool("ExperimentalKernelDriver") {
+ // Does an adapter with this name already exist?
+ adapter, err = driver.DefaultPool.OpenAdapter(config.Name)
+ if err == nil {
+ // If so, we delete it, in case it has weird residual configuration.
+ _, err = adapter.Delete()
+ if err != nil {
+ err = fmt.Errorf("Error deleting already existing adapter: %w", err)
+ serviceError = services.ErrorCreateNetworkAdapter
+ return
+ }
}
- wintun, err = tun.CreateTUNWithRequestedGUID(config.Name, deterministicGUID(config), 0)
- if err == nil || windows.DurationSinceBoot() > time.Minute*4 {
- break
+ for i := 0; i < 5; i++ {
+ if i > 0 {
+ time.Sleep(time.Second)
+ log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
+ }
+ var rebootRequired bool
+ adapter, rebootRequired, err = driver.DefaultPool.CreateAdapter(config.Name, deterministicGUID(config))
+ if err == nil || windows.DurationSinceBoot() > time.Minute*4 {
+ if rebootRequired {
+ log.Println("Windows indicated a reboot is required.")
+ }
+ break
+ }
+ }
+ if err != nil {
+ err = fmt.Errorf("Error creating adapter: %w", err)
+ serviceError = services.ErrorCreateNetworkAdapter
+ return
+ }
+ defer adapter.Delete()
+ luid = adapter.LUID()
+ driverVersion, err := driver.RunningVersion()
+ if err != nil {
+ log.Printf("Warning: unable to determine driver version: %v", err)
+ } else {
+ log.Printf("Using WireGuardNT/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff)
+ }
+ err = adapter.SetLogging(driver.AdapterLogOn)
+ if err != nil {
+ err = fmt.Errorf("Error enabling adapter logging: %w", err)
+ serviceError = services.ErrorCreateNetworkAdapter
+ return
}
- }
- if err != nil {
- serviceError = services.ErrorCreateWintun
- return
- }
- nativeTun = wintun.(*tun.NativeTun)
- wintunVersion, err := nativeTun.RunningVersion()
- if err != nil {
- log.Printf("Warning: unable to determine Wintun version: %v", err)
} else {
- log.Printf("Using Wintun/%d.%d", (wintunVersion>>16)&0xffff, wintunVersion&0xffff)
+ for i := 0; i < 5; i++ {
+ if i > 0 {
+ time.Sleep(time.Second)
+ log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
+ }
+ wintun, err = tun.CreateTUNWithRequestedGUID(config.Name, deterministicGUID(config), 0)
+ if err == nil || windows.DurationSinceBoot() > time.Minute*4 {
+ break
+ }
+ }
+ if err != nil {
+ serviceError = services.ErrorCreateNetworkAdapter
+ return
+ }
+ nativeTun = wintun.(*tun.NativeTun)
+ luid = winipcfg.LUID(nativeTun.LUID())
+ driverVersion, err := nativeTun.RunningVersion()
+ if err != nil {
+ log.Printf("Warning: unable to determine driver version: %v", err)
+ } else {
+ log.Printf("Using Wintun/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff)
+ }
}
err = runScriptCommand(config.Interface.PreUp, config.Name)
@@ -182,7 +233,7 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
return
}
- err = enableFirewall(config, nativeTun)
+ err = enableFirewall(config, luid)
if err != nil {
serviceError = services.ErrorFirewall
return
@@ -195,37 +246,51 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
return
}
- log.Println("Creating interface instance")
- bind := conn.NewDefaultBind()
- dev = device.NewDevice(wintun, bind, &device.Logger{log.Printf, log.Printf})
+ if nativeTun != nil {
+ log.Println("Creating interface instance")
+ bind := conn.NewDefaultBind()
+ dev = device.NewDevice(wintun, bind, &device.Logger{log.Printf, log.Printf})
- log.Println("Setting interface configuration")
- uapi, err = ipc.UAPIListen(config.Name)
- if err != nil {
- serviceError = services.ErrorUAPIListen
- return
- }
- err = dev.IpcSet(uapiConf)
- if err != nil {
- serviceError = services.ErrorDeviceSetConfig
- return
- }
+ log.Println("Setting interface configuration")
+ uapi, err = ipc.UAPIListen(config.Name)
+ if err != nil {
+ serviceError = services.ErrorUAPIListen
+ return
+ }
+ err = dev.IpcSet(config.ToUAPI())
+ if err != nil {
+ serviceError = services.ErrorDeviceSetConfig
+ return
+ }
- log.Println("Bringing peers up")
- dev.Up()
+ log.Println("Bringing peers up")
+ dev.Up()
- watcher.Configure(bind.(conn.BindSocketToInterface), config, nativeTun)
+ var clamper mtuClamper
+ clamper = nativeTun
+ watcher.Configure(bind.(conn.BindSocketToInterface), clamper, config, luid)
- log.Println("Listening for UAPI requests")
- go func() {
- for {
- conn, err := uapi.Accept()
- if err != nil {
- continue
+ log.Println("Listening for UAPI requests")
+ go func() {
+ for {
+ conn, err := uapi.Accept()
+ if err != nil {
+ continue
+ }
+ go dev.IpcHandle(conn)
}
- go dev.IpcHandle(conn)
+ }()
+ } else {
+ err = adapter.SetConfiguration(config.ToDriverConfiguration())
+ if err != nil {
+ serviceError = services.ErrorDeviceSetConfig
}
- }()
+ err = adapter.SetAdapterState(driver.AdapterStateUp)
+ if err != nil {
+ serviceError = services.ErrorDeviceBringUp
+ }
+ watcher.Configure(nil, nil, config, luid)
+ }
err = runScriptCommand(config.Interface.PostUp, config.Name)
if err != nil {
@@ -236,6 +301,12 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown}
log.Println("Startup complete")
+ var devWaitChan chan struct{}
+ if dev != nil {
+ devWaitChan = dev.Wait()
+ } else {
+ devWaitChan = make(chan struct{})
+ }
for {
select {
case c := <-r:
@@ -247,7 +318,7 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
default:
log.Printf("Unexpected service control request #%d\n", c)
}
- case <-dev.Wait():
+ case <-devWaitChan:
return
case e := <-watcher.errors:
serviceError, err = e.serviceError, e.err
diff --git a/tunnel/winipcfg/types.go b/tunnel/winipcfg/types.go
index 4dc52d8b..b06f05dd 100644
--- a/tunnel/winipcfg/types.go
+++ b/tunnel/winipcfg/types.go
@@ -6,6 +6,7 @@
package winipcfg
import (
+ "encoding/binary"
"net"
"unsafe"
@@ -734,6 +735,16 @@ type RawSockaddrInet struct {
data [26]byte
}
+func ntohs(i uint16) uint16 {
+ return binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&i))[:])
+}
+
+func htons(i uint16) uint16 {
+ b := make([]byte, 2)
+ binary.BigEndian.PutUint16(b, i)
+ return *(*uint16)(unsafe.Pointer(&b[0]))
+}
+
// SetIP method sets family, address, and port to the given IPv4 or IPv6 address and port.
// All other members of the structure are set to zero.
func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error {
@@ -741,7 +752,7 @@ func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error {
addr4 := (*windows.RawSockaddrInet4)(unsafe.Pointer(addr))
addr4.Family = windows.AF_INET
copy(addr4.Addr[:], v4)
- addr4.Port = windows.Ntohs(port)
+ addr4.Port = htons(port)
for i := 0; i < 8; i++ {
addr4.Zero[i] = 0
}
@@ -751,7 +762,7 @@ func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error {
if v6 := ip.To16(); v6 != nil {
addr6 := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr))
addr6.Family = windows.AF_INET6
- addr6.Port = windows.Ntohs(port)
+ addr6.Port = htons(port)
addr6.Flowinfo = 0
copy(addr6.Addr[:], v6)
addr6.Scope_id = 0
@@ -761,8 +772,7 @@ func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error {
return windows.ERROR_INVALID_PARAMETER
}
-// IP method returns IPv4 or IPv6 address.
-// If the address is neither IPv4 not IPv6 nil is returned.
+// IP returns IPv4 or IPv6 address, or nil if the address is neither.
func (addr *RawSockaddrInet) IP() net.IP {
switch addr.Family {
case windows.AF_INET:
@@ -775,6 +785,19 @@ func (addr *RawSockaddrInet) IP() net.IP {
return nil
}
+// Port returns the port if the address if IPv4 or IPv6, or 0 if neither.
+func (addr *RawSockaddrInet) Port() uint16 {
+ switch addr.Family {
+ case windows.AF_INET:
+ return ntohs((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Port)
+
+ case windows.AF_INET6:
+ return ntohs((*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Port)
+ }
+
+ return 0
+}
+
// Init method initializes a MibUnicastIPAddressRow structure with default values for a unicast IP address entry on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-initializeunicastipaddressentry
func (row *MibUnicastIPAddressRow) Init() {