aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt
diff options
context:
space:
mode:
Diffstat (limited to 'ui/src/main/java/com/wireguard/android/model/TunnelManager.kt')
-rw-r--r--ui/src/main/java/com/wireguard/android/model/TunnelManager.kt255
1 files changed, 255 insertions, 0 deletions
diff --git a/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt b/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt
new file mode 100644
index 00000000..d7c1391f
--- /dev/null
+++ b/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt
@@ -0,0 +1,255 @@
+/*
+ * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package com.wireguard.android.model
+
+import android.content.BroadcastReceiver
+import android.content.Context
+import android.content.Intent
+import android.os.Build
+import android.util.Log
+import android.widget.Toast
+import androidx.databinding.BaseObservable
+import androidx.databinding.Bindable
+import com.wireguard.android.Application.Companion.get
+import com.wireguard.android.Application.Companion.getBackend
+import com.wireguard.android.Application.Companion.getTunnelManager
+import com.wireguard.android.BR
+import com.wireguard.android.R
+import com.wireguard.android.backend.Statistics
+import com.wireguard.android.backend.Tunnel
+import com.wireguard.android.configStore.ConfigStore
+import com.wireguard.android.databinding.ObservableSortedKeyedArrayList
+import com.wireguard.android.util.ErrorMessages
+import com.wireguard.android.util.UserKnobs
+import com.wireguard.android.util.applicationScope
+import com.wireguard.config.Config
+import kotlinx.coroutines.CompletableDeferred
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.SupervisorJob
+import kotlinx.coroutines.async
+import kotlinx.coroutines.awaitAll
+import kotlinx.coroutines.flow.first
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
+
+/**
+ * Maintains and mediates changes to the set of available WireGuard tunnels,
+ */
+class TunnelManager(private val configStore: ConfigStore) : BaseObservable() {
+ private val tunnels = CompletableDeferred<ObservableSortedKeyedArrayList<String, ObservableTunnel>>()
+ private val context: Context = get()
+ private val tunnelMap: ObservableSortedKeyedArrayList<String, ObservableTunnel> = ObservableSortedKeyedArrayList(TunnelComparator)
+ private var haveLoaded = false
+
+ private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel {
+ val tunnel = ObservableTunnel(this, name, config, state)
+ tunnelMap.add(tunnel)
+ return tunnel
+ }
+
+ suspend fun getTunnels(): ObservableSortedKeyedArrayList<String, ObservableTunnel> = tunnels.await()
+
+ suspend fun create(name: String, config: Config?): ObservableTunnel = withContext(Dispatchers.Main.immediate) {
+ if (Tunnel.isNameInvalid(name))
+ throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))
+ if (tunnelMap.containsKey(name))
+ throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))
+ addToList(name, withContext(Dispatchers.IO) { configStore.create(name, config!!) }, Tunnel.State.DOWN)
+ }
+
+ suspend fun delete(tunnel: ObservableTunnel) = withContext(Dispatchers.Main.immediate) {
+ val originalState = tunnel.state
+ val wasLastUsed = tunnel == lastUsedTunnel
+ // Make sure nothing touches the tunnel.
+ if (wasLastUsed)
+ lastUsedTunnel = null
+ tunnelMap.remove(tunnel)
+ try {
+ if (originalState == Tunnel.State.UP)
+ withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) }
+ try {
+ withContext(Dispatchers.IO) { configStore.delete(tunnel.name) }
+ } catch (e: Throwable) {
+ if (originalState == Tunnel.State.UP)
+ withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) }
+ throw e
+ }
+ } catch (e: Throwable) {
+ // Failure, put the tunnel back.
+ tunnelMap.add(tunnel)
+ if (wasLastUsed)
+ lastUsedTunnel = tunnel
+ throw e
+ }
+ }
+
+ @get:Bindable
+ var lastUsedTunnel: ObservableTunnel? = null
+ private set(value) {
+ if (value == field) return
+ field = value
+ notifyPropertyChanged(BR.lastUsedTunnel)
+ applicationScope.launch { UserKnobs.setLastUsedTunnel(value?.name) }
+ }
+
+ suspend fun getTunnelConfig(tunnel: ObservableTunnel): Config = withContext(Dispatchers.Main.immediate) {
+ tunnel.onConfigChanged(withContext(Dispatchers.IO) { configStore.load(tunnel.name) })!!
+ }
+
+ fun onCreate() {
+ applicationScope.launch {
+ try {
+ onTunnelsLoaded(withContext(Dispatchers.IO) { configStore.enumerate() }, withContext(Dispatchers.IO) { getBackend().runningTunnelNames })
+ } catch (e: Throwable) {
+ Log.e(TAG, Log.getStackTraceString(e))
+ }
+ }
+ }
+
+ private fun onTunnelsLoaded(present: Iterable<String>, running: Collection<String>) {
+ for (name in present)
+ addToList(name, null, if (running.contains(name)) Tunnel.State.UP else Tunnel.State.DOWN)
+ applicationScope.launch {
+ val lastUsedName = UserKnobs.lastUsedTunnel.first()
+ if (lastUsedName != null)
+ lastUsedTunnel = tunnelMap[lastUsedName]
+ haveLoaded = true
+ restoreState(true)
+ tunnels.complete(tunnelMap)
+ }
+ }
+
+ private fun refreshTunnelStates() {
+ applicationScope.launch {
+ try {
+ val running = withContext(Dispatchers.IO) { getBackend().runningTunnelNames }
+ for (tunnel in tunnelMap)
+ tunnel.onStateChanged(if (running.contains(tunnel.name)) Tunnel.State.UP else Tunnel.State.DOWN)
+ } catch (e: Throwable) {
+ Log.e(TAG, Log.getStackTraceString(e))
+ }
+ }
+ }
+
+ suspend fun restoreState(force: Boolean) {
+ if (!haveLoaded || (!force && !UserKnobs.restoreOnBoot.first()))
+ return
+ val previouslyRunning = UserKnobs.runningTunnels.first()
+ if (previouslyRunning.isEmpty()) return
+ withContext(Dispatchers.IO) {
+ try {
+ tunnelMap.filter { previouslyRunning.contains(it.name) }.map { async(Dispatchers.IO + SupervisorJob()) { setTunnelState(it, Tunnel.State.UP) } }
+ .awaitAll()
+ } catch (e: Throwable) {
+ Log.e(TAG, Log.getStackTraceString(e))
+ }
+ }
+ }
+
+ suspend fun saveState() {
+ UserKnobs.setRunningTunnels(tunnelMap.filter { it.state == Tunnel.State.UP }.map { it.name }.toSet())
+ }
+
+ suspend fun setTunnelConfig(tunnel: ObservableTunnel, config: Config): Config = withContext(Dispatchers.Main.immediate) {
+ tunnel.onConfigChanged(withContext(Dispatchers.IO) {
+ getBackend().setState(tunnel, tunnel.state, config)
+ configStore.save(tunnel.name, config)
+ })!!
+ }
+
+ suspend fun setTunnelName(tunnel: ObservableTunnel, name: String): String = withContext(Dispatchers.Main.immediate) {
+ if (Tunnel.isNameInvalid(name))
+ throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))
+ if (tunnelMap.containsKey(name)) {
+ throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))
+ }
+ val originalState = tunnel.state
+ val wasLastUsed = tunnel == lastUsedTunnel
+ // Make sure nothing touches the tunnel.
+ if (wasLastUsed)
+ lastUsedTunnel = null
+ tunnelMap.remove(tunnel)
+ var throwable: Throwable? = null
+ var newName: String? = null
+ try {
+ if (originalState == Tunnel.State.UP)
+ withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) }
+ withContext(Dispatchers.IO) { configStore.rename(tunnel.name, name) }
+ newName = tunnel.onNameChanged(name)
+ if (originalState == Tunnel.State.UP)
+ withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) }
+ } catch (e: Throwable) {
+ throwable = e
+ // On failure, we don't know what state the tunnel might be in. Fix that.
+ getTunnelState(tunnel)
+ }
+ // Add the tunnel back to the manager, under whatever name it thinks it has.
+ tunnelMap.add(tunnel)
+ if (wasLastUsed)
+ lastUsedTunnel = tunnel
+ if (throwable != null)
+ throw throwable
+ newName!!
+ }
+
+ suspend fun setTunnelState(tunnel: ObservableTunnel, state: Tunnel.State): Tunnel.State = withContext(Dispatchers.Main.immediate) {
+ var newState = tunnel.state
+ var throwable: Throwable? = null
+ try {
+ newState = withContext(Dispatchers.IO) { getBackend().setState(tunnel, state, tunnel.getConfigAsync()) }
+ if (newState == Tunnel.State.UP)
+ lastUsedTunnel = tunnel
+ } catch (e: Throwable) {
+ throwable = e
+ }
+ tunnel.onStateChanged(newState)
+ saveState()
+ if (throwable != null)
+ throw throwable
+ newState
+ }
+
+ class IntentReceiver : BroadcastReceiver() {
+ override fun onReceive(context: Context, intent: Intent?) {
+ applicationScope.launch {
+ val manager = getTunnelManager()
+ if (intent == null) return@launch
+ val action = intent.action ?: return@launch
+ if ("com.wireguard.android.action.REFRESH_TUNNEL_STATES" == action) {
+ manager.refreshTunnelStates()
+ return@launch
+ }
+ if (Build.VERSION.SDK_INT < Build.VERSION_CODES.M || !UserKnobs.allowRemoteControlIntents.first())
+ return@launch
+ val state: Tunnel.State
+ state = when (action) {
+ "com.wireguard.android.action.SET_TUNNEL_UP" -> Tunnel.State.UP
+ "com.wireguard.android.action.SET_TUNNEL_DOWN" -> Tunnel.State.DOWN
+ else -> return@launch
+ }
+ val tunnelName = intent.getStringExtra("tunnel") ?: return@launch
+ val tunnels = manager.getTunnels()
+ val tunnel = tunnels[tunnelName] ?: return@launch
+ try {
+ manager.setTunnelState(tunnel, state)
+ } catch (e: Throwable) {
+ Toast.makeText(context, ErrorMessages[e], Toast.LENGTH_LONG).show()
+ }
+ }
+ }
+ }
+
+ suspend fun getTunnelState(tunnel: ObservableTunnel): Tunnel.State = withContext(Dispatchers.Main.immediate) {
+ tunnel.onStateChanged(withContext(Dispatchers.IO) { getBackend().getState(tunnel) })
+ }
+
+ suspend fun getTunnelStatistics(tunnel: ObservableTunnel): Statistics = withContext(Dispatchers.Main.immediate) {
+ tunnel.onStatisticsChanged(withContext(Dispatchers.IO) { getBackend().getStatistics(tunnel) })!!
+ }
+
+ companion object {
+ private const val TAG = "WireGuard/TunnelManager"
+ }
+}