From 1b4a2a1e9704b747e399ccd430b918db38a2bfb6 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Thu, 3 Dec 2020 01:08:55 +0100 Subject: manager: use service subscriptions on win 8+ Work in progress, but this should be more reliable than the older Win 7 code. It's still unclear what the role of checkForDisabled is to be for the Win 8+ path. Signed-off-by: Jason A. Donenfeld --- manager/tunneltracker.go | 240 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 164 insertions(+), 76 deletions(-) (limited to 'manager/tunneltracker.go') diff --git a/manager/tunneltracker.go b/manager/tunneltracker.go index 3532932e..6d376e4a 100644 --- a/manager/tunneltracker.go +++ b/manager/tunneltracker.go @@ -6,12 +6,15 @@ package manager import ( + "errors" "fmt" "log" "runtime" "sync" + "sync/atomic" "syscall" "time" + "unsafe" "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" @@ -44,28 +47,9 @@ func trackExistingTunnels() error { return nil } -var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr { - return 0 -}) - var trackedTunnels = make(map[string]TunnelState) var trackedTunnelsLock = sync.Mutex{} -func svcStateToTunState(s svc.State) TunnelState { - switch s { - case svc.StartPending: - return TunnelStarting - case svc.Running: - return TunnelStarted - case svc.StopPending: - return TunnelStopping - case svc.Stopped: - return TunnelStopped - default: - return TunnelUnknown - } -} - func trackedTunnelsGlobalState() (state TunnelState) { state = TunnelStopped trackedTunnelsLock.Lock() @@ -82,50 +66,95 @@ func trackedTunnelsGlobalState() (state TunnelState) { return } -func trackTunnelService(tunnelName string, service *mgr.Service) { - defer func() { - service.Close() - log.Printf("[%s] Tunnel service tracker finished", tunnelName) - }() +var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr { + return 0 +}) - trackedTunnelsLock.Lock() - if _, found := trackedTunnels[tunnelName]; found { - trackedTunnelsLock.Unlock() - return +type serviceSubscriptionState struct { + service *mgr.Service + cb func(status uint32) bool + done sync.WaitGroup + once uint32 +} + +var serviceSubscriptionCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr { + state := (*serviceSubscriptionState)(unsafe.Pointer(context)) + if atomic.LoadUint32(&state.once) != 0 { + return 0 } - trackedTunnels[tunnelName] = TunnelUnknown - trackedTunnelsLock.Unlock() - defer func() { - trackedTunnelsLock.Lock() - delete(trackedTunnels, tunnelName) - trackedTunnelsLock.Unlock() - }() + if notification == 0 { + status, err := state.service.Query() + if err == nil { + notification = svcStateToNotifyState(uint32(status.State)) + } + } + if state.cb(notification) && atomic.CompareAndSwapUint32(&state.once, 0, 1) { + state.done.Done() + } + return 0 +}) - const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING - notifier := &windows.SERVICE_NOTIFY{ - Version: windows.SERVICE_NOTIFY_STATUS_CHANGE, - NotifyCallback: serviceTrackerCallbackPtr, +func svcStateToNotifyState(s uint32) uint32 { + switch s { + case windows.SERVICE_STOPPED: + return windows.SERVICE_NOTIFY_STOPPED + case windows.SERVICE_START_PENDING: + return windows.SERVICE_NOTIFY_START_PENDING + case windows.SERVICE_STOP_PENDING: + return windows.SERVICE_NOTIFY_STOP_PENDING + case windows.SERVICE_RUNNING: + return windows.SERVICE_NOTIFY_RUNNING + case windows.SERVICE_CONTINUE_PENDING: + return windows.SERVICE_NOTIFY_CONTINUE_PENDING + case windows.SERVICE_PAUSE_PENDING: + return windows.SERVICE_NOTIFY_PAUSE_PENDING + case windows.SERVICE_PAUSED: + return windows.SERVICE_NOTIFY_PAUSED + case windows.SERVICE_NO_CHANGE: + return 0 + default: + return 0 } +} - checkForDisabled := func() (shouldReturn bool) { - config, err := service.Config() - if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || config.StartType == windows.SERVICE_DISABLED { - log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName) - service.Delete() - trackedTunnelsLock.Lock() - trackedTunnels[tunnelName] = TunnelStopped - trackedTunnelsLock.Unlock() - IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil) - return true +func notifyStateToTunState(s uint32) TunnelState { + if s&(windows.SERVICE_NOTIFY_STOPPED|windows.SERVICE_NOTIFY_DELETED) != 0 { + return TunnelStopped + } else if s&(windows.SERVICE_NOTIFY_DELETE_PENDING|windows.SERVICE_NOTIFY_STOP_PENDING) != 0 { + return TunnelStopping + } else if s&windows.SERVICE_NOTIFY_RUNNING != 0 { + return TunnelStarted + } else if s&windows.SERVICE_NOTIFY_START_PENDING != 0 { + return TunnelStarting + } else { + return TunnelUnknown + } +} + +func trackService(service *mgr.Service, callback func(status uint32) bool) error { + var subscription uintptr + state := &serviceSubscriptionState{service: service, cb: callback} + state.done.Add(1) + err := windows.SubscribeServiceChangeNotifications(service.Handle, windows.SC_EVENT_STATUS_CHANGE, serviceSubscriptionCallbackPtr, uintptr(unsafe.Pointer(state)), &subscription) + if err == nil { + defer windows.UnsubscribeServiceChangeNotifications(subscription) + status, err := service.Query() + if err == nil { + if callback(svcStateToNotifyState(uint32(status.State))) { + return nil + } } - return false + state.done.Wait() + runtime.KeepAlive(state.cb) + return nil } - if checkForDisabled() { - return + if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) { + return err } - runtime.LockOSThread() + // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point. + runtime.LockOSThread() // This line would be fitting but is intentionally commented out: // // defer runtime.UnlockOSThread() @@ -134,7 +163,11 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { // with the thread local context, which in turn appears to corrupt Go's own usage of TLS, // leading to crashes sometime later (usually in runtime_unlock()) when the thread is recycled. - lastState := TunnelUnknown + const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING + notifier := &windows.SERVICE_NOTIFY{ + Version: windows.SERVICE_NOTIFY_STATUS_CHANGE, + NotifyCallback: serviceTrackerCallbackPtr, + } for { err := windows.NotifyServiceStatusChange(service.Handle, serviceNotifications, notifier) switch err { @@ -142,42 +175,86 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { for { if windows.SleepEx(uint32(time.Second*3/time.Millisecond), true) == windows.WAIT_IO_COMPLETION { break - } else if checkForDisabled() { - return + } else if callback(0) { + return nil } } case windows.ERROR_SERVICE_MARKED_FOR_DELETE: - trackedTunnelsLock.Lock() - trackedTunnels[tunnelName] = TunnelStopped - trackedTunnelsLock.Unlock() - IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil) - return + // Should be SERVICE_NOTIFY_DELETE_PENDING, but actually, we must release the handle and return here; otherwise it never deletes. + if callback(windows.SERVICE_NOTIFY_DELETED) { + return nil + } case windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING: continue default: + return err + } + if callback(svcStateToNotifyState(notifier.ServiceStatus.CurrentState)) { + return nil + } + } +} + +func trackTunnelService(tunnelName string, service *mgr.Service) { + defer printPanic() + + defer func() { + service.Close() + log.Printf("[%s] Tunnel service tracker finished", tunnelName) + }() + + trackedTunnelsLock.Lock() + if _, found := trackedTunnels[tunnelName]; found { + trackedTunnelsLock.Unlock() + return + } + trackedTunnels[tunnelName] = TunnelUnknown + trackedTunnelsLock.Unlock() + defer func() { + trackedTunnelsLock.Lock() + delete(trackedTunnels, tunnelName) + trackedTunnelsLock.Unlock() + }() + + checkForDisabled := func() (shouldReturn bool) { + config, err := service.Config() + if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || (err != nil && config.StartType == windows.SERVICE_DISABLED) { + log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName) + service.Delete() trackedTunnelsLock.Lock() trackedTunnels[tunnelName] = TunnelStopped trackedTunnelsLock.Unlock() - IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %w", err)) - service.Control(svc.Stop) - return + IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil) + return true } - - state := svcStateToTunState(svc.State(notifier.ServiceStatus.CurrentState)) + return false + } + if checkForDisabled() { + return + } + lastState := TunnelUnknown + err := trackService(service, func(status uint32) bool { + state := notifyStateToTunState(status) var tunnelError error if state == TunnelStopped { - if notifier.ServiceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) { - maybeErr := services.Error(notifier.ServiceStatus.ServiceSpecificExitCode) - if maybeErr != services.ErrorSuccess { - tunnelError = maybeErr - } - } else { - switch notifier.ServiceStatus.Win32ExitCode { - case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED): - default: - tunnelError = syscall.Errno(notifier.ServiceStatus.Win32ExitCode) + serviceStatus, err := service.Query() + if err == nil { + if serviceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) { + maybeErr := services.Error(serviceStatus.ServiceSpecificExitCode) + if maybeErr != services.ErrorSuccess { + tunnelError = maybeErr + } + } else { + switch serviceStatus.Win32ExitCode { + case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED): + default: + tunnelError = syscall.Errno(serviceStatus.Win32ExitCode) + } } } + if tunnelError != nil { + service.Delete() + } } if state != lastState { trackedTunnelsLock.Lock() @@ -186,5 +263,16 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { IPCServerNotifyTunnelChange(tunnelName, state, tunnelError) lastState = state } + if state == TunnelUnknown && checkForDisabled() { + return true + } + return state == TunnelStopped + }) + if err != nil && !checkForDisabled() { + trackedTunnelsLock.Lock() + trackedTunnels[tunnelName] = TunnelStopped + trackedTunnelsLock.Unlock() + IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %w", err)) + service.Control(svc.Stop) } } -- cgit v1.2.3-59-g8ed1b