aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/manager/tunneltracker.go
diff options
context:
space:
mode:
Diffstat (limited to 'manager/tunneltracker.go')
-rw-r--r--manager/tunneltracker.go182
1 files changed, 182 insertions, 0 deletions
diff --git a/manager/tunneltracker.go b/manager/tunneltracker.go
new file mode 100644
index 00000000..1cde98e2
--- /dev/null
+++ b/manager/tunneltracker.go
@@ -0,0 +1,182 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package manager
+
+import (
+ "fmt"
+ "log"
+ "runtime"
+ "sync"
+ "syscall"
+ "time"
+
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/svc"
+ "golang.org/x/sys/windows/svc/mgr"
+
+ "golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/services"
+)
+
+func trackExistingTunnels() error {
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ names, err := conf.ListConfigNames()
+ if err != nil {
+ return err
+ }
+ for _, name := range names {
+ serviceName, err := ServiceNameOfTunnel(name)
+ if err != nil {
+ continue
+ }
+ service, err := m.OpenService(serviceName)
+ if err != nil {
+ continue
+ }
+ go trackTunnelService(name, service)
+ }
+ return nil
+}
+
+var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr {
+ return 0
+})
+
+var trackedTunnels = make(map[string]TunnelState)
+var trackedTunnelsLock = sync.Mutex{}
+
+func svcStateToTunState(s svc.State) TunnelState {
+ switch s {
+ case svc.StartPending:
+ return TunnelStarting
+ case svc.Running:
+ return TunnelStarted
+ case svc.StopPending:
+ return TunnelStopping
+ case svc.Stopped:
+ return TunnelStopped
+ default:
+ return TunnelUnknown
+ }
+}
+
+func trackedTunnelsGlobalState() (state TunnelState) {
+ state = TunnelStopped
+ trackedTunnelsLock.Lock()
+ defer trackedTunnelsLock.Unlock()
+ for _, s := range trackedTunnels {
+ if s == TunnelStarting {
+ return TunnelStarting
+ } else if s == TunnelStopping {
+ return TunnelStopping
+ } else if s == TunnelStarted || s == TunnelUnknown {
+ state = TunnelStarted
+ }
+ }
+ return
+}
+
+func trackTunnelService(tunnelName string, service *mgr.Service) {
+ defer func() {
+ service.Close()
+ log.Printf("[%s] Tunnel managerService tracker finished", tunnelName)
+ }()
+
+ trackedTunnelsLock.Lock()
+ if _, found := trackedTunnels[tunnelName]; found {
+ trackedTunnelsLock.Unlock()
+ return
+ }
+ trackedTunnels[tunnelName] = TunnelUnknown
+ trackedTunnelsLock.Unlock()
+ defer func() {
+ trackedTunnelsLock.Lock()
+ delete(trackedTunnels, tunnelName)
+ trackedTunnelsLock.Unlock()
+ }()
+
+ const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING
+ notifier := &windows.SERVICE_NOTIFY{
+ Version: windows.SERVICE_NOTIFY_STATUS_CHANGE,
+ NotifyCallback: serviceTrackerCallbackPtr,
+ }
+
+ checkForDisabled := func() (shouldReturn bool) {
+ config, err := service.Config()
+ if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || config.StartType == windows.SERVICE_DISABLED {
+ log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName)
+ service.Delete()
+ trackedTunnelsLock.Lock()
+ trackedTunnels[tunnelName] = TunnelStopped
+ trackedTunnelsLock.Unlock()
+ IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil)
+ return true
+ }
+ return false
+ }
+ if checkForDisabled() {
+ return
+ }
+
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+ lastState := TunnelUnknown
+ for {
+ err := windows.NotifyServiceStatusChange(service.Handle, serviceNotifications, notifier)
+ switch err {
+ case nil:
+ for {
+ if windows.SleepEx(uint32(time.Second*3/time.Millisecond), true) == windows.WAIT_IO_COMPLETION {
+ break
+ } else if checkForDisabled() {
+ return
+ }
+ }
+ case windows.ERROR_SERVICE_MARKED_FOR_DELETE:
+ trackedTunnelsLock.Lock()
+ trackedTunnels[tunnelName] = TunnelStopped
+ trackedTunnelsLock.Unlock()
+ IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil)
+ return
+ case windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING:
+ continue
+ default:
+ trackedTunnelsLock.Lock()
+ trackedTunnels[tunnelName] = TunnelStopped
+ trackedTunnelsLock.Unlock()
+ IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring managerService, so stopping: %v", err))
+ service.Control(svc.Stop)
+ return
+ }
+
+ state := svcStateToTunState(svc.State(notifier.ServiceStatus.CurrentState))
+ var tunnelError error
+ if state == TunnelStopped {
+ if notifier.ServiceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) {
+ maybeErr := services.Error(notifier.ServiceStatus.ServiceSpecificExitCode)
+ if maybeErr != services.ErrorSuccess {
+ tunnelError = maybeErr
+ }
+ } else {
+ switch notifier.ServiceStatus.Win32ExitCode {
+ case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED):
+ default:
+ tunnelError = syscall.Errno(notifier.ServiceStatus.Win32ExitCode)
+ }
+ }
+ }
+ if state != lastState {
+ trackedTunnelsLock.Lock()
+ trackedTunnels[tunnelName] = state
+ trackedTunnelsLock.Unlock()
+ IPCServerNotifyTunnelChange(tunnelName, state, tunnelError)
+ lastState = state
+ }
+ }
+}