diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2019-04-26 20:05:24 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2019-04-27 08:24:22 +0200 |
commit | d099b3eda13bda85a3763cb7a73543999c35e11f (patch) | |
tree | d62408b62529d0e18a19866d540c1c87cfa957c1 /service | |
parent | installer: new checksums for working msm (diff) | |
download | wireguard-windows-d099b3eda13bda85a3763cb7a73543999c35e11f.tar.xz wireguard-windows-d099b3eda13bda85a3763cb7a73543999c35e11f.zip |
ui: simplify everything
Diffstat (limited to '')
-rw-r--r-- | service/ipc_client.go | 19 | ||||
-rw-r--r-- | service/ipc_server.go | 17 | ||||
-rw-r--r-- | service/tunneltracker.go | 35 |
3 files changed, 66 insertions, 5 deletions
diff --git a/service/ipc_client.go b/service/ipc_client.go index e6295b91..55fc043d 100644 --- a/service/ipc_client.go +++ b/service/ipc_client.go @@ -113,6 +113,20 @@ func (t *Tunnel) Stop() error { return rpcClient.Call("ManagerService.Stop", t.Name, nil) } +func (t *Tunnel) Toggle() (oldState TunnelState, err error) { + oldState, err = t.State() + if err != nil { + oldState = TunnelUnknown + return + } + if oldState == TunnelStarted { + err = t.Stop() + } else if oldState == TunnelStopped { + err = t.Start() + } + return +} + func (t *Tunnel) WaitForStop() error { return rpcClient.Call("ManagerService.WaitForStop", t.Name, nil) } @@ -136,6 +150,11 @@ func IPCClientTunnels() ([]Tunnel, error) { return tunnels, rpcClient.Call("ManagerService.Tunnels", uintptr(0), &tunnels) } +func IPCClientGlobalState() (TunnelState, error) { + var state TunnelState + return state, rpcClient.Call("ManagerService.GlobalState", uintptr(0), &state) +} + func IPCClientQuit(stopTunnelsOnQuit bool) (bool, error) { var alreadyQuit bool return alreadyQuit, rpcClient.Call("ManagerService.Quit", stopTunnelsOnQuit, &alreadyQuit) diff --git a/service/ipc_server.go b/service/ipc_server.go index 17ea67c2..6d576846 100644 --- a/service/ipc_server.go +++ b/service/ipc_server.go @@ -71,6 +71,18 @@ func (s *ManagerService) RuntimeConfig(tunnelName string, config *conf.Config) e } func (s *ManagerService) Start(tunnelName string, unused *uintptr) error { + // For now, enforce only one tunnel at a time. Later we'll remove this silly restriction. + trackedTunnelsLock.Lock() + tt := make([]string, 0, len(trackedTunnels)) + for t := range trackedTunnels { + tt = append(tt, t) + } + trackedTunnelsLock.Unlock() + for _, t := range tt { + s.Stop(t, unused) + } + + // After that process is started -- it's somewhat asynchronous -- we install the new one. c, err := conf.LoadFromName(tunnelName) if err != nil { return err @@ -156,6 +168,11 @@ func (s *ManagerService) State(tunnelName string, state *TunnelState) error { return nil } +func (s *ManagerService) GlobalState(unused uintptr, state *TunnelState) error { + *state = trackedTunnelsGlobalState() + return nil +} + func (s *ManagerService) Create(tunnelConfig conf.Config, tunnel *Tunnel) error { err := tunnelConfig.Save() if err != nil { diff --git a/service/tunneltracker.go b/service/tunneltracker.go index 3cf9fde5..cfc894f4 100644 --- a/service/tunneltracker.go +++ b/service/tunneltracker.go @@ -85,7 +85,7 @@ var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *serviceNotify return 0 }) -var trackedTunnels = make(map[string]bool) +var trackedTunnels = make(map[string]TunnelState) var trackedTunnelsLock = sync.Mutex{} func svcStateToTunState(s svc.State) TunnelState { @@ -103,16 +103,32 @@ func svcStateToTunState(s svc.State) TunnelState { } } +func trackedTunnelsGlobalState() (state TunnelState) { + state = TunnelStopped + trackedTunnelsLock.Lock() + defer trackedTunnelsLock.Unlock() + for _, s := range trackedTunnels { + if s == TunnelStarting { + return TunnelStarting + } else if s == TunnelStopping { + return TunnelStopping + } else if s == TunnelStarted || s == TunnelUnknown { + state = TunnelStarted + } + } + return +} + func trackTunnelService(tunnelName string, service *mgr.Service) { defer service.Close() trackedTunnelsLock.Lock() - _, isTracked := trackedTunnels[tunnelName] - trackedTunnels[tunnelName] = true - trackedTunnelsLock.Unlock() - if isTracked { + if _, found := trackedTunnels[tunnelName]; found { + trackedTunnelsLock.Unlock() return } + trackedTunnels[tunnelName] = TunnelUnknown + trackedTunnelsLock.Unlock() defer func() { trackedTunnelsLock.Lock() delete(trackedTunnels, tunnelName) @@ -133,11 +149,17 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { case 0: sleepEx(windows.INFINITE, true) case errorServiceMARKED_FOR_DELETE: + trackedTunnelsLock.Lock() + trackedTunnels[tunnelName] = TunnelStopped + trackedTunnelsLock.Unlock() IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil) return case errorServiceNOTIFY_CLIENT_LAGGING: continue default: + trackedTunnelsLock.Lock() + trackedTunnels[tunnelName] = TunnelStopped + trackedTunnelsLock.Unlock() IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %v", syscall.Errno(ret))) service.Control(svc.Stop) return @@ -160,6 +182,9 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { } } if state != lastState { + trackedTunnelsLock.Lock() + trackedTunnels[tunnelName] = state + trackedTunnelsLock.Unlock() IPCServerNotifyTunnelChange(tunnelName, state, tunnelError) lastState = state } |