aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/service/tunneltracker.go
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--service/tunneltracker.go62
1 files changed, 44 insertions, 18 deletions
diff --git a/service/tunneltracker.go b/service/tunneltracker.go
index 402fd070..3cf9fde5 100644
--- a/service/tunneltracker.go
+++ b/service/tunneltracker.go
@@ -12,6 +12,7 @@ import (
"golang.org/x/sys/windows/svc/mgr"
"golang.zx2c4.com/wireguard/windows/conf"
"runtime"
+ "sync"
"syscall"
"unsafe"
)
@@ -84,15 +85,47 @@ var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *serviceNotify
return 0
})
+var trackedTunnels = make(map[string]bool)
+var trackedTunnelsLock = sync.Mutex{}
+
+func svcStateToTunState(s svc.State) TunnelState {
+ switch s {
+ case svc.StartPending:
+ return TunnelStarting
+ case svc.Running:
+ return TunnelStarted
+ case svc.StopPending:
+ return TunnelStopping
+ case svc.Stopped:
+ return TunnelStopped
+ default:
+ return TunnelUnknown
+ }
+}
+
func trackTunnelService(tunnelName string, service *mgr.Service) {
- runtime.LockOSThread()
- const serviceNotifications = serviceNotify_RUNNING | serviceNotify_START_PENDING | serviceNotify_STOP_PENDING | serviceNotify_STOPPED
+ defer service.Close()
+
+ trackedTunnelsLock.Lock()
+ _, isTracked := trackedTunnels[tunnelName]
+ trackedTunnels[tunnelName] = true
+ trackedTunnelsLock.Unlock()
+ if isTracked {
+ return
+ }
+ defer func() {
+ trackedTunnelsLock.Lock()
+ delete(trackedTunnels, tunnelName)
+ trackedTunnelsLock.Unlock()
+ }()
+
+ const serviceNotifications = serviceNotify_RUNNING | serviceNotify_START_PENDING | serviceNotify_STOP_PENDING | serviceNotify_STOPPED | serviceNotify_DELETE_PENDING
notifier := &serviceNotify{
version: serviceNotify_STATUS_CHANGE,
notifyCallback: serviceTrackerCallbackPtr,
}
- defer service.Close()
+ runtime.LockOSThread()
lastState := TunnelUnknown
for {
ret := notifyServiceStatusChange(service.Handle, serviceNotifications, uintptr(unsafe.Pointer(notifier)))
@@ -110,32 +143,25 @@ func trackTunnelService(tunnelName string, service *mgr.Service) {
return
}
- state := TunnelUnknown
+ state := svcStateToTunState(svc.State(notifier.serviceStatus.currentState))
var tunnelError error
- switch svc.State(notifier.serviceStatus.currentState) {
- case svc.Stopped:
- state = TunnelStopped
+ if state == TunnelStopped {
if notifier.serviceStatus.win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) {
maybeErr := Error(notifier.serviceStatus.serviceSpecificExitCode)
if maybeErr != ErrorSuccess {
tunnelError = maybeErr
}
- } else if notifier.serviceStatus.win32ExitCode != uint32(windows.NO_ERROR) {
- tunnelError = syscall.Errno(notifier.serviceStatus.win32ExitCode)
+ } else {
+ switch notifier.serviceStatus.win32ExitCode {
+ case uint32(windows.NO_ERROR), serviceNEVER_STARTED:
+ default:
+ tunnelError = syscall.Errno(notifier.serviceStatus.win32ExitCode)
+ }
}
- case svc.StopPending:
- state = TunnelStopping
- case svc.Running:
- state = TunnelStarted
- case svc.StartPending:
- state = TunnelStarting
}
if state != lastState {
IPCServerNotifyTunnelChange(tunnelName, state, tunnelError)
lastState = state
}
- if state == TunnelStopped {
- return
- }
}
}