aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt
diff options
context:
space:
mode:
Diffstat (limited to 'ui/src/main/java/com/wireguard/android/model/TunnelManager.kt')
-rw-r--r--ui/src/main/java/com/wireguard/android/model/TunnelManager.kt267
1 files changed, 143 insertions, 124 deletions
diff --git a/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt b/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt
index 5091ed3b..d7c1391f 100644
--- a/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt
+++ b/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt
@@ -1,20 +1,19 @@
/*
- * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package com.wireguard.android.model
-import android.annotation.SuppressLint
import android.content.BroadcastReceiver
import android.content.Context
import android.content.Intent
import android.os.Build
+import android.util.Log
+import android.widget.Toast
import androidx.databinding.BaseObservable
import androidx.databinding.Bindable
import com.wireguard.android.Application.Companion.get
-import com.wireguard.android.Application.Companion.getAsyncWorker
import com.wireguard.android.Application.Companion.getBackend
-import com.wireguard.android.Application.Companion.getSharedPreferences
import com.wireguard.android.Application.Companion.getTunnelManager
import com.wireguard.android.BR
import com.wireguard.android.R
@@ -22,145 +21,149 @@ import com.wireguard.android.backend.Statistics
import com.wireguard.android.backend.Tunnel
import com.wireguard.android.configStore.ConfigStore
import com.wireguard.android.databinding.ObservableSortedKeyedArrayList
-import com.wireguard.android.util.ExceptionLoggers
+import com.wireguard.android.util.ErrorMessages
+import com.wireguard.android.util.UserKnobs
+import com.wireguard.android.util.applicationScope
import com.wireguard.config.Config
-import java9.util.concurrent.CompletableFuture
-import java9.util.concurrent.CompletionStage
-import java.util.ArrayList
+import kotlinx.coroutines.CompletableDeferred
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.SupervisorJob
+import kotlinx.coroutines.async
+import kotlinx.coroutines.awaitAll
+import kotlinx.coroutines.flow.first
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
/**
* Maintains and mediates changes to the set of available WireGuard tunnels,
*/
class TunnelManager(private val configStore: ConfigStore) : BaseObservable() {
- val tunnels = CompletableFuture<ObservableSortedKeyedArrayList<String, ObservableTunnel>>()
+ private val tunnels = CompletableDeferred<ObservableSortedKeyedArrayList<String, ObservableTunnel>>()
private val context: Context = get()
- private val delayedLoadRestoreTunnels = ArrayList<CompletableFuture<Void>>()
private val tunnelMap: ObservableSortedKeyedArrayList<String, ObservableTunnel> = ObservableSortedKeyedArrayList(TunnelComparator)
private var haveLoaded = false
- private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel? {
+ private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel {
val tunnel = ObservableTunnel(this, name, config, state)
tunnelMap.add(tunnel)
return tunnel
}
- fun create(name: String, config: Config?): CompletionStage<ObservableTunnel> {
+ suspend fun getTunnels(): ObservableSortedKeyedArrayList<String, ObservableTunnel> = tunnels.await()
+
+ suspend fun create(name: String, config: Config?): ObservableTunnel = withContext(Dispatchers.Main.immediate) {
if (Tunnel.isNameInvalid(name))
- return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name)))
+ throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))
if (tunnelMap.containsKey(name))
- return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name)))
- return getAsyncWorker().supplyAsync { configStore.create(name, config!!) }.thenApply { addToList(name, it, Tunnel.State.DOWN) }
+ throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))
+ addToList(name, withContext(Dispatchers.IO) { configStore.create(name, config!!) }, Tunnel.State.DOWN)
}
- fun delete(tunnel: ObservableTunnel): CompletionStage<Void> {
+ suspend fun delete(tunnel: ObservableTunnel) = withContext(Dispatchers.Main.immediate) {
val originalState = tunnel.state
val wasLastUsed = tunnel == lastUsedTunnel
// Make sure nothing touches the tunnel.
if (wasLastUsed)
lastUsedTunnel = null
tunnelMap.remove(tunnel)
- return getAsyncWorker().runAsync {
+ try {
if (originalState == Tunnel.State.UP)
- getBackend().setState(tunnel, Tunnel.State.DOWN, null)
+ withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) }
try {
- configStore.delete(tunnel.name)
- } catch (e: Exception) {
+ withContext(Dispatchers.IO) { configStore.delete(tunnel.name) }
+ } catch (e: Throwable) {
if (originalState == Tunnel.State.UP)
- getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config)
+ withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) }
throw e
}
- }.whenComplete { _, e ->
- if (e == null)
- return@whenComplete
+ } catch (e: Throwable) {
// Failure, put the tunnel back.
tunnelMap.add(tunnel)
if (wasLastUsed)
lastUsedTunnel = tunnel
+ throw e
}
}
@get:Bindable
- @SuppressLint("ApplySharedPref")
var lastUsedTunnel: ObservableTunnel? = null
private set(value) {
if (value == field) return
field = value
notifyPropertyChanged(BR.lastUsedTunnel)
- if (value != null)
- getSharedPreferences().edit().putString(KEY_LAST_USED_TUNNEL, value.name).commit()
- else
- getSharedPreferences().edit().remove(KEY_LAST_USED_TUNNEL).commit()
+ applicationScope.launch { UserKnobs.setLastUsedTunnel(value?.name) }
}
- fun getTunnelConfig(tunnel: ObservableTunnel): CompletionStage<Config> = getAsyncWorker()
- .supplyAsync { configStore.load(tunnel.name) }.thenApply(tunnel::onConfigChanged)
-
+ suspend fun getTunnelConfig(tunnel: ObservableTunnel): Config = withContext(Dispatchers.Main.immediate) {
+ tunnel.onConfigChanged(withContext(Dispatchers.IO) { configStore.load(tunnel.name) })!!
+ }
fun onCreate() {
- getAsyncWorker().supplyAsync { configStore.enumerate() }
- .thenAcceptBoth(getAsyncWorker().supplyAsync { getBackend().runningTunnelNames }, this::onTunnelsLoaded)
- .whenComplete(ExceptionLoggers.E)
+ applicationScope.launch {
+ try {
+ onTunnelsLoaded(withContext(Dispatchers.IO) { configStore.enumerate() }, withContext(Dispatchers.IO) { getBackend().runningTunnelNames })
+ } catch (e: Throwable) {
+ Log.e(TAG, Log.getStackTraceString(e))
+ }
+ }
}
private fun onTunnelsLoaded(present: Iterable<String>, running: Collection<String>) {
for (name in present)
addToList(name, null, if (running.contains(name)) Tunnel.State.UP else Tunnel.State.DOWN)
- val lastUsedName = getSharedPreferences().getString(KEY_LAST_USED_TUNNEL, null)
- if (lastUsedName != null)
- lastUsedTunnel = tunnelMap[lastUsedName]
- var toComplete: Array<CompletableFuture<Void>>
- synchronized(delayedLoadRestoreTunnels) {
+ applicationScope.launch {
+ val lastUsedName = UserKnobs.lastUsedTunnel.first()
+ if (lastUsedName != null)
+ lastUsedTunnel = tunnelMap[lastUsedName]
haveLoaded = true
- toComplete = delayedLoadRestoreTunnels.toTypedArray()
- delayedLoadRestoreTunnels.clear()
+ restoreState(true)
+ tunnels.complete(tunnelMap)
}
- restoreState(true).whenComplete { v: Void?, t: Throwable? ->
- for (f in toComplete) {
- if (t == null)
- f.complete(v)
- else
- f.completeExceptionally(t)
- }
- }
- tunnels.complete(tunnelMap)
}
- fun refreshTunnelStates() {
- getAsyncWorker().supplyAsync { getBackend().runningTunnelNames }
- .thenAccept { running: Set<String> -> for (tunnel in tunnelMap) tunnel.onStateChanged(if (running.contains(tunnel.name)) Tunnel.State.UP else Tunnel.State.DOWN) }
- .whenComplete(ExceptionLoggers.E)
+ private fun refreshTunnelStates() {
+ applicationScope.launch {
+ try {
+ val running = withContext(Dispatchers.IO) { getBackend().runningTunnelNames }
+ for (tunnel in tunnelMap)
+ tunnel.onStateChanged(if (running.contains(tunnel.name)) Tunnel.State.UP else Tunnel.State.DOWN)
+ } catch (e: Throwable) {
+ Log.e(TAG, Log.getStackTraceString(e))
+ }
+ }
}
- fun restoreState(force: Boolean): CompletionStage<Void> {
- if (!force && !getSharedPreferences().getBoolean(KEY_RESTORE_ON_BOOT, false))
- return CompletableFuture.completedFuture(null)
- synchronized(delayedLoadRestoreTunnels) {
- if (!haveLoaded) {
- val f = CompletableFuture<Void>()
- delayedLoadRestoreTunnels.add(f)
- return f
+ suspend fun restoreState(force: Boolean) {
+ if (!haveLoaded || (!force && !UserKnobs.restoreOnBoot.first()))
+ return
+ val previouslyRunning = UserKnobs.runningTunnels.first()
+ if (previouslyRunning.isEmpty()) return
+ withContext(Dispatchers.IO) {
+ try {
+ tunnelMap.filter { previouslyRunning.contains(it.name) }.map { async(Dispatchers.IO + SupervisorJob()) { setTunnelState(it, Tunnel.State.UP) } }
+ .awaitAll()
+ } catch (e: Throwable) {
+ Log.e(TAG, Log.getStackTraceString(e))
}
}
- val previouslyRunning = getSharedPreferences().getStringSet(KEY_RUNNING_TUNNELS, null)
- ?: return CompletableFuture.completedFuture(null)
- return CompletableFuture.allOf(*tunnelMap.filter { previouslyRunning.contains(it.name) }.map { setTunnelState(it, Tunnel.State.UP).toCompletableFuture() }.toTypedArray())
}
- @SuppressLint("ApplySharedPref")
- fun saveState() {
- getSharedPreferences().edit().putStringSet(KEY_RUNNING_TUNNELS, tunnelMap.filter { it.state == Tunnel.State.UP }.map { it.name }.toSet()).commit()
+ suspend fun saveState() {
+ UserKnobs.setRunningTunnels(tunnelMap.filter { it.state == Tunnel.State.UP }.map { it.name }.toSet())
}
- fun setTunnelConfig(tunnel: ObservableTunnel, config: Config): CompletionStage<Config> = getAsyncWorker().supplyAsync {
- getBackend().setState(tunnel, tunnel.state, config)
- configStore.save(tunnel.name, config)
- }.thenApply { tunnel.onConfigChanged(it) }
+ suspend fun setTunnelConfig(tunnel: ObservableTunnel, config: Config): Config = withContext(Dispatchers.Main.immediate) {
+ tunnel.onConfigChanged(withContext(Dispatchers.IO) {
+ getBackend().setState(tunnel, tunnel.state, config)
+ configStore.save(tunnel.name, config)
+ })!!
+ }
- fun setTunnelName(tunnel: ObservableTunnel, name: String): CompletionStage<String> {
+ suspend fun setTunnelName(tunnel: ObservableTunnel, name: String): String = withContext(Dispatchers.Main.immediate) {
if (Tunnel.isNameInvalid(name))
- return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name)))
+ throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))
if (tunnelMap.containsKey(name)) {
- return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name)))
+ throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))
}
val originalState = tunnel.state
val wasLastUsed = tunnel == lastUsedTunnel
@@ -168,69 +171,85 @@ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() {
if (wasLastUsed)
lastUsedTunnel = null
tunnelMap.remove(tunnel)
- return getAsyncWorker().supplyAsync {
+ var throwable: Throwable? = null
+ var newName: String? = null
+ try {
if (originalState == Tunnel.State.UP)
- getBackend().setState(tunnel, Tunnel.State.DOWN, null)
- configStore.rename(tunnel.name, name)
- val newName = tunnel.onNameChanged(name)
+ withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) }
+ withContext(Dispatchers.IO) { configStore.rename(tunnel.name, name) }
+ newName = tunnel.onNameChanged(name)
if (originalState == Tunnel.State.UP)
- getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config)
- newName
- }.whenComplete { _, e ->
+ withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) }
+ } catch (e: Throwable) {
+ throwable = e
// On failure, we don't know what state the tunnel might be in. Fix that.
- if (e != null)
- getTunnelState(tunnel)
- // Add the tunnel back to the manager, under whatever name it thinks it has.
- tunnelMap.add(tunnel)
- if (wasLastUsed)
- lastUsedTunnel = tunnel
+ getTunnelState(tunnel)
}
+ // Add the tunnel back to the manager, under whatever name it thinks it has.
+ tunnelMap.add(tunnel)
+ if (wasLastUsed)
+ lastUsedTunnel = tunnel
+ if (throwable != null)
+ throw throwable
+ newName!!
}
- fun setTunnelState(tunnel: ObservableTunnel, state: Tunnel.State): CompletionStage<Tunnel.State> = tunnel.configAsync
- .thenCompose { getAsyncWorker().supplyAsync { getBackend().setState(tunnel, state, it) } }
- .whenComplete { newState, e ->
- // Ensure onStateChanged is always called (failure or not), and with the correct state.
- tunnel.onStateChanged(if (e == null) newState else tunnel.state)
- if (e == null && newState == Tunnel.State.UP)
- lastUsedTunnel = tunnel
- saveState()
- }
+ suspend fun setTunnelState(tunnel: ObservableTunnel, state: Tunnel.State): Tunnel.State = withContext(Dispatchers.Main.immediate) {
+ var newState = tunnel.state
+ var throwable: Throwable? = null
+ try {
+ newState = withContext(Dispatchers.IO) { getBackend().setState(tunnel, state, tunnel.getConfigAsync()) }
+ if (newState == Tunnel.State.UP)
+ lastUsedTunnel = tunnel
+ } catch (e: Throwable) {
+ throwable = e
+ }
+ tunnel.onStateChanged(newState)
+ saveState()
+ if (throwable != null)
+ throw throwable
+ newState
+ }
class IntentReceiver : BroadcastReceiver() {
override fun onReceive(context: Context, intent: Intent?) {
- val manager = getTunnelManager()
- if (intent == null) return
- val action = intent.action ?: return
- if ("com.wireguard.android.action.REFRESH_TUNNEL_STATES" == action) {
- manager.refreshTunnelStates()
- return
- }
- if (Build.VERSION.SDK_INT < Build.VERSION_CODES.M || !getSharedPreferences().getBoolean("allow_remote_control_intents", false))
- return
- val state: Tunnel.State
- state = when (action) {
- "com.wireguard.android.action.SET_TUNNEL_UP" -> Tunnel.State.UP
- "com.wireguard.android.action.SET_TUNNEL_DOWN" -> Tunnel.State.DOWN
- else -> return
- }
- val tunnelName = intent.getStringExtra("tunnel") ?: return
- manager.tunnels.thenAccept {
- val tunnel = it[tunnelName] ?: return@thenAccept
- manager.setTunnelState(tunnel, state)
+ applicationScope.launch {
+ val manager = getTunnelManager()
+ if (intent == null) return@launch
+ val action = intent.action ?: return@launch
+ if ("com.wireguard.android.action.REFRESH_TUNNEL_STATES" == action) {
+ manager.refreshTunnelStates()
+ return@launch
+ }
+ if (Build.VERSION.SDK_INT < Build.VERSION_CODES.M || !UserKnobs.allowRemoteControlIntents.first())
+ return@launch
+ val state: Tunnel.State
+ state = when (action) {
+ "com.wireguard.android.action.SET_TUNNEL_UP" -> Tunnel.State.UP
+ "com.wireguard.android.action.SET_TUNNEL_DOWN" -> Tunnel.State.DOWN
+ else -> return@launch
+ }
+ val tunnelName = intent.getStringExtra("tunnel") ?: return@launch
+ val tunnels = manager.getTunnels()
+ val tunnel = tunnels[tunnelName] ?: return@launch
+ try {
+ manager.setTunnelState(tunnel, state)
+ } catch (e: Throwable) {
+ Toast.makeText(context, ErrorMessages[e], Toast.LENGTH_LONG).show()
+ }
}
}
}
- fun getTunnelState(tunnel: ObservableTunnel): CompletionStage<Tunnel.State> = getAsyncWorker()
- .supplyAsync { getBackend().getState(tunnel) }.thenApply(tunnel::onStateChanged)
+ suspend fun getTunnelState(tunnel: ObservableTunnel): Tunnel.State = withContext(Dispatchers.Main.immediate) {
+ tunnel.onStateChanged(withContext(Dispatchers.IO) { getBackend().getState(tunnel) })
+ }
- fun getTunnelStatistics(tunnel: ObservableTunnel): CompletionStage<Statistics> = getAsyncWorker()
- .supplyAsync { getBackend().getStatistics(tunnel) }.thenApply(tunnel::onStatisticsChanged)
+ suspend fun getTunnelStatistics(tunnel: ObservableTunnel): Statistics = withContext(Dispatchers.Main.immediate) {
+ tunnel.onStatisticsChanged(withContext(Dispatchers.IO) { getBackend().getStatistics(tunnel) })!!
+ }
companion object {
- private const val KEY_LAST_USED_TUNNEL = "last_used_tunnel"
- private const val KEY_RESTORE_ON_BOOT = "restore_on_boot"
- private const val KEY_RUNNING_TUNNELS = "enabled_configs"
+ private const val TAG = "WireGuard/TunnelManager"
}
}