aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/manager
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--manager/install.go59
-rw-r--r--manager/interfacecleanup.go58
-rw-r--r--manager/ipc_client.go17
-rw-r--r--manager/ipc_driver.go61
-rw-r--r--manager/ipc_pipe.go55
-rw-r--r--manager/ipc_server.go134
-rw-r--r--manager/service.go157
-rw-r--r--manager/tunneltracker.go357
-rw-r--r--manager/uiprocess.go103
-rw-r--r--manager/updatestate.go36
10 files changed, 671 insertions, 366 deletions
diff --git a/manager/install.go b/manager/install.go
index f84a96ae..44a744cf 100644
--- a/manager/install.go
+++ b/manager/install.go
@@ -1,13 +1,15 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package manager
import (
"errors"
+ "log"
"os"
+ "strings"
"time"
"golang.org/x/sys/windows"
@@ -15,7 +17,6 @@ import (
"golang.org/x/sys/windows/svc/mgr"
"golang.zx2c4.com/wireguard/windows/conf"
- "golang.zx2c4.com/wireguard/windows/services"
)
var cachedServiceManager *mgr.Mgr
@@ -56,6 +57,12 @@ func InstallManager() error {
}
if status.State != svc.Stopped {
service.Close()
+ if status.State == svc.StartPending {
+ // We were *just* started by something else, so return success here, assuming the other program
+ // starting this does the right thing. This can happen when, e.g., the updater relaunches the
+ // manager service and then invokes wireguard.exe to raise the UI.
+ return nil
+ }
return ErrManagerAlreadyRunning
}
err = service.Delete()
@@ -122,7 +129,7 @@ func InstallTunnel(configPath string) error {
return err
}
- serviceName, err := services.ServiceNameOfTunnel(name)
+ serviceName, err := conf.ServiceNameOfTunnel(name)
if err != nil {
return err
}
@@ -156,7 +163,7 @@ func InstallTunnel(configPath string) error {
ServiceType: windows.SERVICE_WIN32_OWN_PROCESS,
StartType: mgr.StartAutomatic,
ErrorControl: mgr.ErrorNormal,
- Dependencies: []string{"Nsi"},
+ Dependencies: []string{"Nsi", "TcpIp"},
DisplayName: "WireGuard Tunnel: " + name,
SidType: windows.SERVICE_SID_TYPE_UNRESTRICTED,
}
@@ -175,7 +182,7 @@ func UninstallTunnel(name string) error {
if err != nil {
return err
}
- serviceName, err := services.ServiceNameOfTunnel(name)
+ serviceName, err := conf.ServiceNameOfTunnel(name)
if err != nil {
return err
}
@@ -191,3 +198,45 @@ func UninstallTunnel(name string) error {
}
return err2
}
+
+func changeTunnelServiceConfigFilePath(name, oldPath, newPath string) {
+ var err error
+ defer func() {
+ if err != nil {
+ log.Printf("Unable to change tunnel service command line argument from %#q to %#q: %v", oldPath, newPath, err)
+ }
+ }()
+ m, err := serviceManager()
+ if err != nil {
+ return
+ }
+ serviceName, err := conf.ServiceNameOfTunnel(name)
+ if err != nil {
+ return
+ }
+ service, err := m.OpenService(serviceName)
+ if err == windows.ERROR_SERVICE_DOES_NOT_EXIST {
+ err = nil
+ return
+ } else if err != nil {
+ return
+ }
+ defer service.Close()
+ config, err := service.Config()
+ if err != nil {
+ return
+ }
+ exePath, err := os.Executable()
+ if err != nil {
+ return
+ }
+ args, err := windows.DecomposeCommandLine(config.BinaryPathName)
+ if err != nil || len(args) != 3 ||
+ !strings.EqualFold(args[0], exePath) || args[1] != "/tunnelservice" || !strings.EqualFold(args[2], oldPath) {
+ err = nil
+ return
+ }
+ args[2] = newPath
+ config.BinaryPathName = windows.ComposeCommandLine(args)
+ err = service.UpdateConfig(config)
+}
diff --git a/manager/interfacecleanup.go b/manager/interfacecleanup.go
deleted file mode 100644
index f5d9ef48..00000000
--- a/manager/interfacecleanup.go
+++ /dev/null
@@ -1,58 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package manager
-
-import (
- "log"
-
- "golang.org/x/sys/windows"
- "golang.org/x/sys/windows/svc"
- "golang.org/x/sys/windows/svc/mgr"
- "golang.zx2c4.com/wireguard/tun/wintun"
-
- "golang.zx2c4.com/wireguard/tun"
- "golang.zx2c4.com/wireguard/windows/services"
-)
-
-func cleanupStaleWintunInterfaces() {
- defer printPanic()
-
- m, err := mgr.Connect()
- if err != nil {
- return
- }
- defer m.Disconnect()
-
- tun.WintunPool.DeleteMatchingInterfaces(func(wintun *wintun.Interface) bool {
- interfaceName, err := wintun.Name()
- if err != nil {
- log.Printf("Removing Wintun interface %s because determining interface name failed: %v", wintun.GUID().String(), err)
- return true
- }
- serviceName, err := services.ServiceNameOfTunnel(interfaceName)
- if err != nil {
- log.Printf("Removing Wintun interface ‘%s’ because determining tunnel service name failed: %v", interfaceName, err)
- return true
- }
- service, err := m.OpenService(serviceName)
- if err == windows.ERROR_SERVICE_DOES_NOT_EXIST {
- log.Printf("Removing Wintun interface ‘%s’ because no service for it exists", interfaceName)
- return true
- } else if err != nil {
- return false
- }
- defer service.Close()
- status, err := service.Query()
- if err != nil {
- return false
- }
- if status.State == svc.Stopped {
- log.Printf("Removing Wintun interface ‘%s’ because its service is stopped", interfaceName)
- return true
- }
- return false
- })
-}
diff --git a/manager/ipc_client.go b/manager/ipc_client.go
index c8b2f852..8c9c4c04 100644
--- a/manager/ipc_client.go
+++ b/manager/ipc_client.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package manager
@@ -64,7 +64,7 @@ var (
)
type TunnelChangeCallback struct {
- cb func(tunnel *Tunnel, state TunnelState, globalState TunnelState, err error)
+ cb func(tunnel *Tunnel, state, globalState TunnelState, err error)
}
var tunnelChangeCallbacks = make(map[*TunnelChangeCallback]bool)
@@ -93,7 +93,7 @@ type UpdateProgressCallback struct {
var updateProgressCallbacks = make(map[*UpdateProgressCallback]bool)
-func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) {
+func InitializeIPCClient(reader, writer, events *os.File) {
rpcDecoder = gob.NewDecoder(reader)
rpcEncoder = gob.NewEncoder(writer)
go func() {
@@ -431,43 +431,52 @@ func IPCClientUpdate() error {
return rpcEncoder.Encode(UpdateMethodType)
}
-func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state TunnelState, globalState TunnelState, err error)) *TunnelChangeCallback {
+func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state, globalState TunnelState, err error)) *TunnelChangeCallback {
s := &TunnelChangeCallback{cb}
tunnelChangeCallbacks[s] = true
return s
}
+
func (cb *TunnelChangeCallback) Unregister() {
delete(tunnelChangeCallbacks, cb)
}
+
func IPCClientRegisterTunnelsChange(cb func()) *TunnelsChangeCallback {
s := &TunnelsChangeCallback{cb}
tunnelsChangeCallbacks[s] = true
return s
}
+
func (cb *TunnelsChangeCallback) Unregister() {
delete(tunnelsChangeCallbacks, cb)
}
+
func IPCClientRegisterManagerStopping(cb func()) *ManagerStoppingCallback {
s := &ManagerStoppingCallback{cb}
managerStoppingCallbacks[s] = true
return s
}
+
func (cb *ManagerStoppingCallback) Unregister() {
delete(managerStoppingCallbacks, cb)
}
+
func IPCClientRegisterUpdateFound(cb func(updateState UpdateState)) *UpdateFoundCallback {
s := &UpdateFoundCallback{cb}
updateFoundCallbacks[s] = true
return s
}
+
func (cb *UpdateFoundCallback) Unregister() {
delete(updateFoundCallbacks, cb)
}
+
func IPCClientRegisterUpdateProgress(cb func(dp updater.DownloadProgress)) *UpdateProgressCallback {
s := &UpdateProgressCallback{cb}
updateProgressCallbacks[s] = true
return s
}
+
func (cb *UpdateProgressCallback) Unregister() {
delete(updateProgressCallbacks, cb)
}
diff --git a/manager/ipc_driver.go b/manager/ipc_driver.go
new file mode 100644
index 00000000..6cb43c38
--- /dev/null
+++ b/manager/ipc_driver.go
@@ -0,0 +1,61 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
+ */
+
+package manager
+
+import (
+ "sync"
+
+ "golang.zx2c4.com/wireguard/windows/driver"
+)
+
+type lockedDriverAdapter struct {
+ *driver.Adapter
+ sync.Mutex
+}
+
+var (
+ driverAdapters = make(map[string]*lockedDriverAdapter)
+ driverAdaptersLock sync.RWMutex
+)
+
+func findDriverAdapter(tunnelName string) (*lockedDriverAdapter, error) {
+ driverAdaptersLock.RLock()
+ driverAdapter, ok := driverAdapters[tunnelName]
+ if ok {
+ driverAdapter.Lock()
+ driverAdaptersLock.RUnlock()
+ return driverAdapter, nil
+ }
+ driverAdaptersLock.RUnlock()
+ driverAdaptersLock.Lock()
+ defer driverAdaptersLock.Unlock()
+ driverAdapter, ok = driverAdapters[tunnelName]
+ if ok {
+ driverAdapter.Lock()
+ return driverAdapter, nil
+ }
+ driverAdapter = &lockedDriverAdapter{}
+ var err error
+ driverAdapter.Adapter, err = driver.OpenAdapter(tunnelName)
+ if err != nil {
+ return nil, err
+ }
+ driverAdapters[tunnelName] = driverAdapter
+ driverAdapter.Lock()
+ return driverAdapter, nil
+}
+
+func releaseDriverAdapter(tunnelName string) {
+ driverAdaptersLock.Lock()
+ defer driverAdaptersLock.Unlock()
+ driverAdapter, ok := driverAdapters[tunnelName]
+ if !ok {
+ return
+ }
+ driverAdapter.Lock()
+ delete(driverAdapters, tunnelName)
+ driverAdapter.Unlock()
+}
diff --git a/manager/ipc_pipe.go b/manager/ipc_pipe.go
deleted file mode 100644
index d4214ac0..00000000
--- a/manager/ipc_pipe.go
+++ /dev/null
@@ -1,55 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package manager
-
-import (
- "os"
- "strconv"
-
- "golang.org/x/sys/windows"
-)
-
-func makeInheritableAndGetStr(f *os.File) (str string, err error) {
- sc, err := f.SyscallConn()
- if err != nil {
- return
- }
- err2 := sc.Control(func(fd uintptr) {
- err = windows.SetHandleInformation(windows.Handle(fd), windows.HANDLE_FLAG_INHERIT, windows.HANDLE_FLAG_INHERIT)
- str = strconv.FormatUint(uint64(fd), 10)
- })
- if err2 != nil {
- err = err2
- }
- return
-}
-
-func inheritableEvents() (ourEvents *os.File, theirEvents *os.File, theirEventStr string, err error) {
- theirEvents, ourEvents, err = os.Pipe()
- if err != nil {
- return
- }
- theirEventStr, err = makeInheritableAndGetStr(theirEvents)
- return
-}
-
-func inheritableSocketpairEmulation() (ourReader *os.File, theirReader *os.File, theirReaderStr string, ourWriter *os.File, theirWriter *os.File, theirWriterStr string, err error) {
- ourReader, theirWriter, err = os.Pipe()
- if err != nil {
- return
- }
- theirWriterStr, err = makeInheritableAndGetStr(theirWriter)
- if err != nil {
- return
- }
-
- theirReader, ourWriter, err = os.Pipe()
- if err != nil {
- return
- }
- theirReaderStr, err = makeInheritableAndGetStr(theirReader)
- return
-}
diff --git a/manager/ipc_server.go b/manager/ipc_server.go
index 1367c2e9..e21ffaf0 100644
--- a/manager/ipc_server.go
+++ b/manager/ipc_server.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package manager
@@ -10,7 +10,6 @@ import (
"encoding/gob"
"fmt"
"io"
- "io/ioutil"
"log"
"os"
"sync"
@@ -20,64 +19,73 @@ import (
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
- "golang.zx2c4.com/wireguard/ipc/winpipe"
-
"golang.zx2c4.com/wireguard/windows/conf"
- "golang.zx2c4.com/wireguard/windows/services"
"golang.zx2c4.com/wireguard/windows/updater"
)
-var managerServices = make(map[*ManagerService]bool)
-var managerServicesLock sync.RWMutex
-var haveQuit uint32
-var quitManagersChan = make(chan struct{}, 1)
+var (
+ managerServices = make(map[*ManagerService]bool)
+ managerServicesLock sync.RWMutex
+ haveQuit uint32
+ quitManagersChan = make(chan struct{}, 1)
+)
type ManagerService struct {
events *os.File
+ eventLock sync.Mutex
elevatedToken windows.Token
}
func (s *ManagerService) StoredConfig(tunnelName string) (*conf.Config, error) {
- return conf.LoadFromName(tunnelName)
-}
-
-func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) {
- storedConfig, err := conf.LoadFromName(tunnelName)
+ conf, err := conf.LoadFromName(tunnelName)
if err != nil {
return nil, err
}
- pipePath, err := services.PipePathOfTunnel(storedConfig.Name)
- if err != nil {
- return nil, err
+ if s.elevatedToken == 0 {
+ conf.Redact()
}
- localSystem, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid)
+ return conf, nil
+}
+
+func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) {
+ storedConfig, err := conf.LoadFromName(tunnelName)
if err != nil {
return nil, err
}
- pipe, err := winpipe.DialPipe(pipePath, nil, localSystem)
+ driverAdapter, err := findDriverAdapter(tunnelName)
if err != nil {
return nil, err
}
- defer pipe.Close()
- pipe.SetWriteDeadline(time.Now().Add(time.Second * 2))
- _, err = pipe.Write([]byte("get=1\n\n"))
+ runtimeConfig, err := driverAdapter.Configuration()
if err != nil {
+ driverAdapter.Unlock()
+ releaseDriverAdapter(tunnelName)
return nil, err
}
- pipe.SetReadDeadline(time.Now().Add(time.Second * 2))
- resp, err := ioutil.ReadAll(pipe)
- if err != nil {
- return nil, err
+ conf := conf.FromDriverConfiguration(runtimeConfig, storedConfig)
+ driverAdapter.Unlock()
+ if s.elevatedToken == 0 {
+ conf.Redact()
}
- return conf.FromUAPI(string(resp), storedConfig)
+ return conf, nil
}
func (s *ManagerService) Start(tunnelName string) error {
- // For now, enforce only one tunnel at a time. Later we'll remove this silly restriction.
+ c, err := conf.LoadFromName(tunnelName)
+ if err != nil {
+ return err
+ }
+
+ // Figure out which tunnels have intersecting addresses/routes and stop those.
trackedTunnelsLock.Lock()
tt := make([]string, 0, len(trackedTunnels))
var inTransition string
for t, state := range trackedTunnels {
+ c2, err := conf.LoadFromName(t)
+ if err != nil || !c.IntersectsWith(c2) {
+ // If we can't get the config, assume it doesn't intersect.
+ continue
+ }
tt = append(tt, t)
if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) {
inTransition = t
@@ -88,6 +96,8 @@ func (s *ManagerService) Start(tunnelName string) error {
if len(inTransition) != 0 {
return fmt.Errorf("Please allow the tunnel ‘%s’ to finish activating", inTransition)
}
+
+ // Stop those intersecting tunnels asynchronously.
go func() {
for _, t := range tt {
s.Stop(t)
@@ -101,13 +111,7 @@ func (s *ManagerService) Start(tunnelName string) error {
}
}
}()
- time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces)
-
- // After that process is started -- it's somewhat asynchronous -- we install the new one.
- c, err := conf.LoadFromName(tunnelName)
- if err != nil {
- return err
- }
+ // After the stop process has begun, but before it's finished, we install the new one.
path, err := c.Path()
if err != nil {
return err
@@ -116,8 +120,6 @@ func (s *ManagerService) Start(tunnelName string) error {
}
func (s *ManagerService) Stop(tunnelName string) error {
- time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces)
-
err := UninstallTunnel(tunnelName)
if err == windows.ERROR_SERVICE_DOES_NOT_EXIST {
_, notExistsError := conf.LoadFromName(tunnelName)
@@ -129,7 +131,7 @@ func (s *ManagerService) Stop(tunnelName string) error {
}
func (s *ManagerService) WaitForStop(tunnelName string) error {
- serviceName, err := services.ServiceNameOfTunnel(tunnelName)
+ serviceName, err := conf.ServiceNameOfTunnel(tunnelName)
if err != nil {
return err
}
@@ -149,6 +151,9 @@ func (s *ManagerService) WaitForStop(tunnelName string) error {
}
func (s *ManagerService) Delete(tunnelName string) error {
+ if s.elevatedToken == 0 {
+ return windows.ERROR_ACCESS_DENIED
+ }
err := s.Stop(tunnelName)
if err != nil {
return err
@@ -157,7 +162,7 @@ func (s *ManagerService) Delete(tunnelName string) error {
}
func (s *ManagerService) State(tunnelName string) (TunnelState, error) {
- serviceName, err := services.ServiceNameOfTunnel(tunnelName)
+ serviceName, err := conf.ServiceNameOfTunnel(tunnelName)
if err != nil {
return 0, err
}
@@ -193,7 +198,10 @@ func (s *ManagerService) GlobalState() TunnelState {
}
func (s *ManagerService) Create(tunnelConfig *conf.Config) (*Tunnel, error) {
- err := tunnelConfig.Save()
+ if s.elevatedToken == 0 {
+ return nil, windows.ERROR_ACCESS_DENIED
+ }
+ err := tunnelConfig.Save(true)
if err != nil {
return nil, err
}
@@ -209,19 +217,25 @@ func (s *ManagerService) Tunnels() ([]Tunnel, error) {
}
tunnels := make([]Tunnel, len(names))
for i := 0; i < len(tunnels); i++ {
- (tunnels)[i].Name = names[i]
+ tunnels[i].Name = names[i]
}
return tunnels, nil
// TODO: account for running ones that aren't in the configuration store somehow
}
func (s *ManagerService) Quit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) {
+ if s.elevatedToken == 0 {
+ return false, windows.ERROR_ACCESS_DENIED
+ }
if !atomic.CompareAndSwapUint32(&haveQuit, 0, 1) {
return true, nil
}
// Work around potential race condition of delivering messages to the wrong process by removing from notifications.
managerServicesLock.Lock()
+ s.eventLock.Lock()
+ s.events = nil
+ s.eventLock.Unlock()
delete(managerServices, s)
managerServicesLock.Unlock()
@@ -244,6 +258,9 @@ func (s *ManagerService) UpdateState() UpdateState {
}
func (s *ManagerService) Update() {
+ if s.elevatedToken == 0 {
+ return
+ }
progress := updater.DownloadVerifyAndExecute(uintptr(s.elevatedToken))
go func() {
for {
@@ -374,6 +391,9 @@ func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) {
return
}
tunnel, retErr := s.Create(&config)
+ if tunnel == nil {
+ tunnel = &Tunnel{}
+ }
err = encoder.Encode(tunnel)
if err != nil {
return
@@ -421,26 +441,27 @@ func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) {
}
}
-func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevatedToken windows.Token) {
+func IPCServerListen(reader, writer, events *os.File, elevatedToken windows.Token) {
service := &ManagerService{
events: events,
elevatedToken: elevatedToken,
}
go func() {
- defer printPanic()
managerServicesLock.Lock()
managerServices[service] = true
managerServicesLock.Unlock()
service.ServeConn(reader, writer)
managerServicesLock.Lock()
+ service.eventLock.Lock()
+ service.events = nil
+ service.eventLock.Unlock()
delete(managerServices, service)
managerServicesLock.Unlock()
-
}()
}
-func notifyAll(notificationType NotificationType, ifaces ...interface{}) {
+func notifyAll(notificationType NotificationType, adminOnly bool, ifaces ...any) {
if len(managerServices) == 0 {
return
}
@@ -460,8 +481,17 @@ func notifyAll(notificationType NotificationType, ifaces ...interface{}) {
managerServicesLock.RLock()
for m := range managerServices {
- m.events.SetWriteDeadline(time.Now().Add(time.Second))
- m.events.Write(buf.Bytes())
+ if m.elevatedToken == 0 && adminOnly {
+ continue
+ }
+ go func(m *ManagerService) {
+ m.eventLock.Lock()
+ defer m.eventLock.Unlock()
+ if m.events != nil {
+ m.events.SetWriteDeadline(time.Now().Add(time.Second))
+ m.events.Write(buf.Bytes())
+ }
+ }(m)
}
managerServicesLock.RUnlock()
}
@@ -474,22 +504,22 @@ func errToString(err error) string {
}
func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) {
- notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), errToString(err))
+ notifyAll(TunnelChangeNotificationType, false, name, state, trackedTunnelsGlobalState(), errToString(err))
}
func IPCServerNotifyTunnelsChange() {
- notifyAll(TunnelsChangeNotificationType)
+ notifyAll(TunnelsChangeNotificationType, false)
}
func IPCServerNotifyUpdateFound(state UpdateState) {
- notifyAll(UpdateFoundNotificationType, state)
+ notifyAll(UpdateFoundNotificationType, false, state)
}
func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) {
- notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, errToString(dp.Error), dp.Complete)
+ notifyAll(UpdateProgressNotificationType, true, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, errToString(dp.Error), dp.Complete)
}
func IPCServerNotifyManagerStopping() {
- notifyAll(ManagerStoppingNotificationType)
+ notifyAll(ManagerStoppingNotificationType, false)
time.Sleep(time.Millisecond * 200)
}
diff --git a/manager/service.go b/manager/service.go
index 6c3b039b..47e20d45 100644
--- a/manager/service.go
+++ b/manager/service.go
@@ -1,46 +1,32 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package manager
import (
"errors"
- "fmt"
"log"
"os"
"runtime"
- "runtime/debug"
- "strings"
+ "strconv"
"sync"
- "syscall"
"time"
"unsafe"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
+ "golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/conf"
"golang.zx2c4.com/wireguard/windows/elevate"
"golang.zx2c4.com/wireguard/windows/ringlogger"
"golang.zx2c4.com/wireguard/windows/services"
- "golang.zx2c4.com/wireguard/windows/version"
)
type managerService struct{}
-func printPanic() {
- if x := recover(); x != nil {
- for _, line := range append([]string{fmt.Sprint(x)}, strings.Split(string(debug.Stack()), "\n")...) {
- if len(strings.TrimSpace(line)) > 0 {
- log.Println(line)
- }
- }
- panic(x)
- }
-}
-
func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) {
changes <- svc.Status{State: svc.StartPending}
@@ -56,40 +42,40 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
changes <- svc.Status{State: svc.StopPending}
}()
- err = ringlogger.InitGlobalLogger("MGR")
+ var logFile string
+ logFile, err = conf.LogFile(true)
if err != nil {
serviceError = services.ErrorRingloggerOpen
return
}
- defer printPanic()
-
- log.Println("Starting", version.UserAgent())
-
- path, err := os.Executable()
+ err = ringlogger.InitGlobalLogger(logFile, "MGR")
if err != nil {
- serviceError = services.ErrorDetermineExecutablePath
+ serviceError = services.ErrorRingloggerOpen
return
}
- devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0)
+ services.PrintStarting()
+
+ path, err := os.Executable()
if err != nil {
- serviceError = services.ErrorOpenNULFile
+ serviceError = services.ErrorDetermineExecutablePath
return
}
- err = trackExistingTunnels()
+ err = watchNewTunnelServices()
if err != nil {
serviceError = services.ErrorTrackTunnels
return
}
- conf.RegisterStoreChangeCallback(func() { conf.MigrateUnencryptedConfigs() }) // Ignore return value for now, but could be useful later.
+ conf.RegisterStoreChangeCallback(func() { conf.MigrateUnencryptedConfigs(changeTunnelServiceConfigFilePath) })
conf.RegisterStoreChangeCallback(IPCServerNotifyTunnelsChange)
- procs := make(map[uint32]*os.Process)
+ procs := make(map[uint32]*uiProcess)
aliveSessions := make(map[uint32]bool)
procsLock := sync.Mutex{}
stoppingManager := false
+ operatorGroupSid, _ := windows.CreateWellKnownSid(windows.WinBuiltinNetworkConfigurationOperatorsSid)
startProcess := func(session uint32) {
defer func() {
@@ -104,7 +90,24 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
if err != nil {
return
}
- if !elevate.TokenIsElevatedOrElevatable(userToken) {
+ isAdmin := elevate.TokenIsElevatedOrElevatable(userToken)
+ isOperator := false
+ if !isAdmin && conf.AdminBool("LimitedOperatorUI") && operatorGroupSid != nil {
+ linkedToken, err := userToken.GetLinkedToken()
+ var impersonationToken windows.Token
+ if err == nil {
+ err = windows.DuplicateTokenEx(linkedToken, windows.TOKEN_QUERY, nil, windows.SecurityImpersonation, windows.TokenImpersonation, &impersonationToken)
+ linkedToken.Close()
+ } else {
+ err = windows.DuplicateTokenEx(userToken, windows.TOKEN_QUERY, nil, windows.SecurityImpersonation, windows.TokenImpersonation, &impersonationToken)
+ }
+ if err == nil {
+ isOperator, err = impersonationToken.IsMember(operatorGroupSid)
+ isOperator = isOperator && err == nil
+ impersonationToken.Close()
+ }
+ }
+ if !isAdmin && !isOperator {
userToken.Close()
return
}
@@ -125,23 +128,28 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
return
}
userProfileDirectory, _ := userToken.GetUserProfileDirectory()
- var elevatedToken windows.Token
- if userToken.IsElevated() {
- elevatedToken = userToken
- } else {
- elevatedToken, err = userToken.GetLinkedToken()
- userToken.Close()
- if err != nil {
- log.Printf("Unable to elevate token: %v", err)
- return
- }
- if !elevatedToken.IsElevated() {
- elevatedToken.Close()
- log.Println("Linked token is not elevated")
- return
+ var elevatedToken, runToken windows.Token
+ if isAdmin {
+ if userToken.IsElevated() {
+ elevatedToken = userToken
+ } else {
+ elevatedToken, err = userToken.GetLinkedToken()
+ userToken.Close()
+ if err != nil {
+ log.Printf("Unable to elevate token: %v", err)
+ return
+ }
+ if !elevatedToken.IsElevated() {
+ elevatedToken.Close()
+ log.Println("Linked token is not elevated")
+ return
+ }
}
+ runToken = elevatedToken
+ } else {
+ runToken = userToken
}
- defer elevatedToken.Close()
+ defer runToken.Close()
userToken = 0
first := true
for {
@@ -162,39 +170,45 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
first = false
}
- // TODO: we lock the OS thread so that these inheritable handles don't escape into other processes that
- // might be running in parallel Go routines. But the Go runtime is strange and who knows what's really
- // happening with these or what is inherited. We need to do some analysis to be certain of what's going on.
- runtime.LockOSThread()
- ourReader, theirReader, theirReaderStr, ourWriter, theirWriter, theirWriterStr, err := inheritableSocketpairEmulation()
+ ourReader, theirWriter, err := os.Pipe()
+ if err != nil {
+ log.Printf("Unable to create pipe: %v", err)
+ return
+ }
+ theirReader, ourWriter, err := os.Pipe()
if err != nil {
- log.Printf("Unable to create two inheritable RPC pipes: %v", err)
+ log.Printf("Unable to create pipe: %v", err)
return
}
- ourEvents, theirEvents, theirEventStr, err := inheritableEvents()
+ theirEvents, ourEvents, err := os.Pipe()
if err != nil {
- log.Printf("Unable to create one inheritable events pipe: %v", err)
+ log.Printf("Unable to create pipe: %v", err)
return
}
IPCServerListen(ourReader, ourWriter, ourEvents, elevatedToken)
- theirLogMapping, theirLogMappingHandle, err := ringlogger.Global.ExportInheritableMappingHandleStr()
+ theirLogMapping, err := ringlogger.Global.ExportInheritableMappingHandle()
if err != nil {
log.Printf("Unable to export inheritable mapping handle for logging: %v", err)
return
}
log.Printf("Starting UI process for user ‘%s@%s’ for session %d", username, domain, session)
- attr := &os.ProcAttr{
- Sys: &syscall.SysProcAttr{
- Token: syscall.Token(elevatedToken),
- },
- Files: []*os.File{devNull, devNull, devNull},
- Dir: userProfileDirectory,
- }
procsLock.Lock()
- var proc *os.Process
+ var proc *uiProcess
if alive := aliveSessions[session]; alive {
- proc, err = os.StartProcess(path, []string{path, "/ui", theirReaderStr, theirWriterStr, theirEventStr, theirLogMapping}, attr)
+ proc, err = launchUIProcess(path, []string{
+ path,
+ "/ui",
+ strconv.FormatUint(uint64(theirReader.Fd()), 10),
+ strconv.FormatUint(uint64(theirWriter.Fd()), 10),
+ strconv.FormatUint(uint64(theirEvents.Fd()), 10),
+ strconv.FormatUint(uint64(theirLogMapping), 10),
+ }, userProfileDirectory, []windows.Handle{
+ windows.Handle(theirReader.Fd()),
+ windows.Handle(theirWriter.Fd()),
+ windows.Handle(theirEvents.Fd()),
+ theirLogMapping,
+ }, runToken)
} else {
err = errors.New("Session has logged out")
}
@@ -202,8 +216,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
theirReader.Close()
theirWriter.Close()
theirEvents.Close()
- windows.Close(theirLogMappingHandle)
- runtime.UnlockOSThread()
+ windows.CloseHandle(theirLogMapping)
if err != nil {
ourReader.Close()
ourWriter.Close()
@@ -217,9 +230,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
procsLock.Unlock()
sessionIsDead := false
- processStatus, err := proc.Wait()
- if err == nil {
- exitCode := processStatus.Sys().(syscall.WaitStatus).ExitCode
+ if exitCode, err := proc.Wait(); err == nil {
log.Printf("Exited UI process for user '%s@%s' for session %d with status %x", username, domain, session, exitCode)
const STATUS_DLL_INIT_FAILED_LOGOFF = 0xC000026B
sessionIsDead = exitCode == STATUS_DLL_INIT_FAILED_LOGOFF
@@ -243,14 +254,13 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
goStartProcess := func(session uint32) {
procsGroup.Add(1)
go func() {
- defer printPanic()
startProcess(session)
procsGroup.Done()
}()
}
- time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces)
go checkForUpdates()
+ go driver.UninstallLegacyWintun() // We uninstall opportunistically here, so that we don't have to carry around the uninstaller code forever.
var sessionsPointer *windows.WTS_SESSION_INFO
var count uint32
@@ -259,12 +269,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
serviceError = services.ErrorEnumerateSessions
return
}
- sessions := *(*[]windows.WTS_SESSION_INFO)(unsafe.Pointer(&struct {
- addr *windows.WTS_SESSION_INFO
- len int
- cap int
- }{sessionsPointer, int(count), int(count)}))
- for _, session := range sessions {
+ for _, session := range unsafe.Slice(sessionsPointer, count) {
if session.State != windows.WTSActive && session.State != windows.WTSDisconnected {
continue
}
diff --git a/manager/tunneltracker.go b/manager/tunneltracker.go
index 0f222aac..9003d445 100644
--- a/manager/tunneltracker.go
+++ b/manager/tunneltracker.go
@@ -1,17 +1,20 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package manager
import (
+ "errors"
"fmt"
"log"
"runtime"
"sync"
+ "sync/atomic"
"syscall"
"time"
+ "unsafe"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
@@ -21,50 +24,10 @@ import (
"golang.zx2c4.com/wireguard/windows/services"
)
-func trackExistingTunnels() error {
- m, err := serviceManager()
- if err != nil {
- return err
- }
- names, err := conf.ListConfigNames()
- if err != nil {
- return err
- }
- for _, name := range names {
- serviceName, err := services.ServiceNameOfTunnel(name)
- if err != nil {
- continue
- }
- service, err := m.OpenService(serviceName)
- if err != nil {
- continue
- }
- go trackTunnelService(name, service)
- }
- return nil
-}
-
-var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr {
- return 0
-})
-
-var trackedTunnels = make(map[string]TunnelState)
-var trackedTunnelsLock = sync.Mutex{}
-
-func svcStateToTunState(s svc.State) TunnelState {
- switch s {
- case svc.StartPending:
- return TunnelStarting
- case svc.Running:
- return TunnelStarted
- case svc.StopPending:
- return TunnelStopping
- case svc.Stopped:
- return TunnelStopped
- default:
- return TunnelUnknown
- }
-}
+var (
+ trackedTunnels = make(map[string]TunnelState)
+ trackedTunnelsLock = sync.Mutex{}
+)
func trackedTunnelsGlobalState() (state TunnelState) {
state = TunnelStopped
@@ -82,50 +45,95 @@ func trackedTunnelsGlobalState() (state TunnelState) {
return
}
-func trackTunnelService(tunnelName string, service *mgr.Service) {
- defer func() {
- service.Close()
- log.Printf("[%s] Tunnel service tracker finished", tunnelName)
- }()
+var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr {
+ return 0
+})
- trackedTunnelsLock.Lock()
- if _, found := trackedTunnels[tunnelName]; found {
- trackedTunnelsLock.Unlock()
- return
+type serviceSubscriptionState struct {
+ service *mgr.Service
+ cb func(status uint32) bool
+ done sync.WaitGroup
+ once uint32
+}
+
+var serviceSubscriptionCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr {
+ state := (*serviceSubscriptionState)(unsafe.Pointer(context))
+ if atomic.LoadUint32(&state.once) != 0 {
+ return 0
}
- trackedTunnels[tunnelName] = TunnelUnknown
- trackedTunnelsLock.Unlock()
- defer func() {
- trackedTunnelsLock.Lock()
- delete(trackedTunnels, tunnelName)
- trackedTunnelsLock.Unlock()
- }()
+ if notification == 0 {
+ status, err := state.service.Query()
+ if err == nil {
+ notification = svcStateToNotifyState(uint32(status.State))
+ }
+ }
+ if state.cb(notification) && atomic.CompareAndSwapUint32(&state.once, 0, 1) {
+ state.done.Done()
+ }
+ return 0
+})
- const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING
- notifier := &windows.SERVICE_NOTIFY{
- Version: windows.SERVICE_NOTIFY_STATUS_CHANGE,
- NotifyCallback: serviceTrackerCallbackPtr,
+func svcStateToNotifyState(s uint32) uint32 {
+ switch s {
+ case windows.SERVICE_STOPPED:
+ return windows.SERVICE_NOTIFY_STOPPED
+ case windows.SERVICE_START_PENDING:
+ return windows.SERVICE_NOTIFY_START_PENDING
+ case windows.SERVICE_STOP_PENDING:
+ return windows.SERVICE_NOTIFY_STOP_PENDING
+ case windows.SERVICE_RUNNING:
+ return windows.SERVICE_NOTIFY_RUNNING
+ case windows.SERVICE_CONTINUE_PENDING:
+ return windows.SERVICE_NOTIFY_CONTINUE_PENDING
+ case windows.SERVICE_PAUSE_PENDING:
+ return windows.SERVICE_NOTIFY_PAUSE_PENDING
+ case windows.SERVICE_PAUSED:
+ return windows.SERVICE_NOTIFY_PAUSED
+ case windows.SERVICE_NO_CHANGE:
+ return 0
+ default:
+ return 0
}
+}
- checkForDisabled := func() (shouldReturn bool) {
- config, err := service.Config()
- if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || config.StartType == windows.SERVICE_DISABLED {
- log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName)
- service.Delete()
- trackedTunnelsLock.Lock()
- trackedTunnels[tunnelName] = TunnelStopped
- trackedTunnelsLock.Unlock()
- IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil)
- return true
+func notifyStateToTunState(s uint32) TunnelState {
+ if s&(windows.SERVICE_NOTIFY_STOPPED|windows.SERVICE_NOTIFY_DELETED) != 0 {
+ return TunnelStopped
+ } else if s&(windows.SERVICE_NOTIFY_DELETE_PENDING|windows.SERVICE_NOTIFY_STOP_PENDING) != 0 {
+ return TunnelStopping
+ } else if s&windows.SERVICE_NOTIFY_RUNNING != 0 {
+ return TunnelStarted
+ } else if s&windows.SERVICE_NOTIFY_START_PENDING != 0 {
+ return TunnelStarting
+ } else {
+ return TunnelUnknown
+ }
+}
+
+func trackService(service *mgr.Service, callback func(status uint32) bool) error {
+ var subscription uintptr
+ state := &serviceSubscriptionState{service: service, cb: callback}
+ state.done.Add(1)
+ err := windows.SubscribeServiceChangeNotifications(service.Handle, windows.SC_EVENT_STATUS_CHANGE, serviceSubscriptionCallbackPtr, uintptr(unsafe.Pointer(state)), &subscription)
+ if err == nil {
+ defer windows.UnsubscribeServiceChangeNotifications(subscription)
+ status, err := service.Query()
+ if err == nil {
+ if callback(svcStateToNotifyState(uint32(status.State))) {
+ return nil
+ }
}
- return false
+ state.done.Wait()
+ runtime.KeepAlive(state.cb)
+ return nil
}
- if checkForDisabled() {
- return
+ if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
+ return err
}
- runtime.LockOSThread()
+ // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point.
+ runtime.LockOSThread()
// This line would be fitting but is intentionally commented out:
//
// defer runtime.UnlockOSThread()
@@ -134,7 +142,11 @@ func trackTunnelService(tunnelName string, service *mgr.Service) {
// with the thread local context, which in turn appears to corrupt Go's own usage of TLS,
// leading to crashes sometime later (usually in runtime_unlock()) when the thread is recycled.
- lastState := TunnelUnknown
+ const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING
+ notifier := &windows.SERVICE_NOTIFY{
+ Version: windows.SERVICE_NOTIFY_STATUS_CHANGE,
+ NotifyCallback: serviceTrackerCallbackPtr,
+ }
for {
err := windows.NotifyServiceStatusChange(service.Handle, serviceNotifications, notifier)
switch err {
@@ -142,42 +154,94 @@ func trackTunnelService(tunnelName string, service *mgr.Service) {
for {
if windows.SleepEx(uint32(time.Second*3/time.Millisecond), true) == windows.WAIT_IO_COMPLETION {
break
- } else if checkForDisabled() {
- return
+ } else if callback(0) {
+ return nil
}
}
case windows.ERROR_SERVICE_MARKED_FOR_DELETE:
- trackedTunnelsLock.Lock()
- trackedTunnels[tunnelName] = TunnelStopped
- trackedTunnelsLock.Unlock()
- IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil)
- return
+ // Should be SERVICE_NOTIFY_DELETE_PENDING, but actually, we must release the handle and return here; otherwise it never deletes.
+ if callback(windows.SERVICE_NOTIFY_DELETED) {
+ return nil
+ }
case windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING:
continue
default:
+ return err
+ }
+ if callback(svcStateToNotifyState(notifier.ServiceStatus.CurrentState)) {
+ return nil
+ }
+ }
+}
+
+func trackTunnelService(tunnelName string, service *mgr.Service) {
+ trackedTunnelsLock.Lock()
+ if _, found := trackedTunnels[tunnelName]; found {
+ trackedTunnelsLock.Unlock()
+ service.Close()
+ return
+ }
+
+ defer func() {
+ service.Close()
+ log.Printf("[%s] Tunnel service tracker finished", tunnelName)
+ }()
+ trackedTunnels[tunnelName] = TunnelUnknown
+ trackedTunnelsLock.Unlock()
+ defer func() {
+ trackedTunnelsLock.Lock()
+ delete(trackedTunnels, tunnelName)
+ trackedTunnelsLock.Unlock()
+ }()
+
+ for i := 0; i < 20; i++ {
+ if i > 0 {
+ time.Sleep(time.Second / 5)
+ }
+ if status, err := service.Query(); err != nil || status.State != svc.Stopped {
+ break
+ }
+ }
+
+ checkForDisabled := func() (shouldReturn bool) {
+ config, err := service.Config()
+ if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || (err != nil && config.StartType == windows.SERVICE_DISABLED) {
+ log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName)
+ service.Delete()
trackedTunnelsLock.Lock()
trackedTunnels[tunnelName] = TunnelStopped
trackedTunnelsLock.Unlock()
- IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %v", err))
- service.Control(svc.Stop)
- return
+ IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil)
+ return true
}
-
- state := svcStateToTunState(svc.State(notifier.ServiceStatus.CurrentState))
+ return false
+ }
+ if checkForDisabled() {
+ return
+ }
+ lastState := TunnelUnknown
+ err := trackService(service, func(status uint32) bool {
+ state := notifyStateToTunState(status)
var tunnelError error
if state == TunnelStopped {
- if notifier.ServiceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) {
- maybeErr := services.Error(notifier.ServiceStatus.ServiceSpecificExitCode)
- if maybeErr != services.ErrorSuccess {
- tunnelError = maybeErr
- }
- } else {
- switch notifier.ServiceStatus.Win32ExitCode {
- case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED):
- default:
- tunnelError = syscall.Errno(notifier.ServiceStatus.Win32ExitCode)
+ serviceStatus, err := service.Query()
+ if err == nil {
+ if serviceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) {
+ maybeErr := services.Error(serviceStatus.ServiceSpecificExitCode)
+ if maybeErr != services.ErrorSuccess {
+ tunnelError = maybeErr
+ }
+ } else {
+ switch serviceStatus.Win32ExitCode {
+ case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED):
+ default:
+ tunnelError = syscall.Errno(serviceStatus.Win32ExitCode)
+ }
}
}
+ if tunnelError != nil {
+ service.Delete()
+ }
}
if state != lastState {
trackedTunnelsLock.Lock()
@@ -186,5 +250,94 @@ func trackTunnelService(tunnelName string, service *mgr.Service) {
IPCServerNotifyTunnelChange(tunnelName, state, tunnelError)
lastState = state
}
+ if state == TunnelUnknown && checkForDisabled() {
+ return true
+ }
+ return state == TunnelStopped
+ })
+ if err != nil && !checkForDisabled() {
+ trackedTunnelsLock.Lock()
+ trackedTunnels[tunnelName] = TunnelStopped
+ trackedTunnelsLock.Unlock()
+ IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %w", err))
+ service.Control(svc.Stop)
+ }
+}
+
+func trackExistingTunnels() error {
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ names, err := conf.ListConfigNames()
+ if err != nil {
+ return err
+ }
+ for _, name := range names {
+ trackedTunnelsLock.Lock()
+ if _, found := trackedTunnels[name]; found {
+ trackedTunnelsLock.Unlock()
+ continue
+ }
+ trackedTunnelsLock.Unlock()
+ serviceName, err := conf.ServiceNameOfTunnel(name)
+ if err != nil {
+ continue
+ }
+ service, err := m.OpenService(serviceName)
+ if err != nil {
+ continue
+ }
+ go trackTunnelService(name, service)
}
+ return nil
+}
+
+var servicesSubscriptionWatcherCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr {
+ trackExistingTunnels()
+ return 0
+})
+
+func watchNewTunnelServices() error {
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ var subscription uintptr
+ err = windows.SubscribeServiceChangeNotifications(m.Handle, windows.SC_EVENT_DATABASE_CHANGE, servicesSubscriptionWatcherCallbackPtr, 0, &subscription)
+ if err == nil {
+ // We probably could do:
+ // defer windows.UnsubscribeServiceChangeNotifications(subscription)
+ // and then terminate after some point, but instead we just let this go forever; it's process-lived.
+ return trackExistingTunnels()
+ }
+ if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
+ return err
+ }
+
+ // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point.
+ go func() {
+ runtime.LockOSThread()
+ notifier := &windows.SERVICE_NOTIFY{
+ Version: windows.SERVICE_NOTIFY_STATUS_CHANGE,
+ NotifyCallback: serviceTrackerCallbackPtr,
+ }
+ for {
+ err := windows.NotifyServiceStatusChange(m.Handle, windows.SERVICE_NOTIFY_CREATED, notifier)
+ if err == nil {
+ windows.SleepEx(windows.INFINITE, true)
+ if notifier.ServiceNames != nil {
+ windows.LocalFree(windows.Handle(unsafe.Pointer(notifier.ServiceNames)))
+ notifier.ServiceNames = nil
+ }
+ trackExistingTunnels()
+ } else if err == windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING {
+ continue
+ } else {
+ time.Sleep(time.Second * 3)
+ trackExistingTunnels()
+ }
+ }
+ }()
+ return trackExistingTunnels()
}
diff --git a/manager/uiprocess.go b/manager/uiprocess.go
new file mode 100644
index 00000000..b33b1ad3
--- /dev/null
+++ b/manager/uiprocess.go
@@ -0,0 +1,103 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
+ */
+
+package manager
+
+import (
+ "errors"
+ "runtime"
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+type uiProcess struct {
+ handle uintptr
+}
+
+func launchUIProcess(executable string, args []string, workingDirectory string, handles []windows.Handle, token windows.Token) (*uiProcess, error) {
+ executable16, err := windows.UTF16PtrFromString(executable)
+ if err != nil {
+ return nil, err
+ }
+ args16, err := windows.UTF16PtrFromString(windows.ComposeCommandLine(args))
+ if err != nil {
+ return nil, err
+ }
+ workingDirectory16, err := windows.UTF16PtrFromString(workingDirectory)
+ if err != nil {
+ return nil, err
+ }
+ var environmentBlock *uint16
+ err = windows.CreateEnvironmentBlock(&environmentBlock, token, false)
+ if err != nil {
+ return nil, err
+ }
+ defer windows.DestroyEnvironmentBlock(environmentBlock)
+ attributeList, err := windows.NewProcThreadAttributeList(1)
+ if err != nil {
+ return nil, err
+ }
+ defer attributeList.Delete()
+ si := &windows.StartupInfoEx{
+ StartupInfo: windows.StartupInfo{Cb: uint32(unsafe.Sizeof(windows.StartupInfoEx{}))},
+ ProcThreadAttributeList: attributeList.List(),
+ }
+ if len(handles) == 0 {
+ handles = []windows.Handle{0}
+ }
+ attributeList.Update(windows.PROC_THREAD_ATTRIBUTE_HANDLE_LIST, unsafe.Pointer(&handles[0]), uintptr(len(handles))*unsafe.Sizeof(handles[0]))
+ pi := new(windows.ProcessInformation)
+ err = windows.CreateProcessAsUser(token, executable16, args16, nil, nil, true, windows.CREATE_DEFAULT_ERROR_MODE|windows.CREATE_UNICODE_ENVIRONMENT|windows.EXTENDED_STARTUPINFO_PRESENT, environmentBlock, workingDirectory16, &si.StartupInfo, pi)
+ if err != nil {
+ return nil, err
+ }
+ windows.CloseHandle(pi.Thread)
+ uiProc := &uiProcess{handle: uintptr(pi.Process)}
+ runtime.SetFinalizer(uiProc, (*uiProcess).release)
+ return uiProc, nil
+}
+
+func (p *uiProcess) release() error {
+ handle := windows.Handle(atomic.SwapUintptr(&p.handle, uintptr(windows.InvalidHandle)))
+ if handle == windows.InvalidHandle {
+ return nil
+ }
+ err := windows.CloseHandle(handle)
+ if err != nil {
+ return err
+ }
+ runtime.SetFinalizer(p, nil)
+ return nil
+}
+
+func (p *uiProcess) Wait() (uint32, error) {
+ handle := windows.Handle(atomic.LoadUintptr(&p.handle))
+ s, err := windows.WaitForSingleObject(handle, syscall.INFINITE)
+ switch s {
+ case windows.WAIT_OBJECT_0:
+ case windows.WAIT_FAILED:
+ return 0, err
+ default:
+ return 0, errors.New("unexpected result from WaitForSingleObject")
+ }
+ var exitCode uint32
+ err = windows.GetExitCodeProcess(handle, &exitCode)
+ if err != nil {
+ return 0, err
+ }
+ p.release()
+ return exitCode, nil
+}
+
+func (p *uiProcess) Kill() error {
+ handle := windows.Handle(atomic.LoadUintptr(&p.handle))
+ if handle == windows.InvalidHandle {
+ return nil
+ }
+ return windows.TerminateProcess(handle, 1)
+}
diff --git a/manager/updatestate.go b/manager/updatestate.go
index b54cc367..d5a19c8d 100644
--- a/manager/updatestate.go
+++ b/manager/updatestate.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package manager
@@ -8,11 +8,16 @@ package manager
import (
"log"
"time"
+ _ "unsafe"
+ "golang.zx2c4.com/wireguard/windows/services"
"golang.zx2c4.com/wireguard/windows/updater"
"golang.zx2c4.com/wireguard/windows/version"
)
+//go:linkname fastrandn runtime.fastrandn
+func fastrandn(n uint32) uint32
+
type UpdateState uint32
const (
@@ -23,35 +28,38 @@ const (
var updateState = UpdateStateUnknown
-func checkForUpdates() {
- defer printPanic()
+func jitterSleep(min, max time.Duration) {
+ time.Sleep(min + time.Millisecond*time.Duration(fastrandn(uint32((max-min+1)/time.Millisecond))))
+}
+func checkForUpdates() {
if !version.IsRunningOfficialVersion() {
log.Println("Build is not official, so updates are disabled")
updateState = UpdateStateUpdatesDisabledUnofficialBuild
IPCServerNotifyUpdateFound(updateState)
return
}
-
- first := true
+ if services.StartedAtBoot() {
+ jitterSleep(time.Minute*2, time.Minute*5)
+ }
+ noError, didNotify := true, false
for {
update, err := updater.CheckForUpdate()
- if err == nil && update != nil {
+ if err == nil && update != nil && !didNotify {
log.Println("An update is available")
updateState = UpdateStateFoundUpdate
IPCServerNotifyUpdateFound(updateState)
- return
- }
- if err != nil {
+ didNotify = true
+ } else if err != nil && !didNotify {
log.Printf("Update checker: %v", err)
- if first {
- time.Sleep(time.Minute * 4)
- first = false
+ if noError {
+ jitterSleep(time.Minute*4, time.Minute*6)
+ noError = false
} else {
- time.Sleep(time.Minute * 25)
+ jitterSleep(time.Minute*25, time.Minute*30)
}
} else {
- time.Sleep(time.Hour)
+ jitterSleep(time.Hour-time.Minute*3, time.Hour+time.Minute*3)
}
}
}