aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/manager
diff options
context:
space:
mode:
Diffstat (limited to 'manager')
-rw-r--r--manager/install.go51
-rw-r--r--manager/interfacecleanup.go56
-rw-r--r--manager/ipc_client.go17
-rw-r--r--manager/ipc_driver.go61
-rw-r--r--manager/ipc_server.go128
-rw-r--r--manager/ipc_uapi.go71
-rw-r--r--manager/legacystore.go129
-rw-r--r--manager/service.go65
-rw-r--r--manager/tunneltracker.go130
-rw-r--r--manager/uiprocess.go103
-rw-r--r--manager/updatestate.go34
11 files changed, 424 insertions, 421 deletions
diff --git a/manager/install.go b/manager/install.go
index bd7d23c7..44a744cf 100644
--- a/manager/install.go
+++ b/manager/install.go
@@ -1,13 +1,15 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package manager
import (
"errors"
+ "log"
"os"
+ "strings"
"time"
"golang.org/x/sys/windows"
@@ -15,7 +17,6 @@ import (
"golang.org/x/sys/windows/svc/mgr"
"golang.zx2c4.com/wireguard/windows/conf"
- "golang.zx2c4.com/wireguard/windows/services"
)
var cachedServiceManager *mgr.Mgr
@@ -128,7 +129,7 @@ func InstallTunnel(configPath string) error {
return err
}
- serviceName, err := services.ServiceNameOfTunnel(name)
+ serviceName, err := conf.ServiceNameOfTunnel(name)
if err != nil {
return err
}
@@ -181,7 +182,7 @@ func UninstallTunnel(name string) error {
if err != nil {
return err
}
- serviceName, err := services.ServiceNameOfTunnel(name)
+ serviceName, err := conf.ServiceNameOfTunnel(name)
if err != nil {
return err
}
@@ -197,3 +198,45 @@ func UninstallTunnel(name string) error {
}
return err2
}
+
+func changeTunnelServiceConfigFilePath(name, oldPath, newPath string) {
+ var err error
+ defer func() {
+ if err != nil {
+ log.Printf("Unable to change tunnel service command line argument from %#q to %#q: %v", oldPath, newPath, err)
+ }
+ }()
+ m, err := serviceManager()
+ if err != nil {
+ return
+ }
+ serviceName, err := conf.ServiceNameOfTunnel(name)
+ if err != nil {
+ return
+ }
+ service, err := m.OpenService(serviceName)
+ if err == windows.ERROR_SERVICE_DOES_NOT_EXIST {
+ err = nil
+ return
+ } else if err != nil {
+ return
+ }
+ defer service.Close()
+ config, err := service.Config()
+ if err != nil {
+ return
+ }
+ exePath, err := os.Executable()
+ if err != nil {
+ return
+ }
+ args, err := windows.DecomposeCommandLine(config.BinaryPathName)
+ if err != nil || len(args) != 3 ||
+ !strings.EqualFold(args[0], exePath) || args[1] != "/tunnelservice" || !strings.EqualFold(args[2], oldPath) {
+ err = nil
+ return
+ }
+ args[2] = newPath
+ config.BinaryPathName = windows.ComposeCommandLine(args)
+ err = service.UpdateConfig(config)
+}
diff --git a/manager/interfacecleanup.go b/manager/interfacecleanup.go
deleted file mode 100644
index c270f4ab..00000000
--- a/manager/interfacecleanup.go
+++ /dev/null
@@ -1,56 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package manager
-
-import (
- "log"
-
- "golang.org/x/sys/windows"
- "golang.org/x/sys/windows/svc"
- "golang.org/x/sys/windows/svc/mgr"
- "golang.zx2c4.com/wireguard/tun/wintun"
-
- "golang.zx2c4.com/wireguard/tun"
- "golang.zx2c4.com/wireguard/windows/services"
-)
-
-func cleanupStaleWintunInterfaces() {
- m, err := mgr.Connect()
- if err != nil {
- return
- }
- defer m.Disconnect()
-
- tun.WintunPool.DeleteMatchingAdapters(func(wintun *wintun.Adapter) bool {
- interfaceName, err := wintun.Name()
- if err != nil {
- log.Printf("Removing Wintun interface because determining interface name failed: %v", err)
- return true
- }
- serviceName, err := services.ServiceNameOfTunnel(interfaceName)
- if err != nil {
- log.Printf("Removing Wintun interface ‘%s’ because determining tunnel service name failed: %v", interfaceName, err)
- return true
- }
- service, err := m.OpenService(serviceName)
- if err == windows.ERROR_SERVICE_DOES_NOT_EXIST {
- log.Printf("Removing Wintun interface ‘%s’ because no service for it exists", interfaceName)
- return true
- } else if err != nil {
- return false
- }
- defer service.Close()
- status, err := service.Query()
- if err != nil {
- return false
- }
- if status.State == svc.Stopped {
- log.Printf("Removing Wintun interface ‘%s’ because its service is stopped", interfaceName)
- return true
- }
- return false
- }, false)
-}
diff --git a/manager/ipc_client.go b/manager/ipc_client.go
index 2f78a47e..8c9c4c04 100644
--- a/manager/ipc_client.go
+++ b/manager/ipc_client.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package manager
@@ -64,7 +64,7 @@ var (
)
type TunnelChangeCallback struct {
- cb func(tunnel *Tunnel, state TunnelState, globalState TunnelState, err error)
+ cb func(tunnel *Tunnel, state, globalState TunnelState, err error)
}
var tunnelChangeCallbacks = make(map[*TunnelChangeCallback]bool)
@@ -93,7 +93,7 @@ type UpdateProgressCallback struct {
var updateProgressCallbacks = make(map[*UpdateProgressCallback]bool)
-func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) {
+func InitializeIPCClient(reader, writer, events *os.File) {
rpcDecoder = gob.NewDecoder(reader)
rpcEncoder = gob.NewEncoder(writer)
go func() {
@@ -431,43 +431,52 @@ func IPCClientUpdate() error {
return rpcEncoder.Encode(UpdateMethodType)
}
-func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state TunnelState, globalState TunnelState, err error)) *TunnelChangeCallback {
+func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state, globalState TunnelState, err error)) *TunnelChangeCallback {
s := &TunnelChangeCallback{cb}
tunnelChangeCallbacks[s] = true
return s
}
+
func (cb *TunnelChangeCallback) Unregister() {
delete(tunnelChangeCallbacks, cb)
}
+
func IPCClientRegisterTunnelsChange(cb func()) *TunnelsChangeCallback {
s := &TunnelsChangeCallback{cb}
tunnelsChangeCallbacks[s] = true
return s
}
+
func (cb *TunnelsChangeCallback) Unregister() {
delete(tunnelsChangeCallbacks, cb)
}
+
func IPCClientRegisterManagerStopping(cb func()) *ManagerStoppingCallback {
s := &ManagerStoppingCallback{cb}
managerStoppingCallbacks[s] = true
return s
}
+
func (cb *ManagerStoppingCallback) Unregister() {
delete(managerStoppingCallbacks, cb)
}
+
func IPCClientRegisterUpdateFound(cb func(updateState UpdateState)) *UpdateFoundCallback {
s := &UpdateFoundCallback{cb}
updateFoundCallbacks[s] = true
return s
}
+
func (cb *UpdateFoundCallback) Unregister() {
delete(updateFoundCallbacks, cb)
}
+
func IPCClientRegisterUpdateProgress(cb func(dp updater.DownloadProgress)) *UpdateProgressCallback {
s := &UpdateProgressCallback{cb}
updateProgressCallbacks[s] = true
return s
}
+
func (cb *UpdateProgressCallback) Unregister() {
delete(updateProgressCallbacks, cb)
}
diff --git a/manager/ipc_driver.go b/manager/ipc_driver.go
new file mode 100644
index 00000000..6cb43c38
--- /dev/null
+++ b/manager/ipc_driver.go
@@ -0,0 +1,61 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
+ */
+
+package manager
+
+import (
+ "sync"
+
+ "golang.zx2c4.com/wireguard/windows/driver"
+)
+
+type lockedDriverAdapter struct {
+ *driver.Adapter
+ sync.Mutex
+}
+
+var (
+ driverAdapters = make(map[string]*lockedDriverAdapter)
+ driverAdaptersLock sync.RWMutex
+)
+
+func findDriverAdapter(tunnelName string) (*lockedDriverAdapter, error) {
+ driverAdaptersLock.RLock()
+ driverAdapter, ok := driverAdapters[tunnelName]
+ if ok {
+ driverAdapter.Lock()
+ driverAdaptersLock.RUnlock()
+ return driverAdapter, nil
+ }
+ driverAdaptersLock.RUnlock()
+ driverAdaptersLock.Lock()
+ defer driverAdaptersLock.Unlock()
+ driverAdapter, ok = driverAdapters[tunnelName]
+ if ok {
+ driverAdapter.Lock()
+ return driverAdapter, nil
+ }
+ driverAdapter = &lockedDriverAdapter{}
+ var err error
+ driverAdapter.Adapter, err = driver.OpenAdapter(tunnelName)
+ if err != nil {
+ return nil, err
+ }
+ driverAdapters[tunnelName] = driverAdapter
+ driverAdapter.Lock()
+ return driverAdapter, nil
+}
+
+func releaseDriverAdapter(tunnelName string) {
+ driverAdaptersLock.Lock()
+ defer driverAdaptersLock.Unlock()
+ driverAdapter, ok := driverAdapters[tunnelName]
+ if !ok {
+ return
+ }
+ driverAdapter.Lock()
+ delete(driverAdapters, tunnelName)
+ driverAdapter.Unlock()
+}
diff --git a/manager/ipc_server.go b/manager/ipc_server.go
index 8ba050f9..e21ffaf0 100644
--- a/manager/ipc_server.go
+++ b/manager/ipc_server.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package manager
@@ -20,14 +20,15 @@ import (
"golang.org/x/sys/windows/svc"
"golang.zx2c4.com/wireguard/windows/conf"
- "golang.zx2c4.com/wireguard/windows/services"
"golang.zx2c4.com/wireguard/windows/updater"
)
-var managerServices = make(map[*ManagerService]bool)
-var managerServicesLock sync.RWMutex
-var haveQuit uint32
-var quitManagersChan = make(chan struct{}, 1)
+var (
+ managerServices = make(map[*ManagerService]bool)
+ managerServicesLock sync.RWMutex
+ haveQuit uint32
+ quitManagersChan = make(chan struct{}, 1)
+)
type ManagerService struct {
events *os.File
@@ -51,33 +52,18 @@ func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error)
if err != nil {
return nil, err
}
- pipe, err := connectTunnelServicePipe(tunnelName)
- if err != nil {
- return nil, err
- }
- pipe.SetDeadline(time.Now().Add(time.Second * 2))
- _, err = pipe.Write([]byte("get=1\n\n"))
- if err == windows.ERROR_NO_DATA {
- log.Println("IPC pipe closed unexpectedly, so reopening")
- pipe.Unlock()
- disconnectTunnelServicePipe(tunnelName)
- pipe, err = connectTunnelServicePipe(tunnelName)
- if err != nil {
- return nil, err
- }
- pipe.SetDeadline(time.Now().Add(time.Second * 2))
- _, err = pipe.Write([]byte("get=1\n\n"))
- }
+ driverAdapter, err := findDriverAdapter(tunnelName)
if err != nil {
- pipe.Unlock()
- disconnectTunnelServicePipe(tunnelName)
return nil, err
}
- conf, err := conf.FromUAPI(pipe, storedConfig)
- pipe.Unlock()
+ runtimeConfig, err := driverAdapter.Configuration()
if err != nil {
+ driverAdapter.Unlock()
+ releaseDriverAdapter(tunnelName)
return nil, err
}
+ conf := conf.FromDriverConfiguration(runtimeConfig, storedConfig)
+ driverAdapter.Unlock()
if s.elevatedToken == 0 {
conf.Redact()
}
@@ -85,44 +71,47 @@ func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error)
}
func (s *ManagerService) Start(tunnelName string) error {
- // TODO: Rather than being lazy and gating this behind a knob (yuck!), we should instead keep track of the routes
- // of each tunnel, and only deactivate in the case of a tunnel with identical routes being added.
- if !conf.AdminBool("MultipleSimultaneousTunnels") {
- trackedTunnelsLock.Lock()
- tt := make([]string, 0, len(trackedTunnels))
- var inTransition string
- for t, state := range trackedTunnels {
- tt = append(tt, t)
- if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) {
- inTransition = t
- break
- }
- }
- trackedTunnelsLock.Unlock()
- if len(inTransition) != 0 {
- return fmt.Errorf("Please allow the tunnel ‘%s’ to finish activating", inTransition)
- }
- go func() {
- for _, t := range tt {
- s.Stop(t)
- }
- for _, t := range tt {
- state, err := s.State(t)
- if err == nil && (state == TunnelStarted || state == TunnelStarting) {
- log.Printf("[%s] Trying again to stop zombie tunnel", t)
- s.Stop(t)
- time.Sleep(time.Millisecond * 100)
- }
- }
- }()
- }
- time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces)
-
- // After that process is started -- it's somewhat asynchronous -- we install the new one.
c, err := conf.LoadFromName(tunnelName)
if err != nil {
return err
}
+
+ // Figure out which tunnels have intersecting addresses/routes and stop those.
+ trackedTunnelsLock.Lock()
+ tt := make([]string, 0, len(trackedTunnels))
+ var inTransition string
+ for t, state := range trackedTunnels {
+ c2, err := conf.LoadFromName(t)
+ if err != nil || !c.IntersectsWith(c2) {
+ // If we can't get the config, assume it doesn't intersect.
+ continue
+ }
+ tt = append(tt, t)
+ if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) {
+ inTransition = t
+ break
+ }
+ }
+ trackedTunnelsLock.Unlock()
+ if len(inTransition) != 0 {
+ return fmt.Errorf("Please allow the tunnel ‘%s’ to finish activating", inTransition)
+ }
+
+ // Stop those intersecting tunnels asynchronously.
+ go func() {
+ for _, t := range tt {
+ s.Stop(t)
+ }
+ for _, t := range tt {
+ state, err := s.State(t)
+ if err == nil && (state == TunnelStarted || state == TunnelStarting) {
+ log.Printf("[%s] Trying again to stop zombie tunnel", t)
+ s.Stop(t)
+ time.Sleep(time.Millisecond * 100)
+ }
+ }
+ }()
+ // After the stop process has begun, but before it's finished, we install the new one.
path, err := c.Path()
if err != nil {
return err
@@ -131,8 +120,6 @@ func (s *ManagerService) Start(tunnelName string) error {
}
func (s *ManagerService) Stop(tunnelName string) error {
- time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces)
-
err := UninstallTunnel(tunnelName)
if err == windows.ERROR_SERVICE_DOES_NOT_EXIST {
_, notExistsError := conf.LoadFromName(tunnelName)
@@ -144,7 +131,7 @@ func (s *ManagerService) Stop(tunnelName string) error {
}
func (s *ManagerService) WaitForStop(tunnelName string) error {
- serviceName, err := services.ServiceNameOfTunnel(tunnelName)
+ serviceName, err := conf.ServiceNameOfTunnel(tunnelName)
if err != nil {
return err
}
@@ -175,7 +162,7 @@ func (s *ManagerService) Delete(tunnelName string) error {
}
func (s *ManagerService) State(tunnelName string) (TunnelState, error) {
- serviceName, err := services.ServiceNameOfTunnel(tunnelName)
+ serviceName, err := conf.ServiceNameOfTunnel(tunnelName)
if err != nil {
return 0, err
}
@@ -230,7 +217,7 @@ func (s *ManagerService) Tunnels() ([]Tunnel, error) {
}
tunnels := make([]Tunnel, len(names))
for i := 0; i < len(tunnels); i++ {
- (tunnels)[i].Name = names[i]
+ tunnels[i].Name = names[i]
}
return tunnels, nil
// TODO: account for running ones that aren't in the configuration store somehow
@@ -267,9 +254,6 @@ func (s *ManagerService) Quit(stopTunnelsOnQuit bool) (alreadyQuit bool, err err
}
func (s *ManagerService) UpdateState() UpdateState {
- if s.elevatedToken == 0 {
- return UpdateStateUnknown
- }
return updateState
}
@@ -457,7 +441,7 @@ func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) {
}
}
-func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevatedToken windows.Token) {
+func IPCServerListen(reader, writer, events *os.File, elevatedToken windows.Token) {
service := &ManagerService{
events: events,
elevatedToken: elevatedToken,
@@ -477,7 +461,7 @@ func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevated
}()
}
-func notifyAll(notificationType NotificationType, adminOnly bool, ifaces ...interface{}) {
+func notifyAll(notificationType NotificationType, adminOnly bool, ifaces ...any) {
if len(managerServices) == 0 {
return
}
@@ -528,7 +512,7 @@ func IPCServerNotifyTunnelsChange() {
}
func IPCServerNotifyUpdateFound(state UpdateState) {
- notifyAll(UpdateFoundNotificationType, true, state)
+ notifyAll(UpdateFoundNotificationType, false, state)
}
func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) {
diff --git a/manager/ipc_uapi.go b/manager/ipc_uapi.go
deleted file mode 100644
index 85477125..00000000
--- a/manager/ipc_uapi.go
+++ /dev/null
@@ -1,71 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package manager
-
-import (
- "net"
- "sync"
-
- "golang.org/x/sys/windows"
- "golang.zx2c4.com/wireguard/ipc/winpipe"
-
- "golang.zx2c4.com/wireguard/windows/services"
-)
-
-type connectedTunnel struct {
- net.Conn
- sync.Mutex
-}
-
-var connectedTunnelServicePipes = make(map[string]*connectedTunnel)
-var connectedTunnelServicePipesLock sync.RWMutex
-
-func connectTunnelServicePipe(tunnelName string) (*connectedTunnel, error) {
- connectedTunnelServicePipesLock.RLock()
- pipe, ok := connectedTunnelServicePipes[tunnelName]
- if ok {
- pipe.Lock()
- connectedTunnelServicePipesLock.RUnlock()
- return pipe, nil
- }
- connectedTunnelServicePipesLock.RUnlock()
- connectedTunnelServicePipesLock.Lock()
- defer connectedTunnelServicePipesLock.Unlock()
- pipe, ok = connectedTunnelServicePipes[tunnelName]
- if ok {
- pipe.Lock()
- return pipe, nil
- }
- pipePath, err := services.PipePathOfTunnel(tunnelName)
- if err != nil {
- return nil, err
- }
- localSystem, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid)
- if err != nil {
- return nil, err
- }
- pipe = &connectedTunnel{}
- pipe.Conn, err = winpipe.Dial(pipePath, nil, &winpipe.DialConfig{ExpectedOwner: localSystem})
- if err != nil {
- return nil, err
- }
- connectedTunnelServicePipes[tunnelName] = pipe
- pipe.Lock()
- return pipe, nil
-}
-
-func disconnectTunnelServicePipe(tunnelName string) {
- connectedTunnelServicePipesLock.Lock()
- defer connectedTunnelServicePipesLock.Unlock()
- pipe, ok := connectedTunnelServicePipes[tunnelName]
- if !ok {
- return
- }
- pipe.Lock()
- pipe.Close()
- delete(connectedTunnelServicePipes, tunnelName)
- pipe.Unlock()
-}
diff --git a/manager/legacystore.go b/manager/legacystore.go
deleted file mode 100644
index 4125cb86..00000000
--- a/manager/legacystore.go
+++ /dev/null
@@ -1,129 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package manager
-
-import (
- "fmt"
- "log"
- "os"
- "path/filepath"
- "regexp"
- "strings"
-
- "golang.org/x/sys/windows"
- "golang.org/x/sys/windows/registry"
- "golang.org/x/sys/windows/svc/mgr"
-
- "golang.zx2c4.com/wireguard/windows/conf"
-)
-
-func moveConfigsFromLegacyStore() {
- oldRoot, err := windows.KnownFolderPath(windows.FOLDERID_LocalAppData, windows.KF_FLAG_DEFAULT)
- if err != nil {
- return
- }
- oldC := filepath.Join(oldRoot, "WireGuard", "Configurations")
- files, err := os.ReadDir(oldC)
- if err != nil {
- return
- }
- pendingDeletion := make(map[string]bool)
- if key, err := registry.OpenKey(registry.LOCAL_MACHINE, `SYSTEM\CurrentControlSet\Control\Session Manager`, registry.READ); err == nil {
- if ntPaths, _, err := key.GetStringsValue("PendingFileRenameOperations"); err == nil {
- for _, ntPath := range ntPaths {
- pendingDeletion[strings.ToLower(strings.TrimPrefix(ntPath, `\??\`))] = true
- }
- }
- key.Close()
- }
- migratedConfigs := make(map[string]string)
- for i := range files {
- if files[i].IsDir() {
- continue
- }
- fileName := files[i].Name()
- oldPath := filepath.Join(oldC, fileName)
- if pendingDeletion[strings.ToLower(oldPath)] {
- continue
- }
- config, err := conf.LoadFromPath(oldPath)
- if err != nil {
- continue
- }
- newPath, err := config.Path()
- if err != nil {
- continue
- }
- err = config.Save(false)
- if err != nil {
- continue
- }
- oldPath16, err := windows.UTF16PtrFromString(oldPath)
- if err == nil {
- windows.MoveFileEx(oldPath16, nil, windows.MOVEFILE_DELAY_UNTIL_REBOOT)
- }
- migratedConfigs[strings.ToLower(oldPath)] = newPath
- log.Printf("Migrated configuration from ‘%s’ to ‘%s’", oldPath, newPath)
- }
- oldC16, err := windows.UTF16PtrFromString(oldC)
- if err == nil {
- windows.MoveFileEx(oldC16, nil, windows.MOVEFILE_DELAY_UNTIL_REBOOT)
- }
- oldLog16, err := windows.UTF16PtrFromString(filepath.Join(oldRoot, "WireGuard", "log.bin"))
- if err == nil {
- windows.MoveFileEx(oldLog16, nil, windows.MOVEFILE_DELAY_UNTIL_REBOOT)
- }
- oldRoot16, err := windows.UTF16PtrFromString(filepath.Join(oldRoot, "WireGuard"))
- if err == nil {
- windows.MoveFileEx(oldRoot16, nil, windows.MOVEFILE_DELAY_UNTIL_REBOOT)
- }
- if len(migratedConfigs) == 0 {
- return
- }
- m, err := mgr.Connect()
- if err != nil {
- return
- }
- defer m.Disconnect()
- services, err := m.ListServices()
- if err != nil {
- return
- }
- matcher, err := regexp.Compile(" /tunnelservice \"?([^\"]+)\"?$")
- if err != nil {
- return
- }
- for _, svcName := range services {
- if !strings.HasPrefix(svcName, "WireGuardTunnel$") {
- continue
- }
- svc, err := m.OpenService(svcName)
- if err != nil {
- continue
- }
- config, err := svc.Config()
- if err != nil {
- continue
- }
- matches := matcher.FindStringSubmatchIndex(config.BinaryPathName)
- if len(matches) != 4 {
- svc.Close()
- continue
- }
- newName, found := migratedConfigs[strings.ToLower(config.BinaryPathName[matches[2]:])]
- if !found {
- svc.Close()
- continue
- }
- config.BinaryPathName = config.BinaryPathName[:matches[0]] + fmt.Sprintf(" /tunnelservice \"%s\"", newName)
- err = svc.UpdateConfig(config)
- svc.Close()
- if err != nil {
- continue
- }
- log.Printf("Migrated service command line arguments for ‘%s’", svcName)
- }
-}
diff --git a/manager/service.go b/manager/service.go
index 99187eab..47e20d45 100644
--- a/manager/service.go
+++ b/manager/service.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package manager
@@ -12,18 +12,17 @@ import (
"runtime"
"strconv"
"sync"
- "syscall"
"time"
"unsafe"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
+ "golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/conf"
"golang.zx2c4.com/wireguard/windows/elevate"
"golang.zx2c4.com/wireguard/windows/ringlogger"
"golang.zx2c4.com/wireguard/windows/services"
- "golang.zx2c4.com/wireguard/windows/version"
)
type managerService struct{}
@@ -43,38 +42,36 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
changes <- svc.Status{State: svc.StopPending}
}()
- err = ringlogger.InitGlobalLogger("MGR")
+ var logFile string
+ logFile, err = conf.LogFile(true)
if err != nil {
serviceError = services.ErrorRingloggerOpen
return
}
-
- log.Println("Starting", version.UserAgent())
-
- path, err := os.Executable()
+ err = ringlogger.InitGlobalLogger(logFile, "MGR")
if err != nil {
- serviceError = services.ErrorDetermineExecutablePath
+ serviceError = services.ErrorRingloggerOpen
return
}
- devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0)
+ services.PrintStarting()
+
+ path, err := os.Executable()
if err != nil {
- serviceError = services.ErrorOpenNULFile
+ serviceError = services.ErrorDetermineExecutablePath
return
}
- moveConfigsFromLegacyStore()
-
- err = trackExistingTunnels()
+ err = watchNewTunnelServices()
if err != nil {
serviceError = services.ErrorTrackTunnels
return
}
- conf.RegisterStoreChangeCallback(conf.MigrateUnencryptedConfigs)
+ conf.RegisterStoreChangeCallback(func() { conf.MigrateUnencryptedConfigs(changeTunnelServiceConfigFilePath) })
conf.RegisterStoreChangeCallback(IPCServerNotifyTunnelsChange)
- procs := make(map[uint32]*os.Process)
+ procs := make(map[uint32]*uiProcess)
aliveSessions := make(map[uint32]bool)
procsLock := sync.Mutex{}
stoppingManager := false
@@ -196,29 +193,22 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
}
log.Printf("Starting UI process for user ‘%s@%s’ for session %d", username, domain, session)
- attr := &os.ProcAttr{
- Sys: &syscall.SysProcAttr{
- Token: syscall.Token(runToken),
- AdditionalInheritedHandles: []syscall.Handle{
- syscall.Handle(theirReader.Fd()),
- syscall.Handle(theirWriter.Fd()),
- syscall.Handle(theirEvents.Fd()),
- syscall.Handle(theirLogMapping)},
- },
- Files: []*os.File{devNull, devNull, devNull},
- Dir: userProfileDirectory,
- }
procsLock.Lock()
- var proc *os.Process
+ var proc *uiProcess
if alive := aliveSessions[session]; alive {
- proc, err = os.StartProcess(path, []string{
+ proc, err = launchUIProcess(path, []string{
path,
"/ui",
strconv.FormatUint(uint64(theirReader.Fd()), 10),
strconv.FormatUint(uint64(theirWriter.Fd()), 10),
strconv.FormatUint(uint64(theirEvents.Fd()), 10),
strconv.FormatUint(uint64(theirLogMapping), 10),
- }, attr)
+ }, userProfileDirectory, []windows.Handle{
+ windows.Handle(theirReader.Fd()),
+ windows.Handle(theirWriter.Fd()),
+ windows.Handle(theirEvents.Fd()),
+ theirLogMapping,
+ }, runToken)
} else {
err = errors.New("Session has logged out")
}
@@ -240,9 +230,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
procsLock.Unlock()
sessionIsDead := false
- processStatus, err := proc.Wait()
- if err == nil {
- exitCode := processStatus.Sys().(syscall.WaitStatus).ExitCode
+ if exitCode, err := proc.Wait(); err == nil {
log.Printf("Exited UI process for user '%s@%s' for session %d with status %x", username, domain, session, exitCode)
const STATUS_DLL_INIT_FAILED_LOGOFF = 0xC000026B
sessionIsDead = exitCode == STATUS_DLL_INIT_FAILED_LOGOFF
@@ -271,8 +259,8 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
}()
}
- time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces)
go checkForUpdates()
+ go driver.UninstallLegacyWintun() // We uninstall opportunistically here, so that we don't have to carry around the uninstaller code forever.
var sessionsPointer *windows.WTS_SESSION_INFO
var count uint32
@@ -281,12 +269,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
serviceError = services.ErrorEnumerateSessions
return
}
- sessions := *(*[]windows.WTS_SESSION_INFO)(unsafe.Pointer(&struct {
- addr *windows.WTS_SESSION_INFO
- len int
- cap int
- }{sessionsPointer, int(count), int(count)}))
- for _, session := range sessions {
+ for _, session := range unsafe.Slice(sessionsPointer, count) {
if session.State != windows.WTSActive && session.State != windows.WTSDisconnected {
continue
}
diff --git a/manager/tunneltracker.go b/manager/tunneltracker.go
index b32450a3..9003d445 100644
--- a/manager/tunneltracker.go
+++ b/manager/tunneltracker.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package manager
@@ -24,31 +24,10 @@ import (
"golang.zx2c4.com/wireguard/windows/services"
)
-func trackExistingTunnels() error {
- m, err := serviceManager()
- if err != nil {
- return err
- }
- names, err := conf.ListConfigNames()
- if err != nil {
- return err
- }
- for _, name := range names {
- serviceName, err := services.ServiceNameOfTunnel(name)
- if err != nil {
- continue
- }
- service, err := m.OpenService(serviceName)
- if err != nil {
- continue
- }
- go trackTunnelService(name, service)
- }
- return nil
-}
-
-var trackedTunnels = make(map[string]TunnelState)
-var trackedTunnelsLock = sync.Mutex{}
+var (
+ trackedTunnels = make(map[string]TunnelState)
+ trackedTunnelsLock = sync.Mutex{}
+)
func trackedTunnelsGlobalState() (state TunnelState) {
state = TunnelStopped
@@ -196,16 +175,17 @@ func trackService(service *mgr.Service, callback func(status uint32) bool) error
}
func trackTunnelService(tunnelName string, service *mgr.Service) {
- defer func() {
- service.Close()
- log.Printf("[%s] Tunnel service tracker finished", tunnelName)
- }()
-
trackedTunnelsLock.Lock()
if _, found := trackedTunnels[tunnelName]; found {
trackedTunnelsLock.Unlock()
+ service.Close()
return
}
+
+ defer func() {
+ service.Close()
+ log.Printf("[%s] Tunnel service tracker finished", tunnelName)
+ }()
trackedTunnels[tunnelName] = TunnelUnknown
trackedTunnelsLock.Unlock()
defer func() {
@@ -214,6 +194,15 @@ func trackTunnelService(tunnelName string, service *mgr.Service) {
trackedTunnelsLock.Unlock()
}()
+ for i := 0; i < 20; i++ {
+ if i > 0 {
+ time.Sleep(time.Second / 5)
+ }
+ if status, err := service.Query(); err != nil || status.State != svc.Stopped {
+ break
+ }
+ }
+
checkForDisabled := func() (shouldReturn bool) {
config, err := service.Config()
if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || (err != nil && config.StartType == windows.SERVICE_DISABLED) {
@@ -273,5 +262,82 @@ func trackTunnelService(tunnelName string, service *mgr.Service) {
IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %w", err))
service.Control(svc.Stop)
}
- disconnectTunnelServicePipe(tunnelName)
+}
+
+func trackExistingTunnels() error {
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ names, err := conf.ListConfigNames()
+ if err != nil {
+ return err
+ }
+ for _, name := range names {
+ trackedTunnelsLock.Lock()
+ if _, found := trackedTunnels[name]; found {
+ trackedTunnelsLock.Unlock()
+ continue
+ }
+ trackedTunnelsLock.Unlock()
+ serviceName, err := conf.ServiceNameOfTunnel(name)
+ if err != nil {
+ continue
+ }
+ service, err := m.OpenService(serviceName)
+ if err != nil {
+ continue
+ }
+ go trackTunnelService(name, service)
+ }
+ return nil
+}
+
+var servicesSubscriptionWatcherCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr {
+ trackExistingTunnels()
+ return 0
+})
+
+func watchNewTunnelServices() error {
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ var subscription uintptr
+ err = windows.SubscribeServiceChangeNotifications(m.Handle, windows.SC_EVENT_DATABASE_CHANGE, servicesSubscriptionWatcherCallbackPtr, 0, &subscription)
+ if err == nil {
+ // We probably could do:
+ // defer windows.UnsubscribeServiceChangeNotifications(subscription)
+ // and then terminate after some point, but instead we just let this go forever; it's process-lived.
+ return trackExistingTunnels()
+ }
+ if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
+ return err
+ }
+
+ // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point.
+ go func() {
+ runtime.LockOSThread()
+ notifier := &windows.SERVICE_NOTIFY{
+ Version: windows.SERVICE_NOTIFY_STATUS_CHANGE,
+ NotifyCallback: serviceTrackerCallbackPtr,
+ }
+ for {
+ err := windows.NotifyServiceStatusChange(m.Handle, windows.SERVICE_NOTIFY_CREATED, notifier)
+ if err == nil {
+ windows.SleepEx(windows.INFINITE, true)
+ if notifier.ServiceNames != nil {
+ windows.LocalFree(windows.Handle(unsafe.Pointer(notifier.ServiceNames)))
+ notifier.ServiceNames = nil
+ }
+ trackExistingTunnels()
+ } else if err == windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING {
+ continue
+ } else {
+ time.Sleep(time.Second * 3)
+ trackExistingTunnels()
+ }
+ }
+ }()
+ return trackExistingTunnels()
}
diff --git a/manager/uiprocess.go b/manager/uiprocess.go
new file mode 100644
index 00000000..b33b1ad3
--- /dev/null
+++ b/manager/uiprocess.go
@@ -0,0 +1,103 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
+ */
+
+package manager
+
+import (
+ "errors"
+ "runtime"
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+type uiProcess struct {
+ handle uintptr
+}
+
+func launchUIProcess(executable string, args []string, workingDirectory string, handles []windows.Handle, token windows.Token) (*uiProcess, error) {
+ executable16, err := windows.UTF16PtrFromString(executable)
+ if err != nil {
+ return nil, err
+ }
+ args16, err := windows.UTF16PtrFromString(windows.ComposeCommandLine(args))
+ if err != nil {
+ return nil, err
+ }
+ workingDirectory16, err := windows.UTF16PtrFromString(workingDirectory)
+ if err != nil {
+ return nil, err
+ }
+ var environmentBlock *uint16
+ err = windows.CreateEnvironmentBlock(&environmentBlock, token, false)
+ if err != nil {
+ return nil, err
+ }
+ defer windows.DestroyEnvironmentBlock(environmentBlock)
+ attributeList, err := windows.NewProcThreadAttributeList(1)
+ if err != nil {
+ return nil, err
+ }
+ defer attributeList.Delete()
+ si := &windows.StartupInfoEx{
+ StartupInfo: windows.StartupInfo{Cb: uint32(unsafe.Sizeof(windows.StartupInfoEx{}))},
+ ProcThreadAttributeList: attributeList.List(),
+ }
+ if len(handles) == 0 {
+ handles = []windows.Handle{0}
+ }
+ attributeList.Update(windows.PROC_THREAD_ATTRIBUTE_HANDLE_LIST, unsafe.Pointer(&handles[0]), uintptr(len(handles))*unsafe.Sizeof(handles[0]))
+ pi := new(windows.ProcessInformation)
+ err = windows.CreateProcessAsUser(token, executable16, args16, nil, nil, true, windows.CREATE_DEFAULT_ERROR_MODE|windows.CREATE_UNICODE_ENVIRONMENT|windows.EXTENDED_STARTUPINFO_PRESENT, environmentBlock, workingDirectory16, &si.StartupInfo, pi)
+ if err != nil {
+ return nil, err
+ }
+ windows.CloseHandle(pi.Thread)
+ uiProc := &uiProcess{handle: uintptr(pi.Process)}
+ runtime.SetFinalizer(uiProc, (*uiProcess).release)
+ return uiProc, nil
+}
+
+func (p *uiProcess) release() error {
+ handle := windows.Handle(atomic.SwapUintptr(&p.handle, uintptr(windows.InvalidHandle)))
+ if handle == windows.InvalidHandle {
+ return nil
+ }
+ err := windows.CloseHandle(handle)
+ if err != nil {
+ return err
+ }
+ runtime.SetFinalizer(p, nil)
+ return nil
+}
+
+func (p *uiProcess) Wait() (uint32, error) {
+ handle := windows.Handle(atomic.LoadUintptr(&p.handle))
+ s, err := windows.WaitForSingleObject(handle, syscall.INFINITE)
+ switch s {
+ case windows.WAIT_OBJECT_0:
+ case windows.WAIT_FAILED:
+ return 0, err
+ default:
+ return 0, errors.New("unexpected result from WaitForSingleObject")
+ }
+ var exitCode uint32
+ err = windows.GetExitCodeProcess(handle, &exitCode)
+ if err != nil {
+ return 0, err
+ }
+ p.release()
+ return exitCode, nil
+}
+
+func (p *uiProcess) Kill() error {
+ handle := windows.Handle(atomic.LoadUintptr(&p.handle))
+ if handle == windows.InvalidHandle {
+ return nil
+ }
+ return windows.TerminateProcess(handle, 1)
+}
diff --git a/manager/updatestate.go b/manager/updatestate.go
index 069e9b8a..d5a19c8d 100644
--- a/manager/updatestate.go
+++ b/manager/updatestate.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package manager
@@ -8,11 +8,16 @@ package manager
import (
"log"
"time"
+ _ "unsafe"
+ "golang.zx2c4.com/wireguard/windows/services"
"golang.zx2c4.com/wireguard/windows/updater"
"golang.zx2c4.com/wireguard/windows/version"
)
+//go:linkname fastrandn runtime.fastrandn
+func fastrandn(n uint32) uint32
+
type UpdateState uint32
const (
@@ -23,6 +28,10 @@ const (
var updateState = UpdateStateUnknown
+func jitterSleep(min, max time.Duration) {
+ time.Sleep(min + time.Millisecond*time.Duration(fastrandn(uint32((max-min+1)/time.Millisecond))))
+}
+
func checkForUpdates() {
if !version.IsRunningOfficialVersion() {
log.Println("Build is not official, so updates are disabled")
@@ -30,26 +39,27 @@ func checkForUpdates() {
IPCServerNotifyUpdateFound(updateState)
return
}
-
- first := true
+ if services.StartedAtBoot() {
+ jitterSleep(time.Minute*2, time.Minute*5)
+ }
+ noError, didNotify := true, false
for {
update, err := updater.CheckForUpdate()
- if err == nil && update != nil {
+ if err == nil && update != nil && !didNotify {
log.Println("An update is available")
updateState = UpdateStateFoundUpdate
IPCServerNotifyUpdateFound(updateState)
- return
- }
- if err != nil {
+ didNotify = true
+ } else if err != nil && !didNotify {
log.Printf("Update checker: %v", err)
- if first {
- time.Sleep(time.Minute * 4)
- first = false
+ if noError {
+ jitterSleep(time.Minute*4, time.Minute*6)
+ noError = false
} else {
- time.Sleep(time.Minute * 25)
+ jitterSleep(time.Minute*25, time.Minute*30)
}
} else {
- time.Sleep(time.Hour)
+ jitterSleep(time.Hour-time.Minute*3, time.Hour+time.Minute*3)
}
}
}