From bab70ab51ecc02c2e8afd1843cdd4d90ae9cc257 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 14 Sep 2020 19:46:49 +0200 Subject: coroutines: convert the rest Signed-off-by: Jason A. Donenfeld --- .../com/wireguard/android/model/TunnelManager.kt | 204 ++++++++++++--------- 1 file changed, 113 insertions(+), 91 deletions(-) (limited to 'ui/src/main/java/com/wireguard/android/model/TunnelManager.kt') diff --git a/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt b/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt index 5091ed3b..b06585e4 100644 --- a/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt +++ b/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt @@ -9,10 +9,10 @@ import android.content.BroadcastReceiver import android.content.Context import android.content.Intent import android.os.Build +import android.util.Log import androidx.databinding.BaseObservable import androidx.databinding.Bindable import com.wireguard.android.Application.Companion.get -import com.wireguard.android.Application.Companion.getAsyncWorker import com.wireguard.android.Application.Companion.getBackend import com.wireguard.android.Application.Companion.getSharedPreferences import com.wireguard.android.Application.Companion.getTunnelManager @@ -22,60 +22,64 @@ import com.wireguard.android.backend.Statistics import com.wireguard.android.backend.Tunnel import com.wireguard.android.configStore.ConfigStore import com.wireguard.android.databinding.ObservableSortedKeyedArrayList -import com.wireguard.android.util.ExceptionLoggers import com.wireguard.config.Config -import java9.util.concurrent.CompletableFuture -import java9.util.concurrent.CompletionStage -import java.util.ArrayList +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.GlobalScope +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext /** * Maintains and mediates changes to the set of available WireGuard tunnels, */ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() { - val tunnels = CompletableFuture>() + private val tunnels = CompletableDeferred>() private val context: Context = get() - private val delayedLoadRestoreTunnels = ArrayList>() private val tunnelMap: ObservableSortedKeyedArrayList = ObservableSortedKeyedArrayList(TunnelComparator) private var haveLoaded = false - private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel? { + private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel { val tunnel = ObservableTunnel(this, name, config, state) tunnelMap.add(tunnel) return tunnel } - fun create(name: String, config: Config?): CompletionStage { + suspend fun getTunnels(): ObservableSortedKeyedArrayList = tunnels.await() + + suspend fun create(name: String, config: Config?): ObservableTunnel = withContext(Dispatchers.Main.immediate) { if (Tunnel.isNameInvalid(name)) - return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))) + throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name)) if (tunnelMap.containsKey(name)) - return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))) - return getAsyncWorker().supplyAsync { configStore.create(name, config!!) }.thenApply { addToList(name, it, Tunnel.State.DOWN) } + throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name)) + addToList(name, withContext(Dispatchers.IO) { configStore.create(name, config!!) }, Tunnel.State.DOWN) } - fun delete(tunnel: ObservableTunnel): CompletionStage { + suspend fun delete(tunnel: ObservableTunnel) = withContext(Dispatchers.Main.immediate) { val originalState = tunnel.state val wasLastUsed = tunnel == lastUsedTunnel // Make sure nothing touches the tunnel. if (wasLastUsed) lastUsedTunnel = null tunnelMap.remove(tunnel) - return getAsyncWorker().runAsync { + try { if (originalState == Tunnel.State.UP) - getBackend().setState(tunnel, Tunnel.State.DOWN, null) + withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) } try { - configStore.delete(tunnel.name) - } catch (e: Exception) { + withContext(Dispatchers.IO) { configStore.delete(tunnel.name) } + } catch (e: Throwable) { if (originalState == Tunnel.State.UP) - getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) + withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) } throw e } - }.whenComplete { _, e -> - if (e == null) - return@whenComplete + } catch (e: Throwable) { // Failure, put the tunnel back. tunnelMap.add(tunnel) if (wasLastUsed) lastUsedTunnel = tunnel + throw e } } @@ -92,14 +96,18 @@ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() { getSharedPreferences().edit().remove(KEY_LAST_USED_TUNNEL).commit() } - fun getTunnelConfig(tunnel: ObservableTunnel): CompletionStage = getAsyncWorker() - .supplyAsync { configStore.load(tunnel.name) }.thenApply(tunnel::onConfigChanged) - + suspend fun getTunnelConfig(tunnel: ObservableTunnel): Config = withContext(Dispatchers.Main.immediate) { + tunnel.onConfigChanged(withContext(Dispatchers.IO) { configStore.load(tunnel.name) })!! + } fun onCreate() { - getAsyncWorker().supplyAsync { configStore.enumerate() } - .thenAcceptBoth(getAsyncWorker().supplyAsync { getBackend().runningTunnelNames }, this::onTunnelsLoaded) - .whenComplete(ExceptionLoggers.E) + GlobalScope.launch(Dispatchers.Main.immediate) { + try { + onTunnelsLoaded(withContext(Dispatchers.IO) { configStore.enumerate() }, withContext(Dispatchers.IO) { getBackend().runningTunnelNames }) + } catch (e: Throwable) { + Log.println(Log.ERROR, TAG, Log.getStackTraceString(e)) + } + } } private fun onTunnelsLoaded(present: Iterable, running: Collection) { @@ -108,42 +116,38 @@ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() { val lastUsedName = getSharedPreferences().getString(KEY_LAST_USED_TUNNEL, null) if (lastUsedName != null) lastUsedTunnel = tunnelMap[lastUsedName] - var toComplete: Array> - synchronized(delayedLoadRestoreTunnels) { - haveLoaded = true - toComplete = delayedLoadRestoreTunnels.toTypedArray() - delayedLoadRestoreTunnels.clear() - } - restoreState(true).whenComplete { v: Void?, t: Throwable? -> - for (f in toComplete) { - if (t == null) - f.complete(v) - else - f.completeExceptionally(t) - } - } + haveLoaded = true + restoreState(true) tunnels.complete(tunnelMap) } - fun refreshTunnelStates() { - getAsyncWorker().supplyAsync { getBackend().runningTunnelNames } - .thenAccept { running: Set -> for (tunnel in tunnelMap) tunnel.onStateChanged(if (running.contains(tunnel.name)) Tunnel.State.UP else Tunnel.State.DOWN) } - .whenComplete(ExceptionLoggers.E) + private fun refreshTunnelStates() { + GlobalScope.launch(Dispatchers.Main.immediate) { + try { + val running = withContext(Dispatchers.IO) { getBackend().runningTunnelNames } + for (tunnel in tunnelMap) + tunnel.onStateChanged(if (running.contains(tunnel.name)) Tunnel.State.UP else Tunnel.State.DOWN) + } catch (e: Throwable) { + Log.println(Log.ERROR, TAG, Log.getStackTraceString(e)) + } + } } - fun restoreState(force: Boolean): CompletionStage { - if (!force && !getSharedPreferences().getBoolean(KEY_RESTORE_ON_BOOT, false)) - return CompletableFuture.completedFuture(null) - synchronized(delayedLoadRestoreTunnels) { - if (!haveLoaded) { - val f = CompletableFuture() - delayedLoadRestoreTunnels.add(f) - return f + fun restoreState(force: Boolean) { + if (!haveLoaded || (!force && !getSharedPreferences().getBoolean(KEY_RESTORE_ON_BOOT, false))) + return + val previouslyRunning = getSharedPreferences().getStringSet(KEY_RUNNING_TUNNELS, null) + ?: return + if (previouslyRunning.isEmpty()) return + GlobalScope.launch(Dispatchers.Main.immediate) { + withContext(Dispatchers.IO) { + try { + tunnelMap.filter { previouslyRunning.contains(it.name) }.map { async(SupervisorJob()) { setTunnelState(it, Tunnel.State.UP) } }.awaitAll() + } catch (e: Throwable) { + Log.println(Log.ERROR, TAG, Log.getStackTraceString(e)) + } } } - val previouslyRunning = getSharedPreferences().getStringSet(KEY_RUNNING_TUNNELS, null) - ?: return CompletableFuture.completedFuture(null) - return CompletableFuture.allOf(*tunnelMap.filter { previouslyRunning.contains(it.name) }.map { setTunnelState(it, Tunnel.State.UP).toCompletableFuture() }.toTypedArray()) } @SuppressLint("ApplySharedPref") @@ -151,16 +155,18 @@ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() { getSharedPreferences().edit().putStringSet(KEY_RUNNING_TUNNELS, tunnelMap.filter { it.state == Tunnel.State.UP }.map { it.name }.toSet()).commit() } - fun setTunnelConfig(tunnel: ObservableTunnel, config: Config): CompletionStage = getAsyncWorker().supplyAsync { - getBackend().setState(tunnel, tunnel.state, config) - configStore.save(tunnel.name, config) - }.thenApply { tunnel.onConfigChanged(it) } + suspend fun setTunnelConfig(tunnel: ObservableTunnel, config: Config): Config = withContext(Dispatchers.Main.immediate) { + tunnel.onConfigChanged(withContext(Dispatchers.IO) { + getBackend().setState(tunnel, tunnel.state, config) + configStore.save(tunnel.name, config) + })!! + } - fun setTunnelName(tunnel: ObservableTunnel, name: String): CompletionStage { + suspend fun setTunnelName(tunnel: ObservableTunnel, name: String): String = withContext(Dispatchers.Main.immediate) { if (Tunnel.isNameInvalid(name)) - return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))) + throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name)) if (tunnelMap.containsKey(name)) { - return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))) + throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name)) } val originalState = tunnel.state val wasLastUsed = tunnel == lastUsedTunnel @@ -168,34 +174,45 @@ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() { if (wasLastUsed) lastUsedTunnel = null tunnelMap.remove(tunnel) - return getAsyncWorker().supplyAsync { + var throwable: Throwable? = null + var newName: String? = null + try { if (originalState == Tunnel.State.UP) - getBackend().setState(tunnel, Tunnel.State.DOWN, null) - configStore.rename(tunnel.name, name) - val newName = tunnel.onNameChanged(name) + withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) } + withContext(Dispatchers.IO) { configStore.rename(tunnel.name, name) } + newName = tunnel.onNameChanged(name) if (originalState == Tunnel.State.UP) - getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) - newName - }.whenComplete { _, e -> + withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) } + } catch (e: Throwable) { + throwable = e // On failure, we don't know what state the tunnel might be in. Fix that. - if (e != null) - getTunnelState(tunnel) - // Add the tunnel back to the manager, under whatever name it thinks it has. - tunnelMap.add(tunnel) - if (wasLastUsed) - lastUsedTunnel = tunnel + getTunnelState(tunnel) } + // Add the tunnel back to the manager, under whatever name it thinks it has. + tunnelMap.add(tunnel) + if (wasLastUsed) + lastUsedTunnel = tunnel + if (throwable != null) + throw throwable + newName!! } - fun setTunnelState(tunnel: ObservableTunnel, state: Tunnel.State): CompletionStage = tunnel.configAsync - .thenCompose { getAsyncWorker().supplyAsync { getBackend().setState(tunnel, state, it) } } - .whenComplete { newState, e -> - // Ensure onStateChanged is always called (failure or not), and with the correct state. - tunnel.onStateChanged(if (e == null) newState else tunnel.state) - if (e == null && newState == Tunnel.State.UP) - lastUsedTunnel = tunnel - saveState() - } + suspend fun setTunnelState(tunnel: ObservableTunnel, state: Tunnel.State): Tunnel.State = withContext(Dispatchers.Main.immediate) { + var newState = tunnel.state + var throwable: Throwable? = null + try { + newState = withContext(Dispatchers.IO) { getBackend().setState(tunnel, state, tunnel.getConfigAsync()) } + if (newState == Tunnel.State.UP) + lastUsedTunnel = tunnel + } catch (e: Throwable) { + throwable = e + } + tunnel.onStateChanged(newState) + saveState() + if (throwable != null) + throw throwable + newState + } class IntentReceiver : BroadcastReceiver() { override fun onReceive(context: Context, intent: Intent?) { @@ -215,20 +232,25 @@ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() { else -> return } val tunnelName = intent.getStringExtra("tunnel") ?: return - manager.tunnels.thenAccept { - val tunnel = it[tunnelName] ?: return@thenAccept + GlobalScope.launch(Dispatchers.Main.immediate) { + val tunnels = manager.getTunnels() + val tunnel = tunnels[tunnelName] ?: return@launch manager.setTunnelState(tunnel, state) } } } - fun getTunnelState(tunnel: ObservableTunnel): CompletionStage = getAsyncWorker() - .supplyAsync { getBackend().getState(tunnel) }.thenApply(tunnel::onStateChanged) + suspend fun getTunnelState(tunnel: ObservableTunnel): Tunnel.State = withContext(Dispatchers.Main.immediate) { + tunnel.onStateChanged(withContext(Dispatchers.IO) { getBackend().getState(tunnel) }) + } - fun getTunnelStatistics(tunnel: ObservableTunnel): CompletionStage = getAsyncWorker() - .supplyAsync { getBackend().getStatistics(tunnel) }.thenApply(tunnel::onStatisticsChanged) + suspend fun getTunnelStatistics(tunnel: ObservableTunnel): Statistics = withContext(Dispatchers.Main.immediate) { + tunnel.onStatisticsChanged(withContext(Dispatchers.IO) { getBackend().getStatistics(tunnel) })!! + } companion object { + private const val TAG = "WireGuard/TunnelManager" + private const val KEY_LAST_USED_TUNNEL = "last_used_tunnel" private const val KEY_RESTORE_ON_BOOT = "restore_on_boot" private const val KEY_RUNNING_TUNNELS = "enabled_configs" -- cgit v1.2.3-59-g8ed1b