aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/service
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--service/errors.go1
-rw-r--r--service/install.go7
-rw-r--r--service/tunneltracker.go62
3 files changed, 49 insertions, 21 deletions
diff --git a/service/errors.go b/service/errors.go
index ea6147a9..fd6bc6ab 100644
--- a/service/errors.go
+++ b/service/errors.go
@@ -95,4 +95,5 @@ func combineErrors(err error, serviceError Error) error {
const (
serviceDOES_NOT_EXIST uint32 = 0x00000424
serviceMARKED_FOR_DELETE uint32 = 0x00000430
+ serviceNEVER_STARTED uint32 = 0x00000435
)
diff --git a/service/install.go b/service/install.go
index 87ac002d..7f39b2bc 100644
--- a/service/install.go
+++ b/service/install.go
@@ -65,7 +65,7 @@ func InstallManager() error {
break
}
service.Close()
- time.Sleep(time.Second)
+ time.Sleep(time.Second / 3)
}
}
@@ -163,8 +163,9 @@ func InstallTunnel(configPath string) error {
if err != nil {
return err
}
- go trackTunnelService(name, service)
- return service.Start()
+ err = service.Start()
+ go trackTunnelService(name, service) // Pass off reference to handle.
+ return err
}
func UninstallTunnel(name string) error {
diff --git a/service/tunneltracker.go b/service/tunneltracker.go
index 402fd070..3cf9fde5 100644
--- a/service/tunneltracker.go
+++ b/service/tunneltracker.go
@@ -12,6 +12,7 @@ import (
"golang.org/x/sys/windows/svc/mgr"
"golang.zx2c4.com/wireguard/windows/conf"
"runtime"
+ "sync"
"syscall"
"unsafe"
)
@@ -84,15 +85,47 @@ var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *serviceNotify
return 0
})
+var trackedTunnels = make(map[string]bool)
+var trackedTunnelsLock = sync.Mutex{}
+
+func svcStateToTunState(s svc.State) TunnelState {
+ switch s {
+ case svc.StartPending:
+ return TunnelStarting
+ case svc.Running:
+ return TunnelStarted
+ case svc.StopPending:
+ return TunnelStopping
+ case svc.Stopped:
+ return TunnelStopped
+ default:
+ return TunnelUnknown
+ }
+}
+
func trackTunnelService(tunnelName string, service *mgr.Service) {
- runtime.LockOSThread()
- const serviceNotifications = serviceNotify_RUNNING | serviceNotify_START_PENDING | serviceNotify_STOP_PENDING | serviceNotify_STOPPED
+ defer service.Close()
+
+ trackedTunnelsLock.Lock()
+ _, isTracked := trackedTunnels[tunnelName]
+ trackedTunnels[tunnelName] = true
+ trackedTunnelsLock.Unlock()
+ if isTracked {
+ return
+ }
+ defer func() {
+ trackedTunnelsLock.Lock()
+ delete(trackedTunnels, tunnelName)
+ trackedTunnelsLock.Unlock()
+ }()
+
+ const serviceNotifications = serviceNotify_RUNNING | serviceNotify_START_PENDING | serviceNotify_STOP_PENDING | serviceNotify_STOPPED | serviceNotify_DELETE_PENDING
notifier := &serviceNotify{
version: serviceNotify_STATUS_CHANGE,
notifyCallback: serviceTrackerCallbackPtr,
}
- defer service.Close()
+ runtime.LockOSThread()
lastState := TunnelUnknown
for {
ret := notifyServiceStatusChange(service.Handle, serviceNotifications, uintptr(unsafe.Pointer(notifier)))
@@ -110,32 +143,25 @@ func trackTunnelService(tunnelName string, service *mgr.Service) {
return
}
- state := TunnelUnknown
+ state := svcStateToTunState(svc.State(notifier.serviceStatus.currentState))
var tunnelError error
- switch svc.State(notifier.serviceStatus.currentState) {
- case svc.Stopped:
- state = TunnelStopped
+ if state == TunnelStopped {
if notifier.serviceStatus.win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) {
maybeErr := Error(notifier.serviceStatus.serviceSpecificExitCode)
if maybeErr != ErrorSuccess {
tunnelError = maybeErr
}
- } else if notifier.serviceStatus.win32ExitCode != uint32(windows.NO_ERROR) {
- tunnelError = syscall.Errno(notifier.serviceStatus.win32ExitCode)
+ } else {
+ switch notifier.serviceStatus.win32ExitCode {
+ case uint32(windows.NO_ERROR), serviceNEVER_STARTED:
+ default:
+ tunnelError = syscall.Errno(notifier.serviceStatus.win32ExitCode)
+ }
}
- case svc.StopPending:
- state = TunnelStopping
- case svc.Running:
- state = TunnelStarted
- case svc.StartPending:
- state = TunnelStarting
}
if state != lastState {
IPCServerNotifyTunnelChange(tunnelName, state, tunnelError)
lastState = state
}
- if state == TunnelStopped {
- return
- }
}
}