diff options
Diffstat (limited to 'service/tunneltracker.go')
-rw-r--r-- | service/tunneltracker.go | 62 |
1 files changed, 44 insertions, 18 deletions
diff --git a/service/tunneltracker.go b/service/tunneltracker.go index 402fd070..3cf9fde5 100644 --- a/service/tunneltracker.go +++ b/service/tunneltracker.go @@ -12,6 +12,7 @@ import ( "golang.org/x/sys/windows/svc/mgr" "golang.zx2c4.com/wireguard/windows/conf" "runtime" + "sync" "syscall" "unsafe" ) @@ -84,15 +85,47 @@ var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *serviceNotify return 0 }) +var trackedTunnels = make(map[string]bool) +var trackedTunnelsLock = sync.Mutex{} + +func svcStateToTunState(s svc.State) TunnelState { + switch s { + case svc.StartPending: + return TunnelStarting + case svc.Running: + return TunnelStarted + case svc.StopPending: + return TunnelStopping + case svc.Stopped: + return TunnelStopped + default: + return TunnelUnknown + } +} + func trackTunnelService(tunnelName string, service *mgr.Service) { - runtime.LockOSThread() - const serviceNotifications = serviceNotify_RUNNING | serviceNotify_START_PENDING | serviceNotify_STOP_PENDING | serviceNotify_STOPPED + defer service.Close() + + trackedTunnelsLock.Lock() + _, isTracked := trackedTunnels[tunnelName] + trackedTunnels[tunnelName] = true + trackedTunnelsLock.Unlock() + if isTracked { + return + } + defer func() { + trackedTunnelsLock.Lock() + delete(trackedTunnels, tunnelName) + trackedTunnelsLock.Unlock() + }() + + const serviceNotifications = serviceNotify_RUNNING | serviceNotify_START_PENDING | serviceNotify_STOP_PENDING | serviceNotify_STOPPED | serviceNotify_DELETE_PENDING notifier := &serviceNotify{ version: serviceNotify_STATUS_CHANGE, notifyCallback: serviceTrackerCallbackPtr, } - defer service.Close() + runtime.LockOSThread() lastState := TunnelUnknown for { ret := notifyServiceStatusChange(service.Handle, serviceNotifications, uintptr(unsafe.Pointer(notifier))) @@ -110,32 +143,25 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { return } - state := TunnelUnknown + state := svcStateToTunState(svc.State(notifier.serviceStatus.currentState)) var tunnelError error - switch svc.State(notifier.serviceStatus.currentState) { - case svc.Stopped: - state = TunnelStopped + if state == TunnelStopped { if notifier.serviceStatus.win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) { maybeErr := Error(notifier.serviceStatus.serviceSpecificExitCode) if maybeErr != ErrorSuccess { tunnelError = maybeErr } - } else if notifier.serviceStatus.win32ExitCode != uint32(windows.NO_ERROR) { - tunnelError = syscall.Errno(notifier.serviceStatus.win32ExitCode) + } else { + switch notifier.serviceStatus.win32ExitCode { + case uint32(windows.NO_ERROR), serviceNEVER_STARTED: + default: + tunnelError = syscall.Errno(notifier.serviceStatus.win32ExitCode) + } } - case svc.StopPending: - state = TunnelStopping - case svc.Running: - state = TunnelStarted - case svc.StartPending: - state = TunnelStarting } if state != lastState { IPCServerNotifyTunnelChange(tunnelName, state, tunnelError) lastState = state } - if state == TunnelStopped { - return - } } } |