aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2020-09-14 19:46:49 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2020-09-15 12:30:15 +0200
commitbab70ab51ecc02c2e8afd1843cdd4d90ae9cc257 (patch)
treebd7117473f42dc6211d9aad4c78cbdddeb851b3e /ui/src/main/java/com/wireguard/android/model/TunnelManager.kt
parentcoroutines: convert low-hanging fruits (diff)
downloadwireguard-android-bab70ab51ecc02c2e8afd1843cdd4d90ae9cc257.tar.xz
wireguard-android-bab70ab51ecc02c2e8afd1843cdd4d90ae9cc257.zip
coroutines: convert the rest
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'ui/src/main/java/com/wireguard/android/model/TunnelManager.kt')
-rw-r--r--ui/src/main/java/com/wireguard/android/model/TunnelManager.kt204
1 files changed, 113 insertions, 91 deletions
diff --git a/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt b/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt
index 5091ed3b..b06585e4 100644
--- a/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt
+++ b/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt
@@ -9,10 +9,10 @@ import android.content.BroadcastReceiver
import android.content.Context
import android.content.Intent
import android.os.Build
+import android.util.Log
import androidx.databinding.BaseObservable
import androidx.databinding.Bindable
import com.wireguard.android.Application.Companion.get
-import com.wireguard.android.Application.Companion.getAsyncWorker
import com.wireguard.android.Application.Companion.getBackend
import com.wireguard.android.Application.Companion.getSharedPreferences
import com.wireguard.android.Application.Companion.getTunnelManager
@@ -22,60 +22,64 @@ import com.wireguard.android.backend.Statistics
import com.wireguard.android.backend.Tunnel
import com.wireguard.android.configStore.ConfigStore
import com.wireguard.android.databinding.ObservableSortedKeyedArrayList
-import com.wireguard.android.util.ExceptionLoggers
import com.wireguard.config.Config
-import java9.util.concurrent.CompletableFuture
-import java9.util.concurrent.CompletionStage
-import java.util.ArrayList
+import kotlinx.coroutines.CompletableDeferred
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.GlobalScope
+import kotlinx.coroutines.SupervisorJob
+import kotlinx.coroutines.async
+import kotlinx.coroutines.awaitAll
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
/**
* Maintains and mediates changes to the set of available WireGuard tunnels,
*/
class TunnelManager(private val configStore: ConfigStore) : BaseObservable() {
- val tunnels = CompletableFuture<ObservableSortedKeyedArrayList<String, ObservableTunnel>>()
+ private val tunnels = CompletableDeferred<ObservableSortedKeyedArrayList<String, ObservableTunnel>>()
private val context: Context = get()
- private val delayedLoadRestoreTunnels = ArrayList<CompletableFuture<Void>>()
private val tunnelMap: ObservableSortedKeyedArrayList<String, ObservableTunnel> = ObservableSortedKeyedArrayList(TunnelComparator)
private var haveLoaded = false
- private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel? {
+ private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel {
val tunnel = ObservableTunnel(this, name, config, state)
tunnelMap.add(tunnel)
return tunnel
}
- fun create(name: String, config: Config?): CompletionStage<ObservableTunnel> {
+ suspend fun getTunnels(): ObservableSortedKeyedArrayList<String, ObservableTunnel> = tunnels.await()
+
+ suspend fun create(name: String, config: Config?): ObservableTunnel = withContext(Dispatchers.Main.immediate) {
if (Tunnel.isNameInvalid(name))
- return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name)))
+ throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))
if (tunnelMap.containsKey(name))
- return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name)))
- return getAsyncWorker().supplyAsync { configStore.create(name, config!!) }.thenApply { addToList(name, it, Tunnel.State.DOWN) }
+ throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))
+ addToList(name, withContext(Dispatchers.IO) { configStore.create(name, config!!) }, Tunnel.State.DOWN)
}
- fun delete(tunnel: ObservableTunnel): CompletionStage<Void> {
+ suspend fun delete(tunnel: ObservableTunnel) = withContext(Dispatchers.Main.immediate) {
val originalState = tunnel.state
val wasLastUsed = tunnel == lastUsedTunnel
// Make sure nothing touches the tunnel.
if (wasLastUsed)
lastUsedTunnel = null
tunnelMap.remove(tunnel)
- return getAsyncWorker().runAsync {
+ try {
if (originalState == Tunnel.State.UP)
- getBackend().setState(tunnel, Tunnel.State.DOWN, null)
+ withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) }
try {
- configStore.delete(tunnel.name)
- } catch (e: Exception) {
+ withContext(Dispatchers.IO) { configStore.delete(tunnel.name) }
+ } catch (e: Throwable) {
if (originalState == Tunnel.State.UP)
- getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config)
+ withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) }
throw e
}
- }.whenComplete { _, e ->
- if (e == null)
- return@whenComplete
+ } catch (e: Throwable) {
// Failure, put the tunnel back.
tunnelMap.add(tunnel)
if (wasLastUsed)
lastUsedTunnel = tunnel
+ throw e
}
}
@@ -92,14 +96,18 @@ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() {
getSharedPreferences().edit().remove(KEY_LAST_USED_TUNNEL).commit()
}
- fun getTunnelConfig(tunnel: ObservableTunnel): CompletionStage<Config> = getAsyncWorker()
- .supplyAsync { configStore.load(tunnel.name) }.thenApply(tunnel::onConfigChanged)
-
+ suspend fun getTunnelConfig(tunnel: ObservableTunnel): Config = withContext(Dispatchers.Main.immediate) {
+ tunnel.onConfigChanged(withContext(Dispatchers.IO) { configStore.load(tunnel.name) })!!
+ }
fun onCreate() {
- getAsyncWorker().supplyAsync { configStore.enumerate() }
- .thenAcceptBoth(getAsyncWorker().supplyAsync { getBackend().runningTunnelNames }, this::onTunnelsLoaded)
- .whenComplete(ExceptionLoggers.E)
+ GlobalScope.launch(Dispatchers.Main.immediate) {
+ try {
+ onTunnelsLoaded(withContext(Dispatchers.IO) { configStore.enumerate() }, withContext(Dispatchers.IO) { getBackend().runningTunnelNames })
+ } catch (e: Throwable) {
+ Log.println(Log.ERROR, TAG, Log.getStackTraceString(e))
+ }
+ }
}
private fun onTunnelsLoaded(present: Iterable<String>, running: Collection<String>) {
@@ -108,42 +116,38 @@ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() {
val lastUsedName = getSharedPreferences().getString(KEY_LAST_USED_TUNNEL, null)
if (lastUsedName != null)
lastUsedTunnel = tunnelMap[lastUsedName]
- var toComplete: Array<CompletableFuture<Void>>
- synchronized(delayedLoadRestoreTunnels) {
- haveLoaded = true
- toComplete = delayedLoadRestoreTunnels.toTypedArray()
- delayedLoadRestoreTunnels.clear()
- }
- restoreState(true).whenComplete { v: Void?, t: Throwable? ->
- for (f in toComplete) {
- if (t == null)
- f.complete(v)
- else
- f.completeExceptionally(t)
- }
- }
+ haveLoaded = true
+ restoreState(true)
tunnels.complete(tunnelMap)
}
- fun refreshTunnelStates() {
- getAsyncWorker().supplyAsync { getBackend().runningTunnelNames }
- .thenAccept { running: Set<String> -> for (tunnel in tunnelMap) tunnel.onStateChanged(if (running.contains(tunnel.name)) Tunnel.State.UP else Tunnel.State.DOWN) }
- .whenComplete(ExceptionLoggers.E)
+ private fun refreshTunnelStates() {
+ GlobalScope.launch(Dispatchers.Main.immediate) {
+ try {
+ val running = withContext(Dispatchers.IO) { getBackend().runningTunnelNames }
+ for (tunnel in tunnelMap)
+ tunnel.onStateChanged(if (running.contains(tunnel.name)) Tunnel.State.UP else Tunnel.State.DOWN)
+ } catch (e: Throwable) {
+ Log.println(Log.ERROR, TAG, Log.getStackTraceString(e))
+ }
+ }
}
- fun restoreState(force: Boolean): CompletionStage<Void> {
- if (!force && !getSharedPreferences().getBoolean(KEY_RESTORE_ON_BOOT, false))
- return CompletableFuture.completedFuture(null)
- synchronized(delayedLoadRestoreTunnels) {
- if (!haveLoaded) {
- val f = CompletableFuture<Void>()
- delayedLoadRestoreTunnels.add(f)
- return f
+ fun restoreState(force: Boolean) {
+ if (!haveLoaded || (!force && !getSharedPreferences().getBoolean(KEY_RESTORE_ON_BOOT, false)))
+ return
+ val previouslyRunning = getSharedPreferences().getStringSet(KEY_RUNNING_TUNNELS, null)
+ ?: return
+ if (previouslyRunning.isEmpty()) return
+ GlobalScope.launch(Dispatchers.Main.immediate) {
+ withContext(Dispatchers.IO) {
+ try {
+ tunnelMap.filter { previouslyRunning.contains(it.name) }.map { async(SupervisorJob()) { setTunnelState(it, Tunnel.State.UP) } }.awaitAll()
+ } catch (e: Throwable) {
+ Log.println(Log.ERROR, TAG, Log.getStackTraceString(e))
+ }
}
}
- val previouslyRunning = getSharedPreferences().getStringSet(KEY_RUNNING_TUNNELS, null)
- ?: return CompletableFuture.completedFuture(null)
- return CompletableFuture.allOf(*tunnelMap.filter { previouslyRunning.contains(it.name) }.map { setTunnelState(it, Tunnel.State.UP).toCompletableFuture() }.toTypedArray())
}
@SuppressLint("ApplySharedPref")
@@ -151,16 +155,18 @@ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() {
getSharedPreferences().edit().putStringSet(KEY_RUNNING_TUNNELS, tunnelMap.filter { it.state == Tunnel.State.UP }.map { it.name }.toSet()).commit()
}
- fun setTunnelConfig(tunnel: ObservableTunnel, config: Config): CompletionStage<Config> = getAsyncWorker().supplyAsync {
- getBackend().setState(tunnel, tunnel.state, config)
- configStore.save(tunnel.name, config)
- }.thenApply { tunnel.onConfigChanged(it) }
+ suspend fun setTunnelConfig(tunnel: ObservableTunnel, config: Config): Config = withContext(Dispatchers.Main.immediate) {
+ tunnel.onConfigChanged(withContext(Dispatchers.IO) {
+ getBackend().setState(tunnel, tunnel.state, config)
+ configStore.save(tunnel.name, config)
+ })!!
+ }
- fun setTunnelName(tunnel: ObservableTunnel, name: String): CompletionStage<String> {
+ suspend fun setTunnelName(tunnel: ObservableTunnel, name: String): String = withContext(Dispatchers.Main.immediate) {
if (Tunnel.isNameInvalid(name))
- return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name)))
+ throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))
if (tunnelMap.containsKey(name)) {
- return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name)))
+ throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))
}
val originalState = tunnel.state
val wasLastUsed = tunnel == lastUsedTunnel
@@ -168,34 +174,45 @@ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() {
if (wasLastUsed)
lastUsedTunnel = null
tunnelMap.remove(tunnel)
- return getAsyncWorker().supplyAsync {
+ var throwable: Throwable? = null
+ var newName: String? = null
+ try {
if (originalState == Tunnel.State.UP)
- getBackend().setState(tunnel, Tunnel.State.DOWN, null)
- configStore.rename(tunnel.name, name)
- val newName = tunnel.onNameChanged(name)
+ withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) }
+ withContext(Dispatchers.IO) { configStore.rename(tunnel.name, name) }
+ newName = tunnel.onNameChanged(name)
if (originalState == Tunnel.State.UP)
- getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config)
- newName
- }.whenComplete { _, e ->
+ withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) }
+ } catch (e: Throwable) {
+ throwable = e
// On failure, we don't know what state the tunnel might be in. Fix that.
- if (e != null)
- getTunnelState(tunnel)
- // Add the tunnel back to the manager, under whatever name it thinks it has.
- tunnelMap.add(tunnel)
- if (wasLastUsed)
- lastUsedTunnel = tunnel
+ getTunnelState(tunnel)
}
+ // Add the tunnel back to the manager, under whatever name it thinks it has.
+ tunnelMap.add(tunnel)
+ if (wasLastUsed)
+ lastUsedTunnel = tunnel
+ if (throwable != null)
+ throw throwable
+ newName!!
}
- fun setTunnelState(tunnel: ObservableTunnel, state: Tunnel.State): CompletionStage<Tunnel.State> = tunnel.configAsync
- .thenCompose { getAsyncWorker().supplyAsync { getBackend().setState(tunnel, state, it) } }
- .whenComplete { newState, e ->
- // Ensure onStateChanged is always called (failure or not), and with the correct state.
- tunnel.onStateChanged(if (e == null) newState else tunnel.state)
- if (e == null && newState == Tunnel.State.UP)
- lastUsedTunnel = tunnel
- saveState()
- }
+ suspend fun setTunnelState(tunnel: ObservableTunnel, state: Tunnel.State): Tunnel.State = withContext(Dispatchers.Main.immediate) {
+ var newState = tunnel.state
+ var throwable: Throwable? = null
+ try {
+ newState = withContext(Dispatchers.IO) { getBackend().setState(tunnel, state, tunnel.getConfigAsync()) }
+ if (newState == Tunnel.State.UP)
+ lastUsedTunnel = tunnel
+ } catch (e: Throwable) {
+ throwable = e
+ }
+ tunnel.onStateChanged(newState)
+ saveState()
+ if (throwable != null)
+ throw throwable
+ newState
+ }
class IntentReceiver : BroadcastReceiver() {
override fun onReceive(context: Context, intent: Intent?) {
@@ -215,20 +232,25 @@ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() {
else -> return
}
val tunnelName = intent.getStringExtra("tunnel") ?: return
- manager.tunnels.thenAccept {
- val tunnel = it[tunnelName] ?: return@thenAccept
+ GlobalScope.launch(Dispatchers.Main.immediate) {
+ val tunnels = manager.getTunnels()
+ val tunnel = tunnels[tunnelName] ?: return@launch
manager.setTunnelState(tunnel, state)
}
}
}
- fun getTunnelState(tunnel: ObservableTunnel): CompletionStage<Tunnel.State> = getAsyncWorker()
- .supplyAsync { getBackend().getState(tunnel) }.thenApply(tunnel::onStateChanged)
+ suspend fun getTunnelState(tunnel: ObservableTunnel): Tunnel.State = withContext(Dispatchers.Main.immediate) {
+ tunnel.onStateChanged(withContext(Dispatchers.IO) { getBackend().getState(tunnel) })
+ }
- fun getTunnelStatistics(tunnel: ObservableTunnel): CompletionStage<Statistics> = getAsyncWorker()
- .supplyAsync { getBackend().getStatistics(tunnel) }.thenApply(tunnel::onStatisticsChanged)
+ suspend fun getTunnelStatistics(tunnel: ObservableTunnel): Statistics = withContext(Dispatchers.Main.immediate) {
+ tunnel.onStatisticsChanged(withContext(Dispatchers.IO) { getBackend().getStatistics(tunnel) })!!
+ }
companion object {
+ private const val TAG = "WireGuard/TunnelManager"
+
private const val KEY_LAST_USED_TUNNEL = "last_used_tunnel"
private const val KEY_RESTORE_ON_BOOT = "restore_on_boot"
private const val KEY_RUNNING_TUNNELS = "enabled_configs"