From 5b8d4ff0a3d0dabe620374c65e36c08b7f501cf2 Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Wed, 10 Apr 2019 12:47:03 +0200 Subject: ui: only allow a single tunnel to be active at any time Signed-off-by: Alexander Neumann Signed-off-by: Jason A. Donenfeld --- ui/confview.go | 49 +++++++++++++++----------- ui/manage_tunnels.go | 11 ++++++ ui/tunneltracker.go | 99 ++++++++++++++++++++++++++++++++++++++++++++++++++++ ui/ui.go | 28 ++++----------- 4 files changed, 145 insertions(+), 42 deletions(-) create mode 100644 ui/tunneltracker.go diff --git a/ui/confview.go b/ui/confview.go index b424720a..230d4dd7 100644 --- a/ui/confview.go +++ b/ui/confview.go @@ -42,8 +42,9 @@ type labelTextLine struct { } type toggleActiveLine struct { - composite *walk.Composite - button *walk.PushButton + composite *walk.Composite + button *walk.PushButton + tunnelTracker *TunnelTracker } type interfaceView struct { @@ -195,6 +196,10 @@ func (tal *toggleActiveLine) update(state service.TunnelState) { enabled, text = false, "" } + if tt := tal.tunnelTracker; tt != nil && tt.InTransition() { + enabled = false + } + tal.button.SetEnabled(enabled) tal.button.SetText(text) tal.button.SetVisible(state != service.TunnelUnknown) @@ -430,28 +435,30 @@ var crossThreadMessageHijack = windows.NewCallback(func(hwnd win.HWND, msg uint3 return win.CallWindowProc(cv.originalWndProc, hwnd, msg, wParam, lParam) }) -func (cv *ConfView) onToggleActiveClicked() { - state, err := cv.tunnel.State() - if err != nil { - walk.MsgBox(cv.Form(), "Failed to retrieve tunnel state", fmt.Sprintf("Error: %s", err.Error()), walk.MsgBoxIconError) - return - } - - cv.interfaze.toggleActive.button.SetEnabled(false) +func (cv *ConfView) TunnelTracker() *TunnelTracker { + return cv.interfaze.toggleActive.tunnelTracker +} - switch state { - case service.TunnelStarted: - if err := cv.tunnel.Stop(); err != nil { - walk.MsgBox(cv.Form(), "Failed to stop tunnel", fmt.Sprintf("Error: %s", err.Error()), walk.MsgBoxIconError) - } +func (cv *ConfView) SetTunnelTracker(tunnelTracker *TunnelTracker) { + cv.interfaze.toggleActive.tunnelTracker = tunnelTracker +} - case service.TunnelStopped: - if err := cv.tunnel.Start(); err != nil { - walk.MsgBox(cv.Form(), "Failed to start tunnel", fmt.Sprintf("Error: %s", err.Error()), walk.MsgBoxIconError) - } +func (cv *ConfView) onToggleActiveClicked() { + cv.interfaze.toggleActive.button.SetEnabled(false) - default: - panic("unexpected state") + var title string + var err error + tt := cv.TunnelTracker() + if activeTunnel := tt.ActiveTunnel(); activeTunnel != nil && activeTunnel.Name == cv.tunnel.Name { + title = "Failed to deactivate tunnel" + err = tt.DeactivateTunnel() + } else { + title = "Failed to activate tunnel" + err = tt.ActivateTunnel(cv.tunnel) + } + if err != nil { + walk.MsgBox(cv.Form(), title, err.Error(), walk.MsgBoxIconError) + return } cv.setTunnel(cv.tunnel) diff --git a/ui/manage_tunnels.go b/ui/manage_tunnels.go index 4a3985d0..815dee00 100644 --- a/ui/manage_tunnels.go +++ b/ui/manage_tunnels.go @@ -25,6 +25,7 @@ type ManageTunnelsWindow struct { icon *walk.Icon + tunnelTracker *TunnelTracker tunnelsView *TunnelsView confView *ConfView tunnelAddedPublisher walk.StringEventPublisher @@ -172,6 +173,16 @@ func (mtw *ManageTunnelsWindow) Show() { win.BringWindowToTop(mtw.Handle()) } +func (mtw *ManageTunnelsWindow) TunnelTracker() *TunnelTracker { + return mtw.tunnelTracker +} + +func (mtw *ManageTunnelsWindow) SetTunnelTracker(tunnelTracker *TunnelTracker) { + mtw.tunnelTracker = tunnelTracker + + mtw.confView.SetTunnelTracker(tunnelTracker) +} + func (mtw *ManageTunnelsWindow) SetTunnelState(tunnel *service.Tunnel, state service.TunnelState) { mtw.tunnelsView.SetTunnelState(tunnel, state) // mtw.confView.SetTunnelState(tunnel, state) diff --git a/ui/tunneltracker.go b/ui/tunneltracker.go new file mode 100644 index 00000000..73d48538 --- /dev/null +++ b/ui/tunneltracker.go @@ -0,0 +1,99 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package ui + +import ( + "fmt" + + "github.com/lxn/walk" + "golang.zx2c4.com/wireguard/windows/service" +) + +type TunnelTracker struct { + activeTunnel *service.Tunnel + activeTunnelChanged walk.EventPublisher + tunnelChangeCB *service.TunnelChangeCallback + inTransition bool +} + +func (tt *TunnelTracker) ActiveTunnel() *service.Tunnel { + return tt.activeTunnel +} + +func (tt *TunnelTracker) ActivateTunnel(tunnel *service.Tunnel) error { + if tunnel == tt.activeTunnel { + return nil + } + + if err := tt.DeactivateTunnel(); err != nil { + return fmt.Errorf("ActivateTunnel: Failed to deactivate tunnel '%s': %v", tunnel.Name, err) + } + + if err := tunnel.Start(); err != nil { + return fmt.Errorf("ActivateTunnel: Failed to start tunnel '%s': %v", tunnel.Name, err) + } + + return nil +} + +func (tt *TunnelTracker) DeactivateTunnel() error { + if tt.activeTunnel == nil { + return nil + } + + state, err := tt.activeTunnel.State() + if err != nil { + return fmt.Errorf("DeactivateTunnel: Failed to retrieve state for tunnel %s: %v", tt.activeTunnel.Name, err) + } + + if state == service.TunnelStarted { + if err := tt.activeTunnel.Stop(); err != nil { + return fmt.Errorf("DeactivateTunnel: Failed to stop tunnel '%s': %v", tt.activeTunnel.Name, err) + } + } + + if state == service.TunnelStarted || state == service.TunnelStopping { + if err := tt.activeTunnel.WaitForStop(); err != nil { + return fmt.Errorf("DeactivateTunnel: Failed to wait for tunnel '%s' to stop: %v", tt.activeTunnel.Name, err) + } + } + + return nil +} + +func (tt *TunnelTracker) ActiveTunnelChanged() *walk.Event { + return tt.activeTunnelChanged.Event() +} + +func (tt *TunnelTracker) InTransition() bool { + return tt.inTransition +} + +func (tt *TunnelTracker) SetTunnelState(tunnel *service.Tunnel, state service.TunnelState, err error) { + if err != nil { + tt.inTransition = false + } + + switch state { + case service.TunnelStarted: + tt.inTransition = false + tt.activeTunnel = tunnel + + case service.TunnelStarting, service.TunnelStopping: + tt.inTransition = true + + case service.TunnelStopped: + if tt.activeTunnel != nil && tt.activeTunnel.Name == tunnel.Name { + tt.inTransition = false + } + tt.activeTunnel = nil + + default: + return + } + + tt.activeTunnelChanged.Publish() +} diff --git a/ui/ui.go b/ui/ui.go index 7b9a8ce0..7cc96ee2 100644 --- a/ui/ui.go +++ b/ui/ui.go @@ -33,6 +33,8 @@ func nag() { func RunUI() { runtime.LockOSThread() + tunnelTracker := new(TunnelTracker) + icon, err := walk.NewIconFromResourceId(1) if err != nil { panic(err) @@ -45,6 +47,8 @@ func RunUI() { } defer mtw.Dispose() + mtw.SetTunnelTracker(tunnelTracker) + tray, err := NewTray(mtw, icon) if err != nil { panic(err) @@ -52,14 +56,13 @@ func RunUI() { defer tray.Dispose() // Bind to updates - setTunnelState := func(tunnel *service.Tunnel, state service.TunnelState, showNotifications bool) { + service.IPCClientRegisterTunnelChange(func(tunnel *service.Tunnel, state service.TunnelState, err error) { mtw.Synchronize(func() { + tunnelTracker.SetTunnelState(tunnel, state, err) mtw.SetTunnelState(tunnel, state) - tray.SetTunnelStateWithNotification(tunnel, state, showNotifications) + tray.SetTunnelStateWithNotification(tunnel, state, err == nil) }) - } - service.IPCClientRegisterTunnelChange(func(tunnel *service.Tunnel, state service.TunnelState, err error) { if err == nil { return } @@ -73,25 +76,8 @@ func RunUI() { } else { tray.ShowError("WireGuard Tunnel Error", err.Error()) } - - setTunnelState(tunnel, state, err == nil) }) - // Fetch current state - go func() { - tunnels, err := service.IPCClientTunnels() - if err != nil { - return - } - for _, tunnel := range tunnels { - state, err := tunnel.State() - if err != nil { - continue - } - setTunnelState(&tunnel, state, false) - } - }() - time.AfterFunc(time.Minute*15, nag) mtw.Run() } -- cgit v1.2.3-59-g8ed1b