aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2020-12-03 01:08:55 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2020-12-09 16:01:47 +0100
commit1b4a2a1e9704b747e399ccd430b918db38a2bfb6 (patch)
tree62c777713e3a85444af6213d42dd5f004971d8a1
parentmod: bump (diff)
downloadwireguard-windows-1b4a2a1e9704b747e399ccd430b918db38a2bfb6.tar.xz
wireguard-windows-1b4a2a1e9704b747e399ccd430b918db38a2bfb6.zip
manager: use service subscriptions on win 8+
Work in progress, but this should be more reliable than the older Win 7 code. It's still unclear what the role of checkForDisabled is to be for the Win 8+ path. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--manager/tunneltracker.go240
1 files changed, 164 insertions, 76 deletions
diff --git a/manager/tunneltracker.go b/manager/tunneltracker.go
index 3532932e..6d376e4a 100644
--- a/manager/tunneltracker.go
+++ b/manager/tunneltracker.go
@@ -6,12 +6,15 @@
package manager
import (
+ "errors"
"fmt"
"log"
"runtime"
"sync"
+ "sync/atomic"
"syscall"
"time"
+ "unsafe"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
@@ -44,28 +47,9 @@ func trackExistingTunnels() error {
return nil
}
-var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr {
- return 0
-})
-
var trackedTunnels = make(map[string]TunnelState)
var trackedTunnelsLock = sync.Mutex{}
-func svcStateToTunState(s svc.State) TunnelState {
- switch s {
- case svc.StartPending:
- return TunnelStarting
- case svc.Running:
- return TunnelStarted
- case svc.StopPending:
- return TunnelStopping
- case svc.Stopped:
- return TunnelStopped
- default:
- return TunnelUnknown
- }
-}
-
func trackedTunnelsGlobalState() (state TunnelState) {
state = TunnelStopped
trackedTunnelsLock.Lock()
@@ -82,50 +66,95 @@ func trackedTunnelsGlobalState() (state TunnelState) {
return
}
-func trackTunnelService(tunnelName string, service *mgr.Service) {
- defer func() {
- service.Close()
- log.Printf("[%s] Tunnel service tracker finished", tunnelName)
- }()
+var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr {
+ return 0
+})
- trackedTunnelsLock.Lock()
- if _, found := trackedTunnels[tunnelName]; found {
- trackedTunnelsLock.Unlock()
- return
+type serviceSubscriptionState struct {
+ service *mgr.Service
+ cb func(status uint32) bool
+ done sync.WaitGroup
+ once uint32
+}
+
+var serviceSubscriptionCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr {
+ state := (*serviceSubscriptionState)(unsafe.Pointer(context))
+ if atomic.LoadUint32(&state.once) != 0 {
+ return 0
}
- trackedTunnels[tunnelName] = TunnelUnknown
- trackedTunnelsLock.Unlock()
- defer func() {
- trackedTunnelsLock.Lock()
- delete(trackedTunnels, tunnelName)
- trackedTunnelsLock.Unlock()
- }()
+ if notification == 0 {
+ status, err := state.service.Query()
+ if err == nil {
+ notification = svcStateToNotifyState(uint32(status.State))
+ }
+ }
+ if state.cb(notification) && atomic.CompareAndSwapUint32(&state.once, 0, 1) {
+ state.done.Done()
+ }
+ return 0
+})
- const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING
- notifier := &windows.SERVICE_NOTIFY{
- Version: windows.SERVICE_NOTIFY_STATUS_CHANGE,
- NotifyCallback: serviceTrackerCallbackPtr,
+func svcStateToNotifyState(s uint32) uint32 {
+ switch s {
+ case windows.SERVICE_STOPPED:
+ return windows.SERVICE_NOTIFY_STOPPED
+ case windows.SERVICE_START_PENDING:
+ return windows.SERVICE_NOTIFY_START_PENDING
+ case windows.SERVICE_STOP_PENDING:
+ return windows.SERVICE_NOTIFY_STOP_PENDING
+ case windows.SERVICE_RUNNING:
+ return windows.SERVICE_NOTIFY_RUNNING
+ case windows.SERVICE_CONTINUE_PENDING:
+ return windows.SERVICE_NOTIFY_CONTINUE_PENDING
+ case windows.SERVICE_PAUSE_PENDING:
+ return windows.SERVICE_NOTIFY_PAUSE_PENDING
+ case windows.SERVICE_PAUSED:
+ return windows.SERVICE_NOTIFY_PAUSED
+ case windows.SERVICE_NO_CHANGE:
+ return 0
+ default:
+ return 0
}
+}
- checkForDisabled := func() (shouldReturn bool) {
- config, err := service.Config()
- if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || config.StartType == windows.SERVICE_DISABLED {
- log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName)
- service.Delete()
- trackedTunnelsLock.Lock()
- trackedTunnels[tunnelName] = TunnelStopped
- trackedTunnelsLock.Unlock()
- IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil)
- return true
+func notifyStateToTunState(s uint32) TunnelState {
+ if s&(windows.SERVICE_NOTIFY_STOPPED|windows.SERVICE_NOTIFY_DELETED) != 0 {
+ return TunnelStopped
+ } else if s&(windows.SERVICE_NOTIFY_DELETE_PENDING|windows.SERVICE_NOTIFY_STOP_PENDING) != 0 {
+ return TunnelStopping
+ } else if s&windows.SERVICE_NOTIFY_RUNNING != 0 {
+ return TunnelStarted
+ } else if s&windows.SERVICE_NOTIFY_START_PENDING != 0 {
+ return TunnelStarting
+ } else {
+ return TunnelUnknown
+ }
+}
+
+func trackService(service *mgr.Service, callback func(status uint32) bool) error {
+ var subscription uintptr
+ state := &serviceSubscriptionState{service: service, cb: callback}
+ state.done.Add(1)
+ err := windows.SubscribeServiceChangeNotifications(service.Handle, windows.SC_EVENT_STATUS_CHANGE, serviceSubscriptionCallbackPtr, uintptr(unsafe.Pointer(state)), &subscription)
+ if err == nil {
+ defer windows.UnsubscribeServiceChangeNotifications(subscription)
+ status, err := service.Query()
+ if err == nil {
+ if callback(svcStateToNotifyState(uint32(status.State))) {
+ return nil
+ }
}
- return false
+ state.done.Wait()
+ runtime.KeepAlive(state.cb)
+ return nil
}
- if checkForDisabled() {
- return
+ if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
+ return err
}
- runtime.LockOSThread()
+ // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point.
+ runtime.LockOSThread()
// This line would be fitting but is intentionally commented out:
//
// defer runtime.UnlockOSThread()
@@ -134,7 +163,11 @@ func trackTunnelService(tunnelName string, service *mgr.Service) {
// with the thread local context, which in turn appears to corrupt Go's own usage of TLS,
// leading to crashes sometime later (usually in runtime_unlock()) when the thread is recycled.
- lastState := TunnelUnknown
+ const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING
+ notifier := &windows.SERVICE_NOTIFY{
+ Version: windows.SERVICE_NOTIFY_STATUS_CHANGE,
+ NotifyCallback: serviceTrackerCallbackPtr,
+ }
for {
err := windows.NotifyServiceStatusChange(service.Handle, serviceNotifications, notifier)
switch err {
@@ -142,42 +175,86 @@ func trackTunnelService(tunnelName string, service *mgr.Service) {
for {
if windows.SleepEx(uint32(time.Second*3/time.Millisecond), true) == windows.WAIT_IO_COMPLETION {
break
- } else if checkForDisabled() {
- return
+ } else if callback(0) {
+ return nil
}
}
case windows.ERROR_SERVICE_MARKED_FOR_DELETE:
- trackedTunnelsLock.Lock()
- trackedTunnels[tunnelName] = TunnelStopped
- trackedTunnelsLock.Unlock()
- IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil)
- return
+ // Should be SERVICE_NOTIFY_DELETE_PENDING, but actually, we must release the handle and return here; otherwise it never deletes.
+ if callback(windows.SERVICE_NOTIFY_DELETED) {
+ return nil
+ }
case windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING:
continue
default:
+ return err
+ }
+ if callback(svcStateToNotifyState(notifier.ServiceStatus.CurrentState)) {
+ return nil
+ }
+ }
+}
+
+func trackTunnelService(tunnelName string, service *mgr.Service) {
+ defer printPanic()
+
+ defer func() {
+ service.Close()
+ log.Printf("[%s] Tunnel service tracker finished", tunnelName)
+ }()
+
+ trackedTunnelsLock.Lock()
+ if _, found := trackedTunnels[tunnelName]; found {
+ trackedTunnelsLock.Unlock()
+ return
+ }
+ trackedTunnels[tunnelName] = TunnelUnknown
+ trackedTunnelsLock.Unlock()
+ defer func() {
+ trackedTunnelsLock.Lock()
+ delete(trackedTunnels, tunnelName)
+ trackedTunnelsLock.Unlock()
+ }()
+
+ checkForDisabled := func() (shouldReturn bool) {
+ config, err := service.Config()
+ if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || (err != nil && config.StartType == windows.SERVICE_DISABLED) {
+ log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName)
+ service.Delete()
trackedTunnelsLock.Lock()
trackedTunnels[tunnelName] = TunnelStopped
trackedTunnelsLock.Unlock()
- IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %w", err))
- service.Control(svc.Stop)
- return
+ IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil)
+ return true
}
-
- state := svcStateToTunState(svc.State(notifier.ServiceStatus.CurrentState))
+ return false
+ }
+ if checkForDisabled() {
+ return
+ }
+ lastState := TunnelUnknown
+ err := trackService(service, func(status uint32) bool {
+ state := notifyStateToTunState(status)
var tunnelError error
if state == TunnelStopped {
- if notifier.ServiceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) {
- maybeErr := services.Error(notifier.ServiceStatus.ServiceSpecificExitCode)
- if maybeErr != services.ErrorSuccess {
- tunnelError = maybeErr
- }
- } else {
- switch notifier.ServiceStatus.Win32ExitCode {
- case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED):
- default:
- tunnelError = syscall.Errno(notifier.ServiceStatus.Win32ExitCode)
+ serviceStatus, err := service.Query()
+ if err == nil {
+ if serviceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) {
+ maybeErr := services.Error(serviceStatus.ServiceSpecificExitCode)
+ if maybeErr != services.ErrorSuccess {
+ tunnelError = maybeErr
+ }
+ } else {
+ switch serviceStatus.Win32ExitCode {
+ case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED):
+ default:
+ tunnelError = syscall.Errno(serviceStatus.Win32ExitCode)
+ }
}
}
+ if tunnelError != nil {
+ service.Delete()
+ }
}
if state != lastState {
trackedTunnelsLock.Lock()
@@ -186,5 +263,16 @@ func trackTunnelService(tunnelName string, service *mgr.Service) {
IPCServerNotifyTunnelChange(tunnelName, state, tunnelError)
lastState = state
}
+ if state == TunnelUnknown && checkForDisabled() {
+ return true
+ }
+ return state == TunnelStopped
+ })
+ if err != nil && !checkForDisabled() {
+ trackedTunnelsLock.Lock()
+ trackedTunnels[tunnelName] = TunnelStopped
+ trackedTunnelsLock.Unlock()
+ IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %w", err))
+ service.Control(svc.Stop)
}
}