aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorHarsh Shandilya <me@msfjarvis.dev>2020-04-08 18:03:30 +0530
committerHarsh Shandilya <me@msfjarvis.dev>2020-04-10 00:09:26 +0530
commitb04dc501691bf07dde0b0a833cde45f39a4cf721 (patch)
tree429c65d489e39d91b160866a92cc43bd08a438c8
parentui: switch to lifecycle scope where available (diff)
downloadwireguard-android-b04dc501691bf07dde0b0a833cde45f39a4cf721.tar.xz
wireguard-android-b04dc501691bf07dde0b0a833cde45f39a4cf721.zip
WIP: revisit async tunnel import
I am so good at this that I broke the code in ways I did not know were possible. Signed-off-by: Harsh Shandilya <me@msfjarvis.dev>
-rw-r--r--ui/src/main/java/com/wireguard/android/fragment/TunnelListFragment.kt158
-rw-r--r--ui/src/main/java/com/wireguard/android/model/TunnelManager.kt11
2 files changed, 86 insertions, 83 deletions
diff --git a/ui/src/main/java/com/wireguard/android/fragment/TunnelListFragment.kt b/ui/src/main/java/com/wireguard/android/fragment/TunnelListFragment.kt
index 4cc9e997..ccf0d584 100644
--- a/ui/src/main/java/com/wireguard/android/fragment/TunnelListFragment.kt
+++ b/ui/src/main/java/com/wireguard/android/fragment/TunnelListFragment.kt
@@ -36,16 +36,17 @@ import com.wireguard.android.widget.EdgeToEdge.setUpRoot
import com.wireguard.android.widget.EdgeToEdge.setUpScrollingContent
import com.wireguard.android.widget.MultiselectableRelativeLayout
import com.wireguard.config.Config
-import java9.util.concurrent.CompletableFuture
+import kotlinx.coroutines.Deferred
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.launch
+import kotlinx.coroutines.supervisorScope
+import kotlinx.coroutines.withContext
import java.io.BufferedReader
import java.io.ByteArrayInputStream
import java.io.InputStreamReader
import java.nio.charset.StandardCharsets
-import java.util.ArrayList
import java.util.HashSet
import java.util.Locale
import java.util.zip.ZipEntry
@@ -72,105 +73,98 @@ class TunnelListFragment : BaseFragment() {
}
}
- private fun importTunnel(uri: Uri?) {
+ private suspend fun importTunnel(uri: Uri?) {
val activity = activity
if (activity == null || uri == null) {
return
}
val contentResolver = activity.contentResolver
- val futureTunnels = ArrayList<CompletableFuture<ObservableTunnel>>()
+ val deferredTunnels = ArrayList<Deferred<ObservableTunnel>>()
val throwables = ArrayList<Throwable>()
- Application.getAsyncWorker().supplyAsync {
- val columns = arrayOf(OpenableColumns.DISPLAY_NAME)
- var name = ""
- contentResolver.query(uri, columns, null, null, null)?.use { cursor ->
- if (cursor.moveToFirst() && !cursor.isNull(0)) {
- name = cursor.getString(0)
- }
- }
- if (name.isEmpty()) {
- name = Uri.decode(uri.lastPathSegment)
- }
- var idx = name.lastIndexOf('/')
- if (idx >= 0) {
- require(idx < name.length - 1) { resources.getString(R.string.illegal_filename_error, name) }
- name = name.substring(idx + 1)
- }
- val isZip = name.toLowerCase(Locale.ROOT).endsWith(".zip")
- if (name.toLowerCase(Locale.ROOT).endsWith(".conf")) {
- name = name.substring(0, name.length - ".conf".length)
- } else {
- require(isZip) { resources.getString(R.string.bad_extension_error) }
+ val columns = arrayOf(OpenableColumns.DISPLAY_NAME)
+ var name = ""
+ contentResolver.query(uri, columns, null, null, null)?.use { cursor ->
+ if (cursor.moveToFirst() && !cursor.isNull(0)) {
+ name = cursor.getString(0)
}
+ }
+ if (name.isEmpty()) {
+ name = Uri.decode(uri.lastPathSegment)
+ }
+ var idx = name.lastIndexOf('/')
+ if (idx >= 0) {
+ require(idx < name.length - 1) { resources.getString(R.string.illegal_filename_error, name) }
+ name = name.substring(idx + 1)
+ }
+ val isZip = name.toLowerCase(Locale.ROOT).endsWith(".zip")
+ if (name.toLowerCase(Locale.ROOT).endsWith(".conf")) {
+ name = name.substring(0, name.length - ".conf".length)
+ } else {
+ require(isZip) { resources.getString(R.string.bad_extension_error) }
+ }
- if (isZip) {
- ZipInputStream(contentResolver.openInputStream(uri)).use { zip ->
- val reader = BufferedReader(InputStreamReader(zip, StandardCharsets.UTF_8))
- var entry: ZipEntry?
- while (true) {
- entry = zip.nextEntry ?: break
- name = entry.name
- idx = name.lastIndexOf('/')
- if (idx >= 0) {
- if (idx >= name.length - 1) {
+ withContext(Dispatchers.IO) {
+ supervisorScope {
+ if (isZip) {
+ ZipInputStream(contentResolver.openInputStream(uri)).use { zip ->
+ val reader = BufferedReader(InputStreamReader(zip, StandardCharsets.UTF_8))
+ var entry: ZipEntry?
+ while (true) {
+ entry = zip.nextEntry ?: break
+ name = entry.name
+ idx = name.lastIndexOf('/')
+ if (idx >= 0) {
+ if (idx >= name.length - 1) {
+ continue
+ }
+ name = name.substring(name.lastIndexOf('/') + 1)
+ }
+ if (name.toLowerCase(Locale.ROOT).endsWith(".conf")) {
+ name = name.substring(0, name.length - ".conf".length)
+ } else {
continue
}
- name = name.substring(name.lastIndexOf('/') + 1)
- }
- if (name.toLowerCase(Locale.ROOT).endsWith(".conf")) {
- name = name.substring(0, name.length - ".conf".length)
- } else {
- continue
- }
- try {
- Config.parse(reader)
- } catch (e: Exception) {
- throwables.add(e)
- null
- }?.let {
- futureTunnels.add(Application.getTunnelManager().create(name, it).toCompletableFuture())
+ try {
+ Config.parse(reader)
+ } catch (e: Exception) {
+ throwables.add(e)
+ null
+ }?.let {
+ deferredTunnels.add(async {
+ Application.getTunnelManager().createAsync(name, it)
+ })
+ }
}
}
- }
- } else {
- futureTunnels.add(
- Application.getTunnelManager().create(
- name,
- Config.parse(contentResolver.openInputStream(uri)!!)
- ).toCompletableFuture()
- )
- }
-
- if (futureTunnels.isEmpty()) {
- if (throwables.size == 1) {
- throw throwables[0]
} else {
- require(throwables.isNotEmpty()) { resources.getString(R.string.no_configs_error) }
+ deferredTunnels.add(async {
+ Application.getTunnelManager().createAsync(name, Config.parse(contentResolver.openInputStream(uri)!!))
+ })
}
}
- CompletableFuture.allOf(*futureTunnels.toTypedArray())
- }.whenComplete { future, exception ->
- if (exception != null) {
- onTunnelImportFinished(emptyList(), listOf(exception))
+ }
+
+ if (deferredTunnels.isEmpty()) {
+ if (throwables.size == 1) {
+ throw throwables[0]
} else {
- future.whenComplete { _, _ ->
- val tunnels = mutableListOf<ObservableTunnel>()
- for (futureTunnel in futureTunnels) {
- val tunnel: ObservableTunnel? = try {
- futureTunnel.getNow(null)
- } catch (e: Exception) {
- throwables.add(e)
- null
- }
+ require(throwables.isNotEmpty()) { resources.getString(R.string.no_configs_error) }
+ }
+ }
- if (tunnel != null) {
- tunnels.add(tunnel)
- }
- }
- onTunnelImportFinished(tunnels, throwables)
+ try {
+ val tunnels = mutableListOf<ObservableTunnel>()
+ deferredTunnels.forEach {
+ try {
+ tunnels.add(it.await())
+ } catch (e: Exception) {
+ throwables.add(e)
}
}
+ onTunnelImportFinished(tunnels, throwables)
+ } catch (e: Exception) {
+ onTunnelImportFinished(emptyList(), listOf(e))
}
}
@@ -187,7 +181,7 @@ class TunnelListFragment : BaseFragment() {
override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
when (requestCode) {
REQUEST_IMPORT -> {
- if (resultCode == Activity.RESULT_OK && data != null) importTunnel(data.data)
+ if (resultCode == Activity.RESULT_OK && data != null) coroutineScope.launch { importTunnel(data.data) }
return
}
IntentIntegrator.REQUEST_CODE -> {
diff --git a/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt b/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt
index 3d48082d..fa844e19 100644
--- a/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt
+++ b/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt
@@ -47,7 +47,7 @@ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() {
private val tunnelMap: ObservableSortedKeyedArrayList<String, ObservableTunnel> = ObservableSortedKeyedArrayList(TunnelComparator)
private var haveLoaded = false
- private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel? {
+ private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel {
val tunnel = ObservableTunnel(this, name, config, state)
tunnelMap.add(tunnel)
return tunnel
@@ -61,6 +61,15 @@ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() {
return getAsyncWorker().supplyAsync { configStore.create(name, config!!) }.thenApply { addToList(name, it, Tunnel.State.DOWN) }
}
+ suspend fun createAsync(name: String, config: Config?): ObservableTunnel {
+ if (Tunnel.isNameInvalid(name))
+ throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))
+ if (tunnelMap.containsKey(name))
+ throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))
+ val newConfig = configStore.create(name, config!!)
+ return withContext(Dispatchers.Main) { addToList(name, newConfig, Tunnel.State.DOWN) }
+ }
+
suspend fun delete(tunnel: ObservableTunnel) {
val originalState = tunnel.state
val wasLastUsed = tunnel == lastUsedTunnel