aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--ui/confview.go49
-rw-r--r--ui/manage_tunnels.go11
-rw-r--r--ui/tunneltracker.go99
-rw-r--r--ui/ui.go28
4 files changed, 145 insertions, 42 deletions
diff --git a/ui/confview.go b/ui/confview.go
index b424720a..230d4dd7 100644
--- a/ui/confview.go
+++ b/ui/confview.go
@@ -42,8 +42,9 @@ type labelTextLine struct {
}
type toggleActiveLine struct {
- composite *walk.Composite
- button *walk.PushButton
+ composite *walk.Composite
+ button *walk.PushButton
+ tunnelTracker *TunnelTracker
}
type interfaceView struct {
@@ -195,6 +196,10 @@ func (tal *toggleActiveLine) update(state service.TunnelState) {
enabled, text = false, ""
}
+ if tt := tal.tunnelTracker; tt != nil && tt.InTransition() {
+ enabled = false
+ }
+
tal.button.SetEnabled(enabled)
tal.button.SetText(text)
tal.button.SetVisible(state != service.TunnelUnknown)
@@ -430,28 +435,30 @@ var crossThreadMessageHijack = windows.NewCallback(func(hwnd win.HWND, msg uint3
return win.CallWindowProc(cv.originalWndProc, hwnd, msg, wParam, lParam)
})
-func (cv *ConfView) onToggleActiveClicked() {
- state, err := cv.tunnel.State()
- if err != nil {
- walk.MsgBox(cv.Form(), "Failed to retrieve tunnel state", fmt.Sprintf("Error: %s", err.Error()), walk.MsgBoxIconError)
- return
- }
-
- cv.interfaze.toggleActive.button.SetEnabled(false)
+func (cv *ConfView) TunnelTracker() *TunnelTracker {
+ return cv.interfaze.toggleActive.tunnelTracker
+}
- switch state {
- case service.TunnelStarted:
- if err := cv.tunnel.Stop(); err != nil {
- walk.MsgBox(cv.Form(), "Failed to stop tunnel", fmt.Sprintf("Error: %s", err.Error()), walk.MsgBoxIconError)
- }
+func (cv *ConfView) SetTunnelTracker(tunnelTracker *TunnelTracker) {
+ cv.interfaze.toggleActive.tunnelTracker = tunnelTracker
+}
- case service.TunnelStopped:
- if err := cv.tunnel.Start(); err != nil {
- walk.MsgBox(cv.Form(), "Failed to start tunnel", fmt.Sprintf("Error: %s", err.Error()), walk.MsgBoxIconError)
- }
+func (cv *ConfView) onToggleActiveClicked() {
+ cv.interfaze.toggleActive.button.SetEnabled(false)
- default:
- panic("unexpected state")
+ var title string
+ var err error
+ tt := cv.TunnelTracker()
+ if activeTunnel := tt.ActiveTunnel(); activeTunnel != nil && activeTunnel.Name == cv.tunnel.Name {
+ title = "Failed to deactivate tunnel"
+ err = tt.DeactivateTunnel()
+ } else {
+ title = "Failed to activate tunnel"
+ err = tt.ActivateTunnel(cv.tunnel)
+ }
+ if err != nil {
+ walk.MsgBox(cv.Form(), title, err.Error(), walk.MsgBoxIconError)
+ return
}
cv.setTunnel(cv.tunnel)
diff --git a/ui/manage_tunnels.go b/ui/manage_tunnels.go
index 4a3985d0..815dee00 100644
--- a/ui/manage_tunnels.go
+++ b/ui/manage_tunnels.go
@@ -25,6 +25,7 @@ type ManageTunnelsWindow struct {
icon *walk.Icon
+ tunnelTracker *TunnelTracker
tunnelsView *TunnelsView
confView *ConfView
tunnelAddedPublisher walk.StringEventPublisher
@@ -172,6 +173,16 @@ func (mtw *ManageTunnelsWindow) Show() {
win.BringWindowToTop(mtw.Handle())
}
+func (mtw *ManageTunnelsWindow) TunnelTracker() *TunnelTracker {
+ return mtw.tunnelTracker
+}
+
+func (mtw *ManageTunnelsWindow) SetTunnelTracker(tunnelTracker *TunnelTracker) {
+ mtw.tunnelTracker = tunnelTracker
+
+ mtw.confView.SetTunnelTracker(tunnelTracker)
+}
+
func (mtw *ManageTunnelsWindow) SetTunnelState(tunnel *service.Tunnel, state service.TunnelState) {
mtw.tunnelsView.SetTunnelState(tunnel, state)
// mtw.confView.SetTunnelState(tunnel, state)
diff --git a/ui/tunneltracker.go b/ui/tunneltracker.go
new file mode 100644
index 00000000..73d48538
--- /dev/null
+++ b/ui/tunneltracker.go
@@ -0,0 +1,99 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package ui
+
+import (
+ "fmt"
+
+ "github.com/lxn/walk"
+ "golang.zx2c4.com/wireguard/windows/service"
+)
+
+type TunnelTracker struct {
+ activeTunnel *service.Tunnel
+ activeTunnelChanged walk.EventPublisher
+ tunnelChangeCB *service.TunnelChangeCallback
+ inTransition bool
+}
+
+func (tt *TunnelTracker) ActiveTunnel() *service.Tunnel {
+ return tt.activeTunnel
+}
+
+func (tt *TunnelTracker) ActivateTunnel(tunnel *service.Tunnel) error {
+ if tunnel == tt.activeTunnel {
+ return nil
+ }
+
+ if err := tt.DeactivateTunnel(); err != nil {
+ return fmt.Errorf("ActivateTunnel: Failed to deactivate tunnel '%s': %v", tunnel.Name, err)
+ }
+
+ if err := tunnel.Start(); err != nil {
+ return fmt.Errorf("ActivateTunnel: Failed to start tunnel '%s': %v", tunnel.Name, err)
+ }
+
+ return nil
+}
+
+func (tt *TunnelTracker) DeactivateTunnel() error {
+ if tt.activeTunnel == nil {
+ return nil
+ }
+
+ state, err := tt.activeTunnel.State()
+ if err != nil {
+ return fmt.Errorf("DeactivateTunnel: Failed to retrieve state for tunnel %s: %v", tt.activeTunnel.Name, err)
+ }
+
+ if state == service.TunnelStarted {
+ if err := tt.activeTunnel.Stop(); err != nil {
+ return fmt.Errorf("DeactivateTunnel: Failed to stop tunnel '%s': %v", tt.activeTunnel.Name, err)
+ }
+ }
+
+ if state == service.TunnelStarted || state == service.TunnelStopping {
+ if err := tt.activeTunnel.WaitForStop(); err != nil {
+ return fmt.Errorf("DeactivateTunnel: Failed to wait for tunnel '%s' to stop: %v", tt.activeTunnel.Name, err)
+ }
+ }
+
+ return nil
+}
+
+func (tt *TunnelTracker) ActiveTunnelChanged() *walk.Event {
+ return tt.activeTunnelChanged.Event()
+}
+
+func (tt *TunnelTracker) InTransition() bool {
+ return tt.inTransition
+}
+
+func (tt *TunnelTracker) SetTunnelState(tunnel *service.Tunnel, state service.TunnelState, err error) {
+ if err != nil {
+ tt.inTransition = false
+ }
+
+ switch state {
+ case service.TunnelStarted:
+ tt.inTransition = false
+ tt.activeTunnel = tunnel
+
+ case service.TunnelStarting, service.TunnelStopping:
+ tt.inTransition = true
+
+ case service.TunnelStopped:
+ if tt.activeTunnel != nil && tt.activeTunnel.Name == tunnel.Name {
+ tt.inTransition = false
+ }
+ tt.activeTunnel = nil
+
+ default:
+ return
+ }
+
+ tt.activeTunnelChanged.Publish()
+}
diff --git a/ui/ui.go b/ui/ui.go
index 7b9a8ce0..7cc96ee2 100644
--- a/ui/ui.go
+++ b/ui/ui.go
@@ -33,6 +33,8 @@ func nag() {
func RunUI() {
runtime.LockOSThread()
+ tunnelTracker := new(TunnelTracker)
+
icon, err := walk.NewIconFromResourceId(1)
if err != nil {
panic(err)
@@ -45,6 +47,8 @@ func RunUI() {
}
defer mtw.Dispose()
+ mtw.SetTunnelTracker(tunnelTracker)
+
tray, err := NewTray(mtw, icon)
if err != nil {
panic(err)
@@ -52,14 +56,13 @@ func RunUI() {
defer tray.Dispose()
// Bind to updates
- setTunnelState := func(tunnel *service.Tunnel, state service.TunnelState, showNotifications bool) {
+ service.IPCClientRegisterTunnelChange(func(tunnel *service.Tunnel, state service.TunnelState, err error) {
mtw.Synchronize(func() {
+ tunnelTracker.SetTunnelState(tunnel, state, err)
mtw.SetTunnelState(tunnel, state)
- tray.SetTunnelStateWithNotification(tunnel, state, showNotifications)
+ tray.SetTunnelStateWithNotification(tunnel, state, err == nil)
})
- }
- service.IPCClientRegisterTunnelChange(func(tunnel *service.Tunnel, state service.TunnelState, err error) {
if err == nil {
return
}
@@ -73,25 +76,8 @@ func RunUI() {
} else {
tray.ShowError("WireGuard Tunnel Error", err.Error())
}
-
- setTunnelState(tunnel, state, err == nil)
})
- // Fetch current state
- go func() {
- tunnels, err := service.IPCClientTunnels()
- if err != nil {
- return
- }
- for _, tunnel := range tunnels {
- state, err := tunnel.State()
- if err != nil {
- continue
- }
- setTunnelState(&tunnel, state, false)
- }
- }()
-
time.AfterFunc(time.Minute*15, nag)
mtw.Run()
}