diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2019-03-12 00:04:40 -0600 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2019-03-12 03:07:59 -0600 |
commit | c273ce15e796bcbf441c233adbe8b113fd2e5991 (patch) | |
tree | 9e620accca016b1520f9437ceb6c948ea53f46a1 /service/tunneltracker.go | |
parent | tunneltracker: redo deletion state machine (diff) | |
download | wireguard-windows-c273ce15e796bcbf441c233adbe8b113fd2e5991.tar.xz wireguard-windows-c273ce15e796bcbf441c233adbe8b113fd2e5991.zip |
tunneltracker: don't track tunnels that haven't been started
Otherwise we get the hasn't-been-started-yet error, and the tracker
quits. Meanwhile this is reported back to the ui as an error. While
we're at it, don't let multiple trackers be run, in the event that the
at-start tracker races with the installation tracker. And, make sure we
actually get the deletion notification.
Diffstat (limited to '')
-rw-r--r-- | service/tunneltracker.go | 62 |
1 files changed, 44 insertions, 18 deletions
diff --git a/service/tunneltracker.go b/service/tunneltracker.go index 402fd070..3cf9fde5 100644 --- a/service/tunneltracker.go +++ b/service/tunneltracker.go @@ -12,6 +12,7 @@ import ( "golang.org/x/sys/windows/svc/mgr" "golang.zx2c4.com/wireguard/windows/conf" "runtime" + "sync" "syscall" "unsafe" ) @@ -84,15 +85,47 @@ var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *serviceNotify return 0 }) +var trackedTunnels = make(map[string]bool) +var trackedTunnelsLock = sync.Mutex{} + +func svcStateToTunState(s svc.State) TunnelState { + switch s { + case svc.StartPending: + return TunnelStarting + case svc.Running: + return TunnelStarted + case svc.StopPending: + return TunnelStopping + case svc.Stopped: + return TunnelStopped + default: + return TunnelUnknown + } +} + func trackTunnelService(tunnelName string, service *mgr.Service) { - runtime.LockOSThread() - const serviceNotifications = serviceNotify_RUNNING | serviceNotify_START_PENDING | serviceNotify_STOP_PENDING | serviceNotify_STOPPED + defer service.Close() + + trackedTunnelsLock.Lock() + _, isTracked := trackedTunnels[tunnelName] + trackedTunnels[tunnelName] = true + trackedTunnelsLock.Unlock() + if isTracked { + return + } + defer func() { + trackedTunnelsLock.Lock() + delete(trackedTunnels, tunnelName) + trackedTunnelsLock.Unlock() + }() + + const serviceNotifications = serviceNotify_RUNNING | serviceNotify_START_PENDING | serviceNotify_STOP_PENDING | serviceNotify_STOPPED | serviceNotify_DELETE_PENDING notifier := &serviceNotify{ version: serviceNotify_STATUS_CHANGE, notifyCallback: serviceTrackerCallbackPtr, } - defer service.Close() + runtime.LockOSThread() lastState := TunnelUnknown for { ret := notifyServiceStatusChange(service.Handle, serviceNotifications, uintptr(unsafe.Pointer(notifier))) @@ -110,32 +143,25 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { return } - state := TunnelUnknown + state := svcStateToTunState(svc.State(notifier.serviceStatus.currentState)) var tunnelError error - switch svc.State(notifier.serviceStatus.currentState) { - case svc.Stopped: - state = TunnelStopped + if state == TunnelStopped { if notifier.serviceStatus.win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) { maybeErr := Error(notifier.serviceStatus.serviceSpecificExitCode) if maybeErr != ErrorSuccess { tunnelError = maybeErr } - } else if notifier.serviceStatus.win32ExitCode != uint32(windows.NO_ERROR) { - tunnelError = syscall.Errno(notifier.serviceStatus.win32ExitCode) + } else { + switch notifier.serviceStatus.win32ExitCode { + case uint32(windows.NO_ERROR), serviceNEVER_STARTED: + default: + tunnelError = syscall.Errno(notifier.serviceStatus.win32ExitCode) + } } - case svc.StopPending: - state = TunnelStopping - case svc.Running: - state = TunnelStarted - case svc.StartPending: - state = TunnelStarting } if state != lastState { IPCServerNotifyTunnelChange(tunnelName, state, tunnelError) lastState = state } - if state == TunnelStopped { - return - } } } |