aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/service
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2019-04-26 20:05:24 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2019-04-27 08:24:22 +0200
commitd099b3eda13bda85a3763cb7a73543999c35e11f (patch)
treed62408b62529d0e18a19866d540c1c87cfa957c1 /service
parentinstaller: new checksums for working msm (diff)
downloadwireguard-windows-d099b3eda13bda85a3763cb7a73543999c35e11f.tar.xz
wireguard-windows-d099b3eda13bda85a3763cb7a73543999c35e11f.zip
ui: simplify everything
Diffstat (limited to 'service')
-rw-r--r--service/ipc_client.go19
-rw-r--r--service/ipc_server.go17
-rw-r--r--service/tunneltracker.go35
3 files changed, 66 insertions, 5 deletions
diff --git a/service/ipc_client.go b/service/ipc_client.go
index e6295b91..55fc043d 100644
--- a/service/ipc_client.go
+++ b/service/ipc_client.go
@@ -113,6 +113,20 @@ func (t *Tunnel) Stop() error {
return rpcClient.Call("ManagerService.Stop", t.Name, nil)
}
+func (t *Tunnel) Toggle() (oldState TunnelState, err error) {
+ oldState, err = t.State()
+ if err != nil {
+ oldState = TunnelUnknown
+ return
+ }
+ if oldState == TunnelStarted {
+ err = t.Stop()
+ } else if oldState == TunnelStopped {
+ err = t.Start()
+ }
+ return
+}
+
func (t *Tunnel) WaitForStop() error {
return rpcClient.Call("ManagerService.WaitForStop", t.Name, nil)
}
@@ -136,6 +150,11 @@ func IPCClientTunnels() ([]Tunnel, error) {
return tunnels, rpcClient.Call("ManagerService.Tunnels", uintptr(0), &tunnels)
}
+func IPCClientGlobalState() (TunnelState, error) {
+ var state TunnelState
+ return state, rpcClient.Call("ManagerService.GlobalState", uintptr(0), &state)
+}
+
func IPCClientQuit(stopTunnelsOnQuit bool) (bool, error) {
var alreadyQuit bool
return alreadyQuit, rpcClient.Call("ManagerService.Quit", stopTunnelsOnQuit, &alreadyQuit)
diff --git a/service/ipc_server.go b/service/ipc_server.go
index 17ea67c2..6d576846 100644
--- a/service/ipc_server.go
+++ b/service/ipc_server.go
@@ -71,6 +71,18 @@ func (s *ManagerService) RuntimeConfig(tunnelName string, config *conf.Config) e
}
func (s *ManagerService) Start(tunnelName string, unused *uintptr) error {
+ // For now, enforce only one tunnel at a time. Later we'll remove this silly restriction.
+ trackedTunnelsLock.Lock()
+ tt := make([]string, 0, len(trackedTunnels))
+ for t := range trackedTunnels {
+ tt = append(tt, t)
+ }
+ trackedTunnelsLock.Unlock()
+ for _, t := range tt {
+ s.Stop(t, unused)
+ }
+
+ // After that process is started -- it's somewhat asynchronous -- we install the new one.
c, err := conf.LoadFromName(tunnelName)
if err != nil {
return err
@@ -156,6 +168,11 @@ func (s *ManagerService) State(tunnelName string, state *TunnelState) error {
return nil
}
+func (s *ManagerService) GlobalState(unused uintptr, state *TunnelState) error {
+ *state = trackedTunnelsGlobalState()
+ return nil
+}
+
func (s *ManagerService) Create(tunnelConfig conf.Config, tunnel *Tunnel) error {
err := tunnelConfig.Save()
if err != nil {
diff --git a/service/tunneltracker.go b/service/tunneltracker.go
index 3cf9fde5..cfc894f4 100644
--- a/service/tunneltracker.go
+++ b/service/tunneltracker.go
@@ -85,7 +85,7 @@ var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *serviceNotify
return 0
})
-var trackedTunnels = make(map[string]bool)
+var trackedTunnels = make(map[string]TunnelState)
var trackedTunnelsLock = sync.Mutex{}
func svcStateToTunState(s svc.State) TunnelState {
@@ -103,16 +103,32 @@ func svcStateToTunState(s svc.State) TunnelState {
}
}
+func trackedTunnelsGlobalState() (state TunnelState) {
+ state = TunnelStopped
+ trackedTunnelsLock.Lock()
+ defer trackedTunnelsLock.Unlock()
+ for _, s := range trackedTunnels {
+ if s == TunnelStarting {
+ return TunnelStarting
+ } else if s == TunnelStopping {
+ return TunnelStopping
+ } else if s == TunnelStarted || s == TunnelUnknown {
+ state = TunnelStarted
+ }
+ }
+ return
+}
+
func trackTunnelService(tunnelName string, service *mgr.Service) {
defer service.Close()
trackedTunnelsLock.Lock()
- _, isTracked := trackedTunnels[tunnelName]
- trackedTunnels[tunnelName] = true
- trackedTunnelsLock.Unlock()
- if isTracked {
+ if _, found := trackedTunnels[tunnelName]; found {
+ trackedTunnelsLock.Unlock()
return
}
+ trackedTunnels[tunnelName] = TunnelUnknown
+ trackedTunnelsLock.Unlock()
defer func() {
trackedTunnelsLock.Lock()
delete(trackedTunnels, tunnelName)
@@ -133,11 +149,17 @@ func trackTunnelService(tunnelName string, service *mgr.Service) {
case 0:
sleepEx(windows.INFINITE, true)
case errorServiceMARKED_FOR_DELETE:
+ trackedTunnelsLock.Lock()
+ trackedTunnels[tunnelName] = TunnelStopped
+ trackedTunnelsLock.Unlock()
IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil)
return
case errorServiceNOTIFY_CLIENT_LAGGING:
continue
default:
+ trackedTunnelsLock.Lock()
+ trackedTunnels[tunnelName] = TunnelStopped
+ trackedTunnelsLock.Unlock()
IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %v", syscall.Errno(ret)))
service.Control(svc.Stop)
return
@@ -160,6 +182,9 @@ func trackTunnelService(tunnelName string, service *mgr.Service) {
}
}
if state != lastState {
+ trackedTunnelsLock.Lock()
+ trackedTunnels[tunnelName] = state
+ trackedTunnelsLock.Unlock()
IPCServerNotifyTunnelChange(tunnelName, state, tunnelError)
lastState = state
}