diff options
Diffstat (limited to 'manager/tunneltracker.go')
-rw-r--r-- | manager/tunneltracker.go | 182 |
1 files changed, 182 insertions, 0 deletions
diff --git a/manager/tunneltracker.go b/manager/tunneltracker.go new file mode 100644 index 00000000..1cde98e2 --- /dev/null +++ b/manager/tunneltracker.go @@ -0,0 +1,182 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package manager + +import ( + "fmt" + "log" + "runtime" + "sync" + "syscall" + "time" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" + + "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/services" +) + +func trackExistingTunnels() error { + m, err := serviceManager() + if err != nil { + return err + } + names, err := conf.ListConfigNames() + if err != nil { + return err + } + for _, name := range names { + serviceName, err := ServiceNameOfTunnel(name) + if err != nil { + continue + } + service, err := m.OpenService(serviceName) + if err != nil { + continue + } + go trackTunnelService(name, service) + } + return nil +} + +var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr { + return 0 +}) + +var trackedTunnels = make(map[string]TunnelState) +var trackedTunnelsLock = sync.Mutex{} + +func svcStateToTunState(s svc.State) TunnelState { + switch s { + case svc.StartPending: + return TunnelStarting + case svc.Running: + return TunnelStarted + case svc.StopPending: + return TunnelStopping + case svc.Stopped: + return TunnelStopped + default: + return TunnelUnknown + } +} + +func trackedTunnelsGlobalState() (state TunnelState) { + state = TunnelStopped + trackedTunnelsLock.Lock() + defer trackedTunnelsLock.Unlock() + for _, s := range trackedTunnels { + if s == TunnelStarting { + return TunnelStarting + } else if s == TunnelStopping { + return TunnelStopping + } else if s == TunnelStarted || s == TunnelUnknown { + state = TunnelStarted + } + } + return +} + +func trackTunnelService(tunnelName string, service *mgr.Service) { + defer func() { + service.Close() + log.Printf("[%s] Tunnel managerService tracker finished", tunnelName) + }() + + trackedTunnelsLock.Lock() + if _, found := trackedTunnels[tunnelName]; found { + trackedTunnelsLock.Unlock() + return + } + trackedTunnels[tunnelName] = TunnelUnknown + trackedTunnelsLock.Unlock() + defer func() { + trackedTunnelsLock.Lock() + delete(trackedTunnels, tunnelName) + trackedTunnelsLock.Unlock() + }() + + const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING + notifier := &windows.SERVICE_NOTIFY{ + Version: windows.SERVICE_NOTIFY_STATUS_CHANGE, + NotifyCallback: serviceTrackerCallbackPtr, + } + + checkForDisabled := func() (shouldReturn bool) { + config, err := service.Config() + if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || config.StartType == windows.SERVICE_DISABLED { + log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName) + service.Delete() + trackedTunnelsLock.Lock() + trackedTunnels[tunnelName] = TunnelStopped + trackedTunnelsLock.Unlock() + IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil) + return true + } + return false + } + if checkForDisabled() { + return + } + + runtime.LockOSThread() + defer runtime.UnlockOSThread() + lastState := TunnelUnknown + for { + err := windows.NotifyServiceStatusChange(service.Handle, serviceNotifications, notifier) + switch err { + case nil: + for { + if windows.SleepEx(uint32(time.Second*3/time.Millisecond), true) == windows.WAIT_IO_COMPLETION { + break + } else if checkForDisabled() { + return + } + } + case windows.ERROR_SERVICE_MARKED_FOR_DELETE: + trackedTunnelsLock.Lock() + trackedTunnels[tunnelName] = TunnelStopped + trackedTunnelsLock.Unlock() + IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil) + return + case windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING: + continue + default: + trackedTunnelsLock.Lock() + trackedTunnels[tunnelName] = TunnelStopped + trackedTunnelsLock.Unlock() + IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring managerService, so stopping: %v", err)) + service.Control(svc.Stop) + return + } + + state := svcStateToTunState(svc.State(notifier.ServiceStatus.CurrentState)) + var tunnelError error + if state == TunnelStopped { + if notifier.ServiceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) { + maybeErr := services.Error(notifier.ServiceStatus.ServiceSpecificExitCode) + if maybeErr != services.ErrorSuccess { + tunnelError = maybeErr + } + } else { + switch notifier.ServiceStatus.Win32ExitCode { + case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED): + default: + tunnelError = syscall.Errno(notifier.ServiceStatus.Win32ExitCode) + } + } + } + if state != lastState { + trackedTunnelsLock.Lock() + trackedTunnels[tunnelName] = state + trackedTunnelsLock.Unlock() + IPCServerNotifyTunnelChange(tunnelName, state, tunnelError) + lastState = state + } + } +} |