diff options
Diffstat (limited to '')
-rw-r--r-- | service/errors.go | 1 | ||||
-rw-r--r-- | service/install.go | 7 | ||||
-rw-r--r-- | service/tunneltracker.go | 62 |
3 files changed, 49 insertions, 21 deletions
diff --git a/service/errors.go b/service/errors.go index ea6147a9..fd6bc6ab 100644 --- a/service/errors.go +++ b/service/errors.go @@ -95,4 +95,5 @@ func combineErrors(err error, serviceError Error) error { const ( serviceDOES_NOT_EXIST uint32 = 0x00000424 serviceMARKED_FOR_DELETE uint32 = 0x00000430 + serviceNEVER_STARTED uint32 = 0x00000435 ) diff --git a/service/install.go b/service/install.go index 87ac002d..7f39b2bc 100644 --- a/service/install.go +++ b/service/install.go @@ -65,7 +65,7 @@ func InstallManager() error { break } service.Close() - time.Sleep(time.Second) + time.Sleep(time.Second / 3) } } @@ -163,8 +163,9 @@ func InstallTunnel(configPath string) error { if err != nil { return err } - go trackTunnelService(name, service) - return service.Start() + err = service.Start() + go trackTunnelService(name, service) // Pass off reference to handle. + return err } func UninstallTunnel(name string) error { diff --git a/service/tunneltracker.go b/service/tunneltracker.go index 402fd070..3cf9fde5 100644 --- a/service/tunneltracker.go +++ b/service/tunneltracker.go @@ -12,6 +12,7 @@ import ( "golang.org/x/sys/windows/svc/mgr" "golang.zx2c4.com/wireguard/windows/conf" "runtime" + "sync" "syscall" "unsafe" ) @@ -84,15 +85,47 @@ var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *serviceNotify return 0 }) +var trackedTunnels = make(map[string]bool) +var trackedTunnelsLock = sync.Mutex{} + +func svcStateToTunState(s svc.State) TunnelState { + switch s { + case svc.StartPending: + return TunnelStarting + case svc.Running: + return TunnelStarted + case svc.StopPending: + return TunnelStopping + case svc.Stopped: + return TunnelStopped + default: + return TunnelUnknown + } +} + func trackTunnelService(tunnelName string, service *mgr.Service) { - runtime.LockOSThread() - const serviceNotifications = serviceNotify_RUNNING | serviceNotify_START_PENDING | serviceNotify_STOP_PENDING | serviceNotify_STOPPED + defer service.Close() + + trackedTunnelsLock.Lock() + _, isTracked := trackedTunnels[tunnelName] + trackedTunnels[tunnelName] = true + trackedTunnelsLock.Unlock() + if isTracked { + return + } + defer func() { + trackedTunnelsLock.Lock() + delete(trackedTunnels, tunnelName) + trackedTunnelsLock.Unlock() + }() + + const serviceNotifications = serviceNotify_RUNNING | serviceNotify_START_PENDING | serviceNotify_STOP_PENDING | serviceNotify_STOPPED | serviceNotify_DELETE_PENDING notifier := &serviceNotify{ version: serviceNotify_STATUS_CHANGE, notifyCallback: serviceTrackerCallbackPtr, } - defer service.Close() + runtime.LockOSThread() lastState := TunnelUnknown for { ret := notifyServiceStatusChange(service.Handle, serviceNotifications, uintptr(unsafe.Pointer(notifier))) @@ -110,32 +143,25 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { return } - state := TunnelUnknown + state := svcStateToTunState(svc.State(notifier.serviceStatus.currentState)) var tunnelError error - switch svc.State(notifier.serviceStatus.currentState) { - case svc.Stopped: - state = TunnelStopped + if state == TunnelStopped { if notifier.serviceStatus.win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) { maybeErr := Error(notifier.serviceStatus.serviceSpecificExitCode) if maybeErr != ErrorSuccess { tunnelError = maybeErr } - } else if notifier.serviceStatus.win32ExitCode != uint32(windows.NO_ERROR) { - tunnelError = syscall.Errno(notifier.serviceStatus.win32ExitCode) + } else { + switch notifier.serviceStatus.win32ExitCode { + case uint32(windows.NO_ERROR), serviceNEVER_STARTED: + default: + tunnelError = syscall.Errno(notifier.serviceStatus.win32ExitCode) + } } - case svc.StopPending: - state = TunnelStopping - case svc.Running: - state = TunnelStarted - case svc.StartPending: - state = TunnelStarting } if state != lastState { IPCServerNotifyTunnelChange(tunnelName, state, tunnelError) lastState = state } - if state == TunnelStopped { - return - } } } |