diff options
Diffstat (limited to '')
-rw-r--r-- | manager/tunneltracker.go | 357 |
1 files changed, 255 insertions, 102 deletions
diff --git a/manager/tunneltracker.go b/manager/tunneltracker.go index 0f222aac..9003d445 100644 --- a/manager/tunneltracker.go +++ b/manager/tunneltracker.go @@ -1,17 +1,20 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package manager import ( + "errors" "fmt" "log" "runtime" "sync" + "sync/atomic" "syscall" "time" + "unsafe" "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" @@ -21,50 +24,10 @@ import ( "golang.zx2c4.com/wireguard/windows/services" ) -func trackExistingTunnels() error { - m, err := serviceManager() - if err != nil { - return err - } - names, err := conf.ListConfigNames() - if err != nil { - return err - } - for _, name := range names { - serviceName, err := services.ServiceNameOfTunnel(name) - if err != nil { - continue - } - service, err := m.OpenService(serviceName) - if err != nil { - continue - } - go trackTunnelService(name, service) - } - return nil -} - -var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr { - return 0 -}) - -var trackedTunnels = make(map[string]TunnelState) -var trackedTunnelsLock = sync.Mutex{} - -func svcStateToTunState(s svc.State) TunnelState { - switch s { - case svc.StartPending: - return TunnelStarting - case svc.Running: - return TunnelStarted - case svc.StopPending: - return TunnelStopping - case svc.Stopped: - return TunnelStopped - default: - return TunnelUnknown - } -} +var ( + trackedTunnels = make(map[string]TunnelState) + trackedTunnelsLock = sync.Mutex{} +) func trackedTunnelsGlobalState() (state TunnelState) { state = TunnelStopped @@ -82,50 +45,95 @@ func trackedTunnelsGlobalState() (state TunnelState) { return } -func trackTunnelService(tunnelName string, service *mgr.Service) { - defer func() { - service.Close() - log.Printf("[%s] Tunnel service tracker finished", tunnelName) - }() +var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr { + return 0 +}) - trackedTunnelsLock.Lock() - if _, found := trackedTunnels[tunnelName]; found { - trackedTunnelsLock.Unlock() - return +type serviceSubscriptionState struct { + service *mgr.Service + cb func(status uint32) bool + done sync.WaitGroup + once uint32 +} + +var serviceSubscriptionCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr { + state := (*serviceSubscriptionState)(unsafe.Pointer(context)) + if atomic.LoadUint32(&state.once) != 0 { + return 0 } - trackedTunnels[tunnelName] = TunnelUnknown - trackedTunnelsLock.Unlock() - defer func() { - trackedTunnelsLock.Lock() - delete(trackedTunnels, tunnelName) - trackedTunnelsLock.Unlock() - }() + if notification == 0 { + status, err := state.service.Query() + if err == nil { + notification = svcStateToNotifyState(uint32(status.State)) + } + } + if state.cb(notification) && atomic.CompareAndSwapUint32(&state.once, 0, 1) { + state.done.Done() + } + return 0 +}) - const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING - notifier := &windows.SERVICE_NOTIFY{ - Version: windows.SERVICE_NOTIFY_STATUS_CHANGE, - NotifyCallback: serviceTrackerCallbackPtr, +func svcStateToNotifyState(s uint32) uint32 { + switch s { + case windows.SERVICE_STOPPED: + return windows.SERVICE_NOTIFY_STOPPED + case windows.SERVICE_START_PENDING: + return windows.SERVICE_NOTIFY_START_PENDING + case windows.SERVICE_STOP_PENDING: + return windows.SERVICE_NOTIFY_STOP_PENDING + case windows.SERVICE_RUNNING: + return windows.SERVICE_NOTIFY_RUNNING + case windows.SERVICE_CONTINUE_PENDING: + return windows.SERVICE_NOTIFY_CONTINUE_PENDING + case windows.SERVICE_PAUSE_PENDING: + return windows.SERVICE_NOTIFY_PAUSE_PENDING + case windows.SERVICE_PAUSED: + return windows.SERVICE_NOTIFY_PAUSED + case windows.SERVICE_NO_CHANGE: + return 0 + default: + return 0 } +} - checkForDisabled := func() (shouldReturn bool) { - config, err := service.Config() - if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || config.StartType == windows.SERVICE_DISABLED { - log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName) - service.Delete() - trackedTunnelsLock.Lock() - trackedTunnels[tunnelName] = TunnelStopped - trackedTunnelsLock.Unlock() - IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil) - return true +func notifyStateToTunState(s uint32) TunnelState { + if s&(windows.SERVICE_NOTIFY_STOPPED|windows.SERVICE_NOTIFY_DELETED) != 0 { + return TunnelStopped + } else if s&(windows.SERVICE_NOTIFY_DELETE_PENDING|windows.SERVICE_NOTIFY_STOP_PENDING) != 0 { + return TunnelStopping + } else if s&windows.SERVICE_NOTIFY_RUNNING != 0 { + return TunnelStarted + } else if s&windows.SERVICE_NOTIFY_START_PENDING != 0 { + return TunnelStarting + } else { + return TunnelUnknown + } +} + +func trackService(service *mgr.Service, callback func(status uint32) bool) error { + var subscription uintptr + state := &serviceSubscriptionState{service: service, cb: callback} + state.done.Add(1) + err := windows.SubscribeServiceChangeNotifications(service.Handle, windows.SC_EVENT_STATUS_CHANGE, serviceSubscriptionCallbackPtr, uintptr(unsafe.Pointer(state)), &subscription) + if err == nil { + defer windows.UnsubscribeServiceChangeNotifications(subscription) + status, err := service.Query() + if err == nil { + if callback(svcStateToNotifyState(uint32(status.State))) { + return nil + } } - return false + state.done.Wait() + runtime.KeepAlive(state.cb) + return nil } - if checkForDisabled() { - return + if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) { + return err } - runtime.LockOSThread() + // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point. + runtime.LockOSThread() // This line would be fitting but is intentionally commented out: // // defer runtime.UnlockOSThread() @@ -134,7 +142,11 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { // with the thread local context, which in turn appears to corrupt Go's own usage of TLS, // leading to crashes sometime later (usually in runtime_unlock()) when the thread is recycled. - lastState := TunnelUnknown + const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING + notifier := &windows.SERVICE_NOTIFY{ + Version: windows.SERVICE_NOTIFY_STATUS_CHANGE, + NotifyCallback: serviceTrackerCallbackPtr, + } for { err := windows.NotifyServiceStatusChange(service.Handle, serviceNotifications, notifier) switch err { @@ -142,42 +154,94 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { for { if windows.SleepEx(uint32(time.Second*3/time.Millisecond), true) == windows.WAIT_IO_COMPLETION { break - } else if checkForDisabled() { - return + } else if callback(0) { + return nil } } case windows.ERROR_SERVICE_MARKED_FOR_DELETE: - trackedTunnelsLock.Lock() - trackedTunnels[tunnelName] = TunnelStopped - trackedTunnelsLock.Unlock() - IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil) - return + // Should be SERVICE_NOTIFY_DELETE_PENDING, but actually, we must release the handle and return here; otherwise it never deletes. + if callback(windows.SERVICE_NOTIFY_DELETED) { + return nil + } case windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING: continue default: + return err + } + if callback(svcStateToNotifyState(notifier.ServiceStatus.CurrentState)) { + return nil + } + } +} + +func trackTunnelService(tunnelName string, service *mgr.Service) { + trackedTunnelsLock.Lock() + if _, found := trackedTunnels[tunnelName]; found { + trackedTunnelsLock.Unlock() + service.Close() + return + } + + defer func() { + service.Close() + log.Printf("[%s] Tunnel service tracker finished", tunnelName) + }() + trackedTunnels[tunnelName] = TunnelUnknown + trackedTunnelsLock.Unlock() + defer func() { + trackedTunnelsLock.Lock() + delete(trackedTunnels, tunnelName) + trackedTunnelsLock.Unlock() + }() + + for i := 0; i < 20; i++ { + if i > 0 { + time.Sleep(time.Second / 5) + } + if status, err := service.Query(); err != nil || status.State != svc.Stopped { + break + } + } + + checkForDisabled := func() (shouldReturn bool) { + config, err := service.Config() + if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || (err != nil && config.StartType == windows.SERVICE_DISABLED) { + log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName) + service.Delete() trackedTunnelsLock.Lock() trackedTunnels[tunnelName] = TunnelStopped trackedTunnelsLock.Unlock() - IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %v", err)) - service.Control(svc.Stop) - return + IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil) + return true } - - state := svcStateToTunState(svc.State(notifier.ServiceStatus.CurrentState)) + return false + } + if checkForDisabled() { + return + } + lastState := TunnelUnknown + err := trackService(service, func(status uint32) bool { + state := notifyStateToTunState(status) var tunnelError error if state == TunnelStopped { - if notifier.ServiceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) { - maybeErr := services.Error(notifier.ServiceStatus.ServiceSpecificExitCode) - if maybeErr != services.ErrorSuccess { - tunnelError = maybeErr - } - } else { - switch notifier.ServiceStatus.Win32ExitCode { - case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED): - default: - tunnelError = syscall.Errno(notifier.ServiceStatus.Win32ExitCode) + serviceStatus, err := service.Query() + if err == nil { + if serviceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) { + maybeErr := services.Error(serviceStatus.ServiceSpecificExitCode) + if maybeErr != services.ErrorSuccess { + tunnelError = maybeErr + } + } else { + switch serviceStatus.Win32ExitCode { + case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED): + default: + tunnelError = syscall.Errno(serviceStatus.Win32ExitCode) + } } } + if tunnelError != nil { + service.Delete() + } } if state != lastState { trackedTunnelsLock.Lock() @@ -186,5 +250,94 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { IPCServerNotifyTunnelChange(tunnelName, state, tunnelError) lastState = state } + if state == TunnelUnknown && checkForDisabled() { + return true + } + return state == TunnelStopped + }) + if err != nil && !checkForDisabled() { + trackedTunnelsLock.Lock() + trackedTunnels[tunnelName] = TunnelStopped + trackedTunnelsLock.Unlock() + IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %w", err)) + service.Control(svc.Stop) + } +} + +func trackExistingTunnels() error { + m, err := serviceManager() + if err != nil { + return err + } + names, err := conf.ListConfigNames() + if err != nil { + return err + } + for _, name := range names { + trackedTunnelsLock.Lock() + if _, found := trackedTunnels[name]; found { + trackedTunnelsLock.Unlock() + continue + } + trackedTunnelsLock.Unlock() + serviceName, err := conf.ServiceNameOfTunnel(name) + if err != nil { + continue + } + service, err := m.OpenService(serviceName) + if err != nil { + continue + } + go trackTunnelService(name, service) } + return nil +} + +var servicesSubscriptionWatcherCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr { + trackExistingTunnels() + return 0 +}) + +func watchNewTunnelServices() error { + m, err := serviceManager() + if err != nil { + return err + } + var subscription uintptr + err = windows.SubscribeServiceChangeNotifications(m.Handle, windows.SC_EVENT_DATABASE_CHANGE, servicesSubscriptionWatcherCallbackPtr, 0, &subscription) + if err == nil { + // We probably could do: + // defer windows.UnsubscribeServiceChangeNotifications(subscription) + // and then terminate after some point, but instead we just let this go forever; it's process-lived. + return trackExistingTunnels() + } + if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) { + return err + } + + // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point. + go func() { + runtime.LockOSThread() + notifier := &windows.SERVICE_NOTIFY{ + Version: windows.SERVICE_NOTIFY_STATUS_CHANGE, + NotifyCallback: serviceTrackerCallbackPtr, + } + for { + err := windows.NotifyServiceStatusChange(m.Handle, windows.SERVICE_NOTIFY_CREATED, notifier) + if err == nil { + windows.SleepEx(windows.INFINITE, true) + if notifier.ServiceNames != nil { + windows.LocalFree(windows.Handle(unsafe.Pointer(notifier.ServiceNames))) + notifier.ServiceNames = nil + } + trackExistingTunnels() + } else if err == windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING { + continue + } else { + time.Sleep(time.Second * 3) + trackExistingTunnels() + } + } + }() + return trackExistingTunnels() } |