diff options
Diffstat (limited to 'ui/src/main/java/com/wireguard/android/model/TunnelManager.kt')
-rw-r--r-- | ui/src/main/java/com/wireguard/android/model/TunnelManager.kt | 267 |
1 files changed, 143 insertions, 124 deletions
diff --git a/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt b/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt index 5091ed3b..d7c1391f 100644 --- a/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt +++ b/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt @@ -1,20 +1,19 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.model -import android.annotation.SuppressLint import android.content.BroadcastReceiver import android.content.Context import android.content.Intent import android.os.Build +import android.util.Log +import android.widget.Toast import androidx.databinding.BaseObservable import androidx.databinding.Bindable import com.wireguard.android.Application.Companion.get -import com.wireguard.android.Application.Companion.getAsyncWorker import com.wireguard.android.Application.Companion.getBackend -import com.wireguard.android.Application.Companion.getSharedPreferences import com.wireguard.android.Application.Companion.getTunnelManager import com.wireguard.android.BR import com.wireguard.android.R @@ -22,145 +21,149 @@ import com.wireguard.android.backend.Statistics import com.wireguard.android.backend.Tunnel import com.wireguard.android.configStore.ConfigStore import com.wireguard.android.databinding.ObservableSortedKeyedArrayList -import com.wireguard.android.util.ExceptionLoggers +import com.wireguard.android.util.ErrorMessages +import com.wireguard.android.util.UserKnobs +import com.wireguard.android.util.applicationScope import com.wireguard.config.Config -import java9.util.concurrent.CompletableFuture -import java9.util.concurrent.CompletionStage -import java.util.ArrayList +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext /** * Maintains and mediates changes to the set of available WireGuard tunnels, */ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() { - val tunnels = CompletableFuture<ObservableSortedKeyedArrayList<String, ObservableTunnel>>() + private val tunnels = CompletableDeferred<ObservableSortedKeyedArrayList<String, ObservableTunnel>>() private val context: Context = get() - private val delayedLoadRestoreTunnels = ArrayList<CompletableFuture<Void>>() private val tunnelMap: ObservableSortedKeyedArrayList<String, ObservableTunnel> = ObservableSortedKeyedArrayList(TunnelComparator) private var haveLoaded = false - private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel? { + private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel { val tunnel = ObservableTunnel(this, name, config, state) tunnelMap.add(tunnel) return tunnel } - fun create(name: String, config: Config?): CompletionStage<ObservableTunnel> { + suspend fun getTunnels(): ObservableSortedKeyedArrayList<String, ObservableTunnel> = tunnels.await() + + suspend fun create(name: String, config: Config?): ObservableTunnel = withContext(Dispatchers.Main.immediate) { if (Tunnel.isNameInvalid(name)) - return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))) + throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name)) if (tunnelMap.containsKey(name)) - return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))) - return getAsyncWorker().supplyAsync { configStore.create(name, config!!) }.thenApply { addToList(name, it, Tunnel.State.DOWN) } + throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name)) + addToList(name, withContext(Dispatchers.IO) { configStore.create(name, config!!) }, Tunnel.State.DOWN) } - fun delete(tunnel: ObservableTunnel): CompletionStage<Void> { + suspend fun delete(tunnel: ObservableTunnel) = withContext(Dispatchers.Main.immediate) { val originalState = tunnel.state val wasLastUsed = tunnel == lastUsedTunnel // Make sure nothing touches the tunnel. if (wasLastUsed) lastUsedTunnel = null tunnelMap.remove(tunnel) - return getAsyncWorker().runAsync { + try { if (originalState == Tunnel.State.UP) - getBackend().setState(tunnel, Tunnel.State.DOWN, null) + withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) } try { - configStore.delete(tunnel.name) - } catch (e: Exception) { + withContext(Dispatchers.IO) { configStore.delete(tunnel.name) } + } catch (e: Throwable) { if (originalState == Tunnel.State.UP) - getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) + withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) } throw e } - }.whenComplete { _, e -> - if (e == null) - return@whenComplete + } catch (e: Throwable) { // Failure, put the tunnel back. tunnelMap.add(tunnel) if (wasLastUsed) lastUsedTunnel = tunnel + throw e } } @get:Bindable - @SuppressLint("ApplySharedPref") var lastUsedTunnel: ObservableTunnel? = null private set(value) { if (value == field) return field = value notifyPropertyChanged(BR.lastUsedTunnel) - if (value != null) - getSharedPreferences().edit().putString(KEY_LAST_USED_TUNNEL, value.name).commit() - else - getSharedPreferences().edit().remove(KEY_LAST_USED_TUNNEL).commit() + applicationScope.launch { UserKnobs.setLastUsedTunnel(value?.name) } } - fun getTunnelConfig(tunnel: ObservableTunnel): CompletionStage<Config> = getAsyncWorker() - .supplyAsync { configStore.load(tunnel.name) }.thenApply(tunnel::onConfigChanged) - + suspend fun getTunnelConfig(tunnel: ObservableTunnel): Config = withContext(Dispatchers.Main.immediate) { + tunnel.onConfigChanged(withContext(Dispatchers.IO) { configStore.load(tunnel.name) })!! + } fun onCreate() { - getAsyncWorker().supplyAsync { configStore.enumerate() } - .thenAcceptBoth(getAsyncWorker().supplyAsync { getBackend().runningTunnelNames }, this::onTunnelsLoaded) - .whenComplete(ExceptionLoggers.E) + applicationScope.launch { + try { + onTunnelsLoaded(withContext(Dispatchers.IO) { configStore.enumerate() }, withContext(Dispatchers.IO) { getBackend().runningTunnelNames }) + } catch (e: Throwable) { + Log.e(TAG, Log.getStackTraceString(e)) + } + } } private fun onTunnelsLoaded(present: Iterable<String>, running: Collection<String>) { for (name in present) addToList(name, null, if (running.contains(name)) Tunnel.State.UP else Tunnel.State.DOWN) - val lastUsedName = getSharedPreferences().getString(KEY_LAST_USED_TUNNEL, null) - if (lastUsedName != null) - lastUsedTunnel = tunnelMap[lastUsedName] - var toComplete: Array<CompletableFuture<Void>> - synchronized(delayedLoadRestoreTunnels) { + applicationScope.launch { + val lastUsedName = UserKnobs.lastUsedTunnel.first() + if (lastUsedName != null) + lastUsedTunnel = tunnelMap[lastUsedName] haveLoaded = true - toComplete = delayedLoadRestoreTunnels.toTypedArray() - delayedLoadRestoreTunnels.clear() + restoreState(true) + tunnels.complete(tunnelMap) } - restoreState(true).whenComplete { v: Void?, t: Throwable? -> - for (f in toComplete) { - if (t == null) - f.complete(v) - else - f.completeExceptionally(t) - } - } - tunnels.complete(tunnelMap) } - fun refreshTunnelStates() { - getAsyncWorker().supplyAsync { getBackend().runningTunnelNames } - .thenAccept { running: Set<String> -> for (tunnel in tunnelMap) tunnel.onStateChanged(if (running.contains(tunnel.name)) Tunnel.State.UP else Tunnel.State.DOWN) } - .whenComplete(ExceptionLoggers.E) + private fun refreshTunnelStates() { + applicationScope.launch { + try { + val running = withContext(Dispatchers.IO) { getBackend().runningTunnelNames } + for (tunnel in tunnelMap) + tunnel.onStateChanged(if (running.contains(tunnel.name)) Tunnel.State.UP else Tunnel.State.DOWN) + } catch (e: Throwable) { + Log.e(TAG, Log.getStackTraceString(e)) + } + } } - fun restoreState(force: Boolean): CompletionStage<Void> { - if (!force && !getSharedPreferences().getBoolean(KEY_RESTORE_ON_BOOT, false)) - return CompletableFuture.completedFuture(null) - synchronized(delayedLoadRestoreTunnels) { - if (!haveLoaded) { - val f = CompletableFuture<Void>() - delayedLoadRestoreTunnels.add(f) - return f + suspend fun restoreState(force: Boolean) { + if (!haveLoaded || (!force && !UserKnobs.restoreOnBoot.first())) + return + val previouslyRunning = UserKnobs.runningTunnels.first() + if (previouslyRunning.isEmpty()) return + withContext(Dispatchers.IO) { + try { + tunnelMap.filter { previouslyRunning.contains(it.name) }.map { async(Dispatchers.IO + SupervisorJob()) { setTunnelState(it, Tunnel.State.UP) } } + .awaitAll() + } catch (e: Throwable) { + Log.e(TAG, Log.getStackTraceString(e)) } } - val previouslyRunning = getSharedPreferences().getStringSet(KEY_RUNNING_TUNNELS, null) - ?: return CompletableFuture.completedFuture(null) - return CompletableFuture.allOf(*tunnelMap.filter { previouslyRunning.contains(it.name) }.map { setTunnelState(it, Tunnel.State.UP).toCompletableFuture() }.toTypedArray()) } - @SuppressLint("ApplySharedPref") - fun saveState() { - getSharedPreferences().edit().putStringSet(KEY_RUNNING_TUNNELS, tunnelMap.filter { it.state == Tunnel.State.UP }.map { it.name }.toSet()).commit() + suspend fun saveState() { + UserKnobs.setRunningTunnels(tunnelMap.filter { it.state == Tunnel.State.UP }.map { it.name }.toSet()) } - fun setTunnelConfig(tunnel: ObservableTunnel, config: Config): CompletionStage<Config> = getAsyncWorker().supplyAsync { - getBackend().setState(tunnel, tunnel.state, config) - configStore.save(tunnel.name, config) - }.thenApply { tunnel.onConfigChanged(it) } + suspend fun setTunnelConfig(tunnel: ObservableTunnel, config: Config): Config = withContext(Dispatchers.Main.immediate) { + tunnel.onConfigChanged(withContext(Dispatchers.IO) { + getBackend().setState(tunnel, tunnel.state, config) + configStore.save(tunnel.name, config) + })!! + } - fun setTunnelName(tunnel: ObservableTunnel, name: String): CompletionStage<String> { + suspend fun setTunnelName(tunnel: ObservableTunnel, name: String): String = withContext(Dispatchers.Main.immediate) { if (Tunnel.isNameInvalid(name)) - return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))) + throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name)) if (tunnelMap.containsKey(name)) { - return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))) + throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name)) } val originalState = tunnel.state val wasLastUsed = tunnel == lastUsedTunnel @@ -168,69 +171,85 @@ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() { if (wasLastUsed) lastUsedTunnel = null tunnelMap.remove(tunnel) - return getAsyncWorker().supplyAsync { + var throwable: Throwable? = null + var newName: String? = null + try { if (originalState == Tunnel.State.UP) - getBackend().setState(tunnel, Tunnel.State.DOWN, null) - configStore.rename(tunnel.name, name) - val newName = tunnel.onNameChanged(name) + withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) } + withContext(Dispatchers.IO) { configStore.rename(tunnel.name, name) } + newName = tunnel.onNameChanged(name) if (originalState == Tunnel.State.UP) - getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) - newName - }.whenComplete { _, e -> + withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) } + } catch (e: Throwable) { + throwable = e // On failure, we don't know what state the tunnel might be in. Fix that. - if (e != null) - getTunnelState(tunnel) - // Add the tunnel back to the manager, under whatever name it thinks it has. - tunnelMap.add(tunnel) - if (wasLastUsed) - lastUsedTunnel = tunnel + getTunnelState(tunnel) } + // Add the tunnel back to the manager, under whatever name it thinks it has. + tunnelMap.add(tunnel) + if (wasLastUsed) + lastUsedTunnel = tunnel + if (throwable != null) + throw throwable + newName!! } - fun setTunnelState(tunnel: ObservableTunnel, state: Tunnel.State): CompletionStage<Tunnel.State> = tunnel.configAsync - .thenCompose { getAsyncWorker().supplyAsync { getBackend().setState(tunnel, state, it) } } - .whenComplete { newState, e -> - // Ensure onStateChanged is always called (failure or not), and with the correct state. - tunnel.onStateChanged(if (e == null) newState else tunnel.state) - if (e == null && newState == Tunnel.State.UP) - lastUsedTunnel = tunnel - saveState() - } + suspend fun setTunnelState(tunnel: ObservableTunnel, state: Tunnel.State): Tunnel.State = withContext(Dispatchers.Main.immediate) { + var newState = tunnel.state + var throwable: Throwable? = null + try { + newState = withContext(Dispatchers.IO) { getBackend().setState(tunnel, state, tunnel.getConfigAsync()) } + if (newState == Tunnel.State.UP) + lastUsedTunnel = tunnel + } catch (e: Throwable) { + throwable = e + } + tunnel.onStateChanged(newState) + saveState() + if (throwable != null) + throw throwable + newState + } class IntentReceiver : BroadcastReceiver() { override fun onReceive(context: Context, intent: Intent?) { - val manager = getTunnelManager() - if (intent == null) return - val action = intent.action ?: return - if ("com.wireguard.android.action.REFRESH_TUNNEL_STATES" == action) { - manager.refreshTunnelStates() - return - } - if (Build.VERSION.SDK_INT < Build.VERSION_CODES.M || !getSharedPreferences().getBoolean("allow_remote_control_intents", false)) - return - val state: Tunnel.State - state = when (action) { - "com.wireguard.android.action.SET_TUNNEL_UP" -> Tunnel.State.UP - "com.wireguard.android.action.SET_TUNNEL_DOWN" -> Tunnel.State.DOWN - else -> return - } - val tunnelName = intent.getStringExtra("tunnel") ?: return - manager.tunnels.thenAccept { - val tunnel = it[tunnelName] ?: return@thenAccept - manager.setTunnelState(tunnel, state) + applicationScope.launch { + val manager = getTunnelManager() + if (intent == null) return@launch + val action = intent.action ?: return@launch + if ("com.wireguard.android.action.REFRESH_TUNNEL_STATES" == action) { + manager.refreshTunnelStates() + return@launch + } + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.M || !UserKnobs.allowRemoteControlIntents.first()) + return@launch + val state: Tunnel.State + state = when (action) { + "com.wireguard.android.action.SET_TUNNEL_UP" -> Tunnel.State.UP + "com.wireguard.android.action.SET_TUNNEL_DOWN" -> Tunnel.State.DOWN + else -> return@launch + } + val tunnelName = intent.getStringExtra("tunnel") ?: return@launch + val tunnels = manager.getTunnels() + val tunnel = tunnels[tunnelName] ?: return@launch + try { + manager.setTunnelState(tunnel, state) + } catch (e: Throwable) { + Toast.makeText(context, ErrorMessages[e], Toast.LENGTH_LONG).show() + } } } } - fun getTunnelState(tunnel: ObservableTunnel): CompletionStage<Tunnel.State> = getAsyncWorker() - .supplyAsync { getBackend().getState(tunnel) }.thenApply(tunnel::onStateChanged) + suspend fun getTunnelState(tunnel: ObservableTunnel): Tunnel.State = withContext(Dispatchers.Main.immediate) { + tunnel.onStateChanged(withContext(Dispatchers.IO) { getBackend().getState(tunnel) }) + } - fun getTunnelStatistics(tunnel: ObservableTunnel): CompletionStage<Statistics> = getAsyncWorker() - .supplyAsync { getBackend().getStatistics(tunnel) }.thenApply(tunnel::onStatisticsChanged) + suspend fun getTunnelStatistics(tunnel: ObservableTunnel): Statistics = withContext(Dispatchers.Main.immediate) { + tunnel.onStatisticsChanged(withContext(Dispatchers.IO) { getBackend().getStatistics(tunnel) })!! + } companion object { - private const val KEY_LAST_USED_TUNNEL = "last_used_tunnel" - private const val KEY_RESTORE_ON_BOOT = "restore_on_boot" - private const val KEY_RUNNING_TUNNELS = "enabled_configs" + private const val TAG = "WireGuard/TunnelManager" } } |