diff options
Diffstat (limited to 'ui/src/main/java/com/wireguard/android/model/TunnelManager.kt')
-rw-r--r-- | ui/src/main/java/com/wireguard/android/model/TunnelManager.kt | 255 |
1 files changed, 255 insertions, 0 deletions
diff --git a/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt b/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt new file mode 100644 index 00000000..d7c1391f --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt @@ -0,0 +1,255 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package com.wireguard.android.model + +import android.content.BroadcastReceiver +import android.content.Context +import android.content.Intent +import android.os.Build +import android.util.Log +import android.widget.Toast +import androidx.databinding.BaseObservable +import androidx.databinding.Bindable +import com.wireguard.android.Application.Companion.get +import com.wireguard.android.Application.Companion.getBackend +import com.wireguard.android.Application.Companion.getTunnelManager +import com.wireguard.android.BR +import com.wireguard.android.R +import com.wireguard.android.backend.Statistics +import com.wireguard.android.backend.Tunnel +import com.wireguard.android.configStore.ConfigStore +import com.wireguard.android.databinding.ObservableSortedKeyedArrayList +import com.wireguard.android.util.ErrorMessages +import com.wireguard.android.util.UserKnobs +import com.wireguard.android.util.applicationScope +import com.wireguard.config.Config +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext + +/** + * Maintains and mediates changes to the set of available WireGuard tunnels, + */ +class TunnelManager(private val configStore: ConfigStore) : BaseObservable() { + private val tunnels = CompletableDeferred<ObservableSortedKeyedArrayList<String, ObservableTunnel>>() + private val context: Context = get() + private val tunnelMap: ObservableSortedKeyedArrayList<String, ObservableTunnel> = ObservableSortedKeyedArrayList(TunnelComparator) + private var haveLoaded = false + + private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel { + val tunnel = ObservableTunnel(this, name, config, state) + tunnelMap.add(tunnel) + return tunnel + } + + suspend fun getTunnels(): ObservableSortedKeyedArrayList<String, ObservableTunnel> = tunnels.await() + + suspend fun create(name: String, config: Config?): ObservableTunnel = withContext(Dispatchers.Main.immediate) { + if (Tunnel.isNameInvalid(name)) + throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name)) + if (tunnelMap.containsKey(name)) + throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name)) + addToList(name, withContext(Dispatchers.IO) { configStore.create(name, config!!) }, Tunnel.State.DOWN) + } + + suspend fun delete(tunnel: ObservableTunnel) = withContext(Dispatchers.Main.immediate) { + val originalState = tunnel.state + val wasLastUsed = tunnel == lastUsedTunnel + // Make sure nothing touches the tunnel. + if (wasLastUsed) + lastUsedTunnel = null + tunnelMap.remove(tunnel) + try { + if (originalState == Tunnel.State.UP) + withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) } + try { + withContext(Dispatchers.IO) { configStore.delete(tunnel.name) } + } catch (e: Throwable) { + if (originalState == Tunnel.State.UP) + withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) } + throw e + } + } catch (e: Throwable) { + // Failure, put the tunnel back. + tunnelMap.add(tunnel) + if (wasLastUsed) + lastUsedTunnel = tunnel + throw e + } + } + + @get:Bindable + var lastUsedTunnel: ObservableTunnel? = null + private set(value) { + if (value == field) return + field = value + notifyPropertyChanged(BR.lastUsedTunnel) + applicationScope.launch { UserKnobs.setLastUsedTunnel(value?.name) } + } + + suspend fun getTunnelConfig(tunnel: ObservableTunnel): Config = withContext(Dispatchers.Main.immediate) { + tunnel.onConfigChanged(withContext(Dispatchers.IO) { configStore.load(tunnel.name) })!! + } + + fun onCreate() { + applicationScope.launch { + try { + onTunnelsLoaded(withContext(Dispatchers.IO) { configStore.enumerate() }, withContext(Dispatchers.IO) { getBackend().runningTunnelNames }) + } catch (e: Throwable) { + Log.e(TAG, Log.getStackTraceString(e)) + } + } + } + + private fun onTunnelsLoaded(present: Iterable<String>, running: Collection<String>) { + for (name in present) + addToList(name, null, if (running.contains(name)) Tunnel.State.UP else Tunnel.State.DOWN) + applicationScope.launch { + val lastUsedName = UserKnobs.lastUsedTunnel.first() + if (lastUsedName != null) + lastUsedTunnel = tunnelMap[lastUsedName] + haveLoaded = true + restoreState(true) + tunnels.complete(tunnelMap) + } + } + + private fun refreshTunnelStates() { + applicationScope.launch { + try { + val running = withContext(Dispatchers.IO) { getBackend().runningTunnelNames } + for (tunnel in tunnelMap) + tunnel.onStateChanged(if (running.contains(tunnel.name)) Tunnel.State.UP else Tunnel.State.DOWN) + } catch (e: Throwable) { + Log.e(TAG, Log.getStackTraceString(e)) + } + } + } + + suspend fun restoreState(force: Boolean) { + if (!haveLoaded || (!force && !UserKnobs.restoreOnBoot.first())) + return + val previouslyRunning = UserKnobs.runningTunnels.first() + if (previouslyRunning.isEmpty()) return + withContext(Dispatchers.IO) { + try { + tunnelMap.filter { previouslyRunning.contains(it.name) }.map { async(Dispatchers.IO + SupervisorJob()) { setTunnelState(it, Tunnel.State.UP) } } + .awaitAll() + } catch (e: Throwable) { + Log.e(TAG, Log.getStackTraceString(e)) + } + } + } + + suspend fun saveState() { + UserKnobs.setRunningTunnels(tunnelMap.filter { it.state == Tunnel.State.UP }.map { it.name }.toSet()) + } + + suspend fun setTunnelConfig(tunnel: ObservableTunnel, config: Config): Config = withContext(Dispatchers.Main.immediate) { + tunnel.onConfigChanged(withContext(Dispatchers.IO) { + getBackend().setState(tunnel, tunnel.state, config) + configStore.save(tunnel.name, config) + })!! + } + + suspend fun setTunnelName(tunnel: ObservableTunnel, name: String): String = withContext(Dispatchers.Main.immediate) { + if (Tunnel.isNameInvalid(name)) + throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name)) + if (tunnelMap.containsKey(name)) { + throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name)) + } + val originalState = tunnel.state + val wasLastUsed = tunnel == lastUsedTunnel + // Make sure nothing touches the tunnel. + if (wasLastUsed) + lastUsedTunnel = null + tunnelMap.remove(tunnel) + var throwable: Throwable? = null + var newName: String? = null + try { + if (originalState == Tunnel.State.UP) + withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) } + withContext(Dispatchers.IO) { configStore.rename(tunnel.name, name) } + newName = tunnel.onNameChanged(name) + if (originalState == Tunnel.State.UP) + withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) } + } catch (e: Throwable) { + throwable = e + // On failure, we don't know what state the tunnel might be in. Fix that. + getTunnelState(tunnel) + } + // Add the tunnel back to the manager, under whatever name it thinks it has. + tunnelMap.add(tunnel) + if (wasLastUsed) + lastUsedTunnel = tunnel + if (throwable != null) + throw throwable + newName!! + } + + suspend fun setTunnelState(tunnel: ObservableTunnel, state: Tunnel.State): Tunnel.State = withContext(Dispatchers.Main.immediate) { + var newState = tunnel.state + var throwable: Throwable? = null + try { + newState = withContext(Dispatchers.IO) { getBackend().setState(tunnel, state, tunnel.getConfigAsync()) } + if (newState == Tunnel.State.UP) + lastUsedTunnel = tunnel + } catch (e: Throwable) { + throwable = e + } + tunnel.onStateChanged(newState) + saveState() + if (throwable != null) + throw throwable + newState + } + + class IntentReceiver : BroadcastReceiver() { + override fun onReceive(context: Context, intent: Intent?) { + applicationScope.launch { + val manager = getTunnelManager() + if (intent == null) return@launch + val action = intent.action ?: return@launch + if ("com.wireguard.android.action.REFRESH_TUNNEL_STATES" == action) { + manager.refreshTunnelStates() + return@launch + } + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.M || !UserKnobs.allowRemoteControlIntents.first()) + return@launch + val state: Tunnel.State + state = when (action) { + "com.wireguard.android.action.SET_TUNNEL_UP" -> Tunnel.State.UP + "com.wireguard.android.action.SET_TUNNEL_DOWN" -> Tunnel.State.DOWN + else -> return@launch + } + val tunnelName = intent.getStringExtra("tunnel") ?: return@launch + val tunnels = manager.getTunnels() + val tunnel = tunnels[tunnelName] ?: return@launch + try { + manager.setTunnelState(tunnel, state) + } catch (e: Throwable) { + Toast.makeText(context, ErrorMessages[e], Toast.LENGTH_LONG).show() + } + } + } + } + + suspend fun getTunnelState(tunnel: ObservableTunnel): Tunnel.State = withContext(Dispatchers.Main.immediate) { + tunnel.onStateChanged(withContext(Dispatchers.IO) { getBackend().getState(tunnel) }) + } + + suspend fun getTunnelStatistics(tunnel: ObservableTunnel): Statistics = withContext(Dispatchers.Main.immediate) { + tunnel.onStatisticsChanged(withContext(Dispatchers.IO) { getBackend().getStatistics(tunnel) })!! + } + + companion object { + private const val TAG = "WireGuard/TunnelManager" + } +} |