aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--service/errors.go1
-rw-r--r--service/install.go7
-rw-r--r--service/tunneltracker.go62
-rw-r--r--ui/ui.go18
4 files changed, 58 insertions, 30 deletions
diff --git a/service/errors.go b/service/errors.go
index ea6147a9..fd6bc6ab 100644
--- a/service/errors.go
+++ b/service/errors.go
@@ -95,4 +95,5 @@ func combineErrors(err error, serviceError Error) error {
const (
serviceDOES_NOT_EXIST uint32 = 0x00000424
serviceMARKED_FOR_DELETE uint32 = 0x00000430
+ serviceNEVER_STARTED uint32 = 0x00000435
)
diff --git a/service/install.go b/service/install.go
index 87ac002d..7f39b2bc 100644
--- a/service/install.go
+++ b/service/install.go
@@ -65,7 +65,7 @@ func InstallManager() error {
break
}
service.Close()
- time.Sleep(time.Second)
+ time.Sleep(time.Second / 3)
}
}
@@ -163,8 +163,9 @@ func InstallTunnel(configPath string) error {
if err != nil {
return err
}
- go trackTunnelService(name, service)
- return service.Start()
+ err = service.Start()
+ go trackTunnelService(name, service) // Pass off reference to handle.
+ return err
}
func UninstallTunnel(name string) error {
diff --git a/service/tunneltracker.go b/service/tunneltracker.go
index 402fd070..3cf9fde5 100644
--- a/service/tunneltracker.go
+++ b/service/tunneltracker.go
@@ -12,6 +12,7 @@ import (
"golang.org/x/sys/windows/svc/mgr"
"golang.zx2c4.com/wireguard/windows/conf"
"runtime"
+ "sync"
"syscall"
"unsafe"
)
@@ -84,15 +85,47 @@ var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *serviceNotify
return 0
})
+var trackedTunnels = make(map[string]bool)
+var trackedTunnelsLock = sync.Mutex{}
+
+func svcStateToTunState(s svc.State) TunnelState {
+ switch s {
+ case svc.StartPending:
+ return TunnelStarting
+ case svc.Running:
+ return TunnelStarted
+ case svc.StopPending:
+ return TunnelStopping
+ case svc.Stopped:
+ return TunnelStopped
+ default:
+ return TunnelUnknown
+ }
+}
+
func trackTunnelService(tunnelName string, service *mgr.Service) {
- runtime.LockOSThread()
- const serviceNotifications = serviceNotify_RUNNING | serviceNotify_START_PENDING | serviceNotify_STOP_PENDING | serviceNotify_STOPPED
+ defer service.Close()
+
+ trackedTunnelsLock.Lock()
+ _, isTracked := trackedTunnels[tunnelName]
+ trackedTunnels[tunnelName] = true
+ trackedTunnelsLock.Unlock()
+ if isTracked {
+ return
+ }
+ defer func() {
+ trackedTunnelsLock.Lock()
+ delete(trackedTunnels, tunnelName)
+ trackedTunnelsLock.Unlock()
+ }()
+
+ const serviceNotifications = serviceNotify_RUNNING | serviceNotify_START_PENDING | serviceNotify_STOP_PENDING | serviceNotify_STOPPED | serviceNotify_DELETE_PENDING
notifier := &serviceNotify{
version: serviceNotify_STATUS_CHANGE,
notifyCallback: serviceTrackerCallbackPtr,
}
- defer service.Close()
+ runtime.LockOSThread()
lastState := TunnelUnknown
for {
ret := notifyServiceStatusChange(service.Handle, serviceNotifications, uintptr(unsafe.Pointer(notifier)))
@@ -110,32 +143,25 @@ func trackTunnelService(tunnelName string, service *mgr.Service) {
return
}
- state := TunnelUnknown
+ state := svcStateToTunState(svc.State(notifier.serviceStatus.currentState))
var tunnelError error
- switch svc.State(notifier.serviceStatus.currentState) {
- case svc.Stopped:
- state = TunnelStopped
+ if state == TunnelStopped {
if notifier.serviceStatus.win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) {
maybeErr := Error(notifier.serviceStatus.serviceSpecificExitCode)
if maybeErr != ErrorSuccess {
tunnelError = maybeErr
}
- } else if notifier.serviceStatus.win32ExitCode != uint32(windows.NO_ERROR) {
- tunnelError = syscall.Errno(notifier.serviceStatus.win32ExitCode)
+ } else {
+ switch notifier.serviceStatus.win32ExitCode {
+ case uint32(windows.NO_ERROR), serviceNEVER_STARTED:
+ default:
+ tunnelError = syscall.Errno(notifier.serviceStatus.win32ExitCode)
+ }
}
- case svc.StopPending:
- state = TunnelStopping
- case svc.Running:
- state = TunnelStarted
- case svc.StartPending:
- state = TunnelStarting
}
if state != lastState {
IPCServerNotifyTunnelChange(tunnelName, state, tunnelError)
lastState = state
}
- if state == TunnelStopped {
- return
- }
}
}
diff --git a/ui/ui.go b/ui/ui.go
index cdfc8472..7bc49489 100644
--- a/ui/ui.go
+++ b/ui/ui.go
@@ -157,7 +157,6 @@ func RunUI() {
return
}
restoreState = false
- runningTunnel = nil
return
}
c, err := conf.FromWgQuick(se.Text(), "test")
@@ -177,7 +176,6 @@ func RunUI() {
return
}
restoreState = false
- runningTunnel = &tunnel
})
quitAction := walk.NewAction()
@@ -209,12 +207,14 @@ func RunUI() {
//TODO: also set tray icon to reflect state
switch state {
case service.TunnelStarting:
+ runningTunnel = tunnel
showRunningView(false)
se.SetEnabled(false)
pb.SetText("Starting...")
pb.SetEnabled(false)
tray.SetToolTip("WireGuard: Activating...")
case service.TunnelStarted:
+ runningTunnel = tunnel
showRunningView(true)
se.SetEnabled(false)
pb.SetText("Stop")
@@ -225,6 +225,7 @@ func RunUI() {
tray.ShowInfo("WireGuard Activated", fmt.Sprintf("The %s tunnel has been activated.", tunnel.Name))
}
case service.TunnelStopping:
+ runningTunnel = tunnel
showRunningView(false)
se.SetEnabled(false)
pb.SetText("Stopping...")
@@ -232,18 +233,15 @@ func RunUI() {
tray.SetToolTip("WireGuard: Deactivating...")
case service.TunnelStopped:
showRunningView(false)
- if runningTunnel != nil {
- runningTunnel.Stop()
- runningTunnel = nil
- }
se.SetEnabled(true)
pb.SetText("Start")
pb.SetEnabled(true)
tray.SetToolTip("WireGuard: Deactivated")
- if showNotifications {
+ if showNotifications && runningTunnel != nil {
//TODO: ShowCustom with right icon
tray.ShowInfo("WireGuard Deactivated", fmt.Sprintf("The %s tunnel has been deactivated.", tunnel.Name))
}
+ runningTunnel = nil
}
mw.SetSuspended(false)
}
@@ -267,8 +265,10 @@ func RunUI() {
if err != nil {
continue
}
- runningTunnel = &tunnel
- setServiceState(&tunnel, state, false)
+ if tunnel.Name == "test" && state != service.TunnelStopped {
+ runningTunnel = &tunnel
+ setServiceState(&tunnel, state, false)
+ }
}
}()