diff options
Diffstat (limited to 'ui/src/main/java/com/wireguard')
66 files changed, 6000 insertions, 1731 deletions
diff --git a/ui/src/main/java/com/wireguard/android/Application.kt b/ui/src/main/java/com/wireguard/android/Application.kt index d533028f..74eaccf8 100644 --- a/ui/src/main/java/com/wireguard/android/Application.kt +++ b/ui/src/main/java/com/wireguard/android/Application.kt @@ -1,43 +1,51 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android import android.content.Context import android.content.Intent -import android.content.SharedPreferences -import android.content.SharedPreferences.OnSharedPreferenceChangeListener -import android.os.AsyncTask import android.os.Build -import android.os.Handler -import android.os.Looper import android.os.StrictMode +import android.os.StrictMode.ThreadPolicy import android.os.StrictMode.VmPolicy import android.util.Log import androidx.appcompat.app.AppCompatDelegate -import androidx.preference.PreferenceManager +import androidx.datastore.core.DataStore +import androidx.datastore.preferences.core.PreferenceDataStoreFactory +import androidx.datastore.preferences.core.Preferences +import androidx.datastore.preferences.preferencesDataStoreFile +import com.google.android.material.color.DynamicColors import com.wireguard.android.backend.Backend import com.wireguard.android.backend.GoBackend import com.wireguard.android.backend.WgQuickBackend import com.wireguard.android.configStore.FileConfigStore import com.wireguard.android.model.TunnelManager -import com.wireguard.android.util.AsyncWorker -import com.wireguard.android.util.ExceptionLoggers -import com.wireguard.android.util.ModuleLoader +import com.wireguard.android.updater.Updater import com.wireguard.android.util.RootShell import com.wireguard.android.util.ToolsInstaller -import java9.util.concurrent.CompletableFuture +import com.wireguard.android.util.UserKnobs +import com.wireguard.android.util.applicationScope +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.cancel +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.launchIn +import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking import java.lang.ref.WeakReference import java.util.Locale -class Application : android.app.Application(), OnSharedPreferenceChangeListener { - private val futureBackend = CompletableFuture<Backend>() - private lateinit var asyncWorker: AsyncWorker +class Application : android.app.Application() { + private val futureBackend = CompletableDeferred<Backend>() + private val coroutineScope = CoroutineScope(Job() + Dispatchers.Main.immediate) private var backend: Backend? = null - private lateinit var moduleLoader: ModuleLoader private lateinit var rootShell: RootShell - private lateinit var sharedPreferences: SharedPreferences + private lateinit var preferencesDataStore: DataStore<Preferences> private lateinit var toolsInstaller: ToolsInstaller private lateinit var tunnelManager: TunnelManager @@ -51,38 +59,73 @@ class Application : android.app.Application(), OnSharedPreferenceChangeListener startActivity(intent) System.exit(0) } - if (BuildConfig.DEBUG) { - StrictMode.setVmPolicy(VmPolicy.Builder().detectAll().penaltyLog().build()) + } + + private suspend fun determineBackend(): Backend { + var backend: Backend? = null + if (UserKnobs.enableKernelModule.first() && WgQuickBackend.hasKernelSupport()) { + try { + rootShell.start() + val wgQuickBackend = WgQuickBackend(applicationContext, rootShell, toolsInstaller) + wgQuickBackend.setMultipleTunnels(UserKnobs.multipleTunnels.first()) + backend = wgQuickBackend + UserKnobs.multipleTunnels.onEach { + wgQuickBackend.setMultipleTunnels(it) + }.launchIn(coroutineScope) + } catch (ignored: Exception) { + } + } + if (backend == null) { + backend = GoBackend(applicationContext) + GoBackend.setAlwaysOnCallback { get().applicationScope.launch { get().tunnelManager.restoreState(true) } } } + return backend } override fun onCreate() { Log.i(TAG, USER_AGENT) super.onCreate() - asyncWorker = AsyncWorker(AsyncTask.SERIAL_EXECUTOR, Handler(Looper.getMainLooper())) + DynamicColors.applyToActivitiesIfAvailable(this) rootShell = RootShell(applicationContext) toolsInstaller = ToolsInstaller(applicationContext, rootShell) - moduleLoader = ModuleLoader(applicationContext, rootShell, USER_AGENT) - sharedPreferences = PreferenceManager.getDefaultSharedPreferences(applicationContext) + preferencesDataStore = PreferenceDataStoreFactory.create { applicationContext.preferencesDataStoreFile("settings") } if (Build.VERSION.SDK_INT < Build.VERSION_CODES.Q) { - AppCompatDelegate.setDefaultNightMode( - if (sharedPreferences.getBoolean("dark_theme", false)) AppCompatDelegate.MODE_NIGHT_YES else AppCompatDelegate.MODE_NIGHT_NO) + runBlocking { + AppCompatDelegate.setDefaultNightMode(if (UserKnobs.darkTheme.first()) AppCompatDelegate.MODE_NIGHT_YES else AppCompatDelegate.MODE_NIGHT_NO) + } + UserKnobs.darkTheme.onEach { + val newMode = if (it) { + AppCompatDelegate.MODE_NIGHT_YES + } else { + AppCompatDelegate.MODE_NIGHT_NO + } + if (AppCompatDelegate.getDefaultNightMode() != newMode) { + AppCompatDelegate.setDefaultNightMode(newMode) + } + }.launchIn(coroutineScope) } else { AppCompatDelegate.setDefaultNightMode(AppCompatDelegate.MODE_NIGHT_FOLLOW_SYSTEM) } tunnelManager = TunnelManager(FileConfigStore(applicationContext)) tunnelManager.onCreate() - asyncWorker.supplyAsync(Companion::getBackend).thenAccept { futureBackend.complete(it) } - sharedPreferences.registerOnSharedPreferenceChangeListener(this) - } + coroutineScope.launch(Dispatchers.IO) { + try { + backend = determineBackend() + futureBackend.complete(backend!!) + } catch (e: Throwable) { + Log.e(TAG, Log.getStackTraceString(e)) + } + } + Updater.monitorForUpdates() - override fun onSharedPreferenceChanged(sharedPreferences: SharedPreferences, key: String) { - if ("multiple_tunnels" == key && backend != null && backend is WgQuickBackend) - (backend as WgQuickBackend).setMultipleTunnels(sharedPreferences.getBoolean(key, false)) + if (BuildConfig.DEBUG) { + StrictMode.setVmPolicy(VmPolicy.Builder().detectAll().penaltyLog().build()) + StrictMode.setThreadPolicy(ThreadPolicy.Builder().detectAll().penaltyLog().build()) + } } override fun onTerminate() { - sharedPreferences.unregisterOnSharedPreferenceChangeListener(this) + coroutineScope.cancel() super.onTerminate() } @@ -91,66 +134,21 @@ class Application : android.app.Application(), OnSharedPreferenceChangeListener private const val TAG = "WireGuard/Application" private lateinit var weakSelf: WeakReference<Application> - @JvmStatic fun get(): Application { return weakSelf.get()!! } - @JvmStatic - fun getAsyncWorker() = get().asyncWorker - - @JvmStatic - fun getBackend(): Backend { - val app = get() - synchronized(app.futureBackend) { - if (app.backend == null) { - var backend: Backend? = null - var didStartRootShell = false - if (!ModuleLoader.isModuleLoaded() && app.moduleLoader.moduleMightExist()) { - try { - app.rootShell.start() - didStartRootShell = true - app.moduleLoader.loadModule() - } catch (ignored: Exception) { - } - } - if (!app.sharedPreferences.getBoolean("disable_kernel_module", false) && ModuleLoader.isModuleLoaded()) { - try { - if (!didStartRootShell) - app.rootShell.start() - val wgQuickBackend = WgQuickBackend(app.applicationContext, app.rootShell, app.toolsInstaller) - wgQuickBackend.setMultipleTunnels(app.sharedPreferences.getBoolean("multiple_tunnels", false)) - backend = wgQuickBackend - } catch (ignored: Exception) { - } - } - if (backend == null) { - backend = GoBackend(app.applicationContext) - GoBackend.setAlwaysOnCallback { get().tunnelManager.restoreState(true).whenComplete(ExceptionLoggers.D) } - } - app.backend = backend - } - return app.backend!! - } - } - - @JvmStatic - fun getBackendAsync() = get().futureBackend + suspend fun getBackend() = get().futureBackend.await() - @JvmStatic - fun getModuleLoader() = get().moduleLoader - - @JvmStatic fun getRootShell() = get().rootShell - @JvmStatic - fun getSharedPreferences() = get().sharedPreferences + fun getPreferencesDataStore() = get().preferencesDataStore - @JvmStatic fun getToolsInstaller() = get().toolsInstaller - @JvmStatic fun getTunnelManager() = get().tunnelManager + + fun getCoroutineScope() = get().coroutineScope } init { diff --git a/ui/src/main/java/com/wireguard/android/BootShutdownReceiver.kt b/ui/src/main/java/com/wireguard/android/BootShutdownReceiver.kt index 41aff76d..59769df4 100644 --- a/ui/src/main/java/com/wireguard/android/BootShutdownReceiver.kt +++ b/ui/src/main/java/com/wireguard/android/BootShutdownReceiver.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android @@ -8,19 +8,19 @@ import android.content.BroadcastReceiver import android.content.Context import android.content.Intent import android.util.Log -import com.wireguard.android.backend.Backend import com.wireguard.android.backend.WgQuickBackend -import com.wireguard.android.util.ExceptionLoggers +import com.wireguard.android.util.applicationScope +import kotlinx.coroutines.launch class BootShutdownReceiver : BroadcastReceiver() { override fun onReceive(context: Context, intent: Intent) { - Application.getBackendAsync().thenAccept { backend: Backend? -> - if (backend !is WgQuickBackend) return@thenAccept - val action = intent.action ?: return@thenAccept + val action = intent.action ?: return + applicationScope.launch { + if (Application.getBackend() !is WgQuickBackend) return@launch val tunnelManager = Application.getTunnelManager() if (Intent.ACTION_BOOT_COMPLETED == action) { Log.i(TAG, "Broadcast receiver restoring state (boot)") - tunnelManager.restoreState(false).whenComplete(ExceptionLoggers.D) + tunnelManager.restoreState(false) } else if (Intent.ACTION_SHUTDOWN == action) { Log.i(TAG, "Broadcast receiver saving state (shutdown)") tunnelManager.saveState() diff --git a/ui/src/main/java/com/wireguard/android/QuickTileService.kt b/ui/src/main/java/com/wireguard/android/QuickTileService.kt index 5099668e..a8650b78 100644 --- a/ui/src/main/java/com/wireguard/android/QuickTileService.kt +++ b/ui/src/main/java/com/wireguard/android/QuickTileService.kt @@ -1,15 +1,18 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android +import android.app.PendingIntent import android.content.Intent import android.graphics.Bitmap import android.graphics.Canvas import android.graphics.drawable.Icon +import android.net.Uri import android.os.Build import android.os.IBinder +import android.provider.Settings import android.service.quicksettings.Tile import android.service.quicksettings.TileService import android.util.Log @@ -20,7 +23,9 @@ import com.wireguard.android.activity.MainActivity import com.wireguard.android.activity.TunnelToggleActivity import com.wireguard.android.backend.Tunnel import com.wireguard.android.model.ObservableTunnel +import com.wireguard.android.util.applicationScope import com.wireguard.android.widget.SlashDrawable +import kotlinx.coroutines.launch /** * Service that maintains the application's custom Quick Settings tile. This service is bound by the @@ -40,38 +45,66 @@ class QuickTileService : TileService() { var ret: IBinder? = null try { ret = super.onBind(intent) - } catch (e: Exception) { + } catch (e: Throwable) { Log.d(TAG, "Failed to bind to TileService", e) } return ret } override fun onClick() { - if (tunnel != null) { - unlockAndRun { - val tile = qsTile - if (tile != null) { - tile.icon = if (tile.icon == iconOn) iconOff else iconOn - tile.updateTile() - } - tunnel!!.setStateAsync(Tunnel.State.TOGGLE).whenComplete { _, t -> - if (t == null) { - updateTile() + applicationScope.launch { + if (tunnel == null) { + Application.getTunnelManager().getTunnels() + updateTile() + } + when (val tunnel = tunnel) { + null -> { + Log.d(TAG, "No tunnel set, so launching main activity") + val intent = Intent(this@QuickTileService, MainActivity::class.java) + intent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK) + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) { + startActivityAndCollapse(PendingIntent.getActivity(this@QuickTileService, 0, intent, PendingIntent.FLAG_IMMUTABLE)) } else { - val toggleIntent = Intent(this, TunnelToggleActivity::class.java) - toggleIntent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK) - startActivity(toggleIntent) + @Suppress("DEPRECATION") + startActivityAndCollapse(intent) + } + } + + else -> { + unlockAndRun { + applicationScope.launch { + try { + tunnel.setStateAsync(Tunnel.State.TOGGLE) + updateTile() + } catch (e: Throwable) { + Log.d(TAG, "Failed to set state, so falling back", e) + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.UPSIDE_DOWN_CAKE && !Settings.canDrawOverlays(this@QuickTileService)) { + Log.d(TAG, "Need overlay permissions") + val permissionIntent = Intent(Settings.ACTION_MANAGE_OVERLAY_PERMISSION, Uri.parse("package:$packageName")) + permissionIntent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK) + startActivityAndCollapse( + PendingIntent.getActivity( + this@QuickTileService, + 0, + permissionIntent, + PendingIntent.FLAG_IMMUTABLE + ) + ) + return@launch + } + val toggleIntent = Intent(this@QuickTileService, TunnelToggleActivity::class.java) + toggleIntent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK) + startActivity(toggleIntent) + } + } } } } - } else { - val intent = Intent(this, MainActivity::class.java) - intent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK) - startActivityAndCollapse(intent) } } override fun onCreate() { + isAdded = true if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) { iconOn = Icon.createWithResource(this, R.drawable.ic_tile) iconOff = iconOn @@ -81,55 +114,64 @@ class QuickTileService : TileService() { icon.setAnimationEnabled(false) /* Unfortunately we can't have animations, since Icons are marshaled. */ icon.setSlashed(false) var b = Bitmap.createBitmap(icon.intrinsicWidth, icon.intrinsicHeight, Bitmap.Config.ARGB_8888) - ?: return var c = Canvas(b) icon.setBounds(0, 0, c.width, c.height) icon.draw(c) iconOn = Icon.createWithBitmap(b) icon.setSlashed(true) b = Bitmap.createBitmap(icon.intrinsicWidth, icon.intrinsicHeight, Bitmap.Config.ARGB_8888) - ?: return c = Canvas(b) icon.setBounds(0, 0, c.width, c.height) icon.draw(c) iconOff = Icon.createWithBitmap(b) } + override fun onDestroy() { + super.onDestroy() + isAdded = false + } + override fun onStartListening() { Application.getTunnelManager().addOnPropertyChangedCallback(onTunnelChangedCallback) - if (tunnel != null) tunnel!!.addOnPropertyChangedCallback(onStateChangedCallback) + tunnel?.addOnPropertyChangedCallback(onStateChangedCallback) updateTile() } override fun onStopListening() { - if (tunnel != null) tunnel!!.removeOnPropertyChangedCallback(onStateChangedCallback) + tunnel?.removeOnPropertyChangedCallback(onStateChangedCallback) Application.getTunnelManager().removeOnPropertyChangedCallback(onTunnelChangedCallback) } + override fun onTileAdded() { + isAdded = true + } + + override fun onTileRemoved() { + isAdded = false + } + private fun updateTile() { // Update the tunnel. val newTunnel = Application.getTunnelManager().lastUsedTunnel if (newTunnel != tunnel) { - if (tunnel != null) tunnel!!.removeOnPropertyChangedCallback(onStateChangedCallback) + tunnel?.removeOnPropertyChangedCallback(onStateChangedCallback) tunnel = newTunnel - if (tunnel != null) tunnel!!.addOnPropertyChangedCallback(onStateChangedCallback) + tunnel?.addOnPropertyChangedCallback(onStateChangedCallback) } // Update the tile contents. - val label: String - val state: Int - val tile = qsTile - if (tunnel != null) { - label = tunnel!!.name - state = if (tunnel!!.state == Tunnel.State.UP) Tile.STATE_ACTIVE else Tile.STATE_INACTIVE - } else { - label = getString(R.string.app_name) - state = Tile.STATE_INACTIVE - } - if (tile == null) return - tile.label = label - if (tile.state != state) { - tile.icon = if (state == Tile.STATE_ACTIVE) iconOn else iconOff - tile.state = state + val tile = qsTile ?: return + + when (val tunnel = tunnel) { + null -> { + tile.label = getString(R.string.app_name) + tile.state = Tile.STATE_INACTIVE + tile.icon = iconOff + } + else -> { + tile.label = tunnel.name + tile.state = if (tunnel.state == Tunnel.State.UP) Tile.STATE_ACTIVE else Tile.STATE_INACTIVE + tile.icon = if (tunnel.state == Tunnel.State.UP) iconOn else iconOff + } } tile.updateTile() } @@ -140,19 +182,23 @@ class QuickTileService : TileService() { sender.removeOnPropertyChangedCallback(this) return } - if (propertyId != 0 && propertyId != BR.state) return + if (propertyId != 0 && propertyId != BR.state) + return updateTile() } } private inner class OnTunnelChangedCallback : OnPropertyChangedCallback() { override fun onPropertyChanged(sender: Observable, propertyId: Int) { - if (propertyId != 0 && propertyId != BR.lastUsedTunnel) return + if (propertyId != 0 && propertyId != BR.lastUsedTunnel) + return updateTile() } } companion object { private const val TAG = "WireGuard/QuickTileService" + var isAdded: Boolean = false + private set } } diff --git a/ui/src/main/java/com/wireguard/android/activity/BaseActivity.kt b/ui/src/main/java/com/wireguard/android/activity/BaseActivity.kt index 14ab0bdb..5ff11062 100644 --- a/ui/src/main/java/com/wireguard/android/activity/BaseActivity.kt +++ b/ui/src/main/java/com/wireguard/android/activity/BaseActivity.kt @@ -1,27 +1,36 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.activity import android.os.Bundle +import androidx.appcompat.app.AppCompatActivity import androidx.databinding.CallbackRegistry import androidx.databinding.CallbackRegistry.NotifierCallback +import androidx.lifecycle.lifecycleScope import com.wireguard.android.Application import com.wireguard.android.model.ObservableTunnel +import kotlinx.coroutines.launch /** * Base class for activities that need to remember the currently-selected tunnel. */ -abstract class BaseActivity : ThemeChangeAwareActivity() { +abstract class BaseActivity : AppCompatActivity() { private val selectionChangeRegistry = SelectionChangeRegistry() + private var created = false var selectedTunnel: ObservableTunnel? = null set(value) { val oldTunnel = field if (oldTunnel == value) return field = value - onSelectedTunnelChanged(oldTunnel, value) - selectionChangeRegistry.notifyCallbacks(oldTunnel, 0, value) + if (created) { + if (!onSelectedTunnelChanged(oldTunnel, value)) { + field = oldTunnel + } else { + selectionChangeRegistry.notifyCallbacks(oldTunnel, 0, value) + } + } } fun addOnSelectedTunnelChangedListener(listener: OnSelectedTunnelChangedListener) { @@ -29,6 +38,8 @@ abstract class BaseActivity : ThemeChangeAwareActivity() { } override fun onCreate(savedInstanceState: Bundle?) { + super.onCreate(savedInstanceState) + // Restore the saved tunnel if there is one; otherwise grab it from the arguments. val savedTunnelName = when { savedInstanceState != null -> savedInstanceState.getString(KEY_SELECTED_TUNNEL) @@ -36,13 +47,16 @@ abstract class BaseActivity : ThemeChangeAwareActivity() { else -> null } if (savedTunnelName != null) { - Application.getTunnelManager() - .tunnels - .thenAccept { selectedTunnel = it[savedTunnelName] } + lifecycleScope.launch { + val tunnel = Application.getTunnelManager().getTunnels()[savedTunnelName] + if (tunnel == null) + created = true + selectedTunnel = tunnel + created = true + } + } else { + created = true } - - // The selected tunnel must be set before the superclass method recreates fragments. - super.onCreate(savedInstanceState) } override fun onSaveInstanceState(outState: Bundle) { @@ -50,9 +64,11 @@ abstract class BaseActivity : ThemeChangeAwareActivity() { super.onSaveInstanceState(outState) } - protected abstract fun onSelectedTunnelChanged(oldTunnel: ObservableTunnel?, newTunnel: ObservableTunnel?) + protected abstract fun onSelectedTunnelChanged(oldTunnel: ObservableTunnel?, newTunnel: ObservableTunnel?): Boolean + fun removeOnSelectedTunnelChangedListener( - listener: OnSelectedTunnelChangedListener) { + listener: OnSelectedTunnelChangedListener + ) { selectionChangeRegistry.remove(listener) } @@ -62,17 +78,17 @@ abstract class BaseActivity : ThemeChangeAwareActivity() { private class SelectionChangeNotifier : NotifierCallback<OnSelectedTunnelChangedListener, ObservableTunnel, ObservableTunnel>() { override fun onNotifyCallback( - listener: OnSelectedTunnelChangedListener, - oldTunnel: ObservableTunnel?, - ignored: Int, - newTunnel: ObservableTunnel? + listener: OnSelectedTunnelChangedListener, + oldTunnel: ObservableTunnel?, + ignored: Int, + newTunnel: ObservableTunnel? ) { listener.onSelectedTunnelChanged(oldTunnel, newTunnel) } } private class SelectionChangeRegistry : - CallbackRegistry<OnSelectedTunnelChangedListener, ObservableTunnel, ObservableTunnel>(SelectionChangeNotifier()) + CallbackRegistry<OnSelectedTunnelChangedListener, ObservableTunnel, ObservableTunnel>(SelectionChangeNotifier()) companion object { private const val KEY_SELECTED_TUNNEL = "selected_tunnel" diff --git a/ui/src/main/java/com/wireguard/android/activity/LogViewerActivity.kt b/ui/src/main/java/com/wireguard/android/activity/LogViewerActivity.kt index 87fdc236..195e9592 100644 --- a/ui/src/main/java/com/wireguard/android/activity/LogViewerActivity.kt +++ b/ui/src/main/java/com/wireguard/android/activity/LogViewerActivity.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ @@ -19,13 +19,18 @@ import android.text.Spannable import android.text.SpannableString import android.text.style.ForegroundColorSpan import android.text.style.StyleSpan +import android.util.Log import android.view.LayoutInflater import android.view.Menu import android.view.MenuItem import android.view.View import android.view.ViewGroup +import androidx.activity.result.contract.ActivityResultContracts import androidx.appcompat.app.AppCompatActivity +import androidx.collection.CircularArray import androidx.core.app.ShareCompat +import androidx.core.content.res.ResourcesCompat +import androidx.lifecycle.lifecycleScope import androidx.recyclerview.widget.DividerItemDecoration import androidx.recyclerview.widget.LinearLayoutManager import androidx.recyclerview.widget.RecyclerView @@ -36,13 +41,9 @@ import com.wireguard.android.R import com.wireguard.android.databinding.LogViewerActivityBinding import com.wireguard.android.util.DownloadsFileSaver import com.wireguard.android.util.ErrorMessages -import com.wireguard.android.widget.EdgeToEdge.setUpFAB -import com.wireguard.android.widget.EdgeToEdge.setUpRoot -import com.wireguard.android.widget.EdgeToEdge.setUpScrollingContent +import com.wireguard.android.util.resolveAttribute import com.wireguard.crypto.KeyPair -import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.cancel import kotlinx.coroutines.launch import kotlinx.coroutines.withContext import java.io.BufferedReader @@ -60,33 +61,26 @@ import java.util.regex.Matcher import java.util.regex.Pattern class LogViewerActivity : AppCompatActivity() { - private lateinit var binding: LogViewerActivityBinding private lateinit var logAdapter: LogEntryAdapter - private var logLines = arrayListOf<LogLine>() - private var rawLogLines = StringBuffer() + private var logLines = CircularArray<LogLine>() + private var rawLogLines = CircularArray<String>() private var recyclerView: RecyclerView? = null private var saveButton: MenuItem? = null - private val coroutineScope = CoroutineScope(Dispatchers.Default) private val year by lazy { val yearFormatter: DateFormat = SimpleDateFormat("yyyy", Locale.US) yearFormatter.format(Date()) } - @Suppress("Deprecation") - private val defaultColor by lazy { resources.getColor(R.color.primary_text_color) } + private val defaultColor by lazy { resolveAttribute(com.google.android.material.R.attr.colorOnSurface) } - @Suppress("Deprecation") - private val debugColor by lazy { resources.getColor(R.color.debug_tag_color) } + private val debugColor by lazy { ResourcesCompat.getColor(resources, R.color.debug_tag_color, theme) } - @Suppress("Deprecation") - private val errorColor by lazy { resources.getColor(R.color.error_tag_color) } + private val errorColor by lazy { ResourcesCompat.getColor(resources, R.color.error_tag_color, theme) } - @Suppress("Deprecation") - private val infoColor by lazy { resources.getColor(R.color.info_tag_color) } + private val infoColor by lazy { ResourcesCompat.getColor(resources, R.color.info_tag_color, theme) } - @Suppress("Deprecation") - private val warningColor by lazy { resources.getColor(R.color.warning_tag_color) } + private val warningColor by lazy { ResourcesCompat.getColor(resources, R.color.warning_tag_color, theme) } private var lastUri: Uri? = null @@ -103,9 +97,6 @@ class LogViewerActivity : AppCompatActivity() { binding = LogViewerActivityBinding.inflate(layoutInflater) setContentView(binding.root) supportActionBar?.setDisplayHomeAsUpEnabled(true) - setUpFAB(binding.shareFab) - setUpRoot(binding.root) - setUpScrollingContent(binding.recyclerView, binding.shareFab) logAdapter = LogEntryAdapter() binding.recyclerView.apply { recyclerView = this @@ -114,35 +105,34 @@ class LogViewerActivity : AppCompatActivity() { addItemDecoration(DividerItemDecoration(context, LinearLayoutManager.VERTICAL)) } - coroutineScope.launch { streamingLog() } + lifecycleScope.launch(Dispatchers.IO) { streamingLog() } - binding.shareFab.setOnClickListener { + val revokeLastActivityResultLauncher = registerForActivityResult(ActivityResultContracts.StartActivityForResult()) { revokeLastUri() - val key = KeyPair().privateKey.toHex() - LOGS[key] = rawLogLines.toString().toByteArray(Charsets.UTF_8) - lastUri = Uri.parse("content://${BuildConfig.APPLICATION_ID}.exported-log/$key") - val shareIntent = ShareCompat.IntentBuilder.from(this) + } + + binding.shareFab.setOnClickListener { + lifecycleScope.launch { + revokeLastUri() + val key = KeyPair().privateKey.toHex() + LOGS[key] = rawLogBytes() + lastUri = Uri.parse("content://${BuildConfig.APPLICATION_ID}.exported-log/$key") + val shareIntent = ShareCompat.IntentBuilder(this@LogViewerActivity) .setType("text/plain") .setSubject(getString(R.string.log_export_subject)) .setStream(lastUri) .setChooserTitle(R.string.log_export_title) .createChooserIntent() .addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION) - grantUriPermission("android", lastUri, Intent.FLAG_GRANT_READ_URI_PERMISSION) - startActivityForResult(shareIntent, SHARE_ACTIVITY_REQUEST) - } - } - - override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) { - if (requestCode == SHARE_ACTIVITY_REQUEST) { - revokeLastUri() + grantUriPermission("android", lastUri, Intent.FLAG_GRANT_READ_URI_PERMISSION) + revokeLastActivityResultLauncher.launch(shareIntent) + } } - super.onActivityResult(requestCode, resultCode, data) } - override fun onCreateOptionsMenu(menu: Menu?): Boolean { + override fun onCreateOptionsMenu(menu: Menu): Boolean { menuInflater.inflate(R.menu.log_viewer, menu) - saveButton = menu?.findItem(R.id.save_log) + saveButton = menu.findItem(R.id.save_log) return true } @@ -152,94 +142,126 @@ class LogViewerActivity : AppCompatActivity() { finish() true } + R.id.save_log -> { - coroutineScope.launch { saveLog() } + saveButton?.isEnabled = false + lifecycleScope.launch { saveLog() } true } + else -> super.onOptionsItemSelected(item) } } - override fun onDestroy() { - super.onDestroy() - coroutineScope.cancel() + private val downloadsFileSaver = DownloadsFileSaver(this) + + private suspend fun rawLogBytes(): ByteArray { + val builder = StringBuilder() + withContext(Dispatchers.IO) { + for (i in 0 until rawLogLines.size()) { + builder.append(rawLogLines[i]) + builder.append('\n') + } + } + return builder.toString().toByteArray(Charsets.UTF_8) } private suspend fun saveLog() { - val context = this - withContext(Dispatchers.Main) { - saveButton?.isEnabled = false - withContext(Dispatchers.IO) { - var exception: Throwable? = null - var outputFile: DownloadsFileSaver.DownloadsFile? = null - try { - outputFile = DownloadsFileSaver.save(context, "wireguard-log.txt", "text/plain", true) - outputFile.outputStream.use { - it.write(rawLogLines.toString().toByteArray(Charsets.UTF_8)) - } - } catch (e: Throwable) { - outputFile?.delete() - exception = e - } - withContext(Dispatchers.Main) { - saveButton?.isEnabled = true - Snackbar.make(findViewById(android.R.id.content), - if (exception == null) getString(R.string.log_export_success, outputFile?.fileName) - else getString(R.string.log_export_error, ErrorMessages[exception]), - if (exception == null) Snackbar.LENGTH_SHORT else Snackbar.LENGTH_LONG) - .setAnchorView(binding.shareFab) - .show() - } + var exception: Throwable? = null + var outputFile: DownloadsFileSaver.DownloadsFile? = null + withContext(Dispatchers.IO) { + try { + outputFile = downloadsFileSaver.save("wireguard-log.txt", "text/plain", true) + outputFile?.outputStream?.write(rawLogBytes()) + } catch (e: Throwable) { + outputFile?.delete() + exception = e } } + saveButton?.isEnabled = true + if (outputFile == null) + return + Snackbar.make( + findViewById(android.R.id.content), + if (exception == null) getString(R.string.log_export_success, outputFile?.fileName) + else getString(R.string.log_export_error, ErrorMessages[exception]), + if (exception == null) Snackbar.LENGTH_SHORT else Snackbar.LENGTH_LONG + ) + .setAnchorView(binding.shareFab) + .show() } private suspend fun streamingLog() = withContext(Dispatchers.IO) { val builder = ProcessBuilder().command("logcat", "-b", "all", "-v", "threadtime", "*:V") builder.environment()["LC_ALL"] = "C" - val process = try { - builder.start() - } catch (e: IOException) { - e.printStackTrace() - return@withContext - } - val stdout = BufferedReader(InputStreamReader(process!!.inputStream, StandardCharsets.UTF_8)) - var haveScrolled = false - val start = System.nanoTime() - var startPeriod = start - while (true) { - val line = stdout.readLine() ?: break - rawLogLines.append(line) - rawLogLines.append('\n') - val logLine = parseLine(line) - withContext(Dispatchers.Main) { + var process: Process? = null + try { + process = try { + builder.start() + } catch (e: IOException) { + Log.e(TAG, Log.getStackTraceString(e)) + return@withContext + } + val stdout = BufferedReader(InputStreamReader(process!!.inputStream, StandardCharsets.UTF_8)) + + var posStart = 0 + var timeLastNotify = System.nanoTime() + var priorModified = false + val bufferedLogLines = arrayListOf<LogLine>() + var timeout = 1000000000L / 2 // The timeout is initially small so that the view gets populated immediately. + val MAX_LINES = (1 shl 16) - 1 + val MAX_BUFFERED_LINES = (1 shl 14) - 1 + + while (true) { + val line = stdout.readLine() ?: break + if (rawLogLines.size() >= MAX_LINES) + rawLogLines.popFirst() + rawLogLines.addLast(line) + val logLine = parseLine(line) if (logLine != null) { - recyclerView?.let { - val shouldScroll = haveScrolled && !it.canScrollVertically(1) - logLines.add(logLine) - if (haveScrolled) logAdapter.notifyDataSetChanged() - if (shouldScroll) - it.scrollToPosition(logLines.size - 1) - } + bufferedLogLines.add(logLine) } else { - /* I'd prefer for the next line to be: - * logLines.lastOrNull()?.msg += "\n$line" - * However, as of writing, that causes the kotlin compiler to freak out and crash, spewing bytecode. - */ - logLines.lastOrNull()?.apply { msg += "\n$line" } - if (haveScrolled) logAdapter.notifyDataSetChanged() + if (bufferedLogLines.isNotEmpty()) { + bufferedLogLines.last().msg += "\n$line" + } else if (!logLines.isEmpty()) { + logLines[logLines.size() - 1].msg += "\n$line" + priorModified = true + } } - if (!haveScrolled) { - val end = System.nanoTime() - val scroll = (end - start) > 1000000000L * 2.5 || !stdout.ready() - if (logLines.isNotEmpty() && (scroll || (end - startPeriod) > 1000000000L / 4)) { - logAdapter.notifyDataSetChanged() - recyclerView?.scrollToPosition(logLines.size - 1) - startPeriod = end + val timeNow = System.nanoTime() + if (bufferedLogLines.size < MAX_BUFFERED_LINES && (timeNow - timeLastNotify) < timeout && stdout.ready()) + continue + timeout = 1000000000L * 5 / 2 // Increase the timeout after the initial view has something in it. + timeLastNotify = timeNow + + withContext(Dispatchers.Main.immediate) { + val isScrolledToBottomAlready = recyclerView?.canScrollVertically(1) == false + if (priorModified) { + logAdapter.notifyItemChanged(posStart - 1) + priorModified = false + } + val fullLen = logLines.size() + bufferedLogLines.size + if (fullLen >= MAX_LINES) { + val numToRemove = fullLen - MAX_LINES + 1 + logLines.removeFromStart(numToRemove) + logAdapter.notifyItemRangeRemoved(0, numToRemove) + posStart -= numToRemove + + } + for (bufferedLine in bufferedLogLines) { + logLines.addLast(bufferedLine) + } + bufferedLogLines.clear() + logAdapter.notifyItemRangeInserted(posStart, logLines.size() - posStart) + posStart = logLines.size() + + if (isScrolledToBottomAlready) { + recyclerView?.scrollToPosition(logLines.size() - 1) } - if (scroll) haveScrolled = true } } + } finally { + process?.destroy() } } @@ -269,9 +291,10 @@ class LogViewerActivity : AppCompatActivity() { * * <pre>05-26 11:02:36.886 5689 5689 D AndroidRuntime: CheckJNI is OFF.</pre> */ - private val THREADTIME_LINE: Pattern = Pattern.compile("^(\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}.\\d{3})(?:\\s+[0-9A-Za-z]+)?\\s+(\\d+)\\s+(\\d+)\\s+([A-Z])\\s+(.+?)\\s*: (.*)$") + private val THREADTIME_LINE: Pattern = + Pattern.compile("^(\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}.\\d{3})(?:\\s+[0-9A-Za-z]+)?\\s+(\\d+)\\s+(\\d+)\\s+([A-Z])\\s+(.+?)\\s*: (.*)$") private val LOGS: MutableMap<String, ByteArray> = ConcurrentHashMap() - private const val SHARE_ACTIVITY_REQUEST = 49133 + private const val TAG = "WireGuard/LogViewerActivity" } private inner class LogEntryAdapter : RecyclerView.Adapter<LogEntryAdapter.ViewHolder>() { @@ -280,7 +303,7 @@ class LogViewerActivity : AppCompatActivity() { private fun levelToColor(level: String): Int { return when (level) { - "D" -> debugColor + "V", "D" -> debugColor "E" -> errorColor "I" -> infoColor "W" -> warningColor @@ -288,11 +311,11 @@ class LogViewerActivity : AppCompatActivity() { } } - override fun getItemCount() = logLines.size + override fun getItemCount() = logLines.size() override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): ViewHolder { val view = LayoutInflater.from(parent.context) - .inflate(R.layout.log_viewer_entry, parent, false) + .inflate(R.layout.log_viewer_entry, parent, false) return ViewHolder(view) } @@ -303,8 +326,10 @@ class LogViewerActivity : AppCompatActivity() { else SpannableString("${line.tag}: ${line.msg}").apply { setSpan(StyleSpan(BOLD), 0, "${line.tag}:".length, Spannable.SPAN_EXCLUSIVE_EXCLUSIVE) - setSpan(ForegroundColorSpan(levelToColor(line.level)), - 0, "${line.tag}:".length, Spannable.SPAN_EXCLUSIVE_EXCLUSIVE) + setSpan( + ForegroundColorSpan(levelToColor(line.level)), + 0, "${line.tag}:".length, Spannable.SPAN_EXCLUSIVE_EXCLUSIVE + ) } holder.layout.apply { findViewById<MaterialTextView>(R.id.log_date).text = line.time.toString() @@ -326,11 +351,11 @@ class LogViewerActivity : AppCompatActivity() { override fun insert(uri: Uri, values: ContentValues?): Uri? = null override fun query(uri: Uri, projection: Array<out String>?, selection: String?, selectionArgs: Array<out String>?, sortOrder: String?): Cursor? = - logForUri(uri)?.let { - val m = MatrixCursor(arrayOf(android.provider.OpenableColumns.DISPLAY_NAME, android.provider.OpenableColumns.SIZE), 1) - m.addRow(arrayOf("wireguard-log.txt", it.size.toLong())) - m - } + logForUri(uri)?.let { + val m = MatrixCursor(arrayOf(android.provider.OpenableColumns.DISPLAY_NAME, android.provider.OpenableColumns.SIZE), 1) + m.addRow(arrayOf<Any>("wireguard-log.txt", it.size.toLong())) + m + } override fun onCreate(): Boolean = true @@ -340,7 +365,8 @@ class LogViewerActivity : AppCompatActivity() { override fun getType(uri: Uri): String? = logForUri(uri)?.let { "text/plain" } - override fun getStreamTypes(uri: Uri, mimeTypeFilter: String): Array<String>? = getType(uri)?.let { if (compareMimeTypes(it, mimeTypeFilter)) arrayOf(it) else null } + override fun getStreamTypes(uri: Uri, mimeTypeFilter: String): Array<String>? = + getType(uri)?.let { if (compareMimeTypes(it, mimeTypeFilter)) arrayOf(it) else null } override fun openFile(uri: Uri, mode: String): ParcelFileDescriptor? { if (mode != "r") return null @@ -348,7 +374,7 @@ class LogViewerActivity : AppCompatActivity() { return openPipeHelper(uri, "text/plain", null, log) { output, _, _, _, l -> try { FileOutputStream(output.fileDescriptor).write(l!!) - } catch (_: Exception) { + } catch (_: Throwable) { } } } diff --git a/ui/src/main/java/com/wireguard/android/activity/MainActivity.kt b/ui/src/main/java/com/wireguard/android/activity/MainActivity.kt index f567e763..087ca08e 100644 --- a/ui/src/main/java/com/wireguard/android/activity/MainActivity.kt +++ b/ui/src/main/java/com/wireguard/android/activity/MainActivity.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.activity @@ -9,11 +9,12 @@ import android.os.Bundle import android.view.Menu import android.view.MenuItem import android.view.View +import androidx.activity.OnBackPressedCallback +import androidx.activity.addCallback import androidx.appcompat.app.ActionBar -import androidx.core.view.ViewCompat -import androidx.core.view.WindowInsetsCompat import androidx.fragment.app.FragmentManager import androidx.fragment.app.FragmentTransaction +import androidx.fragment.app.commit import com.wireguard.android.R import com.wireguard.android.fragment.TunnelDetailFragment import com.wireguard.android.fragment.TunnelEditorFragment @@ -27,27 +28,29 @@ import com.wireguard.android.model.ObservableTunnel class MainActivity : BaseActivity(), FragmentManager.OnBackStackChangedListener { private var actionBar: ActionBar? = null private var isTwoPaneLayout = false + private var backPressedCallback: OnBackPressedCallback? = null - override fun onBackPressed() { + private fun handleBackPressed() { val backStackEntries = supportFragmentManager.backStackEntryCount // If the two-pane layout does not have an editor open, going back should exit the app. if (isTwoPaneLayout && backStackEntries <= 1) { finish() return } - // Deselect the current tunnel on navigating back from the detail pane to the one-pane list. - if (!isTwoPaneLayout && backStackEntries == 1) { + + if (backStackEntries >= 1) supportFragmentManager.popBackStack() + + // Deselect the current tunnel on navigating back from the detail pane to the one-pane list. + if (backStackEntries == 1) selectedTunnel = null - return - } - super.onBackPressed() } override fun onBackStackChanged() { + val backStackEntries = supportFragmentManager.backStackEntryCount + backPressedCallback?.isEnabled = backStackEntries >= 1 if (actionBar == null) return // Do not show the home menu when the two-pane layout is at the detail view (see above). - val backStackEntries = supportFragmentManager.backStackEntryCount val minBackStackEntries = if (isTwoPaneLayout) 2 else 1 actionBar!!.setDisplayHomeAsUpEnabled(backStackEntries >= minBackStackEntries) } @@ -58,17 +61,8 @@ class MainActivity : BaseActivity(), FragmentManager.OnBackStackChangedListener actionBar = supportActionBar isTwoPaneLayout = findViewById<View?>(R.id.master_detail_wrapper) != null supportFragmentManager.addOnBackStackChangedListener(this) + backPressedCallback = onBackPressedDispatcher.addCallback(this) { handleBackPressed() } onBackStackChanged() - // Dispatch insets on back stack change - // This is required to ensure replaced fragments are also able to consume insets - findViewById<View>(R.id.main_activity_container).setOnApplyWindowInsetsListener { _, insets -> - supportFragmentManager.addOnBackStackChangedListener { - supportFragmentManager.fragments.forEach { - ViewCompat.dispatchApplyWindowInsets(it.requireView(), WindowInsetsCompat.toWindowInsetsCompat(insets)) - } - } - insets - } } override fun onCreateOptionsMenu(menu: Menu): Boolean { @@ -80,15 +74,16 @@ class MainActivity : BaseActivity(), FragmentManager.OnBackStackChangedListener return when (item.itemId) { android.R.id.home -> { // The back arrow in the action bar should act the same as the back button. - onBackPressed() + onBackPressedDispatcher.onBackPressed() true } + R.id.menu_action_edit -> { - supportFragmentManager.beginTransaction() - .replace(R.id.detail_container, TunnelEditorFragment()) - .setTransition(FragmentTransaction.TRANSIT_FRAGMENT_FADE) - .addToBackStack(null) - .commit() + supportFragmentManager.commit { + replace(if (isTwoPaneLayout) R.id.detail_container else R.id.list_detail_container, TunnelEditorFragment()) + setTransition(FragmentTransaction.TRANSIT_FRAGMENT_FADE) + addToBackStack(null) + } true } // This menu item is handled by the editor fragment. @@ -97,18 +92,25 @@ class MainActivity : BaseActivity(), FragmentManager.OnBackStackChangedListener startActivity(Intent(this, SettingsActivity::class.java)) true } + else -> super.onOptionsItemSelected(item) } } - override fun onSelectedTunnelChanged(oldTunnel: ObservableTunnel?, - newTunnel: ObservableTunnel?) { + override fun onSelectedTunnelChanged( + oldTunnel: ObservableTunnel?, + newTunnel: ObservableTunnel? + ): Boolean { val fragmentManager = supportFragmentManager + if (fragmentManager.isStateSaved) { + return false + } + val backStackEntries = fragmentManager.backStackEntryCount if (newTunnel == null) { // Clear everything off the back stack (all editors and detail fragments). fragmentManager.popBackStackImmediate(0, FragmentManager.POP_BACK_STACK_INCLUSIVE) - return + return true } if (backStackEntries == 2) { // Pop the editor off the back stack to reveal the detail fragment. Use the immediate @@ -116,11 +118,12 @@ class MainActivity : BaseActivity(), FragmentManager.OnBackStackChangedListener fragmentManager.popBackStackImmediate() } else if (backStackEntries == 0) { // Create and show a new detail fragment. - fragmentManager.beginTransaction() - .add(R.id.detail_container, TunnelDetailFragment()) - .setTransition(FragmentTransaction.TRANSIT_FRAGMENT_FADE) - .addToBackStack(null) - .commit() + fragmentManager.commit { + add(if (isTwoPaneLayout) R.id.detail_container else R.id.list_detail_container, TunnelDetailFragment()) + setTransition(FragmentTransaction.TRANSIT_FRAGMENT_FADE) + addToBackStack(null) + } } + return true } } diff --git a/ui/src/main/java/com/wireguard/android/activity/SettingsActivity.kt b/ui/src/main/java/com/wireguard/android/activity/SettingsActivity.kt index 103b6b44..5ce71792 100644 --- a/ui/src/main/java/com/wireguard/android/activity/SettingsActivity.kt +++ b/ui/src/main/java/com/wireguard/android/activity/SettingsActivity.kt @@ -1,58 +1,41 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.activity import android.content.Intent -import android.content.pm.PackageManager import android.os.Build import android.os.Bundle -import android.util.SparseArray +import android.view.LayoutInflater import android.view.MenuItem -import androidx.core.app.ActivityCompat -import androidx.core.content.ContextCompat +import android.view.View +import android.view.ViewGroup +import androidx.appcompat.app.AppCompatActivity +import androidx.fragment.app.commit +import androidx.lifecycle.lifecycleScope import androidx.preference.Preference import androidx.preference.PreferenceFragmentCompat import com.wireguard.android.Application +import com.wireguard.android.QuickTileService import com.wireguard.android.R import com.wireguard.android.backend.WgQuickBackend -import com.wireguard.android.util.ModuleLoader -import java.util.ArrayList -import java.util.Arrays +import com.wireguard.android.preference.PreferencesPreferenceDataStore +import com.wireguard.android.util.AdminKnobs +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext /** * Interface for changing application-global persistent settings. */ -class SettingsActivity : ThemeChangeAwareActivity() { - private val permissionRequestCallbacks = SparseArray<(permissions: Array<String>, granted: IntArray) -> Unit>() - private var permissionRequestCounter = 0 - - fun ensurePermissions(permissions: Array<String>, cb: (permissions: Array<String>, granted: IntArray) -> Unit) { - val needPermissions: MutableList<String> = ArrayList(permissions.size) - permissions.forEach { - if (ContextCompat.checkSelfPermission(this, it) != PackageManager.PERMISSION_GRANTED) { - needPermissions.add(it) - } - } - if (needPermissions.isEmpty()) { - val granted = IntArray(permissions.size) - Arrays.fill(granted, PackageManager.PERMISSION_GRANTED) - cb.invoke(permissions, granted) - return - } - val idx = permissionRequestCounter++ - permissionRequestCallbacks.put(idx, cb) - ActivityCompat.requestPermissions(this, - needPermissions.toTypedArray(), idx) - } - +class SettingsActivity : AppCompatActivity() { override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) if (supportFragmentManager.findFragmentById(android.R.id.content) == null) { - supportFragmentManager.beginTransaction() - .add(android.R.id.content, SettingsFragment()) - .commit() + supportFragmentManager.commit { + add(android.R.id.content, SettingsFragment()) + } } } @@ -64,20 +47,26 @@ class SettingsActivity : ThemeChangeAwareActivity() { return super.onOptionsItemSelected(item) } - override fun onRequestPermissionsResult(requestCode: Int, - permissions: Array<String>, - grantResults: IntArray) { - val f = permissionRequestCallbacks[requestCode] - if (f != null) { - permissionRequestCallbacks.remove(requestCode) - f.invoke(permissions, grantResults) + class SettingsFragment : PreferenceFragmentCompat() { + + // Since this is pretty much abandoned by androidx, it never got updated for proper EdgeToEdge support, + // which is enabled everywhere for API 35. So handle the insets manually here. + override fun onCreateView(inflater: LayoutInflater, container: ViewGroup?, savedInstanceState: Bundle?): View { + val view = super.onCreateView(inflater, container, savedInstanceState) + view.fitsSystemWindows = true + return view } - } - class SettingsFragment : PreferenceFragmentCompat() { override fun onCreatePreferences(savedInstanceState: Bundle?, key: String?) { + preferenceManager.preferenceDataStore = PreferencesPreferenceDataStore(lifecycleScope, Application.getPreferencesDataStore()) addPreferencesFromResource(R.xml.preferences) - preferenceScreen.initialExpandedChildrenCount = 4 + preferenceScreen.initialExpandedChildrenCount = 5 + + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.TIRAMISU || QuickTileService.isAdded) { + val quickTile = preferenceManager.findPreference<Preference>("quick_tile") + quickTile?.parent?.removePreference(quickTile) + --preferenceScreen.initialExpandedChildrenCount + } if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) { val darkTheme = preferenceManager.findPreference<Preference>("dark_theme") darkTheme?.parent?.removePreference(darkTheme) @@ -87,14 +76,18 @@ class SettingsActivity : ThemeChangeAwareActivity() { val remoteApps = preferenceManager.findPreference<Preference>("allow_remote_control_intents") remoteApps?.parent?.removePreference(remoteApps) } + if (AdminKnobs.disableConfigExport) { + val zipExporter = preferenceManager.findPreference<Preference>("zip_exporter") + zipExporter?.parent?.removePreference(zipExporter) + } val wgQuickOnlyPrefs = arrayOf( - preferenceManager.findPreference("tools_installer"), - preferenceManager.findPreference("restore_on_boot"), - preferenceManager.findPreference<Preference>("multiple_tunnels") + preferenceManager.findPreference("tools_installer"), + preferenceManager.findPreference("restore_on_boot"), + preferenceManager.findPreference<Preference>("multiple_tunnels") ).filterNotNull() wgQuickOnlyPrefs.forEach { it.isVisible = false } - Application.getBackendAsync().thenAccept { backend -> - if (backend is WgQuickBackend) { + lifecycleScope.launch { + if (Application.getBackend() is WgQuickBackend) { ++preferenceScreen.initialExpandedChildrenCount wgQuickOnlyPrefs.forEach { it.isVisible = true } } else { @@ -105,19 +98,19 @@ class SettingsActivity : ThemeChangeAwareActivity() { startActivity(Intent(requireContext(), LogViewerActivity::class.java)) true } - val moduleInstaller = preferenceManager.findPreference<Preference>("module_downloader") - val kernelModuleDisabler = preferenceManager.findPreference<Preference>("kernel_module_disabler") - moduleInstaller?.isVisible = false - if (ModuleLoader.isModuleLoaded()) { - moduleInstaller?.parent?.removePreference(moduleInstaller) - } else { - kernelModuleDisabler?.parent?.removePreference(kernelModuleDisabler) - Application.getAsyncWorker().runAsync(Application.getRootShell()::start).whenComplete { _, e -> - if (e == null) - moduleInstaller?.isVisible = true - else - moduleInstaller?.parent?.removePreference(moduleInstaller) + val kernelModuleEnabler = preferenceManager.findPreference<Preference>("kernel_module_enabler") + if (WgQuickBackend.hasKernelSupport()) { + lifecycleScope.launch { + if (Application.getBackend() !is WgQuickBackend) { + try { + withContext(Dispatchers.IO) { Application.getRootShell().start() } + } catch (_: Throwable) { + kernelModuleEnabler?.parent?.removePreference(kernelModuleEnabler) + } + } } + } else { + kernelModuleEnabler?.parent?.removePreference(kernelModuleEnabler) } } } diff --git a/ui/src/main/java/com/wireguard/android/activity/ThemeChangeAwareActivity.kt b/ui/src/main/java/com/wireguard/android/activity/ThemeChangeAwareActivity.kt deleted file mode 100644 index bd124cbc..00000000 --- a/ui/src/main/java/com/wireguard/android/activity/ThemeChangeAwareActivity.kt +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package com.wireguard.android.activity - -import android.content.SharedPreferences -import android.content.SharedPreferences.OnSharedPreferenceChangeListener -import android.os.Build -import android.os.Bundle -import androidx.appcompat.app.AppCompatActivity -import androidx.appcompat.app.AppCompatDelegate -import com.wireguard.android.Application - -abstract class ThemeChangeAwareActivity : AppCompatActivity(), OnSharedPreferenceChangeListener { - override fun onCreate(savedInstanceState: Bundle?) { - super.onCreate(savedInstanceState) - if (Build.VERSION.SDK_INT < Build.VERSION_CODES.Q) { - Application.getSharedPreferences().registerOnSharedPreferenceChangeListener(this) - } - } - - override fun onDestroy() { - if (Build.VERSION.SDK_INT < Build.VERSION_CODES.Q) { - Application.getSharedPreferences().unregisterOnSharedPreferenceChangeListener(this) - } - super.onDestroy() - } - - override fun onSharedPreferenceChanged(sharedPreferences: SharedPreferences, key: String) { - when (key) { - "dark_theme" -> { - AppCompatDelegate.setDefaultNightMode(if (sharedPreferences.getBoolean(key, false)) { - AppCompatDelegate.MODE_NIGHT_YES - } else { - AppCompatDelegate.MODE_NIGHT_NO - }) - recreate() - } - } - } -} diff --git a/ui/src/main/java/com/wireguard/android/activity/TunnelCreatorActivity.kt b/ui/src/main/java/com/wireguard/android/activity/TunnelCreatorActivity.kt index d3d4d4f9..8d5f4cfa 100644 --- a/ui/src/main/java/com/wireguard/android/activity/TunnelCreatorActivity.kt +++ b/ui/src/main/java/com/wireguard/android/activity/TunnelCreatorActivity.kt @@ -1,11 +1,11 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.activity import android.os.Bundle -import com.wireguard.android.fragment.TunnelEditorFragment +import com.wireguard.android.R import com.wireguard.android.model.ObservableTunnel /** @@ -14,14 +14,11 @@ import com.wireguard.android.model.ObservableTunnel class TunnelCreatorActivity : BaseActivity() { override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) - if (supportFragmentManager.findFragmentById(android.R.id.content) == null) { - supportFragmentManager.beginTransaction() - .add(android.R.id.content, TunnelEditorFragment()) - .commit() - } + setContentView(R.layout.tunnel_creator_activity) } - override fun onSelectedTunnelChanged(oldTunnel: ObservableTunnel?, newTunnel: ObservableTunnel?) { + override fun onSelectedTunnelChanged(oldTunnel: ObservableTunnel?, newTunnel: ObservableTunnel?): Boolean { finish() + return true } } diff --git a/ui/src/main/java/com/wireguard/android/activity/TunnelToggleActivity.kt b/ui/src/main/java/com/wireguard/android/activity/TunnelToggleActivity.kt index 44d81c01..dfc1f5b8 100644 --- a/ui/src/main/java/com/wireguard/android/activity/TunnelToggleActivity.kt +++ b/ui/src/main/java/com/wireguard/android/activity/TunnelToggleActivity.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.activity @@ -10,32 +10,58 @@ import android.os.Bundle import android.service.quicksettings.TileService import android.util.Log import android.widget.Toast +import androidx.activity.result.contract.ActivityResultContracts import androidx.annotation.RequiresApi import androidx.appcompat.app.AppCompatActivity +import androidx.lifecycle.lifecycleScope import com.wireguard.android.Application import com.wireguard.android.QuickTileService import com.wireguard.android.R +import com.wireguard.android.backend.GoBackend import com.wireguard.android.backend.Tunnel import com.wireguard.android.util.ErrorMessages +import kotlinx.coroutines.launch @RequiresApi(Build.VERSION_CODES.N) class TunnelToggleActivity : AppCompatActivity() { - override fun onCreate(savedInstanceState: Bundle?) { - super.onCreate(savedInstanceState) + private val permissionActivityResultLauncher = + registerForActivityResult(ActivityResultContracts.StartActivityForResult()) { toggleTunnelWithPermissionsResult() } + + private fun toggleTunnelWithPermissionsResult() { val tunnel = Application.getTunnelManager().lastUsedTunnel ?: return - tunnel.setStateAsync(Tunnel.State.TOGGLE).whenComplete { _, t -> - TileService.requestListeningState(this, ComponentName(this, QuickTileService::class.java)) - onToggleFinished(t) + lifecycleScope.launch { + try { + tunnel.setStateAsync(Tunnel.State.TOGGLE) + } catch (e: Throwable) { + TileService.requestListeningState(this@TunnelToggleActivity, ComponentName(this@TunnelToggleActivity, QuickTileService::class.java)) + val error = ErrorMessages[e] + val message = getString(R.string.toggle_error, error) + Log.e(TAG, message, e) + Toast.makeText(this@TunnelToggleActivity, message, Toast.LENGTH_LONG).show() + finishAffinity() + return@launch + } + TileService.requestListeningState(this@TunnelToggleActivity, ComponentName(this@TunnelToggleActivity, QuickTileService::class.java)) finishAffinity() } } - private fun onToggleFinished(throwable: Throwable?) { - if (throwable == null) return - val error = ErrorMessages[throwable] - val message = getString(R.string.toggle_error, error) - Log.e(TAG, message, throwable) - Toast.makeText(this, message, Toast.LENGTH_LONG).show() + override fun onCreate(savedInstanceState: Bundle?) { + super.onCreate(savedInstanceState) + lifecycleScope.launch { + if (Application.getBackend() is GoBackend) { + try { + val intent = GoBackend.VpnService.prepare(this@TunnelToggleActivity) + if (intent != null) { + permissionActivityResultLauncher.launch(intent) + return@launch + } + } catch (e: Exception) { + Toast.makeText(this@TunnelToggleActivity, ErrorMessages[e], Toast.LENGTH_LONG).show() + } + } + toggleTunnelWithPermissionsResult() + } } companion object { diff --git a/ui/src/main/java/com/wireguard/android/activity/TvMainActivity.kt b/ui/src/main/java/com/wireguard/android/activity/TvMainActivity.kt new file mode 100644 index 00000000..3084d314 --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/activity/TvMainActivity.kt @@ -0,0 +1,448 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.activity + +import android.Manifest +import android.content.ActivityNotFoundException +import android.content.Context +import android.content.Intent +import android.content.pm.PackageManager +import android.net.Uri +import android.os.Build +import android.os.Bundle +import android.os.Environment +import android.os.storage.StorageManager +import android.os.storage.StorageVolume +import android.util.Log +import android.view.View +import android.widget.Toast +import androidx.activity.addCallback +import androidx.activity.result.contract.ActivityResultContracts +import androidx.appcompat.app.AppCompatActivity +import androidx.appcompat.app.AppCompatDelegate +import androidx.core.content.ContextCompat +import androidx.core.content.getSystemService +import androidx.core.view.forEach +import androidx.databinding.DataBindingUtil +import androidx.databinding.Observable +import androidx.databinding.ObservableBoolean +import androidx.databinding.ObservableField +import androidx.lifecycle.lifecycleScope +import androidx.recyclerview.widget.GridLayoutManager +import androidx.recyclerview.widget.GridLayoutManager.SpanSizeLookup +import com.google.android.material.dialog.MaterialAlertDialogBuilder +import com.wireguard.android.Application +import com.wireguard.android.R +import com.wireguard.android.backend.GoBackend +import com.wireguard.android.backend.Tunnel +import com.wireguard.android.databinding.Keyed +import com.wireguard.android.databinding.ObservableKeyedArrayList +import com.wireguard.android.databinding.ObservableKeyedRecyclerViewAdapter +import com.wireguard.android.databinding.TvActivityBinding +import com.wireguard.android.databinding.TvFileListItemBinding +import com.wireguard.android.databinding.TvTunnelListItemBinding +import com.wireguard.android.model.ObservableTunnel +import com.wireguard.android.util.ErrorMessages +import com.wireguard.android.util.QuantityFormatter +import com.wireguard.android.util.TunnelImporter +import com.wireguard.android.util.UserKnobs +import com.wireguard.android.util.applicationScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import java.io.File + +class TvMainActivity : AppCompatActivity() { + private val tunnelFileImportResultLauncher = registerForActivityResult(object : ActivityResultContracts.OpenDocument() { + override fun createIntent(context: Context, input: Array<String>): Intent { + val intent = super.createIntent(context, input) + + /* AndroidTV now comes with stubs that do nothing but display a Toast less helpful than + * what we can do, so detect this and throw an exception that we can catch later. */ + val activitiesToResolveIntent = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) { + context.packageManager.queryIntentActivities(intent, PackageManager.ResolveInfoFlags.of(PackageManager.MATCH_DEFAULT_ONLY.toLong())) + } else { + @Suppress("DEPRECATION") + context.packageManager.queryIntentActivities(intent, PackageManager.MATCH_DEFAULT_ONLY) + } + if (activitiesToResolveIntent.all { + val name = it.activityInfo.packageName + name.startsWith("com.google.android.tv.frameworkpackagestubs") || name.startsWith("com.android.tv.frameworkpackagestubs") + }) { + throw ActivityNotFoundException() + } + return intent + } + }) { data -> + if (data == null) return@registerForActivityResult + lifecycleScope.launch { + TunnelImporter.importTunnel(contentResolver, data) { + Toast.makeText(this@TvMainActivity, it, Toast.LENGTH_LONG).show() + } + } + } + private var pendingTunnel: ObservableTunnel? = null + private val permissionActivityResultLauncher = registerForActivityResult(ActivityResultContracts.StartActivityForResult()) { + val tunnel = pendingTunnel + if (tunnel != null) + setTunnelStateWithPermissionsResult(tunnel) + pendingTunnel = null + } + + private fun setTunnelStateWithPermissionsResult(tunnel: ObservableTunnel) { + lifecycleScope.launch { + try { + tunnel.setStateAsync(Tunnel.State.TOGGLE) + } catch (e: Throwable) { + val error = ErrorMessages[e] + val message = getString(R.string.error_up, error) + Toast.makeText(this@TvMainActivity, message, Toast.LENGTH_LONG).show() + Log.e(TAG, message, e) + } + updateStats() + } + } + + private lateinit var binding: TvActivityBinding + private val isDeleting = ObservableBoolean() + private val files = ObservableKeyedArrayList<String, KeyedFile>() + private val filesRoot = ObservableField("") + + override fun onCreate(savedInstanceState: Bundle?) { + if (AppCompatDelegate.getDefaultNightMode() != AppCompatDelegate.MODE_NIGHT_YES) { + AppCompatDelegate.setDefaultNightMode(AppCompatDelegate.MODE_NIGHT_YES) + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.Q) { + applicationScope.launch { + UserKnobs.setDarkTheme(true) + } + } + } + super.onCreate(savedInstanceState) + binding = TvActivityBinding.inflate(layoutInflater) + lifecycleScope.launch { + binding.tunnels = Application.getTunnelManager().getTunnels() + if (binding.tunnels?.isEmpty() == true) + binding.importButton.requestFocus() + else + binding.tunnelList.requestFocus() + } + binding.isDeleting = isDeleting + binding.files = files + binding.filesRoot = filesRoot + val gridManager = binding.tunnelList.layoutManager as GridLayoutManager + gridManager.spanSizeLookup = SlatedSpanSizeLookup(gridManager) + binding.tunnelRowConfigurationHandler = object : ObservableKeyedRecyclerViewAdapter.RowConfigurationHandler<TvTunnelListItemBinding, ObservableTunnel> { + override fun onConfigureRow(binding: TvTunnelListItemBinding, item: ObservableTunnel, position: Int) { + binding.isDeleting = isDeleting + binding.isFocused = ObservableBoolean() + binding.root.setOnFocusChangeListener { _, focused -> + binding.isFocused?.set(focused) + } + binding.root.setOnClickListener { + lifecycleScope.launch { + if (isDeleting.get()) { + try { + item.deleteAsync() + if (this@TvMainActivity.binding.tunnels?.isEmpty() != false) + isDeleting.set(false) + } catch (e: Throwable) { + val error = ErrorMessages[e] + val message = getString(R.string.config_delete_error, error) + Toast.makeText(this@TvMainActivity, message, Toast.LENGTH_LONG).show() + Log.e(TAG, message, e) + } + } else { + if (Application.getBackend() is GoBackend) { + val intent = GoBackend.VpnService.prepare(binding.root.context) + if (intent != null) { + pendingTunnel = item + permissionActivityResultLauncher.launch(intent) + return@launch + } + } + setTunnelStateWithPermissionsResult(item) + } + } + } + } + } + + binding.filesRowConfigurationHandler = object : ObservableKeyedRecyclerViewAdapter.RowConfigurationHandler<TvFileListItemBinding, KeyedFile> { + override fun onConfigureRow(binding: TvFileListItemBinding, item: KeyedFile, position: Int) { + binding.root.setOnClickListener { + if (item.file.isDirectory) + navigateTo(item.file) + else { + val uri = Uri.fromFile(item.file) + files.clear() + filesRoot.set("") + lifecycleScope.launch { + TunnelImporter.importTunnel(contentResolver, uri) { + Toast.makeText(this@TvMainActivity, it, Toast.LENGTH_LONG).show() + } + } + runOnUiThread { + this@TvMainActivity.binding.tunnelList.requestFocus() + } + } + } + } + } + + binding.importButton.setOnClickListener { + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.Q) { + if (filesRoot.get()?.isEmpty() != false) { + navigateTo(File("/")) + runOnUiThread { + binding.filesList.requestFocus() + } + } else { + files.clear() + filesRoot.set("") + runOnUiThread { + binding.tunnelList.requestFocus() + } + } + } else { + try { + tunnelFileImportResultLauncher.launch(arrayOf("*/*")) + } catch (_: Throwable) { + MaterialAlertDialogBuilder(binding.root.context).setMessage(R.string.tv_no_file_picker).setCancelable(false) + .setPositiveButton(android.R.string.ok) { _, _ -> + try { + startActivity(Intent(Intent.ACTION_VIEW).apply { + data = Uri.parse("https://play.google.com/store/apps/details?id=com.cxinventor.file.explorer") + setPackage("com.android.vending") + }) + } catch (_: Throwable) { + } + }.show() + } + } + } + + binding.deleteButton.setOnClickListener { + isDeleting.set(!isDeleting.get()) + runOnUiThread { + binding.tunnelList.requestFocus() + } + } + + val backPressedCallback = onBackPressedDispatcher.addCallback(this) { handleBackPressed() } + val updateBackPressedCallback = object : Observable.OnPropertyChangedCallback() { + override fun onPropertyChanged(sender: Observable?, propertyId: Int) { + backPressedCallback.isEnabled = isDeleting.get() || filesRoot.get()?.isNotEmpty() == true + } + } + isDeleting.addOnPropertyChangedCallback(updateBackPressedCallback) + filesRoot.addOnPropertyChangedCallback(updateBackPressedCallback) + backPressedCallback.isEnabled = false + + binding.executePendingBindings() + setContentView(binding.root) + + lifecycleScope.launch { + while (true) { + updateStats() + delay(1000) + } + } + } + + private var pendingNavigation: File? = null + private val permissionRequestPermissionLauncher = registerForActivityResult(ActivityResultContracts.RequestPermission()) { + val to = pendingNavigation + if (it && to != null) + navigateTo(to) + pendingNavigation = null + } + + private var cachedRoots: Collection<KeyedFile>? = null + + private suspend fun makeStorageRoots(): Collection<KeyedFile> = withContext(Dispatchers.IO) { + cachedRoots?.let { return@withContext it } + val list = HashSet<KeyedFile>() + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { + val storageManager: StorageManager = getSystemService() ?: return@withContext list + list.addAll(storageManager.storageVolumes.mapNotNull { volume -> + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.R) { + volume.directory?.let { KeyedFile(it, volume.getDescription(this@TvMainActivity)) } + } else { + KeyedFile((StorageVolume::class.java.getMethod("getPathFile").invoke(volume) as File), volume.getDescription(this@TvMainActivity)) + } + }) + } else { + @Suppress("DEPRECATION") + list.add(KeyedFile(Environment.getExternalStorageDirectory())) + try { + File("/storage").listFiles()?.forEach { + if (!it.isDirectory) return@forEach + try { + if (Environment.isExternalStorageRemovable(it)) { + list.add(KeyedFile(it)) + } + } catch (_: Throwable) { + } + } + } catch (_: Throwable) { + } + } + cachedRoots = list + list + } + + private fun isBelowCachedRoots(maybeChild: File): Boolean { + val cachedRoots = cachedRoots ?: return true + for (root in cachedRoots) { + if (maybeChild.canonicalPath.startsWith(root.file.canonicalPath)) + return false + } + return true + } + + private fun navigateTo(directory: File) { + require(Build.VERSION.SDK_INT < Build.VERSION_CODES.Q) + + if (ContextCompat.checkSelfPermission(this, Manifest.permission.READ_EXTERNAL_STORAGE) != PackageManager.PERMISSION_GRANTED) { + pendingNavigation = directory + permissionRequestPermissionLauncher.launch(Manifest.permission.READ_EXTERNAL_STORAGE) + return + } + + lifecycleScope.launch { + if (isBelowCachedRoots(directory)) { + val roots = makeStorageRoots() + if (roots.count() == 1) { + navigateTo(roots.first().file) + return@launch + } + files.clear() + files.addAll(roots) + filesRoot.set(getString(R.string.tv_select_a_storage_drive)) + return@launch + } + + val newFiles = withContext(Dispatchers.IO) { + val newFiles = ArrayList<KeyedFile>() + try { + directory.parentFile?.let { + newFiles.add(KeyedFile(it, "../")) + } + val listing = directory.listFiles() ?: return@withContext null + listing.forEach { + if (it.extension == "conf" || it.extension == "zip" || it.isDirectory) + newFiles.add(KeyedFile(it)) + } + newFiles.sortWith { a, b -> + if (a.file.isDirectory && !b.file.isDirectory) -1 + else if (!a.file.isDirectory && b.file.isDirectory) 1 + else a.file.compareTo(b.file) + } + } catch (e: Throwable) { + Log.e(TAG, Log.getStackTraceString(e)) + } + newFiles + } + if (newFiles?.isEmpty() != false) + return@launch + files.clear() + files.addAll(newFiles) + filesRoot.set(directory.canonicalPath) + } + } + + private fun handleBackPressed() { + when { + isDeleting.get() -> { + isDeleting.set(false) + runOnUiThread { + binding.tunnelList.requestFocus() + } + } + + filesRoot.get()?.isNotEmpty() == true -> { + files.clear() + filesRoot.set("") + runOnUiThread { + binding.tunnelList.requestFocus() + } + } + } + } + + private suspend fun updateStats() { + binding.tunnelList.forEach { viewItem -> + val listItem = DataBindingUtil.findBinding<TvTunnelListItemBinding>(viewItem) + ?: return@forEach + try { + val tunnel = listItem.item!! + if (tunnel.state != Tunnel.State.UP || isDeleting.get()) { + throw Exception() + } + val statistics = tunnel.getStatisticsAsync() + val rx = statistics.totalRx() + val tx = statistics.totalTx() + listItem.tunnelTransfer.text = getString(R.string.transfer_rx_tx, QuantityFormatter.formatBytes(rx), QuantityFormatter.formatBytes(tx)) + listItem.tunnelTransfer.visibility = View.VISIBLE + } catch (_: Throwable) { + listItem.tunnelTransfer.visibility = View.GONE + listItem.tunnelTransfer.text = "" + } + } + } + + class KeyedFile(val file: File, private val forcedKey: String? = null) : Keyed<String> { + override val key: String + get() = forcedKey ?: if (file.isDirectory) "${file.name}/" else file.name + } + + private class SlatedSpanSizeLookup(private val gridManager: GridLayoutManager) : SpanSizeLookup() { + private val originalHeight = gridManager.spanCount + private var newWidth = 0 + private lateinit var sizeMap: Array<IntArray?> + + private fun emptyUnderIndex(index: Int, size: Int): Int { + sizeMap[size - 1]?.let { return it[index] } + val sizes = IntArray(size) + val oh = originalHeight + val nw = newWidth + var empties = 0 + for (i in 0 until size) { + val ox = (i + empties) / oh + val oy = (i + empties) % oh + var empty = 0 + for (j in oy + 1 until oh) { + val ni = nw * j + ox + if (ni < size) + break + empty++ + } + empties += empty + sizes[i] = empty + } + sizeMap[size - 1] = sizes + return sizes[index] + } + + override fun getSpanSize(position: Int): Int { + if (newWidth == 0) { + val child = gridManager.getChildAt(0) ?: return 1 + if (child.width == 0) return 1 + newWidth = gridManager.width / child.width + sizeMap = Array(originalHeight * newWidth - 1) { null } + } + val total = gridManager.itemCount + if (total >= originalHeight * newWidth || total == 0) + return 1 + return emptyUnderIndex(position, total) + 1 + } + } + + companion object { + private const val TAG = "WireGuard/TvMainActivity" + } +} diff --git a/ui/src/main/java/com/wireguard/android/configStore/ConfigStore.kt b/ui/src/main/java/com/wireguard/android/configStore/ConfigStore.kt index 6c381b69..45f38600 100644 --- a/ui/src/main/java/com/wireguard/android/configStore/ConfigStore.kt +++ b/ui/src/main/java/com/wireguard/android/configStore/ConfigStore.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.configStore diff --git a/ui/src/main/java/com/wireguard/android/configStore/FileConfigStore.kt b/ui/src/main/java/com/wireguard/android/configStore/FileConfigStore.kt index 6fe2bece..98b738e1 100644 --- a/ui/src/main/java/com/wireguard/android/configStore/FileConfigStore.kt +++ b/ui/src/main/java/com/wireguard/android/configStore/FileConfigStore.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.configStore @@ -40,9 +40,9 @@ class FileConfigStore(private val context: Context) : ConfigStore { override fun enumerate(): Set<String> { return context.fileList() - .filter { it.endsWith(".conf") } - .map { it.substring(0, it.length - ".conf".length) } - .toSet() + .filter { it.endsWith(".conf") } + .map { it.substring(0, it.length - ".conf".length) } + .toSet() } private fun fileFor(name: String): File { diff --git a/ui/src/main/java/com/wireguard/android/databinding/BindingAdapters.kt b/ui/src/main/java/com/wireguard/android/databinding/BindingAdapters.kt index 055c2f06..df3bd08b 100644 --- a/ui/src/main/java/com/wireguard/android/databinding/BindingAdapters.kt +++ b/ui/src/main/java/com/wireguard/android/databinding/BindingAdapters.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.databinding @@ -23,10 +23,11 @@ import com.wireguard.android.R import com.wireguard.android.databinding.ObservableKeyedRecyclerViewAdapter.RowConfigurationHandler import com.wireguard.android.widget.ToggleSwitch import com.wireguard.android.widget.ToggleSwitch.OnBeforeCheckedChangeListener +import com.wireguard.android.widget.TvCardView import com.wireguard.config.Attribute import com.wireguard.config.InetNetwork -import java9.util.Optional import java.net.InetAddress +import java.util.Optional /** * Static methods for use by generated code in the Android data binding library. @@ -46,9 +47,11 @@ object BindingAdapters { @JvmStatic @BindingAdapter("items", "layout", "fragment") - fun <E> setItems(view: LinearLayout, - oldList: ObservableList<E>?, oldLayoutId: Int, @Suppress("UNUSED_PARAMETER") oldFragment: Fragment?, - newList: ObservableList<E>?, newLayoutId: Int, newFragment: Fragment?) { + fun <E> setItems( + view: LinearLayout, + oldList: ObservableList<E>?, oldLayoutId: Int, @Suppress("UNUSED_PARAMETER") oldFragment: Fragment?, + newList: ObservableList<E>?, newLayoutId: Int, newFragment: Fragment? + ) { if (oldList === newList && oldLayoutId == newLayoutId) return var listener: ItemChangeListener<E>? = ListenerUtil.getListener(view, R.id.item_change_listener) @@ -72,9 +75,11 @@ object BindingAdapters { @JvmStatic @BindingAdapter("items", "layout") - fun <E> setItems(view: LinearLayout, - oldList: Iterable<E>?, oldLayoutId: Int, - newList: Iterable<E>?, newLayoutId: Int) { + fun <E> setItems( + view: LinearLayout, + oldList: Iterable<E>?, oldLayoutId: Int, + newList: Iterable<E>?, newLayoutId: Int + ) { if (oldList === newList && oldLayoutId == newLayoutId) return view.removeAllViews() @@ -92,11 +97,13 @@ object BindingAdapters { @JvmStatic @BindingAdapter(requireAll = false, value = ["items", "layout", "configurationHandler"]) - fun <K, E : Keyed<out K>> setItems(view: RecyclerView, - oldList: ObservableKeyedArrayList<K, E>?, oldLayoutId: Int, - @Suppress("UNUSED_PARAMETER") oldRowConfigurationHandler: RowConfigurationHandler<*, *>?, - newList: ObservableKeyedArrayList<K, E>?, newLayoutId: Int, - newRowConfigurationHandler: RowConfigurationHandler<*, *>?) { + fun <K, E : Keyed<out K>> setItems( + view: RecyclerView, + oldList: ObservableKeyedArrayList<K, E>?, oldLayoutId: Int, + @Suppress("UNUSED_PARAMETER") oldRowConfigurationHandler: RowConfigurationHandler<*, *>?, + newList: ObservableKeyedArrayList<K, E>?, newLayoutId: Int, + newRowConfigurationHandler: RowConfigurationHandler<*, *>? + ) { if (view.layoutManager == null) view.layoutManager = LinearLayoutManager(view.context, RecyclerView.VERTICAL, false) if (oldList === newList && oldLayoutId == newLayoutId) @@ -122,16 +129,20 @@ object BindingAdapters { @JvmStatic @BindingAdapter("onBeforeCheckedChanged") - fun setOnBeforeCheckedChanged(view: ToggleSwitch, - listener: OnBeforeCheckedChangeListener?) { + fun setOnBeforeCheckedChanged( + view: ToggleSwitch, + listener: OnBeforeCheckedChangeListener? + ) { view.setOnBeforeCheckedChangeListener(listener) } @JvmStatic @BindingAdapter("onFocusChange") - fun setOnFocusChange(view: EditText, - listener: View.OnFocusChangeListener?) { - view.setOnFocusChangeListener(listener) + fun setOnFocusChange( + view: EditText, + listener: View.OnFocusChangeListener? + ) { + view.onFocusChangeListener = listener } @JvmStatic @@ -153,13 +164,31 @@ object BindingAdapters { } @JvmStatic + @BindingAdapter("android:text") + fun setStringSetText(view: TextView, strings: Iterable<String?>?) { + view.text = if (strings != null) Attribute.join(strings) else "" + } + + @JvmStatic fun tryParseInt(s: String?): Int { if (s == null) return 0 return try { Integer.parseInt(s) - } catch (_: Exception) { + } catch (_: Throwable) { 0 } } + + @JvmStatic + @BindingAdapter("isUp") + fun setIsUp(card: TvCardView, up: Boolean) { + card.isUp = up + } + + @JvmStatic + @BindingAdapter("isDeleting") + fun setIsDeleting(card: TvCardView, deleting: Boolean) { + card.isDeleting = deleting + } } diff --git a/ui/src/main/java/com/wireguard/android/databinding/ItemChangeListener.kt b/ui/src/main/java/com/wireguard/android/databinding/ItemChangeListener.kt index 29784a75..84ec3ed8 100644 --- a/ui/src/main/java/com/wireguard/android/databinding/ItemChangeListener.kt +++ b/ui/src/main/java/com/wireguard/android/databinding/ItemChangeListener.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.databinding @@ -61,8 +61,10 @@ internal class ItemChangeListener<T>(private val container: ViewGroup, private v } } - override fun onItemRangeChanged(sender: ObservableList<T>, positionStart: Int, - itemCount: Int) { + override fun onItemRangeChanged( + sender: ObservableList<T>, positionStart: Int, + itemCount: Int + ) { val listener = weakListener.get() if (listener != null) { for (i in positionStart until positionStart + itemCount) { @@ -75,8 +77,10 @@ internal class ItemChangeListener<T>(private val container: ViewGroup, private v } } - override fun onItemRangeInserted(sender: ObservableList<T>, positionStart: Int, - itemCount: Int) { + override fun onItemRangeInserted( + sender: ObservableList<T>, positionStart: Int, + itemCount: Int + ) { val listener = weakListener.get() if (listener != null) { for (i in positionStart until positionStart + itemCount) @@ -86,8 +90,10 @@ internal class ItemChangeListener<T>(private val container: ViewGroup, private v } } - override fun onItemRangeMoved(sender: ObservableList<T>, fromPosition: Int, - toPosition: Int, itemCount: Int) { + override fun onItemRangeMoved( + sender: ObservableList<T>, fromPosition: Int, + toPosition: Int, itemCount: Int + ) { val listener = weakListener.get() if (listener != null) { val views = arrayOfNulls<View>(itemCount) @@ -99,8 +105,10 @@ internal class ItemChangeListener<T>(private val container: ViewGroup, private v } } - override fun onItemRangeRemoved(sender: ObservableList<T>, positionStart: Int, - itemCount: Int) { + override fun onItemRangeRemoved( + sender: ObservableList<T>, positionStart: Int, + itemCount: Int + ) { val listener = weakListener.get() if (listener != null) { listener.container.removeViews(positionStart, itemCount) diff --git a/ui/src/main/java/com/wireguard/android/databinding/Keyed.kt b/ui/src/main/java/com/wireguard/android/databinding/Keyed.kt index dd03e4c9..fc4ee357 100644 --- a/ui/src/main/java/com/wireguard/android/databinding/Keyed.kt +++ b/ui/src/main/java/com/wireguard/android/databinding/Keyed.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.databinding diff --git a/ui/src/main/java/com/wireguard/android/databinding/ObservableKeyedArrayList.kt b/ui/src/main/java/com/wireguard/android/databinding/ObservableKeyedArrayList.kt index c00f553c..4d6c3a21 100644 --- a/ui/src/main/java/com/wireguard/android/databinding/ObservableKeyedArrayList.kt +++ b/ui/src/main/java/com/wireguard/android/databinding/ObservableKeyedArrayList.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.databinding diff --git a/ui/src/main/java/com/wireguard/android/databinding/ObservableKeyedRecyclerViewAdapter.kt b/ui/src/main/java/com/wireguard/android/databinding/ObservableKeyedRecyclerViewAdapter.kt index 29d09fbc..a9ef4913 100644 --- a/ui/src/main/java/com/wireguard/android/databinding/ObservableKeyedRecyclerViewAdapter.kt +++ b/ui/src/main/java/com/wireguard/android/databinding/ObservableKeyedRecyclerViewAdapter.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.databinding diff --git a/ui/src/main/java/com/wireguard/android/databinding/ObservableSortedKeyedArrayList.kt b/ui/src/main/java/com/wireguard/android/databinding/ObservableSortedKeyedArrayList.kt index 98e9e915..d6c039f2 100644 --- a/ui/src/main/java/com/wireguard/android/databinding/ObservableSortedKeyedArrayList.kt +++ b/ui/src/main/java/com/wireguard/android/databinding/ObservableSortedKeyedArrayList.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.databinding diff --git a/ui/src/main/java/com/wireguard/android/fragment/AddTunnelsSheet.kt b/ui/src/main/java/com/wireguard/android/fragment/AddTunnelsSheet.kt index 17bdac04..f077cbae 100644 --- a/ui/src/main/java/com/wireguard/android/fragment/AddTunnelsSheet.kt +++ b/ui/src/main/java/com/wireguard/android/fragment/AddTunnelsSheet.kt @@ -1,10 +1,10 @@ /* - * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.fragment -import android.content.Intent +import android.content.pm.PackageManager import android.graphics.drawable.GradientDrawable import android.os.Bundle import android.view.LayoutInflater @@ -12,13 +12,12 @@ import android.view.View import android.view.ViewGroup import android.view.ViewTreeObserver import android.widget.FrameLayout +import androidx.core.os.bundleOf +import androidx.fragment.app.setFragmentResult import com.google.android.material.bottomsheet.BottomSheetBehavior import com.google.android.material.bottomsheet.BottomSheetDialog import com.google.android.material.bottomsheet.BottomSheetDialogFragment -import com.google.zxing.integration.android.IntentIntegrator import com.wireguard.android.R -import com.wireguard.android.activity.TunnelCreatorActivity -import com.wireguard.android.util.requireTargetFragment import com.wireguard.android.util.resolveAttribute class AddTunnelsSheet : BottomSheetDialogFragment() { @@ -35,13 +34,15 @@ class AddTunnelsSheet : BottomSheetDialogFragment() { } } - override fun getTheme(): Int { - return R.style.BottomSheetDialogTheme - } - override fun onCreateView(inflater: LayoutInflater, container: ViewGroup?, savedInstanceState: Bundle?): View? { if (savedInstanceState != null) dismiss() - return inflater.inflate(R.layout.add_tunnels_bottom_sheet, container, false) + val view = inflater.inflate(R.layout.add_tunnels_bottom_sheet, container, false) + if (activity?.packageManager?.hasSystemFeature(PackageManager.FEATURE_CAMERA_ANY) != true) { + val qrcode = view.findViewById<View>(R.id.create_from_qrcode) + qrcode.isEnabled = false + qrcode.visibility = View.GONE + } + return view } override fun onViewCreated(view: View, savedInstanceState: Bundle?) { @@ -71,7 +72,7 @@ class AddTunnelsSheet : BottomSheetDialogFragment() { } }) val gradientDrawable = GradientDrawable().apply { - setColor(requireContext().resolveAttribute(R.attr.colorBackground)) + setColor(requireContext().resolveAttribute(com.google.android.material.R.attr.colorSurface)) } view.background = gradientDrawable } @@ -82,23 +83,22 @@ class AddTunnelsSheet : BottomSheetDialogFragment() { } private fun onRequestCreateConfig() { - startActivity(Intent(activity, TunnelCreatorActivity::class.java)) + setFragmentResult(REQUEST_KEY_NEW_TUNNEL, bundleOf(REQUEST_METHOD to REQUEST_CREATE)) } private fun onRequestImportConfig() { - val intent = Intent(Intent.ACTION_GET_CONTENT).apply { - addCategory(Intent.CATEGORY_OPENABLE) - type = "*/*" - } - requireTargetFragment().startActivityForResult(intent, TunnelListFragment.REQUEST_IMPORT) + setFragmentResult(REQUEST_KEY_NEW_TUNNEL, bundleOf(REQUEST_METHOD to REQUEST_IMPORT)) } private fun onRequestScanQRCode() { - val integrator = IntentIntegrator.forSupportFragment(requireTargetFragment()).apply { - setOrientationLocked(false) - setBeepEnabled(false) - setPrompt(getString(R.string.qr_code_hint)) - } - integrator.initiateScan(listOf(IntentIntegrator.QR_CODE)) + setFragmentResult(REQUEST_KEY_NEW_TUNNEL, bundleOf(REQUEST_METHOD to REQUEST_SCAN)) + } + + companion object { + const val REQUEST_KEY_NEW_TUNNEL = "request_new_tunnel" + const val REQUEST_METHOD = "request_method" + const val REQUEST_CREATE = "request_create" + const val REQUEST_IMPORT = "request_import" + const val REQUEST_SCAN = "request_scan" } } diff --git a/ui/src/main/java/com/wireguard/android/fragment/AppListDialogFragment.kt b/ui/src/main/java/com/wireguard/android/fragment/AppListDialogFragment.kt index 35bd3ce9..692dd809 100644 --- a/ui/src/main/java/com/wireguard/android/fragment/AppListDialogFragment.kt +++ b/ui/src/main/java/com/wireguard/android/fragment/AppListDialogFragment.kt @@ -1,27 +1,35 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.fragment +import android.Manifest import android.app.Dialog -import android.content.Intent +import android.content.pm.PackageInfo +import android.content.pm.PackageManager +import android.content.pm.PackageManager.PackageInfoFlags +import android.os.Build import android.os.Bundle import android.widget.Button import android.widget.Toast import androidx.appcompat.app.AlertDialog +import androidx.core.os.bundleOf import androidx.databinding.Observable import androidx.fragment.app.DialogFragment -import androidx.fragment.app.Fragment +import androidx.fragment.app.setFragmentResult +import androidx.lifecycle.lifecycleScope +import com.google.android.material.dialog.MaterialAlertDialogBuilder import com.google.android.material.tabs.TabLayout -import com.wireguard.android.Application import com.wireguard.android.BR import com.wireguard.android.R import com.wireguard.android.databinding.AppListDialogFragmentBinding import com.wireguard.android.databinding.ObservableKeyedArrayList import com.wireguard.android.model.ApplicationData import com.wireguard.android.util.ErrorMessages -import com.wireguard.android.util.requireTargetFragment +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext class AppListDialogFragment : DialogFragment() { private val appData = ObservableKeyedArrayList<String, ApplicationData>() @@ -33,44 +41,57 @@ class AppListDialogFragment : DialogFragment() { private fun loadData() { val activity = activity ?: return val pm = activity.packageManager - Application.getAsyncWorker().supplyAsync<List<ApplicationData>> { - val launcherIntent = Intent(Intent.ACTION_MAIN, null) - launcherIntent.addCategory(Intent.CATEGORY_LAUNCHER) - val resolveInfos = pm.queryIntentActivities(launcherIntent, 0) - val applicationData: MutableList<ApplicationData> = ArrayList() - resolveInfos.forEach { - val packageName = it.activityInfo.packageName - val appData = ApplicationData(it.loadIcon(pm), it.loadLabel(pm).toString(), packageName, currentlySelectedApps.contains(packageName)) - applicationData.add(appData) - appData.addOnPropertyChangedCallback(object : Observable.OnPropertyChangedCallback() { - override fun onPropertyChanged(sender: Observable?, propertyId: Int) { - if (propertyId == BR.selected) - setButtonText() + lifecycleScope.launch(Dispatchers.Default) { + try { + val applicationData: MutableList<ApplicationData> = ArrayList() + withContext(Dispatchers.IO) { + val packageInfos = getPackagesHoldingPermissions(pm, arrayOf(Manifest.permission.INTERNET)) + packageInfos.forEach { + val packageName = it.packageName + val appInfo = it.applicationInfo ?: return@forEach + val appData = + ApplicationData(appInfo.loadIcon(pm), appInfo.loadLabel(pm).toString(), packageName, currentlySelectedApps.contains(packageName)) + applicationData.add(appData) + appData.addOnPropertyChangedCallback(object : Observable.OnPropertyChangedCallback() { + override fun onPropertyChanged(sender: Observable?, propertyId: Int) { + if (propertyId == BR.selected) + setButtonText() + } + }) } - }) - } - applicationData.sortWith(compareBy(String.CASE_INSENSITIVE_ORDER) { it.name }) - applicationData - }.whenComplete { data, throwable -> - if (data != null) { - appData.clear() - appData.addAll(data) - } else { - val error = ErrorMessages[throwable] - val message = activity.getString(R.string.error_fetching_apps, error) - Toast.makeText(activity, message, Toast.LENGTH_LONG).show() - dismissAllowingStateLoss() + } + applicationData.sortWith(compareBy(String.CASE_INSENSITIVE_ORDER) { it.name }) + withContext(Dispatchers.Main.immediate) { + appData.clear() + appData.addAll(applicationData) + setButtonText() + } + } catch (e: Throwable) { + withContext(Dispatchers.Main.immediate) { + val error = ErrorMessages[e] + val message = activity.getString(R.string.error_fetching_apps, error) + Toast.makeText(activity, message, Toast.LENGTH_LONG).show() + dismissAllowingStateLoss() + } } } } override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) - require(requireTargetFragment() is AppSelectionListener) { "${requireTargetFragment()} must implement AppSelectionListener" } currentlySelectedApps = (arguments?.getStringArrayList(KEY_SELECTED_APPS) ?: emptyList()) initiallyExcluded = arguments?.getBoolean(KEY_IS_EXCLUDED) ?: true } + private fun getPackagesHoldingPermissions(pm: PackageManager, permissions: Array<String>): List<PackageInfo> { + return if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) { + pm.getPackagesHoldingPermissions(permissions, PackageInfoFlags.of(0L)) + } else { + @Suppress("DEPRECATION") + pm.getPackagesHoldingPermissions(permissions, 0) + } + } + private fun setButtonText() { val numSelected = appData.count { it.isSelected } button?.text = if (numSelected == 0) @@ -83,7 +104,7 @@ class AppListDialogFragment : DialogFragment() { } override fun onCreateDialog(savedInstanceState: Bundle?): Dialog { - val alertDialogBuilder = AlertDialog.Builder(requireActivity()) + val alertDialogBuilder = MaterialAlertDialogBuilder(requireActivity()) val binding = AppListDialogFragmentBinding.inflate(requireActivity().layoutInflater, null, false) binding.executePendingBindings() alertDialogBuilder.setView(binding.root) @@ -123,23 +144,25 @@ class AppListDialogFragment : DialogFragment() { selectedApps.add(data.packageName) } } - (requireTargetFragment() as AppSelectionListener).onSelectedAppsSelected(selectedApps, tabs?.selectedTabPosition == 0) + setFragmentResult( + REQUEST_SELECTION, bundleOf( + KEY_SELECTED_APPS to selectedApps.toTypedArray(), + KEY_IS_EXCLUDED to (tabs?.selectedTabPosition == 0) + ) + ) dismiss() } - interface AppSelectionListener { - fun onSelectedAppsSelected(selectedApps: List<String>, isExcluded: Boolean) - } - companion object { - private const val KEY_SELECTED_APPS = "selected_apps" - private const val KEY_IS_EXCLUDED = "is_excluded" - fun <T> newInstance(selectedApps: ArrayList<String?>?, isExcluded: Boolean, target: T): AppListDialogFragment where T : Fragment?, T : AppSelectionListener? { + const val KEY_SELECTED_APPS = "selected_apps" + const val KEY_IS_EXCLUDED = "is_excluded" + const val REQUEST_SELECTION = "request_selection" + + fun newInstance(selectedApps: ArrayList<String?>?, isExcluded: Boolean): AppListDialogFragment { val extras = Bundle() extras.putStringArrayList(KEY_SELECTED_APPS, selectedApps) extras.putBoolean(KEY_IS_EXCLUDED, isExcluded) val fragment = AppListDialogFragment() - fragment.setTargetFragment(target, 0) fragment.arguments = extras return fragment } diff --git a/ui/src/main/java/com/wireguard/android/fragment/BaseFragment.kt b/ui/src/main/java/com/wireguard/android/fragment/BaseFragment.kt index 82802623..2e551f83 100644 --- a/ui/src/main/java/com/wireguard/android/fragment/BaseFragment.kt +++ b/ui/src/main/java/com/wireguard/android/fragment/BaseFragment.kt @@ -1,66 +1,60 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.fragment import android.content.Context -import android.content.Intent import android.util.Log import android.view.View import android.widget.Toast +import androidx.activity.result.contract.ActivityResultContracts import androidx.databinding.DataBindingUtil import androidx.databinding.ViewDataBinding import androidx.fragment.app.Fragment +import androidx.lifecycle.lifecycleScope import com.google.android.material.snackbar.Snackbar import com.wireguard.android.Application import com.wireguard.android.R import com.wireguard.android.activity.BaseActivity import com.wireguard.android.activity.BaseActivity.OnSelectedTunnelChangedListener -import com.wireguard.android.backend.Backend import com.wireguard.android.backend.GoBackend import com.wireguard.android.backend.Tunnel import com.wireguard.android.databinding.TunnelDetailFragmentBinding import com.wireguard.android.databinding.TunnelListItemBinding import com.wireguard.android.model.ObservableTunnel import com.wireguard.android.util.ErrorMessages +import kotlinx.coroutines.launch /** * Base class for fragments that need to know the currently-selected tunnel. Only does anything when * attached to a `BaseActivity`. */ abstract class BaseFragment : Fragment(), OnSelectedTunnelChangedListener { - private var baseActivity: BaseActivity? = null private var pendingTunnel: ObservableTunnel? = null private var pendingTunnelUp: Boolean? = null + private val permissionActivityResultLauncher = registerForActivityResult(ActivityResultContracts.StartActivityForResult()) { + val tunnel = pendingTunnel + val checked = pendingTunnelUp + if (tunnel != null && checked != null) + setTunnelStateWithPermissionsResult(tunnel, checked) + pendingTunnel = null + pendingTunnelUp = null + } + protected var selectedTunnel: ObservableTunnel? - get() = baseActivity?.selectedTunnel + get() = (activity as? BaseActivity)?.selectedTunnel protected set(tunnel) { - baseActivity?.selectedTunnel = tunnel + (activity as? BaseActivity)?.selectedTunnel = tunnel } - override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) { - super.onActivityResult(requestCode, resultCode, data) - if (requestCode == REQUEST_CODE_VPN_PERMISSION) { - if (pendingTunnel != null && pendingTunnelUp != null) setTunnelStateWithPermissionsResult(pendingTunnel!!, pendingTunnelUp!!) - pendingTunnel = null - pendingTunnelUp = null - } - } - override fun onAttach(context: Context) { super.onAttach(context) - if (context is BaseActivity) { - baseActivity = context - baseActivity?.addOnSelectedTunnelChangedListener(this) - } else { - baseActivity = null - } + (activity as? BaseActivity)?.addOnSelectedTunnelChangedListener(this) } override fun onDetach() { - baseActivity?.removeOnSelectedTunnelChangedListener(this) - baseActivity = null + (activity as? BaseActivity)?.removeOnSelectedTunnelChangedListener(this) super.onDetach() } @@ -70,14 +64,23 @@ abstract class BaseFragment : Fragment(), OnSelectedTunnelChangedListener { is TunnelListItemBinding -> binding.item else -> return } ?: return - Application.getBackendAsync().thenAccept { backend: Backend? -> - if (backend is GoBackend) { - val intent = GoBackend.VpnService.prepare(view.context) - if (intent != null) { - pendingTunnel = tunnel - pendingTunnelUp = checked - startActivityForResult(intent, REQUEST_CODE_VPN_PERMISSION) - return@thenAccept + val activity = activity ?: return + activity.lifecycleScope.launch { + if (Application.getBackend() is GoBackend) { + try { + val intent = GoBackend.VpnService.prepare(activity) + if (intent != null) { + pendingTunnel = tunnel + pendingTunnelUp = checked + permissionActivityResultLauncher.launch(intent) + return@launch + } + } catch (e: Throwable) { + val message = activity.getString(R.string.error_prepare, ErrorMessages[e]) + Snackbar.make(view, message, Snackbar.LENGTH_LONG) + .setAnchorView(view.findViewById(R.id.create_fab)) + .show() + Log.e(TAG, message, e) } } setTunnelStateWithPermissionsResult(tunnel, checked) @@ -85,24 +88,27 @@ abstract class BaseFragment : Fragment(), OnSelectedTunnelChangedListener { } private fun setTunnelStateWithPermissionsResult(tunnel: ObservableTunnel, checked: Boolean) { - tunnel.setStateAsync(Tunnel.State.of(checked)).whenComplete { _, throwable -> - if (throwable == null) return@whenComplete - val error = ErrorMessages[throwable] - val messageResId = if (checked) R.string.error_up else R.string.error_down - val message = requireContext().getString(messageResId, error) - val view = view - if (view != null) - Snackbar.make(view, message, Snackbar.LENGTH_LONG) - .setAnchorView(view.findViewById<View>(R.id.create_fab)) + val activity = activity ?: return + activity.lifecycleScope.launch { + try { + tunnel.setStateAsync(Tunnel.State.of(checked)) + } catch (e: Throwable) { + val error = ErrorMessages[e] + val messageResId = if (checked) R.string.error_up else R.string.error_down + val message = activity.getString(messageResId, error) + val view = view + if (view != null) + Snackbar.make(view, message, Snackbar.LENGTH_LONG) + .setAnchorView(view.findViewById(R.id.create_fab)) .show() - else - Toast.makeText(requireContext(), message, Toast.LENGTH_LONG).show() - Log.e(TAG, message, throwable) + else + Toast.makeText(activity, message, Toast.LENGTH_LONG).show() + Log.e(TAG, message, e) + } } } companion object { - private const val REQUEST_CODE_VPN_PERMISSION = 23491 private const val TAG = "WireGuard/BaseFragment" } } diff --git a/ui/src/main/java/com/wireguard/android/fragment/ConfigNamingDialogFragment.kt b/ui/src/main/java/com/wireguard/android/fragment/ConfigNamingDialogFragment.kt index d1b01944..23da3fca 100644 --- a/ui/src/main/java/com/wireguard/android/fragment/ConfigNamingDialogFragment.kt +++ b/ui/src/main/java/com/wireguard/android/fragment/ConfigNamingDialogFragment.kt @@ -1,21 +1,21 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.fragment import android.app.Dialog -import android.content.Context -import android.content.DialogInterface import android.os.Bundle -import android.view.inputmethod.InputMethodManager -import androidx.appcompat.app.AlertDialog +import android.view.WindowManager import androidx.fragment.app.DialogFragment +import androidx.lifecycle.lifecycleScope +import com.google.android.material.dialog.MaterialAlertDialogBuilder import com.wireguard.android.Application import com.wireguard.android.R import com.wireguard.android.databinding.ConfigNamingDialogFragmentBinding import com.wireguard.config.BadConfigException import com.wireguard.config.Config +import kotlinx.coroutines.launch import java.io.ByteArrayInputStream import java.io.IOException import java.nio.charset.StandardCharsets @@ -23,33 +23,28 @@ import java.nio.charset.StandardCharsets class ConfigNamingDialogFragment : DialogFragment() { private var binding: ConfigNamingDialogFragmentBinding? = null private var config: Config? = null - private var imm: InputMethodManager? = null private fun createTunnelAndDismiss() { - binding?.let { - val name = it.tunnelNameText.text.toString() - Application.getTunnelManager().create(name, config).whenComplete { tunnel, throwable -> - if (tunnel != null) { - dismiss() - } else { - it.tunnelNameTextLayout.error = throwable.message - } + val binding = binding ?: return + val activity = activity ?: return + val name = binding.tunnelNameText.text.toString() + activity.lifecycleScope.launch { + try { + Application.getTunnelManager().create(name, config) + dismiss() + } catch (e: Throwable) { + binding.tunnelNameTextLayout.error = e.message } } } - override fun dismiss() { - setKeyboardVisible(false) - super.dismiss() - } - override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) val configText = requireArguments().getString(KEY_CONFIG_TEXT) val configBytes = configText!!.toByteArray(StandardCharsets.UTF_8) config = try { Config.parse(ByteArrayInputStream(configBytes)) - } catch (e: Exception) { + } catch (e: Throwable) { when (e) { is BadConfigException, is IOException -> throw IllegalArgumentException("Invalid config passed to ${javaClass.simpleName}", e) else -> throw e @@ -59,40 +54,23 @@ class ConfigNamingDialogFragment : DialogFragment() { override fun onCreateDialog(savedInstanceState: Bundle?): Dialog { val activity = requireActivity() - imm = activity.getSystemService(Context.INPUT_METHOD_SERVICE) as InputMethodManager - val alertDialogBuilder = AlertDialog.Builder(activity) + val alertDialogBuilder = MaterialAlertDialogBuilder(activity) alertDialogBuilder.setTitle(R.string.import_from_qr_code) binding = ConfigNamingDialogFragmentBinding.inflate(activity.layoutInflater, null, false) binding?.apply { executePendingBindings() alertDialogBuilder.setView(root) } - alertDialogBuilder.setPositiveButton(R.string.create_tunnel, null) + alertDialogBuilder.setPositiveButton(R.string.create_tunnel) { _, _ -> createTunnelAndDismiss() } alertDialogBuilder.setNegativeButton(R.string.cancel) { _, _ -> dismiss() } - return alertDialogBuilder.create() - } - - override fun onResume() { - super.onResume() - val dialog = dialog as AlertDialog? - if (dialog != null) { - dialog.getButton(DialogInterface.BUTTON_POSITIVE).setOnClickListener { createTunnelAndDismiss() } - setKeyboardVisible(true) - } - } - - private fun setKeyboardVisible(visible: Boolean) { - if (visible) { - imm!!.toggleSoftInput(InputMethodManager.SHOW_FORCED, 0) - } else if (binding != null) { - imm!!.hideSoftInputFromWindow(binding!!.tunnelNameText.windowToken, 0) - } + val dialog = alertDialogBuilder.create() + dialog.window?.setSoftInputMode(WindowManager.LayoutParams.SOFT_INPUT_STATE_ALWAYS_VISIBLE) + return dialog } companion object { private const val KEY_CONFIG_TEXT = "config_text" - @JvmStatic fun newInstance(configText: String?): ConfigNamingDialogFragment { val extras = Bundle() extras.putString(KEY_CONFIG_TEXT, configText) diff --git a/ui/src/main/java/com/wireguard/android/fragment/TunnelDetailFragment.kt b/ui/src/main/java/com/wireguard/android/fragment/TunnelDetailFragment.kt index 2b5a4ba6..7731391d 100644 --- a/ui/src/main/java/com/wireguard/android/fragment/TunnelDetailFragment.kt +++ b/ui/src/main/java/com/wireguard/android/fragment/TunnelDetailFragment.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.fragment @@ -8,57 +8,51 @@ import android.os.Bundle import android.view.LayoutInflater import android.view.Menu import android.view.MenuInflater +import android.view.MenuItem import android.view.View import android.view.ViewGroup +import androidx.core.view.MenuProvider import androidx.databinding.DataBindingUtil +import androidx.lifecycle.Lifecycle +import androidx.lifecycle.lifecycleScope import com.wireguard.android.R import com.wireguard.android.backend.Tunnel import com.wireguard.android.databinding.TunnelDetailFragmentBinding import com.wireguard.android.databinding.TunnelDetailPeerBinding import com.wireguard.android.model.ObservableTunnel -import com.wireguard.android.widget.EdgeToEdge.setUpRoot -import com.wireguard.android.widget.EdgeToEdge.setUpScrollingContent -import java.util.Timer -import java.util.TimerTask +import com.wireguard.android.util.QuantityFormatter +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch /** * Fragment that shows details about a specific tunnel. */ -class TunnelDetailFragment : BaseFragment() { +class TunnelDetailFragment : BaseFragment(), MenuProvider { private var binding: TunnelDetailFragmentBinding? = null private var lastState = Tunnel.State.TOGGLE - private var timer: Timer? = null + private var timerActive = true - private fun formatBytes(bytes: Long): String { - val context = requireContext() - return when { - bytes < 1024 -> context.getString(R.string.transfer_bytes, bytes) - bytes < 1024 * 1024 -> context.getString(R.string.transfer_kibibytes, bytes / 1024.0) - bytes < 1024 * 1024 * 1024 -> context.getString(R.string.transfer_mibibytes, bytes / (1024.0 * 1024.0)) - bytes < 1024 * 1024 * 1024 * 1024L -> context.getString(R.string.transfer_gibibytes, bytes / (1024.0 * 1024.0 * 1024.0)) - else -> context.getString(R.string.transfer_tibibytes, bytes / (1024.0 * 1024.0 * 1024.0) / 1024.0) - } - } - - override fun onCreate(savedInstanceState: Bundle?) { - super.onCreate(savedInstanceState) - setHasOptionsMenu(true) + override fun onMenuItemSelected(menuItem: MenuItem): Boolean { + return false } - override fun onCreateOptionsMenu(menu: Menu, inflater: MenuInflater) { - inflater.inflate(R.menu.tunnel_detail, menu) + override fun onCreateMenu(menu: Menu, menuInflater: MenuInflater) { + menuInflater.inflate(R.menu.tunnel_detail, menu) } - override fun onCreateView(inflater: LayoutInflater, container: ViewGroup?, - savedInstanceState: Bundle?): View? { + override fun onCreateView( + inflater: LayoutInflater, container: ViewGroup?, + savedInstanceState: Bundle? + ): View? { super.onCreateView(inflater, container, savedInstanceState) binding = TunnelDetailFragmentBinding.inflate(inflater, container, false) - binding?.apply { - executePendingBindings() - setUpRoot(root as ViewGroup) - setUpScrollingContent(root as ViewGroup, null) - } - return binding!!.root + binding?.executePendingBindings() + return binding?.root + } + + override fun onViewCreated(view: View, savedInstanceState: Bundle?) { + super.onViewCreated(view, savedInstanceState) + requireActivity().addMenuProvider(this, viewLifecycleOwner, Lifecycle.State.RESUMED) } override fun onDestroyView() { @@ -68,28 +62,36 @@ class TunnelDetailFragment : BaseFragment() { override fun onResume() { super.onResume() - timer = Timer() - timer!!.scheduleAtFixedRate(object : TimerTask() { - override fun run() { + timerActive = true + lifecycleScope.launch { + while (timerActive) { updateStats() + delay(1000) } - }, 0, 1000) + } } override fun onSelectedTunnelChanged(oldTunnel: ObservableTunnel?, newTunnel: ObservableTunnel?) { - binding ?: return - binding!!.tunnel = newTunnel - if (newTunnel == null) binding!!.config = null else newTunnel.configAsync.thenAccept { config -> binding!!.config = config } + val binding = binding ?: return + binding.tunnel = newTunnel + if (newTunnel == null) { + binding.config = null + } else { + lifecycleScope.launch { + try { + binding.config = newTunnel.getConfigAsync() + } catch (_: Throwable) { + binding.config = null + } + } + } lastState = Tunnel.State.TOGGLE - updateStats() + lifecycleScope.launch { updateStats() } } override fun onStop() { + timerActive = false super.onStop() - if (timer != null) { - timer!!.cancel() - timer = null - } } override fun onViewStateRestored(savedInstanceState: Bundle?) { @@ -99,36 +101,49 @@ class TunnelDetailFragment : BaseFragment() { super.onViewStateRestored(savedInstanceState) } - private fun updateStats() { - if (binding == null || !isResumed) return - val tunnel = binding!!.tunnel ?: return + private suspend fun updateStats() { + val binding = binding ?: return + val tunnel = binding.tunnel ?: return + if (!isResumed) return val state = tunnel.state if (state != Tunnel.State.UP && lastState == state) return lastState = state - tunnel.statisticsAsync.whenComplete { statistics, throwable -> - if (throwable != null) { - for (i in 0 until binding!!.peersLayout.childCount) { - val peer: TunnelDetailPeerBinding = DataBindingUtil.getBinding(binding!!.peersLayout.getChildAt(i)) - ?: continue - peer.transferLabel.visibility = View.GONE - peer.transferText.visibility = View.GONE - } - return@whenComplete - } - for (i in 0 until binding!!.peersLayout.childCount) { - val peer: TunnelDetailPeerBinding = DataBindingUtil.getBinding(binding!!.peersLayout.getChildAt(i)) - ?: continue + try { + val statistics = tunnel.getStatisticsAsync() + for (i in 0 until binding.peersLayout.childCount) { + val peer: TunnelDetailPeerBinding = DataBindingUtil.getBinding(binding.peersLayout.getChildAt(i)) + ?: continue val publicKey = peer.item!!.publicKey - val rx = statistics.peerRx(publicKey) - val tx = statistics.peerTx(publicKey) - if (rx == 0L && tx == 0L) { + val peerStats = statistics.peer(publicKey) + if (peerStats == null || (peerStats.rxBytes == 0L && peerStats.txBytes == 0L)) { peer.transferLabel.visibility = View.GONE peer.transferText.visibility = View.GONE - continue + } else { + peer.transferText.text = getString( + R.string.transfer_rx_tx, + QuantityFormatter.formatBytes(peerStats.rxBytes), + QuantityFormatter.formatBytes(peerStats.txBytes) + ) + peer.transferLabel.visibility = View.VISIBLE + peer.transferText.visibility = View.VISIBLE } - peer.transferText.text = requireContext().getString(R.string.transfer_rx_tx, formatBytes(rx), formatBytes(tx)) - peer.transferLabel.visibility = View.VISIBLE - peer.transferText.visibility = View.VISIBLE + if (peerStats == null || peerStats.latestHandshakeEpochMillis == 0L) { + peer.latestHandshakeLabel.visibility = View.GONE + peer.latestHandshakeText.visibility = View.GONE + } else { + peer.latestHandshakeText.text = QuantityFormatter.formatEpochAgo(peerStats.latestHandshakeEpochMillis) + peer.latestHandshakeLabel.visibility = View.VISIBLE + peer.latestHandshakeText.visibility = View.VISIBLE + } + } + } catch (e: Throwable) { + for (i in 0 until binding.peersLayout.childCount) { + val peer: TunnelDetailPeerBinding = DataBindingUtil.getBinding(binding.peersLayout.getChildAt(i)) + ?: continue + peer.transferLabel.visibility = View.GONE + peer.transferText.visibility = View.GONE + peer.latestHandshakeLabel.visibility = View.GONE + peer.latestHandshakeText.visibility = View.GONE } } } diff --git a/ui/src/main/java/com/wireguard/android/fragment/TunnelEditorFragment.kt b/ui/src/main/java/com/wireguard/android/fragment/TunnelEditorFragment.kt index dc1b8aa2..f5d28ad5 100644 --- a/ui/src/main/java/com/wireguard/android/fragment/TunnelEditorFragment.kt +++ b/ui/src/main/java/com/wireguard/android/fragment/TunnelEditorFragment.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.fragment @@ -18,118 +18,106 @@ import android.view.WindowManager import android.view.inputmethod.InputMethodManager import android.widget.EditText import android.widget.Toast +import androidx.core.os.BundleCompat +import androidx.core.view.MenuProvider +import androidx.lifecycle.Lifecycle +import androidx.lifecycle.lifecycleScope import com.google.android.material.snackbar.Snackbar import com.wireguard.android.Application import com.wireguard.android.R import com.wireguard.android.backend.Tunnel import com.wireguard.android.databinding.TunnelEditorFragmentBinding -import com.wireguard.android.fragment.AppListDialogFragment.AppSelectionListener import com.wireguard.android.model.ObservableTunnel +import com.wireguard.android.util.AdminKnobs import com.wireguard.android.util.BiometricAuthenticator import com.wireguard.android.util.ErrorMessages import com.wireguard.android.viewmodel.ConfigProxy -import com.wireguard.android.widget.EdgeToEdge.setUpRoot -import com.wireguard.android.widget.EdgeToEdge.setUpScrollingContent import com.wireguard.config.Config +import kotlinx.coroutines.launch /** * Fragment for editing a WireGuard configuration. */ -class TunnelEditorFragment : BaseFragment(), AppSelectionListener { +class TunnelEditorFragment : BaseFragment(), MenuProvider { private var haveShownKeys = false private var binding: TunnelEditorFragmentBinding? = null private var tunnel: ObservableTunnel? = null + private fun onConfigLoaded(config: Config) { binding?.config = ConfigProxy(config) } private fun onConfigSaved(savedTunnel: Tunnel, throwable: Throwable?) { - val message: String + val ctx = activity ?: Application.get() if (throwable == null) { - message = getString(R.string.config_save_success, savedTunnel.name) + val message = ctx.getString(R.string.config_save_success, savedTunnel.name) Log.d(TAG, message) - Toast.makeText(requireContext(), message, Toast.LENGTH_SHORT).show() + Toast.makeText(ctx, message, Toast.LENGTH_SHORT).show() onFinished() } else { val error = ErrorMessages[throwable] - message = getString(R.string.config_save_error, savedTunnel.name, error) + val message = ctx.getString(R.string.config_save_error, savedTunnel.name, error) Log.e(TAG, message, throwable) - binding?.let { - Snackbar.make(it.mainContainer, message, Snackbar.LENGTH_LONG).show() - } + val binding = binding + if (binding != null) + Snackbar.make(binding.mainContainer, message, Snackbar.LENGTH_LONG).show() + else + Toast.makeText(ctx, message, Toast.LENGTH_SHORT).show() } } - override fun onCreate(savedInstanceState: Bundle?) { - super.onCreate(savedInstanceState) - setHasOptionsMenu(true) - } - - override fun onCreateOptionsMenu(menu: Menu, inflater: MenuInflater) { - inflater.inflate(R.menu.config_editor, menu) + override fun onCreateMenu(menu: Menu, menuInflater: MenuInflater) { + menuInflater.inflate(R.menu.config_editor, menu) } - override fun onCreateView(inflater: LayoutInflater, container: ViewGroup?, - savedInstanceState: Bundle?): View? { + override fun onCreateView( + inflater: LayoutInflater, container: ViewGroup?, + savedInstanceState: Bundle? + ): View? { super.onCreateView(inflater, container, savedInstanceState) binding = TunnelEditorFragmentBinding.inflate(inflater, container, false) binding?.apply { executePendingBindings() - setUpRoot(root as ViewGroup) - setUpScrollingContent(mainContainer, null) privateKeyTextLayout.setEndIconOnClickListener { config?.`interface`?.generateKeyPair() } } return binding?.root } + override fun onViewCreated(view: View, savedInstanceState: Bundle?) { + super.onViewCreated(view, savedInstanceState) + requireActivity().addMenuProvider(this, viewLifecycleOwner, Lifecycle.State.RESUMED) + } + override fun onDestroyView() { activity?.window?.clearFlags(WindowManager.LayoutParams.FLAG_SECURE) binding = null super.onDestroyView() } - override fun onSelectedAppsSelected(selectedApps: List<String>, isExcluded: Boolean) { - requireNotNull(binding) { "Tried to set excluded/included apps while no view was loaded" } - if (isExcluded) { - binding!!.config!!.`interface`.includedApplications.clear() - binding!!.config!!.`interface`.excludedApplications.apply { - clear() - addAll(selectedApps) - } - } else { - binding!!.config!!.`interface`.excludedApplications.clear() - binding!!.config!!.`interface`.includedApplications.apply { - clear() - addAll(selectedApps) - } - } - } - private fun onFinished() { // Hide the keyboard; it rarely goes away on its own. val activity = activity ?: return val focusedView = activity.currentFocus if (focusedView != null) { val inputManager = activity.getSystemService(Context.INPUT_METHOD_SERVICE) as? InputMethodManager - inputManager?.hideSoftInputFromWindow(focusedView.windowToken, - InputMethodManager.HIDE_NOT_ALWAYS) - } - // Tell the activity to finish itself or go back to the detail view. - activity.runOnUiThread { - // TODO(smaeul): Remove this hack when fixing the Config ViewModel - // The selected tunnel has to actually change, but we have to remember this one. - val savedTunnel = tunnel - if (savedTunnel === selectedTunnel) selectedTunnel = null - selectedTunnel = savedTunnel + inputManager?.hideSoftInputFromWindow( + focusedView.windowToken, + InputMethodManager.HIDE_NOT_ALWAYS + ) } + parentFragmentManager.popBackStackImmediate() + + // If we just made a new one, save it to select the details page. + if (selectedTunnel != tunnel) + selectedTunnel = tunnel } - override fun onOptionsItemSelected(item: MenuItem): Boolean { - if (item.itemId == R.id.menu_action_save) { + override fun onMenuItemSelected(menuItem: MenuItem): Boolean { + if (menuItem.itemId == R.id.menu_action_save) { binding ?: return false val newConfig = try { binding!!.config!!.resolve() - } catch (e: Exception) { + } catch (e: Throwable) { val error = ErrorMessages[e] val tunnelName = if (tunnel == null) binding!!.name else tunnel!!.name val message = getString(R.string.config_save_error, tunnelName, error) @@ -137,25 +125,43 @@ class TunnelEditorFragment : BaseFragment(), AppSelectionListener { Snackbar.make(binding!!.mainContainer, error, Snackbar.LENGTH_LONG).show() return false } - when { - tunnel == null -> { - Log.d(TAG, "Attempting to create new tunnel " + binding!!.name) - val manager = Application.getTunnelManager() - manager.create(binding!!.name!!, newConfig).whenComplete(this::onTunnelCreated) - } - tunnel!!.name != binding!!.name -> { - Log.d(TAG, "Attempting to rename tunnel to " + binding!!.name) - tunnel!!.setNameAsync(binding!!.name!!).whenComplete { _, t -> onTunnelRenamed(tunnel!!, newConfig, t) } - } - else -> { - Log.d(TAG, "Attempting to save config of " + tunnel!!.name) - tunnel!!.setConfigAsync(newConfig) - .whenComplete { _, t -> onConfigSaved(tunnel!!, t) } + val activity = requireActivity() + activity.lifecycleScope.launch { + when { + tunnel == null -> { + Log.d(TAG, "Attempting to create new tunnel " + binding!!.name) + val manager = Application.getTunnelManager() + try { + onTunnelCreated(manager.create(binding!!.name!!, newConfig), null) + } catch (e: Throwable) { + onTunnelCreated(null, e) + } + } + + tunnel!!.name != binding!!.name -> { + Log.d(TAG, "Attempting to rename tunnel to " + binding!!.name) + try { + tunnel!!.setNameAsync(binding!!.name!!) + onTunnelRenamed(tunnel!!, newConfig, null) + } catch (e: Throwable) { + onTunnelRenamed(tunnel!!, newConfig, e) + } + } + + else -> { + Log.d(TAG, "Attempting to save config of " + tunnel!!.name) + try { + tunnel!!.setConfigAsync(newConfig) + onConfigSaved(tunnel!!, null) + } catch (e: Throwable) { + onConfigSaved(tunnel!!, e) + } + } } } return true } - return super.onOptionsItemSelected(item) + return false } @Suppress("UNUSED_PARAMETER") @@ -168,8 +174,26 @@ class TunnelEditorFragment : BaseFragment(), AppSelectionListener { if (selectedApps.isNotEmpty()) isExcluded = false } - val fragment = AppListDialogFragment.newInstance(selectedApps, isExcluded, this) - fragment.show(parentFragmentManager, null) + val fragment = AppListDialogFragment.newInstance(selectedApps, isExcluded) + childFragmentManager.setFragmentResultListener(AppListDialogFragment.REQUEST_SELECTION, viewLifecycleOwner) { _, bundle -> + requireNotNull(binding) { "Tried to set excluded/included apps while no view was loaded" } + val newSelections = requireNotNull(bundle.getStringArray(AppListDialogFragment.KEY_SELECTED_APPS)) + val excluded = requireNotNull(bundle.getBoolean(AppListDialogFragment.KEY_IS_EXCLUDED)) + if (excluded) { + binding!!.config!!.`interface`.includedApplications.clear() + binding!!.config!!.`interface`.excludedApplications.apply { + clear() + addAll(newSelections) + } + } else { + binding!!.config!!.`interface`.excludedApplications.clear() + binding!!.config!!.`interface`.includedApplications.apply { + clear() + addAll(newSelections) + } + } + } + fragment.show(childFragmentManager, null) } } @@ -179,53 +203,71 @@ class TunnelEditorFragment : BaseFragment(), AppSelectionListener { super.onSaveInstanceState(outState) } - override fun onSelectedTunnelChanged(oldTunnel: ObservableTunnel?, - newTunnel: ObservableTunnel?) { + override fun onSelectedTunnelChanged( + oldTunnel: ObservableTunnel?, + newTunnel: ObservableTunnel? + ) { tunnel = newTunnel if (binding == null) return binding!!.config = ConfigProxy() if (tunnel != null) { binding!!.name = tunnel!!.name - tunnel!!.configAsync.thenAccept(this::onConfigLoaded) + lifecycleScope.launch { + try { + onConfigLoaded(tunnel!!.getConfigAsync()) + } catch (_: Throwable) { + } + } } else { binding!!.name = "" } } - private fun onTunnelCreated(newTunnel: ObservableTunnel, throwable: Throwable?) { - val message: String + private fun onTunnelCreated(newTunnel: ObservableTunnel?, throwable: Throwable?) { + val ctx = activity ?: Application.get() if (throwable == null) { tunnel = newTunnel - message = getString(R.string.tunnel_create_success, tunnel!!.name) + val message = ctx.getString(R.string.tunnel_create_success, tunnel!!.name) Log.d(TAG, message) - Toast.makeText(requireContext(), message, Toast.LENGTH_SHORT).show() + Toast.makeText(ctx, message, Toast.LENGTH_SHORT).show() onFinished() } else { val error = ErrorMessages[throwable] - message = getString(R.string.tunnel_create_error, error) + val message = ctx.getString(R.string.tunnel_create_error, error) Log.e(TAG, message, throwable) - binding?.let { - Snackbar.make(it.mainContainer, message, Snackbar.LENGTH_LONG).show() - } + val binding = binding + if (binding != null) + Snackbar.make(binding.mainContainer, message, Snackbar.LENGTH_LONG).show() + else + Toast.makeText(ctx, message, Toast.LENGTH_SHORT).show() } } - private fun onTunnelRenamed(renamedTunnel: ObservableTunnel, newConfig: Config, - throwable: Throwable?) { - val message: String + private suspend fun onTunnelRenamed( + renamedTunnel: ObservableTunnel, newConfig: Config, + throwable: Throwable? + ) { + val ctx = activity ?: Application.get() if (throwable == null) { - message = getString(R.string.tunnel_rename_success, renamedTunnel.name) + val message = ctx.getString(R.string.tunnel_rename_success, renamedTunnel.name) Log.d(TAG, message) // Now save the rest of configuration changes. Log.d(TAG, "Attempting to save config of renamed tunnel " + tunnel!!.name) - renamedTunnel.setConfigAsync(newConfig).whenComplete { _, t -> onConfigSaved(renamedTunnel, t) } + try { + renamedTunnel.setConfigAsync(newConfig) + onConfigSaved(renamedTunnel, null) + } catch (e: Throwable) { + onConfigSaved(renamedTunnel, e) + } } else { val error = ErrorMessages[throwable] - message = getString(R.string.tunnel_rename_error, error) + val message = ctx.getString(R.string.tunnel_rename_error, error) Log.e(TAG, message, throwable) - binding?.let { - Snackbar.make(it.mainContainer, message, Snackbar.LENGTH_LONG).show() - } + val binding = binding + if (binding != null) + Snackbar.make(binding.mainContainer, message, Snackbar.LENGTH_LONG).show() + else + Toast.makeText(ctx, message, Toast.LENGTH_SHORT).show() } } @@ -236,7 +278,7 @@ class TunnelEditorFragment : BaseFragment(), AppSelectionListener { onSelectedTunnelChanged(null, selectedTunnel) } else { tunnel = selectedTunnel - val config: ConfigProxy = savedInstanceState.getParcelable(KEY_LOCAL_CONFIG)!! + val config = BundleCompat.getParcelable(savedInstanceState, KEY_LOCAL_CONFIG, ConfigProxy::class.java)!! val originalName = savedInstanceState.getString(KEY_ORIGINAL_NAME) if (tunnel != null && tunnel!!.name != originalName) onSelectedTunnelChanged(null, tunnel) else binding!!.config = config } @@ -252,6 +294,7 @@ class TunnelEditorFragment : BaseFragment(), AppSelectionListener { val edit = view as? EditText ?: return if (edit.inputType == InputType.TYPE_TEXT_FLAG_NO_SUGGESTIONS or InputType.TYPE_TEXT_VARIATION_VISIBLE_PASSWORD) return if (!haveShownKeys && edit.text.isNotEmpty()) { + if (AdminKnobs.disableConfigExport) return showingAuthenticator = true BiometricAuthenticator.authenticate(R.string.biometric_prompt_private_key_title, this) { showingAuthenticator = false @@ -260,13 +303,16 @@ class TunnelEditorFragment : BaseFragment(), AppSelectionListener { haveShownKeys = true showPrivateKey(edit) } + is BiometricAuthenticator.Result.Failure -> { Snackbar.make( - binding!!.mainContainer, - it.message, - Snackbar.LENGTH_SHORT + binding!!.mainContainer, + it.message, + Snackbar.LENGTH_SHORT ).show() } + + is BiometricAuthenticator.Result.Cancelled -> {} } } } else { diff --git a/ui/src/main/java/com/wireguard/android/fragment/TunnelListFragment.kt b/ui/src/main/java/com/wireguard/android/fragment/TunnelListFragment.kt index 7af5e06b..119b6afe 100644 --- a/ui/src/main/java/com/wireguard/android/fragment/TunnelListFragment.kt +++ b/ui/src/main/java/com/wireguard/android/fragment/TunnelListFragment.kt @@ -1,16 +1,12 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.fragment -import android.annotation.SuppressLint -import android.app.Activity import android.content.Intent import android.content.res.Resources -import android.net.Uri import android.os.Bundle -import android.provider.OpenableColumns import android.util.Log import android.view.LayoutInflater import android.view.Menu @@ -19,33 +15,33 @@ import android.view.View import android.view.ViewGroup import android.view.animation.Animation import android.view.animation.AnimationUtils +import android.widget.Toast +import androidx.activity.OnBackPressedCallback +import androidx.activity.addCallback +import androidx.activity.result.contract.ActivityResultContracts import androidx.appcompat.app.AppCompatActivity import androidx.appcompat.view.ActionMode +import androidx.lifecycle.lifecycleScope import com.google.android.material.snackbar.Snackbar -import com.google.zxing.integration.android.IntentIntegrator +import com.google.zxing.qrcode.QRCodeReader +import com.journeyapps.barcodescanner.ScanContract +import com.journeyapps.barcodescanner.ScanOptions import com.wireguard.android.Application import com.wireguard.android.R +import com.wireguard.android.activity.TunnelCreatorActivity import com.wireguard.android.databinding.ObservableKeyedRecyclerViewAdapter.RowConfigurationHandler import com.wireguard.android.databinding.TunnelListFragmentBinding import com.wireguard.android.databinding.TunnelListItemBinding -import com.wireguard.android.fragment.ConfigNamingDialogFragment.Companion.newInstance import com.wireguard.android.model.ObservableTunnel +import com.wireguard.android.updater.SnackbarUpdateShower import com.wireguard.android.util.ErrorMessages -import com.wireguard.android.widget.EdgeToEdge.setUpFAB -import com.wireguard.android.widget.EdgeToEdge.setUpRoot -import com.wireguard.android.widget.EdgeToEdge.setUpScrollingContent +import com.wireguard.android.util.QrCodeFromFileScanner +import com.wireguard.android.util.TunnelImporter import com.wireguard.android.widget.MultiselectableRelativeLayout -import com.wireguard.config.Config -import java9.util.concurrent.CompletableFuture -import java.io.BufferedReader -import java.io.ByteArrayInputStream -import java.io.InputStreamReader -import java.nio.charset.StandardCharsets -import java.util.ArrayList -import java.util.HashSet -import java.util.Locale -import java.util.zip.ZipEntry -import java.util.zip.ZipInputStream +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.launch /** * Fragment containing a list of known WireGuard tunnels. It allows creating and deleting tunnels. @@ -53,123 +49,42 @@ import java.util.zip.ZipInputStream class TunnelListFragment : BaseFragment() { private val actionModeListener = ActionModeListener() private var actionMode: ActionMode? = null + private var backPressedCallback: OnBackPressedCallback? = null private var binding: TunnelListFragmentBinding? = null - private fun importTunnel(configText: String) { - try { - // Ensure the config text is parseable before proceeding… - Config.parse(ByteArrayInputStream(configText.toByteArray(StandardCharsets.UTF_8))) - - // Config text is valid, now create the tunnel… - newInstance(configText).show(parentFragmentManager, null) - } catch (e: Exception) { - onTunnelImportFinished(emptyList(), listOf<Throwable>(e)) - } - } - - private fun importTunnel(uri: Uri?) { - val activity = activity - if (activity == null || uri == null) { - return - } - val contentResolver = activity.contentResolver - - val futureTunnels = ArrayList<CompletableFuture<ObservableTunnel>>() - val throwables = ArrayList<Throwable>() - Application.getAsyncWorker().supplyAsync { - val columns = arrayOf(OpenableColumns.DISPLAY_NAME) - var name = "" - contentResolver.query(uri, columns, null, null, null)?.use { cursor -> - if (cursor.moveToFirst() && !cursor.isNull(0)) { - name = cursor.getString(0) - } - } - if (name.isEmpty()) { - name = Uri.decode(uri.lastPathSegment) - } - var idx = name.lastIndexOf('/') - if (idx >= 0) { - require(idx < name.length - 1) { resources.getString(R.string.illegal_filename_error, name) } - name = name.substring(idx + 1) - } - val isZip = name.toLowerCase(Locale.ROOT).endsWith(".zip") - if (name.toLowerCase(Locale.ROOT).endsWith(".conf")) { - name = name.substring(0, name.length - ".conf".length) - } else { - require(isZip) { resources.getString(R.string.bad_extension_error) } - } - - if (isZip) { - ZipInputStream(contentResolver.openInputStream(uri)).use { zip -> - val reader = BufferedReader(InputStreamReader(zip, StandardCharsets.UTF_8)) - var entry: ZipEntry? - while (true) { - entry = zip.nextEntry ?: break - name = entry.name - idx = name.lastIndexOf('/') - if (idx >= 0) { - if (idx >= name.length - 1) { - continue - } - name = name.substring(name.lastIndexOf('/') + 1) - } - if (name.toLowerCase(Locale.ROOT).endsWith(".conf")) { - name = name.substring(0, name.length - ".conf".length) - } else { - continue - } - try { - Config.parse(reader) - } catch (e: Exception) { - throwables.add(e) - null - }?.let { - futureTunnels.add(Application.getTunnelManager().create(name, it).toCompletableFuture()) - } - } + private val tunnelFileImportResultLauncher = registerForActivityResult(ActivityResultContracts.GetContent()) { data -> + if (data == null) return@registerForActivityResult + val activity = activity ?: return@registerForActivityResult + val contentResolver = activity.contentResolver ?: return@registerForActivityResult + activity.lifecycleScope.launch { + if (QrCodeFromFileScanner.validContentType(contentResolver, data)) { + try { + val qrCodeFromFileScanner = QrCodeFromFileScanner(contentResolver, QRCodeReader()) + val result = qrCodeFromFileScanner.scan(data) + TunnelImporter.importTunnel(parentFragmentManager, result.text) { showSnackbar(it) } + } catch (e: Exception) { + val error = ErrorMessages[e] + val message = Application.get().resources.getString(R.string.import_error, error) + Log.e(TAG, message, e) + showSnackbar(message) } } else { - futureTunnels.add( - Application.getTunnelManager().create( - name, - Config.parse(contentResolver.openInputStream(uri)!!) - ).toCompletableFuture() - ) + TunnelImporter.importTunnel(contentResolver, data) { showSnackbar(it) } } + } + } - if (futureTunnels.isEmpty()) { - if (throwables.size == 1) { - throw throwables[0] - } else { - require(throwables.isNotEmpty()) { resources.getString(R.string.no_configs_error) } - } - } - CompletableFuture.allOf(*futureTunnels.toTypedArray()) - }.whenComplete { future, exception -> - if (exception != null) { - onTunnelImportFinished(emptyList(), listOf(exception)) - } else { - future.whenComplete { _, _ -> - val tunnels = mutableListOf<ObservableTunnel>() - for (futureTunnel in futureTunnels) { - val tunnel: ObservableTunnel? = try { - futureTunnel.getNow(null) - } catch (e: Exception) { - throwables.add(e) - null - } - - if (tunnel != null) { - tunnels.add(tunnel) - } - } - onTunnelImportFinished(tunnels, throwables) - } - } + private val qrImportResultLauncher = registerForActivityResult(ScanContract()) { result -> + val qrCode = result.contents + val activity = activity + if (qrCode != null && activity != null) { + activity.lifecycleScope.launch { TunnelImporter.importTunnel(parentFragmentManager, qrCode) { showSnackbar(it) } } } } - override fun onActivityCreated(savedInstanceState: Bundle?) { - super.onActivityCreated(savedInstanceState) + private val snackbarUpdateShower = SnackbarUpdateShower(this) + + override fun onViewCreated(view: View, savedInstanceState: Bundle?) { + super.onViewCreated(view, savedInstanceState) if (savedInstanceState != null) { val checkedItems = savedInstanceState.getIntegerArrayList(CHECKED_ITEMS) if (checkedItems != null) { @@ -178,40 +93,46 @@ class TunnelListFragment : BaseFragment() { } } - override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) { - when (requestCode) { - REQUEST_IMPORT -> { - if (resultCode == Activity.RESULT_OK && data != null) importTunnel(data.data) - return - } - IntentIntegrator.REQUEST_CODE -> { - val result = IntentIntegrator.parseActivityResult(requestCode, resultCode, data) - if (result != null && result.contents != null) { - importTunnel(result.contents) - } - return - } - else -> super.onActivityResult(requestCode, resultCode, data) - } - } - - @SuppressLint("ClickableViewAccessibility") - override fun onCreateView(inflater: LayoutInflater, container: ViewGroup?, - savedInstanceState: Bundle?): View? { + override fun onCreateView( + inflater: LayoutInflater, container: ViewGroup?, + savedInstanceState: Bundle? + ): View? { super.onCreateView(inflater, container, savedInstanceState) binding = TunnelListFragmentBinding.inflate(inflater, container, false) + val bottomSheet = AddTunnelsSheet() binding?.apply { createFab.setOnClickListener { - val bottomSheet = AddTunnelsSheet() - bottomSheet.setTargetFragment(fragment, REQUEST_TARGET_FRAGMENT) - bottomSheet.show(parentFragmentManager, "BOTTOM_SHEET") + if (childFragmentManager.findFragmentByTag("BOTTOM_SHEET") != null) + return@setOnClickListener + childFragmentManager.setFragmentResultListener(AddTunnelsSheet.REQUEST_KEY_NEW_TUNNEL, viewLifecycleOwner) { _, bundle -> + when (bundle.getString(AddTunnelsSheet.REQUEST_METHOD)) { + AddTunnelsSheet.REQUEST_CREATE -> { + startActivity(Intent(requireActivity(), TunnelCreatorActivity::class.java)) + } + + AddTunnelsSheet.REQUEST_IMPORT -> { + tunnelFileImportResultLauncher.launch("*/*") + } + + AddTunnelsSheet.REQUEST_SCAN -> { + qrImportResultLauncher.launch( + ScanOptions() + .setOrientationLocked(false) + .setBeepEnabled(false) + .setPrompt(getString(R.string.qr_code_hint)) + ) + } + } + } + bottomSheet.showNow(childFragmentManager, "BOTTOM_SHEET") } executePendingBindings() - setUpRoot(root as ViewGroup) - setUpFAB(createFab) - setUpScrollingContent(tunnelList, createFab) + snackbarUpdateShower.attach(mainContainer, createFab) } - return binding!!.root + backPressedCallback = requireActivity().onBackPressedDispatcher.addCallback(this) { actionMode?.finish() } + backPressedCallback?.isEnabled = false + + return binding?.root } override fun onDestroyView() { @@ -226,53 +147,34 @@ class TunnelListFragment : BaseFragment() { override fun onSelectedTunnelChanged(oldTunnel: ObservableTunnel?, newTunnel: ObservableTunnel?) { binding ?: return - Application.getTunnelManager().tunnels.thenAccept { tunnels -> - if (newTunnel != null) viewForTunnel(newTunnel, tunnels).setSingleSelected(true) - if (oldTunnel != null) viewForTunnel(oldTunnel, tunnels).setSingleSelected(false) + lifecycleScope.launch { + val tunnels = Application.getTunnelManager().getTunnels() + if (newTunnel != null) viewForTunnel(newTunnel, tunnels)?.setSingleSelected(true) + if (oldTunnel != null) viewForTunnel(oldTunnel, tunnels)?.setSingleSelected(false) } } private fun onTunnelDeletionFinished(count: Int, throwable: Throwable?) { val message: String + val ctx = activity ?: Application.get() if (throwable == null) { - message = resources.getQuantityString(R.plurals.delete_success, count, count) + message = ctx.resources.getQuantityString(R.plurals.delete_success, count, count) } else { val error = ErrorMessages[throwable] - message = resources.getQuantityString(R.plurals.delete_error, count, count, error) + message = ctx.resources.getQuantityString(R.plurals.delete_error, count, count, error) Log.e(TAG, message, throwable) } showSnackbar(message) } - private fun onTunnelImportFinished(tunnels: List<ObservableTunnel>, throwables: Collection<Throwable>) { - var message = "" - for (throwable in throwables) { - val error = ErrorMessages[throwable] - message = getString(R.string.import_error, error) - Log.e(TAG, message, throwable) - } - if (tunnels.size == 1 && throwables.isEmpty()) - message = getString(R.string.import_success, tunnels[0].name) - else if (tunnels.isEmpty() && throwables.size == 1) - else if (throwables.isEmpty()) - message = resources.getQuantityString(R.plurals.import_total_success, - tunnels.size, tunnels.size) - else if (!throwables.isEmpty()) - message = resources.getQuantityString(R.plurals.import_partial_success, - tunnels.size + throwables.size, - tunnels.size, tunnels.size + throwables.size) - showSnackbar(message) - } - override fun onViewStateRestored(savedInstanceState: Bundle?) { super.onViewStateRestored(savedInstanceState) binding ?: return binding!!.fragment = this - Application.getTunnelManager().tunnels.thenAccept { tunnels -> binding!!.tunnels = tunnels } - val parent = this + lifecycleScope.launch { binding!!.tunnels = Application.getTunnelManager().getTunnels() } binding!!.rowConfigurationHandler = object : RowConfigurationHandler<TunnelListItemBinding, ObservableTunnel> { override fun onConfigureRow(binding: TunnelListItemBinding, item: ObservableTunnel, position: Int) { - binding.fragment = parent + binding.fragment = this@TunnelListFragment binding.root.setOnClickListener { if (actionMode == null) { selectedTunnel = item @@ -293,15 +195,17 @@ class TunnelListFragment : BaseFragment() { } private fun showSnackbar(message: CharSequence) { - binding?.let { - Snackbar.make(it.mainContainer, message, Snackbar.LENGTH_LONG) - .setAnchorView(it.createFab) - .show() - } + val binding = binding + if (binding != null) + Snackbar.make(binding.mainContainer, message, Snackbar.LENGTH_LONG) + .setAnchorView(binding.createFab) + .show() + else + Toast.makeText(activity ?: Application.get(), message, Toast.LENGTH_SHORT).show() } - private fun viewForTunnel(tunnel: ObservableTunnel, tunnels: List<*>): MultiselectableRelativeLayout { - return binding!!.tunnelList.findViewHolderForAdapterPosition(tunnels.indexOf(tunnel))!!.itemView as MultiselectableRelativeLayout + private fun viewForTunnel(tunnel: ObservableTunnel, tunnels: List<*>): MultiselectableRelativeLayout? { + return binding?.tunnelList?.findViewHolderForAdapterPosition(tunnels.indexOf(tunnel))?.itemView as? MultiselectableRelativeLayout } private inner class ActionModeListener : ActionMode.Callback { @@ -315,38 +219,46 @@ class TunnelListFragment : BaseFragment() { override fun onActionItemClicked(mode: ActionMode, item: MenuItem): Boolean { return when (item.itemId) { R.id.menu_action_delete -> { + val activity = activity ?: return true val copyCheckedItems = HashSet(checkedItems) binding?.createFab?.apply { visibility = View.VISIBLE scaleX = 1f scaleY = 1f } - Application.getTunnelManager().tunnels.thenAccept { tunnels -> - val tunnelsToDelete = ArrayList<ObservableTunnel>() - for (position in copyCheckedItems) tunnelsToDelete.add(tunnels[position]) - val futures = tunnelsToDelete.map { it.delete().toCompletableFuture() }.toTypedArray() - CompletableFuture.allOf(*futures) - .thenApply { futures.size } - .whenComplete(this@TunnelListFragment::onTunnelDeletionFinished) + activity.lifecycleScope.launch { + try { + val tunnels = Application.getTunnelManager().getTunnels() + val tunnelsToDelete = ArrayList<ObservableTunnel>() + for (position in copyCheckedItems) tunnelsToDelete.add(tunnels[position]) + val futures = tunnelsToDelete.map { async(SupervisorJob()) { it.deleteAsync() } } + onTunnelDeletionFinished(futures.awaitAll().size, null) + } catch (e: Throwable) { + onTunnelDeletionFinished(0, e) + } } checkedItems.clear() mode.finish() true } + R.id.menu_action_select_all -> { - Application.getTunnelManager().tunnels.thenAccept { tunnels -> + lifecycleScope.launch { + val tunnels = Application.getTunnelManager().getTunnels() for (i in 0 until tunnels.size) { setItemChecked(i, true) } } true } + else -> false } } override fun onCreateActionMode(mode: ActionMode, menu: Menu): Boolean { actionMode = mode + backPressedCallback?.isEnabled = true if (activity != null) { resources = activity!!.resources } @@ -358,10 +270,11 @@ class TunnelListFragment : BaseFragment() { override fun onDestroyActionMode(mode: ActionMode) { actionMode = null + backPressedCallback?.isEnabled = false resources = null animateFab(binding?.createFab, true) checkedItems.clear() - binding!!.tunnelList.adapter!!.notifyDataSetChanged() + binding?.tunnelList?.adapter?.notifyDataSetChanged() } override fun onPrepareActionMode(mode: ActionMode, menu: Menu): Boolean { @@ -377,7 +290,7 @@ class TunnelListFragment : BaseFragment() { } val adapter = if (binding == null) null else binding!!.tunnelList.adapter if (actionMode == null && !checkedItems.isEmpty() && activity != null) { - (activity as AppCompatActivity?)!!.startSupportActionMode(this) + (activity as AppCompatActivity).startSupportActionMode(this) } else if (actionMode != null && checkedItems.isEmpty()) { actionMode!!.finish() } @@ -404,7 +317,7 @@ class TunnelListFragment : BaseFragment() { private fun animateFab(view: View?, show: Boolean) { view ?: return val animation = AnimationUtils.loadAnimation( - context, if (show) R.anim.scale_up else R.anim.scale_down + context, if (show) R.anim.scale_up else R.anim.scale_down ) animation.setAnimationListener(object : Animation.AnimationListener { override fun onAnimationRepeat(animation: Animation?) { @@ -423,8 +336,6 @@ class TunnelListFragment : BaseFragment() { } companion object { - const val REQUEST_IMPORT = 1 - private const val REQUEST_TARGET_FRAGMENT = 2 private const val CHECKED_ITEMS = "CHECKED_ITEMS" private const val TAG = "WireGuard/TunnelListFragment" } diff --git a/ui/src/main/java/com/wireguard/android/model/ApplicationData.kt b/ui/src/main/java/com/wireguard/android/model/ApplicationData.kt index e0961f04..e6b5705a 100644 --- a/ui/src/main/java/com/wireguard/android/model/ApplicationData.kt +++ b/ui/src/main/java/com/wireguard/android/model/ApplicationData.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.model diff --git a/ui/src/main/java/com/wireguard/android/model/ObservableTunnel.kt b/ui/src/main/java/com/wireguard/android/model/ObservableTunnel.kt index f8691cbb..227c1291 100644 --- a/ui/src/main/java/com/wireguard/android/model/ObservableTunnel.kt +++ b/ui/src/main/java/com/wireguard/android/model/ObservableTunnel.kt @@ -1,28 +1,30 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.model +import android.util.Log import androidx.databinding.BaseObservable import androidx.databinding.Bindable import com.wireguard.android.BR import com.wireguard.android.backend.Statistics import com.wireguard.android.backend.Tunnel import com.wireguard.android.databinding.Keyed -import com.wireguard.android.util.ExceptionLoggers +import com.wireguard.android.util.applicationScope import com.wireguard.config.Config -import java9.util.concurrent.CompletableFuture -import java9.util.concurrent.CompletionStage +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext /** * Encapsulates the volatile and nonvolatile state of a WireGuard tunnel. */ class ObservableTunnel internal constructor( - private val manager: TunnelManager, - private var name: String, - config: Config?, - state: Tunnel.State + private val manager: TunnelManager, + private var name: String, + config: Config?, + state: Tunnel.State ) : BaseObservable(), Keyed<String>, Tunnel { override val key get() = name @@ -30,10 +32,12 @@ class ObservableTunnel internal constructor( @Bindable override fun getName() = name - fun setNameAsync(name: String): CompletionStage<String> = if (name != this.name) - manager.setTunnelName(this, name) - else - CompletableFuture.completedFuture(this.name) + suspend fun setNameAsync(name: String): String = withContext(Dispatchers.Main.immediate) { + if (name != this@ObservableTunnel.name) + manager.setTunnelName(this@ObservableTunnel, name) + else + this@ObservableTunnel.name + } fun onNameChanged(name: String): String { this.name = name @@ -57,31 +61,42 @@ class ObservableTunnel internal constructor( return state } - fun setStateAsync(state: Tunnel.State): CompletionStage<Tunnel.State> = if (state != this.state) - manager.setTunnelState(this, state) - else - CompletableFuture.completedFuture(this.state) + suspend fun setStateAsync(state: Tunnel.State): Tunnel.State = withContext(Dispatchers.Main.immediate) { + if (state != this@ObservableTunnel.state) + manager.setTunnelState(this@ObservableTunnel, state) + else + this@ObservableTunnel.state + } @get:Bindable var config = config get() { if (field == null) - manager.getTunnelConfig(this).whenComplete(ExceptionLoggers.E) + // Opportunistically fetch this if we don't have a cached one, and rely on data bindings to update it eventually + applicationScope.launch { + try { + manager.getTunnelConfig(this@ObservableTunnel) + } catch (e: Throwable) { + Log.e(TAG, Log.getStackTraceString(e)) + } + } return field } private set - val configAsync: CompletionStage<Config> - get() = if (config == null) - manager.getTunnelConfig(this) - else - CompletableFuture.completedFuture(config) + suspend fun getConfigAsync(): Config = withContext(Dispatchers.Main.immediate) { + config ?: manager.getTunnelConfig(this@ObservableTunnel) + } - fun setConfigAsync(config: Config): CompletionStage<Config> = if (config != this.config) - manager.setTunnelConfig(this, config) - else - CompletableFuture.completedFuture(this.config) + suspend fun setConfigAsync(config: Config): Config = withContext(Dispatchers.Main.immediate) { + this@ObservableTunnel.config.let { + if (config != it) + manager.setTunnelConfig(this@ObservableTunnel, config) + else + it + } + } fun onConfigChanged(config: Config?): Config? { this.config = config @@ -94,16 +109,26 @@ class ObservableTunnel internal constructor( var statistics: Statistics? = null get() { if (field == null || field?.isStale != false) - manager.getTunnelStatistics(this).whenComplete(ExceptionLoggers.E) + // Opportunistically fetch this if we don't have a cached one, and rely on data bindings to update it eventually + applicationScope.launch { + try { + manager.getTunnelStatistics(this@ObservableTunnel) + } catch (e: Throwable) { + Log.e(TAG, Log.getStackTraceString(e)) + } + } return field } private set - val statisticsAsync: CompletionStage<Statistics> - get() = if (statistics == null || statistics?.isStale != false) - manager.getTunnelStatistics(this) - else - CompletableFuture.completedFuture(statistics) + suspend fun getStatisticsAsync(): Statistics = withContext(Dispatchers.Main.immediate) { + statistics.let { + if (it == null || it.isStale) + manager.getTunnelStatistics(this@ObservableTunnel) + else + it + } + } fun onStatisticsChanged(statistics: Statistics?): Statistics? { this.statistics = statistics @@ -112,5 +137,10 @@ class ObservableTunnel internal constructor( } - fun delete(): CompletionStage<Void> = manager.delete(this) + suspend fun deleteAsync() = manager.delete(this) + + + companion object { + private const val TAG = "WireGuard/ObservableTunnel" + } } diff --git a/ui/src/main/java/com/wireguard/android/model/TunnelComparator.kt b/ui/src/main/java/com/wireguard/android/model/TunnelComparator.kt index 9fb96cab..3be1019a 100644 --- a/ui/src/main/java/com/wireguard/android/model/TunnelComparator.kt +++ b/ui/src/main/java/com/wireguard/android/model/TunnelComparator.kt @@ -1,12 +1,10 @@ /* - * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.model -import java.util.Locale - object TunnelComparator : Comparator<String> { private class NaturalSortString(originalString: String) { class NaturalSortToken(val maybeString: String?, val maybeNumber: Int?) : Comparable<NaturalSortToken> { @@ -29,7 +27,7 @@ object TunnelComparator : Comparator<String> { val tokens: MutableList<NaturalSortToken> = ArrayList() init { - for (s in NATURAL_SORT_DIGIT_FINDER.findAll(originalString.split(WHITESPACE_FINDER).joinToString(" ").toLowerCase(Locale.ENGLISH))) { + for (s in NATURAL_SORT_DIGIT_FINDER.findAll(originalString.split(WHITESPACE_FINDER).joinToString(" ").lowercase())) { try { val n = s.value.toInt() tokens.add(NaturalSortToken(null, n)) diff --git a/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt b/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt index 5091ed3b..d7c1391f 100644 --- a/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt +++ b/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt @@ -1,20 +1,19 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.model -import android.annotation.SuppressLint import android.content.BroadcastReceiver import android.content.Context import android.content.Intent import android.os.Build +import android.util.Log +import android.widget.Toast import androidx.databinding.BaseObservable import androidx.databinding.Bindable import com.wireguard.android.Application.Companion.get -import com.wireguard.android.Application.Companion.getAsyncWorker import com.wireguard.android.Application.Companion.getBackend -import com.wireguard.android.Application.Companion.getSharedPreferences import com.wireguard.android.Application.Companion.getTunnelManager import com.wireguard.android.BR import com.wireguard.android.R @@ -22,145 +21,149 @@ import com.wireguard.android.backend.Statistics import com.wireguard.android.backend.Tunnel import com.wireguard.android.configStore.ConfigStore import com.wireguard.android.databinding.ObservableSortedKeyedArrayList -import com.wireguard.android.util.ExceptionLoggers +import com.wireguard.android.util.ErrorMessages +import com.wireguard.android.util.UserKnobs +import com.wireguard.android.util.applicationScope import com.wireguard.config.Config -import java9.util.concurrent.CompletableFuture -import java9.util.concurrent.CompletionStage -import java.util.ArrayList +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext /** * Maintains and mediates changes to the set of available WireGuard tunnels, */ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() { - val tunnels = CompletableFuture<ObservableSortedKeyedArrayList<String, ObservableTunnel>>() + private val tunnels = CompletableDeferred<ObservableSortedKeyedArrayList<String, ObservableTunnel>>() private val context: Context = get() - private val delayedLoadRestoreTunnels = ArrayList<CompletableFuture<Void>>() private val tunnelMap: ObservableSortedKeyedArrayList<String, ObservableTunnel> = ObservableSortedKeyedArrayList(TunnelComparator) private var haveLoaded = false - private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel? { + private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel { val tunnel = ObservableTunnel(this, name, config, state) tunnelMap.add(tunnel) return tunnel } - fun create(name: String, config: Config?): CompletionStage<ObservableTunnel> { + suspend fun getTunnels(): ObservableSortedKeyedArrayList<String, ObservableTunnel> = tunnels.await() + + suspend fun create(name: String, config: Config?): ObservableTunnel = withContext(Dispatchers.Main.immediate) { if (Tunnel.isNameInvalid(name)) - return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))) + throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name)) if (tunnelMap.containsKey(name)) - return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))) - return getAsyncWorker().supplyAsync { configStore.create(name, config!!) }.thenApply { addToList(name, it, Tunnel.State.DOWN) } + throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name)) + addToList(name, withContext(Dispatchers.IO) { configStore.create(name, config!!) }, Tunnel.State.DOWN) } - fun delete(tunnel: ObservableTunnel): CompletionStage<Void> { + suspend fun delete(tunnel: ObservableTunnel) = withContext(Dispatchers.Main.immediate) { val originalState = tunnel.state val wasLastUsed = tunnel == lastUsedTunnel // Make sure nothing touches the tunnel. if (wasLastUsed) lastUsedTunnel = null tunnelMap.remove(tunnel) - return getAsyncWorker().runAsync { + try { if (originalState == Tunnel.State.UP) - getBackend().setState(tunnel, Tunnel.State.DOWN, null) + withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) } try { - configStore.delete(tunnel.name) - } catch (e: Exception) { + withContext(Dispatchers.IO) { configStore.delete(tunnel.name) } + } catch (e: Throwable) { if (originalState == Tunnel.State.UP) - getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) + withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) } throw e } - }.whenComplete { _, e -> - if (e == null) - return@whenComplete + } catch (e: Throwable) { // Failure, put the tunnel back. tunnelMap.add(tunnel) if (wasLastUsed) lastUsedTunnel = tunnel + throw e } } @get:Bindable - @SuppressLint("ApplySharedPref") var lastUsedTunnel: ObservableTunnel? = null private set(value) { if (value == field) return field = value notifyPropertyChanged(BR.lastUsedTunnel) - if (value != null) - getSharedPreferences().edit().putString(KEY_LAST_USED_TUNNEL, value.name).commit() - else - getSharedPreferences().edit().remove(KEY_LAST_USED_TUNNEL).commit() + applicationScope.launch { UserKnobs.setLastUsedTunnel(value?.name) } } - fun getTunnelConfig(tunnel: ObservableTunnel): CompletionStage<Config> = getAsyncWorker() - .supplyAsync { configStore.load(tunnel.name) }.thenApply(tunnel::onConfigChanged) - + suspend fun getTunnelConfig(tunnel: ObservableTunnel): Config = withContext(Dispatchers.Main.immediate) { + tunnel.onConfigChanged(withContext(Dispatchers.IO) { configStore.load(tunnel.name) })!! + } fun onCreate() { - getAsyncWorker().supplyAsync { configStore.enumerate() } - .thenAcceptBoth(getAsyncWorker().supplyAsync { getBackend().runningTunnelNames }, this::onTunnelsLoaded) - .whenComplete(ExceptionLoggers.E) + applicationScope.launch { + try { + onTunnelsLoaded(withContext(Dispatchers.IO) { configStore.enumerate() }, withContext(Dispatchers.IO) { getBackend().runningTunnelNames }) + } catch (e: Throwable) { + Log.e(TAG, Log.getStackTraceString(e)) + } + } } private fun onTunnelsLoaded(present: Iterable<String>, running: Collection<String>) { for (name in present) addToList(name, null, if (running.contains(name)) Tunnel.State.UP else Tunnel.State.DOWN) - val lastUsedName = getSharedPreferences().getString(KEY_LAST_USED_TUNNEL, null) - if (lastUsedName != null) - lastUsedTunnel = tunnelMap[lastUsedName] - var toComplete: Array<CompletableFuture<Void>> - synchronized(delayedLoadRestoreTunnels) { + applicationScope.launch { + val lastUsedName = UserKnobs.lastUsedTunnel.first() + if (lastUsedName != null) + lastUsedTunnel = tunnelMap[lastUsedName] haveLoaded = true - toComplete = delayedLoadRestoreTunnels.toTypedArray() - delayedLoadRestoreTunnels.clear() + restoreState(true) + tunnels.complete(tunnelMap) } - restoreState(true).whenComplete { v: Void?, t: Throwable? -> - for (f in toComplete) { - if (t == null) - f.complete(v) - else - f.completeExceptionally(t) - } - } - tunnels.complete(tunnelMap) } - fun refreshTunnelStates() { - getAsyncWorker().supplyAsync { getBackend().runningTunnelNames } - .thenAccept { running: Set<String> -> for (tunnel in tunnelMap) tunnel.onStateChanged(if (running.contains(tunnel.name)) Tunnel.State.UP else Tunnel.State.DOWN) } - .whenComplete(ExceptionLoggers.E) + private fun refreshTunnelStates() { + applicationScope.launch { + try { + val running = withContext(Dispatchers.IO) { getBackend().runningTunnelNames } + for (tunnel in tunnelMap) + tunnel.onStateChanged(if (running.contains(tunnel.name)) Tunnel.State.UP else Tunnel.State.DOWN) + } catch (e: Throwable) { + Log.e(TAG, Log.getStackTraceString(e)) + } + } } - fun restoreState(force: Boolean): CompletionStage<Void> { - if (!force && !getSharedPreferences().getBoolean(KEY_RESTORE_ON_BOOT, false)) - return CompletableFuture.completedFuture(null) - synchronized(delayedLoadRestoreTunnels) { - if (!haveLoaded) { - val f = CompletableFuture<Void>() - delayedLoadRestoreTunnels.add(f) - return f + suspend fun restoreState(force: Boolean) { + if (!haveLoaded || (!force && !UserKnobs.restoreOnBoot.first())) + return + val previouslyRunning = UserKnobs.runningTunnels.first() + if (previouslyRunning.isEmpty()) return + withContext(Dispatchers.IO) { + try { + tunnelMap.filter { previouslyRunning.contains(it.name) }.map { async(Dispatchers.IO + SupervisorJob()) { setTunnelState(it, Tunnel.State.UP) } } + .awaitAll() + } catch (e: Throwable) { + Log.e(TAG, Log.getStackTraceString(e)) } } - val previouslyRunning = getSharedPreferences().getStringSet(KEY_RUNNING_TUNNELS, null) - ?: return CompletableFuture.completedFuture(null) - return CompletableFuture.allOf(*tunnelMap.filter { previouslyRunning.contains(it.name) }.map { setTunnelState(it, Tunnel.State.UP).toCompletableFuture() }.toTypedArray()) } - @SuppressLint("ApplySharedPref") - fun saveState() { - getSharedPreferences().edit().putStringSet(KEY_RUNNING_TUNNELS, tunnelMap.filter { it.state == Tunnel.State.UP }.map { it.name }.toSet()).commit() + suspend fun saveState() { + UserKnobs.setRunningTunnels(tunnelMap.filter { it.state == Tunnel.State.UP }.map { it.name }.toSet()) } - fun setTunnelConfig(tunnel: ObservableTunnel, config: Config): CompletionStage<Config> = getAsyncWorker().supplyAsync { - getBackend().setState(tunnel, tunnel.state, config) - configStore.save(tunnel.name, config) - }.thenApply { tunnel.onConfigChanged(it) } + suspend fun setTunnelConfig(tunnel: ObservableTunnel, config: Config): Config = withContext(Dispatchers.Main.immediate) { + tunnel.onConfigChanged(withContext(Dispatchers.IO) { + getBackend().setState(tunnel, tunnel.state, config) + configStore.save(tunnel.name, config) + })!! + } - fun setTunnelName(tunnel: ObservableTunnel, name: String): CompletionStage<String> { + suspend fun setTunnelName(tunnel: ObservableTunnel, name: String): String = withContext(Dispatchers.Main.immediate) { if (Tunnel.isNameInvalid(name)) - return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))) + throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name)) if (tunnelMap.containsKey(name)) { - return CompletableFuture.failedFuture(IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))) + throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name)) } val originalState = tunnel.state val wasLastUsed = tunnel == lastUsedTunnel @@ -168,69 +171,85 @@ class TunnelManager(private val configStore: ConfigStore) : BaseObservable() { if (wasLastUsed) lastUsedTunnel = null tunnelMap.remove(tunnel) - return getAsyncWorker().supplyAsync { + var throwable: Throwable? = null + var newName: String? = null + try { if (originalState == Tunnel.State.UP) - getBackend().setState(tunnel, Tunnel.State.DOWN, null) - configStore.rename(tunnel.name, name) - val newName = tunnel.onNameChanged(name) + withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) } + withContext(Dispatchers.IO) { configStore.rename(tunnel.name, name) } + newName = tunnel.onNameChanged(name) if (originalState == Tunnel.State.UP) - getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) - newName - }.whenComplete { _, e -> + withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) } + } catch (e: Throwable) { + throwable = e // On failure, we don't know what state the tunnel might be in. Fix that. - if (e != null) - getTunnelState(tunnel) - // Add the tunnel back to the manager, under whatever name it thinks it has. - tunnelMap.add(tunnel) - if (wasLastUsed) - lastUsedTunnel = tunnel + getTunnelState(tunnel) } + // Add the tunnel back to the manager, under whatever name it thinks it has. + tunnelMap.add(tunnel) + if (wasLastUsed) + lastUsedTunnel = tunnel + if (throwable != null) + throw throwable + newName!! } - fun setTunnelState(tunnel: ObservableTunnel, state: Tunnel.State): CompletionStage<Tunnel.State> = tunnel.configAsync - .thenCompose { getAsyncWorker().supplyAsync { getBackend().setState(tunnel, state, it) } } - .whenComplete { newState, e -> - // Ensure onStateChanged is always called (failure or not), and with the correct state. - tunnel.onStateChanged(if (e == null) newState else tunnel.state) - if (e == null && newState == Tunnel.State.UP) - lastUsedTunnel = tunnel - saveState() - } + suspend fun setTunnelState(tunnel: ObservableTunnel, state: Tunnel.State): Tunnel.State = withContext(Dispatchers.Main.immediate) { + var newState = tunnel.state + var throwable: Throwable? = null + try { + newState = withContext(Dispatchers.IO) { getBackend().setState(tunnel, state, tunnel.getConfigAsync()) } + if (newState == Tunnel.State.UP) + lastUsedTunnel = tunnel + } catch (e: Throwable) { + throwable = e + } + tunnel.onStateChanged(newState) + saveState() + if (throwable != null) + throw throwable + newState + } class IntentReceiver : BroadcastReceiver() { override fun onReceive(context: Context, intent: Intent?) { - val manager = getTunnelManager() - if (intent == null) return - val action = intent.action ?: return - if ("com.wireguard.android.action.REFRESH_TUNNEL_STATES" == action) { - manager.refreshTunnelStates() - return - } - if (Build.VERSION.SDK_INT < Build.VERSION_CODES.M || !getSharedPreferences().getBoolean("allow_remote_control_intents", false)) - return - val state: Tunnel.State - state = when (action) { - "com.wireguard.android.action.SET_TUNNEL_UP" -> Tunnel.State.UP - "com.wireguard.android.action.SET_TUNNEL_DOWN" -> Tunnel.State.DOWN - else -> return - } - val tunnelName = intent.getStringExtra("tunnel") ?: return - manager.tunnels.thenAccept { - val tunnel = it[tunnelName] ?: return@thenAccept - manager.setTunnelState(tunnel, state) + applicationScope.launch { + val manager = getTunnelManager() + if (intent == null) return@launch + val action = intent.action ?: return@launch + if ("com.wireguard.android.action.REFRESH_TUNNEL_STATES" == action) { + manager.refreshTunnelStates() + return@launch + } + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.M || !UserKnobs.allowRemoteControlIntents.first()) + return@launch + val state: Tunnel.State + state = when (action) { + "com.wireguard.android.action.SET_TUNNEL_UP" -> Tunnel.State.UP + "com.wireguard.android.action.SET_TUNNEL_DOWN" -> Tunnel.State.DOWN + else -> return@launch + } + val tunnelName = intent.getStringExtra("tunnel") ?: return@launch + val tunnels = manager.getTunnels() + val tunnel = tunnels[tunnelName] ?: return@launch + try { + manager.setTunnelState(tunnel, state) + } catch (e: Throwable) { + Toast.makeText(context, ErrorMessages[e], Toast.LENGTH_LONG).show() + } } } } - fun getTunnelState(tunnel: ObservableTunnel): CompletionStage<Tunnel.State> = getAsyncWorker() - .supplyAsync { getBackend().getState(tunnel) }.thenApply(tunnel::onStateChanged) + suspend fun getTunnelState(tunnel: ObservableTunnel): Tunnel.State = withContext(Dispatchers.Main.immediate) { + tunnel.onStateChanged(withContext(Dispatchers.IO) { getBackend().getState(tunnel) }) + } - fun getTunnelStatistics(tunnel: ObservableTunnel): CompletionStage<Statistics> = getAsyncWorker() - .supplyAsync { getBackend().getStatistics(tunnel) }.thenApply(tunnel::onStatisticsChanged) + suspend fun getTunnelStatistics(tunnel: ObservableTunnel): Statistics = withContext(Dispatchers.Main.immediate) { + tunnel.onStatisticsChanged(withContext(Dispatchers.IO) { getBackend().getStatistics(tunnel) })!! + } companion object { - private const val KEY_LAST_USED_TUNNEL = "last_used_tunnel" - private const val KEY_RESTORE_ON_BOOT = "restore_on_boot" - private const val KEY_RUNNING_TUNNELS = "enabled_configs" + private const val TAG = "WireGuard/TunnelManager" } } diff --git a/ui/src/main/java/com/wireguard/android/preference/DonatePreference.kt b/ui/src/main/java/com/wireguard/android/preference/DonatePreference.kt new file mode 100644 index 00000000..2f66a2ca --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/preference/DonatePreference.kt @@ -0,0 +1,43 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.preference + +import android.content.Context +import android.content.Intent +import android.net.Uri +import android.util.AttributeSet +import android.widget.Toast +import androidx.preference.Preference +import com.google.android.material.dialog.MaterialAlertDialogBuilder +import com.wireguard.android.R +import com.wireguard.android.updater.Updater +import com.wireguard.android.util.ErrorMessages +import androidx.core.net.toUri + +class DonatePreference(context: Context, attrs: AttributeSet?) : Preference(context, attrs) { + override fun getSummary() = context.getString(R.string.donate_summary) + + override fun getTitle() = context.getString(R.string.donate_title) + + override fun onClick() { + /* Google Play Store forbids links to our donation page. */ + if (Updater.installerIsGooglePlay(context)) { + MaterialAlertDialogBuilder(context) + .setTitle(R.string.donate_title) + .setMessage(R.string.donate_google_play_disappointment) + .show() + return + } + + val intent = Intent(Intent.ACTION_VIEW) + intent.data = "https://www.wireguard.com/donations/".toUri() + try { + context.startActivity(intent) + } catch (e: Throwable) { + Toast.makeText(context, ErrorMessages[e], Toast.LENGTH_SHORT).show() + } + } +} diff --git a/ui/src/main/java/com/wireguard/android/preference/KernelModuleDisablerPreference.kt b/ui/src/main/java/com/wireguard/android/preference/KernelModuleDisablerPreference.kt deleted file mode 100644 index 1479d7b6..00000000 --- a/ui/src/main/java/com/wireguard/android/preference/KernelModuleDisablerPreference.kt +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright © 2020 WireGuard LLC. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package com.wireguard.android.preference - -import android.annotation.SuppressLint -import android.content.Context -import android.content.Intent -import android.util.AttributeSet -import androidx.preference.Preference -import com.wireguard.android.Application -import com.wireguard.android.R -import com.wireguard.android.activity.SettingsActivity -import com.wireguard.android.backend.Tunnel -import com.wireguard.android.backend.WgQuickBackend -import java9.util.concurrent.CompletableFuture -import kotlin.system.exitProcess - -class KernelModuleDisablerPreference(context: Context, attrs: AttributeSet?) : Preference(context, attrs) { - private var state = State.UNKNOWN - - init { - isVisible = false - Application.getBackendAsync().thenAccept { backend -> - setState(if (backend is WgQuickBackend) State.ENABLED else State.DISABLED) - } - } - - override fun getSummary() = if (state == State.UNKNOWN) "" else context.getString(state.summaryResourceId) - - override fun getTitle() = if (state == State.UNKNOWN) "" else context.getString(state.titleResourceId) - - @SuppressLint("ApplySharedPref") - override fun onClick() { - if (state == State.DISABLED) { - setState(State.ENABLING) - Application.getSharedPreferences().edit().putBoolean("disable_kernel_module", false).commit() - } else if (state == State.ENABLED) { - setState(State.DISABLING) - Application.getSharedPreferences().edit().putBoolean("disable_kernel_module", true).commit() - } - Application.getAsyncWorker().runAsync { - Application.getTunnelManager().tunnels.thenApply { observableTunnels -> - val downings = observableTunnels.map { it.setStateAsync(Tunnel.State.DOWN).toCompletableFuture() }.toTypedArray() - CompletableFuture.allOf(*downings).thenRun { - val restartIntent = Intent(context, SettingsActivity::class.java) - restartIntent.addFlags(Intent.FLAG_ACTIVITY_CLEAR_TOP) - restartIntent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK) - Application.get().startActivity(restartIntent) - exitProcess(0) - } - }.join() - } - } - - private fun setState(state: State) { - if (this.state == state) return - this.state = state - if (isEnabled != state.shouldEnableView) isEnabled = state.shouldEnableView - if (isVisible != state.visible) isVisible = state.visible - notifyChanged() - } - - private enum class State(val titleResourceId: Int, val summaryResourceId: Int, val shouldEnableView: Boolean, val visible: Boolean) { - UNKNOWN(0, 0, false, false), - ENABLED(R.string.module_disabler_enabled_title, R.string.module_disabler_enabled_summary, true, true), - DISABLED(R.string.module_disabler_disabled_title, R.string.module_disabler_disabled_summary, true, true), - ENABLING(R.string.module_disabler_disabled_title, R.string.success_application_will_restart, false, true), - DISABLING(R.string.module_disabler_enabled_title, R.string.success_application_will_restart, false, true); - } -} diff --git a/ui/src/main/java/com/wireguard/android/preference/KernelModuleEnablerPreference.kt b/ui/src/main/java/com/wireguard/android/preference/KernelModuleEnablerPreference.kt new file mode 100644 index 00000000..3d1c27f1 --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/preference/KernelModuleEnablerPreference.kt @@ -0,0 +1,88 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package com.wireguard.android.preference + +import android.content.Context +import android.content.Intent +import android.util.AttributeSet +import android.util.Log +import androidx.lifecycle.lifecycleScope +import androidx.preference.Preference +import com.wireguard.android.Application +import com.wireguard.android.R +import com.wireguard.android.activity.SettingsActivity +import com.wireguard.android.backend.Tunnel +import com.wireguard.android.backend.WgQuickBackend +import com.wireguard.android.util.UserKnobs +import com.wireguard.android.util.activity +import com.wireguard.android.util.lifecycleScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import kotlin.system.exitProcess + +class KernelModuleEnablerPreference(context: Context, attrs: AttributeSet?) : Preference(context, attrs) { + private var state = State.UNKNOWN + + init { + isVisible = false + lifecycleScope.launch { + setState(if (Application.getBackend() is WgQuickBackend) State.ENABLED else State.DISABLED) + } + } + + override fun getSummary() = if (state == State.UNKNOWN) "" else context.getString(state.summaryResourceId) + + override fun getTitle() = if (state == State.UNKNOWN) "" else context.getString(state.titleResourceId) + + override fun onClick() { + activity.lifecycleScope.launch { + if (state == State.DISABLED) { + setState(State.ENABLING) + UserKnobs.setEnableKernelModule(true) + } else if (state == State.ENABLED) { + setState(State.DISABLING) + UserKnobs.setEnableKernelModule(false) + } + val observableTunnels = Application.getTunnelManager().getTunnels() + val downings = observableTunnels.map { async(SupervisorJob()) { it.setStateAsync(Tunnel.State.DOWN) } } + try { + downings.awaitAll() + withContext(Dispatchers.IO) { + val restartIntent = Intent(context, SettingsActivity::class.java) + restartIntent.addFlags(Intent.FLAG_ACTIVITY_CLEAR_TOP) + restartIntent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK) + Application.get().startActivity(restartIntent) + exitProcess(0) + } + } catch (e: Throwable) { + Log.e(TAG, Log.getStackTraceString(e)) + } + } + } + + private fun setState(state: State) { + if (this.state == state) return + this.state = state + if (isEnabled != state.shouldEnableView) isEnabled = state.shouldEnableView + if (isVisible != state.visible) isVisible = state.visible + notifyChanged() + } + + private enum class State(val titleResourceId: Int, val summaryResourceId: Int, val shouldEnableView: Boolean, val visible: Boolean) { + UNKNOWN(0, 0, false, false), + ENABLED(R.string.module_enabler_enabled_title, R.string.module_enabler_enabled_summary, true, true), + DISABLED(R.string.module_enabler_disabled_title, R.string.module_enabler_disabled_summary, true, true), + ENABLING(R.string.module_enabler_disabled_title, R.string.success_application_will_restart, false, true), + DISABLING(R.string.module_enabler_enabled_title, R.string.success_application_will_restart, false, true); + } + + companion object { + private const val TAG = "WireGuard/KernelModuleEnablerPreference" + } +} diff --git a/ui/src/main/java/com/wireguard/android/preference/ModuleDownloaderPreference.kt b/ui/src/main/java/com/wireguard/android/preference/ModuleDownloaderPreference.kt deleted file mode 100644 index 055ed449..00000000 --- a/ui/src/main/java/com/wireguard/android/preference/ModuleDownloaderPreference.kt +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright © 2019 WireGuard LLC. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package com.wireguard.android.preference - -import android.annotation.SuppressLint -import android.content.Context -import android.content.Intent -import android.system.OsConstants -import android.util.AttributeSet -import android.widget.Toast -import androidx.preference.Preference -import com.wireguard.android.Application -import com.wireguard.android.R -import com.wireguard.android.activity.SettingsActivity -import com.wireguard.android.util.ErrorMessages -import kotlin.system.exitProcess - -class ModuleDownloaderPreference(context: Context, attrs: AttributeSet?) : Preference(context, attrs) { - private var state = State.INITIAL - - override fun getSummary() = context.getString(state.messageResourceId) - - override fun getTitle() = context.getString(R.string.module_installer_title) - - override fun onClick() { - setState(State.WORKING) - Application.getAsyncWorker().supplyAsync(Application.getModuleLoader()::download).whenComplete(this::onDownloadResult) - } - - @SuppressLint("ApplySharedPref") - private fun onDownloadResult(result: Int, throwable: Throwable?) { - when { - throwable != null -> { - setState(State.FAILURE) - Toast.makeText(context, ErrorMessages[throwable], Toast.LENGTH_LONG).show() - } - result == OsConstants.ENOENT -> setState(State.NOTFOUND) - result == OsConstants.EXIT_SUCCESS -> { - setState(State.SUCCESS) - Application.getSharedPreferences().edit().remove("disable_kernel_module").commit() - Application.getAsyncWorker().runAsync { - val restartIntent = Intent(context, SettingsActivity::class.java) - restartIntent.addFlags(Intent.FLAG_ACTIVITY_CLEAR_TOP) - restartIntent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK) - Application.get().startActivity(restartIntent) - exitProcess(0) - } - } - else -> setState(State.FAILURE) - } - } - - private fun setState(state: State) { - if (this.state == state) return - this.state = state - if (isEnabled != state.shouldEnableView) isEnabled = state.shouldEnableView - notifyChanged() - } - - private enum class State(val messageResourceId: Int, val shouldEnableView: Boolean) { - INITIAL(R.string.module_installer_initial, true), - FAILURE(R.string.module_installer_error, true), - WORKING(R.string.module_installer_working, false), - SUCCESS(R.string.success_application_will_restart, false), - NOTFOUND(R.string.module_installer_not_found, false); - } -} diff --git a/ui/src/main/java/com/wireguard/android/preference/PreferencesPreferenceDataStore.kt b/ui/src/main/java/com/wireguard/android/preference/PreferencesPreferenceDataStore.kt new file mode 100644 index 00000000..e2fc51e3 --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/preference/PreferencesPreferenceDataStore.kt @@ -0,0 +1,135 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.preference + +import androidx.datastore.core.DataStore +import androidx.datastore.preferences.core.Preferences +import androidx.datastore.preferences.core.booleanPreferencesKey +import androidx.datastore.preferences.core.edit +import androidx.datastore.preferences.core.floatPreferencesKey +import androidx.datastore.preferences.core.intPreferencesKey +import androidx.datastore.preferences.core.longPreferencesKey +import androidx.datastore.preferences.core.stringPreferencesKey +import androidx.datastore.preferences.core.stringSetPreferencesKey +import androidx.preference.PreferenceDataStore +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking + +class PreferencesPreferenceDataStore(private val coroutineScope: CoroutineScope, private val dataStore: DataStore<Preferences>) : PreferenceDataStore() { + override fun putString(key: String?, value: String?) { + if (key == null) return + val pk = stringPreferencesKey(key) + coroutineScope.launch { + dataStore.edit { + if (value == null) it.remove(pk) + else it[pk] = value + } + } + } + + override fun putStringSet(key: String?, values: Set<String?>?) { + if (key == null) return + val pk = stringSetPreferencesKey(key) + val filteredValues = values?.filterNotNull()?.toSet() + coroutineScope.launch { + dataStore.edit { + if (filteredValues == null || filteredValues.isEmpty()) it.remove(pk) + else it[pk] = filteredValues + } + } + } + + override fun putInt(key: String?, value: Int) { + if (key == null) return + val pk = intPreferencesKey(key) + coroutineScope.launch { + dataStore.edit { + it[pk] = value + } + } + } + + override fun putLong(key: String?, value: Long) { + if (key == null) return + val pk = longPreferencesKey(key) + coroutineScope.launch { + dataStore.edit { + it[pk] = value + } + } + } + + override fun putFloat(key: String?, value: Float) { + if (key == null) return + val pk = floatPreferencesKey(key) + coroutineScope.launch { + dataStore.edit { + it[pk] = value + } + } + } + + override fun putBoolean(key: String?, value: Boolean) { + if (key == null) return + val pk = booleanPreferencesKey(key) + coroutineScope.launch { + dataStore.edit { + it[pk] = value + } + } + } + + override fun getString(key: String?, defValue: String?): String? { + if (key == null) return defValue + val pk = stringPreferencesKey(key) + return runBlocking { + dataStore.data.map { it[pk] ?: defValue }.first() + } + } + + override fun getStringSet(key: String?, defValues: Set<String?>?): Set<String?>? { + if (key == null) return defValues + val pk = stringSetPreferencesKey(key) + return runBlocking { + dataStore.data.map { it[pk] ?: defValues }.first() + } + } + + override fun getInt(key: String?, defValue: Int): Int { + if (key == null) return defValue + val pk = intPreferencesKey(key) + return runBlocking { + dataStore.data.map { it[pk] ?: defValue }.first() + } + } + + override fun getLong(key: String?, defValue: Long): Long { + if (key == null) return defValue + val pk = longPreferencesKey(key) + return runBlocking { + dataStore.data.map { it[pk] ?: defValue }.first() + } + } + + override fun getFloat(key: String?, defValue: Float): Float { + if (key == null) return defValue + val pk = floatPreferencesKey(key) + return runBlocking { + dataStore.data.map { it[pk] ?: defValue }.first() + } + } + + override fun getBoolean(key: String?, defValue: Boolean): Boolean { + if (key == null) return defValue + val pk = booleanPreferencesKey(key) + return runBlocking { + dataStore.data.map { it[pk] ?: defValue }.first() + } + } +} diff --git a/ui/src/main/java/com/wireguard/android/preference/QuickTilePreference.kt b/ui/src/main/java/com/wireguard/android/preference/QuickTilePreference.kt new file mode 100644 index 00000000..458b9f9a --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/preference/QuickTilePreference.kt @@ -0,0 +1,50 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.preference + +import android.app.StatusBarManager +import android.content.ComponentName +import android.content.Context +import android.graphics.drawable.Icon +import android.os.Build +import android.util.AttributeSet +import android.widget.Toast +import androidx.annotation.RequiresApi +import androidx.preference.Preference +import com.wireguard.android.QuickTileService +import com.wireguard.android.R + +@RequiresApi(Build.VERSION_CODES.TIRAMISU) +class QuickTilePreference(context: Context, attrs: AttributeSet?) : Preference(context, attrs) { + override fun getSummary() = context.getString(R.string.quick_settings_tile_add_summary) + + override fun getTitle() = context.getString(R.string.quick_settings_tile_add_title) + + override fun onClick() { + val statusBarManager = context.getSystemService(StatusBarManager::class.java) + statusBarManager.requestAddTileService( + ComponentName(context, QuickTileService::class.java), + context.getString(R.string.quick_settings_tile_action), + Icon.createWithResource(context, R.drawable.ic_tile), + context.mainExecutor + ) { + when (it) { + StatusBarManager.TILE_ADD_REQUEST_RESULT_TILE_ALREADY_ADDED, + StatusBarManager.TILE_ADD_REQUEST_RESULT_TILE_ADDED -> { + parent?.removePreference(this) + --preferenceManager.preferenceScreen.initialExpandedChildrenCount + } + StatusBarManager.TILE_ADD_REQUEST_ERROR_MISMATCHED_PACKAGE, + StatusBarManager.TILE_ADD_REQUEST_ERROR_REQUEST_IN_PROGRESS, + StatusBarManager.TILE_ADD_REQUEST_ERROR_BAD_COMPONENT, + StatusBarManager.TILE_ADD_REQUEST_ERROR_NOT_CURRENT_USER, + StatusBarManager.TILE_ADD_REQUEST_ERROR_APP_NOT_IN_FOREGROUND, + StatusBarManager.TILE_ADD_REQUEST_ERROR_NO_STATUS_BAR_SERVICE -> + Toast.makeText(context, context.getString(R.string.quick_settings_tile_add_failure, it), Toast.LENGTH_SHORT).show() + } + } + } +} diff --git a/ui/src/main/java/com/wireguard/android/preference/ToolsInstallerPreference.kt b/ui/src/main/java/com/wireguard/android/preference/ToolsInstallerPreference.kt index f7dd932d..b22048b5 100644 --- a/ui/src/main/java/com/wireguard/android/preference/ToolsInstallerPreference.kt +++ b/ui/src/main/java/com/wireguard/android/preference/ToolsInstallerPreference.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.preference @@ -10,6 +10,10 @@ import androidx.preference.Preference import com.wireguard.android.Application import com.wireguard.android.R import com.wireguard.android.util.ToolsInstaller +import com.wireguard.android.util.lifecycleScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext /** * Preference implementing a button that asynchronously runs `ToolsInstaller` and displays the @@ -17,37 +21,41 @@ import com.wireguard.android.util.ToolsInstaller */ class ToolsInstallerPreference(context: Context, attrs: AttributeSet?) : Preference(context, attrs) { private var state = State.INITIAL - override fun getSummary() = context.getString(state.messageResourceId) override fun getTitle() = context.getString(R.string.tools_installer_title) override fun onAttached() { super.onAttached() - Application.getAsyncWorker().supplyAsync(Application.getToolsInstaller()::areInstalled).whenComplete(this::onCheckResult) - } - - private fun onCheckResult(state: Int, throwable: Throwable?) { - when { - throwable != null || state == ToolsInstaller.ERROR -> setState(State.INITIAL) - state and ToolsInstaller.YES == ToolsInstaller.YES -> setState(State.ALREADY) - state and (ToolsInstaller.MAGISK or ToolsInstaller.NO) == ToolsInstaller.MAGISK or ToolsInstaller.NO -> setState(State.INITIAL_MAGISK) - state and (ToolsInstaller.SYSTEM or ToolsInstaller.NO) == ToolsInstaller.SYSTEM or ToolsInstaller.NO -> setState(State.INITIAL_SYSTEM) - else -> setState(State.INITIAL) + lifecycleScope.launch { + try { + val state = withContext(Dispatchers.IO) { Application.getToolsInstaller().areInstalled() } + when { + state == ToolsInstaller.ERROR -> setState(State.INITIAL) + state and ToolsInstaller.YES == ToolsInstaller.YES -> setState(State.ALREADY) + state and (ToolsInstaller.MAGISK or ToolsInstaller.NO) == ToolsInstaller.MAGISK or ToolsInstaller.NO -> setState(State.INITIAL_MAGISK) + state and (ToolsInstaller.SYSTEM or ToolsInstaller.NO) == ToolsInstaller.SYSTEM or ToolsInstaller.NO -> setState(State.INITIAL_SYSTEM) + else -> setState(State.INITIAL) + } + } catch (_: Throwable) { + setState(State.INITIAL) + } } } override fun onClick() { setState(State.WORKING) - Application.getAsyncWorker().supplyAsync { Application.getToolsInstaller().install() }.whenComplete { result: Int, throwable: Throwable? -> onInstallResult(result, throwable) } - } - - private fun onInstallResult(result: Int, throwable: Throwable?) { - when { - throwable != null -> setState(State.FAILURE) - result and (ToolsInstaller.YES or ToolsInstaller.MAGISK) == ToolsInstaller.YES or ToolsInstaller.MAGISK -> setState(State.SUCCESS_MAGISK) - result and (ToolsInstaller.YES or ToolsInstaller.SYSTEM) == ToolsInstaller.YES or ToolsInstaller.SYSTEM -> setState(State.SUCCESS_SYSTEM) - else -> setState(State.FAILURE) + lifecycleScope.launch { + try { + val result = withContext(Dispatchers.IO) { Application.getToolsInstaller().install() } + when { + result and (ToolsInstaller.YES or ToolsInstaller.MAGISK) == ToolsInstaller.YES or ToolsInstaller.MAGISK -> setState(State.SUCCESS_MAGISK) + result and (ToolsInstaller.YES or ToolsInstaller.SYSTEM) == ToolsInstaller.YES or ToolsInstaller.SYSTEM -> setState(State.SUCCESS_SYSTEM) + else -> setState(State.FAILURE) + } + } catch (_: Throwable) { + setState(State.FAILURE) + } } } diff --git a/ui/src/main/java/com/wireguard/android/preference/VersionPreference.kt b/ui/src/main/java/com/wireguard/android/preference/VersionPreference.kt index 0734df45..3850482b 100644 --- a/ui/src/main/java/com/wireguard/android/preference/VersionPreference.kt +++ b/ui/src/main/java/com/wireguard/android/preference/VersionPreference.kt @@ -1,14 +1,14 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.preference -import android.content.ActivityNotFoundException import android.content.Context import android.content.Intent import android.net.Uri import android.util.AttributeSet +import android.widget.Toast import androidx.preference.Preference import com.wireguard.android.Application import com.wireguard.android.BuildConfig @@ -16,7 +16,11 @@ import com.wireguard.android.R import com.wireguard.android.backend.Backend import com.wireguard.android.backend.GoBackend import com.wireguard.android.backend.WgQuickBackend -import java.util.Locale +import com.wireguard.android.util.ErrorMessages +import com.wireguard.android.util.lifecycleScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext class VersionPreference(context: Context, attrs: AttributeSet?) : Preference(context, attrs) { private var versionSummary: String? = null @@ -30,7 +34,8 @@ class VersionPreference(context: Context, attrs: AttributeSet?) : Preference(con intent.data = Uri.parse("https://www.wireguard.com/") try { context.startActivity(intent) - } catch (_: ActivityNotFoundException) { + } catch (e: Throwable) { + Toast.makeText(context, ErrorMessages[e], Toast.LENGTH_SHORT).show() } } @@ -43,15 +48,16 @@ class VersionPreference(context: Context, attrs: AttributeSet?) : Preference(con } init { - Application.getBackendAsync().thenAccept { backend -> - versionSummary = getContext().getString(R.string.version_summary_checking, getBackendPrettyName(context, backend).toLowerCase(Locale.ENGLISH)) - Application.getAsyncWorker().supplyAsync(backend::getVersion).whenComplete { version, exception -> - versionSummary = if (exception == null) - getContext().getString(R.string.version_summary, getBackendPrettyName(context, backend), version) - else - getContext().getString(R.string.version_summary_unknown, getBackendPrettyName(context, backend).toLowerCase(Locale.ENGLISH)) - notifyChanged() + lifecycleScope.launch { + val backend = Application.getBackend() + versionSummary = getContext().getString(R.string.version_summary_checking, getBackendPrettyName(context, backend).lowercase()) + notifyChanged() + versionSummary = try { + getContext().getString(R.string.version_summary, getBackendPrettyName(context, backend), withContext(Dispatchers.IO) { backend.version }) + } catch (_: Throwable) { + getContext().getString(R.string.version_summary_unknown, getBackendPrettyName(context, backend).lowercase()) } + notifyChanged() } } } diff --git a/ui/src/main/java/com/wireguard/android/preference/ZipExporterPreference.kt b/ui/src/main/java/com/wireguard/android/preference/ZipExporterPreference.kt index cdd25134..52701157 100644 --- a/ui/src/main/java/com/wireguard/android/preference/ZipExporterPreference.kt +++ b/ui/src/main/java/com/wireguard/android/preference/ZipExporterPreference.kt @@ -1,24 +1,28 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.preference -import android.Manifest import android.content.Context -import android.content.pm.PackageManager import android.util.AttributeSet import android.util.Log import androidx.preference.Preference import com.google.android.material.snackbar.Snackbar import com.wireguard.android.Application import com.wireguard.android.R -import com.wireguard.android.model.ObservableTunnel +import com.wireguard.android.util.AdminKnobs import com.wireguard.android.util.BiometricAuthenticator import com.wireguard.android.util.DownloadsFileSaver import com.wireguard.android.util.ErrorMessages -import com.wireguard.android.util.FragmentUtils -import java9.util.concurrent.CompletableFuture +import com.wireguard.android.util.activity +import com.wireguard.android.util.lifecycleScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext import java.nio.charset.StandardCharsets import java.util.zip.ZipEntry import java.util.zip.ZipOutputStream @@ -28,80 +32,77 @@ import java.util.zip.ZipOutputStream */ class ZipExporterPreference(context: Context, attrs: AttributeSet?) : Preference(context, attrs) { private var exportedFilePath: String? = null + private val downloadsFileSaver = DownloadsFileSaver(activity) private fun exportZip() { - Application.getTunnelManager().tunnels.thenAccept(this::exportZip) - } - - private fun exportZip(tunnels: List<ObservableTunnel>) { - val futureConfigs = tunnels.map { it.configAsync.toCompletableFuture() }.toTypedArray() - if (futureConfigs.isEmpty()) { - exportZipComplete(null, IllegalArgumentException( - context.getString(R.string.no_tunnels_error))) - return - } - CompletableFuture.allOf(*futureConfigs) - .whenComplete { _, exception -> - Application.getAsyncWorker().supplyAsync { - if (exception != null) throw exception - val outputFile = DownloadsFileSaver.save(context, "wireguard-export.zip", "application/zip", true) - try { - ZipOutputStream(outputFile.outputStream).use { zip -> - for (i in futureConfigs.indices) { - zip.putNextEntry(ZipEntry(tunnels[i].name + ".conf")) - zip.write(futureConfigs[i].getNow(null)!!.toWgQuickString().toByteArray(StandardCharsets.UTF_8)) - } - zip.closeEntry() + lifecycleScope.launch { + val tunnels = Application.getTunnelManager().getTunnels() + try { + exportedFilePath = withContext(Dispatchers.IO) { + val configs = tunnels.map { async(SupervisorJob()) { it.getConfigAsync() } }.awaitAll() + if (configs.isEmpty()) { + throw IllegalArgumentException(context.getString(R.string.no_tunnels_error)) + } + val outputFile = downloadsFileSaver.save("wireguard-export.zip", "application/zip", true) + if (outputFile == null) { + withContext(Dispatchers.Main.immediate) { + isEnabled = true + } + return@withContext null + } + try { + ZipOutputStream(outputFile.outputStream).use { zip -> + for (i in configs.indices) { + zip.putNextEntry(ZipEntry(tunnels[i].name + ".conf")) + zip.write(configs[i].toWgQuickString().toByteArray(StandardCharsets.UTF_8)) } - } catch (e: Exception) { - outputFile.delete() - throw e + zip.closeEntry() } - outputFile.fileName - }.whenComplete(this::exportZipComplete) + } catch (e: Throwable) { + outputFile.delete() + throw e + } + outputFile.fileName } - } - - private fun exportZipComplete(filePath: String?, throwable: Throwable?) { - if (throwable != null) { - val error = ErrorMessages[throwable] - val message = context.getString(R.string.zip_export_error, error) - Log.e(TAG, message, throwable) - Snackbar.make( - FragmentUtils.getPrefActivity(this).findViewById(android.R.id.content), - message, Snackbar.LENGTH_LONG).show() - isEnabled = true - } else { - exportedFilePath = filePath - notifyChanged() + notifyChanged() + } catch (e: Throwable) { + val error = ErrorMessages[e] + val message = context.getString(R.string.zip_export_error, error) + Log.e(TAG, message, e) + Snackbar.make( + activity.findViewById(android.R.id.content), + message, Snackbar.LENGTH_LONG + ).show() + isEnabled = true + } } } - override fun getSummary() = if (exportedFilePath == null) context.getString(R.string.zip_export_summary) else context.getString(R.string.zip_export_success, exportedFilePath) + override fun getSummary() = + if (exportedFilePath == null) context.getString(R.string.zip_export_summary) else context.getString(R.string.zip_export_success, exportedFilePath) override fun getTitle() = context.getString(R.string.zip_export_title) override fun onClick() { - val prefActivity = FragmentUtils.getPrefActivity(this) - val fragment = prefActivity.supportFragmentManager.fragments.first() + if (AdminKnobs.disableConfigExport) return + val fragment = activity.supportFragmentManager.fragments.first() BiometricAuthenticator.authenticate(R.string.biometric_prompt_zip_exporter_title, fragment) { when (it) { // When we have successful authentication, or when there is no biometric hardware available. is BiometricAuthenticator.Result.Success, is BiometricAuthenticator.Result.HardwareUnavailableOrDisabled -> { - prefActivity.ensurePermissions(arrayOf(Manifest.permission.WRITE_EXTERNAL_STORAGE)) { _, grantResults -> - if (grantResults.isNotEmpty() && grantResults[0] == PackageManager.PERMISSION_GRANTED) { - isEnabled = false - exportZip() - } - } + isEnabled = false + exportZip() } + is BiometricAuthenticator.Result.Failure -> { Snackbar.make( - prefActivity.findViewById(android.R.id.content), - it.message, - Snackbar.LENGTH_SHORT + activity.findViewById(android.R.id.content), + it.message, + Snackbar.LENGTH_SHORT ).show() } + + is BiometricAuthenticator.Result.Cancelled -> {} } } } diff --git a/ui/src/main/java/com/wireguard/android/updater/Ed25519.java b/ui/src/main/java/com/wireguard/android/updater/Ed25519.java new file mode 100644 index 00000000..44e99b86 --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/updater/Ed25519.java @@ -0,0 +1,2507 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * Copyright 2017 Google Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.updater; + +import java.math.BigInteger; +import java.security.GeneralSecurityException; +import java.security.MessageDigest; +import java.util.Arrays; + +/** + * Implementation of Ed25519 signature verification. + * + * <p>This implementation is based on the ed25519/ref10 implementation in NaCl.</p> + * + * <p>It implements this twisted Edwards curve: + * + * <pre> + * -x^2 + y^2 = 1 + (-121665 / 121666 mod 2^255-19)*x^2*y^2 + * </pre> + * + * @see <a href="https://eprint.iacr.org/2008/013.pdf">Bernstein D.J., Birkner P., Joye M., Lange + * T., Peters C. (2008) Twisted Edwards Curves</a> + * @see <a href="https://eprint.iacr.org/2008/522.pdf">Hisil H., Wong K.KH., Carter G., Dawson E. + * (2008) Twisted Edwards Curves Revisited</a> + */ +final class Ed25519 { + + // d = -121665 / 121666 mod 2^255-19 + private static final long[] D; + // 2d + private static final long[] D2; + // 2^((p-1)/4) mod p where p = 2^255-19 + private static final long[] SQRTM1; + + /** + * Base point for the Edwards twisted curve = (x, 4/5) and its exponentiations. B_TABLE[i][j] = + * (j+1)*256^i*B for i in [0, 32) and j in [0, 8). Base point B = B_TABLE[0][0] + */ + private static final CachedXYT[][] B_TABLE; + private static final CachedXYT[] B2; + + private static final BigInteger P_BI = + BigInteger.valueOf(2).pow(255).subtract(BigInteger.valueOf(19)); + private static final BigInteger D_BI = + BigInteger.valueOf(-121665).multiply(BigInteger.valueOf(121666).modInverse(P_BI)).mod(P_BI); + private static final BigInteger D2_BI = BigInteger.valueOf(2).multiply(D_BI).mod(P_BI); + private static final BigInteger SQRTM1_BI = + BigInteger.valueOf(2).modPow(P_BI.subtract(BigInteger.ONE).divide(BigInteger.valueOf(4)), P_BI); + + private Ed25519() { + } + + private static class Point { + private BigInteger x; + private BigInteger y; + } + + private static BigInteger recoverX(BigInteger y) { + // x^2 = (y^2 - 1) / (d * y^2 + 1) mod 2^255-19 + BigInteger xx = + y.pow(2) + .subtract(BigInteger.ONE) + .multiply(D_BI.multiply(y.pow(2)).add(BigInteger.ONE).modInverse(P_BI)); + BigInteger x = xx.modPow(P_BI.add(BigInteger.valueOf(3)).divide(BigInteger.valueOf(8)), P_BI); + if (!x.pow(2).subtract(xx).mod(P_BI).equals(BigInteger.ZERO)) { + x = x.multiply(SQRTM1_BI).mod(P_BI); + } + if (x.testBit(0)) { + x = P_BI.subtract(x); + } + return x; + } + + private static Point edwards(Point a, Point b) { + Point o = new Point(); + BigInteger xxyy = D_BI.multiply(a.x.multiply(b.x).multiply(a.y).multiply(b.y)).mod(P_BI); + o.x = + (a.x.multiply(b.y).add(b.x.multiply(a.y))) + .multiply(BigInteger.ONE.add(xxyy).modInverse(P_BI)) + .mod(P_BI); + o.y = + (a.y.multiply(b.y).add(a.x.multiply(b.x))) + .multiply(BigInteger.ONE.subtract(xxyy).modInverse(P_BI)) + .mod(P_BI); + return o; + } + + private static byte[] toLittleEndian(BigInteger n) { + byte[] b = new byte[32]; + byte[] nBytes = n.toByteArray(); + System.arraycopy(nBytes, 0, b, 32 - nBytes.length, nBytes.length); + for (int i = 0; i < b.length / 2; i++) { + byte t = b[i]; + b[i] = b[b.length - i - 1]; + b[b.length - i - 1] = t; + } + return b; + } + + private static CachedXYT getCachedXYT(Point p) { + return new CachedXYT( + Field25519.expand(toLittleEndian(p.y.add(p.x).mod(P_BI))), + Field25519.expand(toLittleEndian(p.y.subtract(p.x).mod(P_BI))), + Field25519.expand(toLittleEndian(D2_BI.multiply(p.x).multiply(p.y).mod(P_BI)))); + } + + static { + Point b = new Point(); + b.y = BigInteger.valueOf(4).multiply(BigInteger.valueOf(5).modInverse(P_BI)).mod(P_BI); + b.x = recoverX(b.y); + + D = Field25519.expand(toLittleEndian(D_BI)); + D2 = Field25519.expand(toLittleEndian(D2_BI)); + SQRTM1 = Field25519.expand(toLittleEndian(SQRTM1_BI)); + + Point bi = b; + B_TABLE = new CachedXYT[32][8]; + for (int i = 0; i < 32; i++) { + Point bij = bi; + for (int j = 0; j < 8; j++) { + B_TABLE[i][j] = getCachedXYT(bij); + bij = edwards(bij, bi); + } + for (int j = 0; j < 8; j++) { + bi = edwards(bi, bi); + } + } + bi = b; + Point b2 = edwards(b, b); + B2 = new CachedXYT[8]; + for (int i = 0; i < 8; i++) { + B2[i] = getCachedXYT(bi); + bi = edwards(bi, b2); + } + } + + private static final int PUBLIC_KEY_LEN = Field25519.FIELD_LEN; + private static final int SIGNATURE_LEN = Field25519.FIELD_LEN * 2; + + /** + * Defines field 25519 function based on <a + * href="https://github.com/agl/curve25519-donna/blob/master/curve25519-donna.c">curve25519-donna C + * implementation</a> (mostly identical). + * + * <p>Field elements are written as an array of signed, 64-bit limbs (an array of longs), least + * significant first. The value of the field element is: + * + * <pre> + * x[0] + 2^26·x[1] + 2^51·x[2] + 2^77·x[3] + 2^102·x[4] + 2^128·x[5] + 2^153·x[6] + 2^179·x[7] + + * 2^204·x[8] + 2^230·x[9], + * </pre> + * + * <p>i.e. the limbs are 26, 25, 26, 25, ... bits wide. + */ + private static final class Field25519 { + /** + * During Field25519 computation, the mixed radix representation may be in different forms: + * <ul> + * <li> Reduced-size form: the array has size at most 10. + * <li> Non-reduced-size form: the array is not reduced modulo 2^255 - 19 and has size at most + * 19. + * </ul> + * <p> + * TODO(quannguyen): + * <ul> + * <li> Clarify ill-defined terminologies. + * <li> The reduction procedure is different from DJB's paper + * (http://cr.yp.to/ecdh/curve25519-20060209.pdf). The coefficients after reducing degree and + * reducing coefficients aren't guaranteed to be in range {-2^25, ..., 2^25}. We should check to + * see what's going on. + * <li> Consider using method mult() everywhere and making product() private. + * </ul> + */ + + static final int FIELD_LEN = 32; + static final int LIMB_CNT = 10; + private static final long TWO_TO_25 = 1 << 25; + private static final long TWO_TO_26 = TWO_TO_25 << 1; + + private static final int[] EXPAND_START = {0, 3, 6, 9, 12, 16, 19, 22, 25, 28}; + private static final int[] EXPAND_SHIFT = {0, 2, 3, 5, 6, 0, 1, 3, 4, 6}; + private static final int[] MASK = {0x3ffffff, 0x1ffffff}; + private static final int[] SHIFT = {26, 25}; + + /** + * Sums two numbers: output = in1 + in2 + * <p> + * On entry: in1, in2 are in reduced-size form. + */ + static void sum(long[] output, long[] in1, long[] in2) { + for (int i = 0; i < LIMB_CNT; i++) { + output[i] = in1[i] + in2[i]; + } + } + + /** + * Sums two numbers: output += in + * <p> + * On entry: in is in reduced-size form. + */ + static void sum(long[] output, long[] in) { + sum(output, output, in); + } + + /** + * Find the difference of two numbers: output = in1 - in2 + * (note the order of the arguments!). + * <p> + * On entry: in1, in2 are in reduced-size form. + */ + static void sub(long[] output, long[] in1, long[] in2) { + for (int i = 0; i < LIMB_CNT; i++) { + output[i] = in1[i] - in2[i]; + } + } + + /** + * Find the difference of two numbers: output = in - output + * (note the order of the arguments!). + * <p> + * On entry: in, output are in reduced-size form. + */ + static void sub(long[] output, long[] in) { + sub(output, in, output); + } + + /** + * Multiply a number by a scalar: output = in * scalar + */ + static void scalarProduct(long[] output, long[] in, long scalar) { + for (int i = 0; i < LIMB_CNT; i++) { + output[i] = in[i] * scalar; + } + } + + /** + * Multiply two numbers: out = in2 * in + * <p> + * output must be distinct to both inputs. The inputs are reduced coefficient form, + * the output is not. + * <p> + * out[x] <= 14 * the largest product of the input limbs. + */ + static void product(long[] out, long[] in2, long[] in) { + out[0] = in2[0] * in[0]; + out[1] = in2[0] * in[1] + + in2[1] * in[0]; + out[2] = 2 * in2[1] * in[1] + + in2[0] * in[2] + + in2[2] * in[0]; + out[3] = in2[1] * in[2] + + in2[2] * in[1] + + in2[0] * in[3] + + in2[3] * in[0]; + out[4] = in2[2] * in[2] + + 2 * (in2[1] * in[3] + in2[3] * in[1]) + + in2[0] * in[4] + + in2[4] * in[0]; + out[5] = in2[2] * in[3] + + in2[3] * in[2] + + in2[1] * in[4] + + in2[4] * in[1] + + in2[0] * in[5] + + in2[5] * in[0]; + out[6] = 2 * (in2[3] * in[3] + in2[1] * in[5] + in2[5] * in[1]) + + in2[2] * in[4] + + in2[4] * in[2] + + in2[0] * in[6] + + in2[6] * in[0]; + out[7] = in2[3] * in[4] + + in2[4] * in[3] + + in2[2] * in[5] + + in2[5] * in[2] + + in2[1] * in[6] + + in2[6] * in[1] + + in2[0] * in[7] + + in2[7] * in[0]; + out[8] = in2[4] * in[4] + + 2 * (in2[3] * in[5] + in2[5] * in[3] + in2[1] * in[7] + in2[7] * in[1]) + + in2[2] * in[6] + + in2[6] * in[2] + + in2[0] * in[8] + + in2[8] * in[0]; + out[9] = in2[4] * in[5] + + in2[5] * in[4] + + in2[3] * in[6] + + in2[6] * in[3] + + in2[2] * in[7] + + in2[7] * in[2] + + in2[1] * in[8] + + in2[8] * in[1] + + in2[0] * in[9] + + in2[9] * in[0]; + out[10] = + 2 * (in2[5] * in[5] + in2[3] * in[7] + in2[7] * in[3] + in2[1] * in[9] + in2[9] * in[1]) + + in2[4] * in[6] + + in2[6] * in[4] + + in2[2] * in[8] + + in2[8] * in[2]; + out[11] = in2[5] * in[6] + + in2[6] * in[5] + + in2[4] * in[7] + + in2[7] * in[4] + + in2[3] * in[8] + + in2[8] * in[3] + + in2[2] * in[9] + + in2[9] * in[2]; + out[12] = in2[6] * in[6] + + 2 * (in2[5] * in[7] + in2[7] * in[5] + in2[3] * in[9] + in2[9] * in[3]) + + in2[4] * in[8] + + in2[8] * in[4]; + out[13] = in2[6] * in[7] + + in2[7] * in[6] + + in2[5] * in[8] + + in2[8] * in[5] + + in2[4] * in[9] + + in2[9] * in[4]; + out[14] = 2 * (in2[7] * in[7] + in2[5] * in[9] + in2[9] * in[5]) + + in2[6] * in[8] + + in2[8] * in[6]; + out[15] = in2[7] * in[8] + + in2[8] * in[7] + + in2[6] * in[9] + + in2[9] * in[6]; + out[16] = in2[8] * in[8] + + 2 * (in2[7] * in[9] + in2[9] * in[7]); + out[17] = in2[8] * in[9] + + in2[9] * in[8]; + out[18] = 2 * in2[9] * in[9]; + } + + /** + * Reduce a field element by calling reduceSizeByModularReduction and reduceCoefficients. + * + * @param input An input array of any length. If the array has 19 elements, it will be used as + * temporary buffer and its contents changed. + * @param output An output array of size LIMB_CNT. After the call |output[i]| < 2^26 will hold. + */ + static void reduce(long[] input, long[] output) { + long[] tmp; + if (input.length == 19) { + tmp = input; + } else { + tmp = new long[19]; + System.arraycopy(input, 0, tmp, 0, input.length); + } + reduceSizeByModularReduction(tmp); + reduceCoefficients(tmp); + System.arraycopy(tmp, 0, output, 0, LIMB_CNT); + } + + /** + * Reduce a long form to a reduced-size form by taking the input mod 2^255 - 19. + * <p> + * On entry: |output[i]| < 14*2^54 + * On exit: |output[0..8]| < 280*2^54 + */ + static void reduceSizeByModularReduction(long[] output) { + // The coefficients x[10], x[11],..., x[18] are eliminated by reduction modulo 2^255 - 19. + // For example, the coefficient x[18] is multiplied by 19 and added to the coefficient x[8]. + // + // Each of these shifts and adds ends up multiplying the value by 19. + // + // For output[0..8], the absolute entry value is < 14*2^54 and we add, at most, 19*14*2^54 thus, + // on exit, |output[0..8]| < 280*2^54. + output[8] += output[18] << 4; + output[8] += output[18] << 1; + output[8] += output[18]; + output[7] += output[17] << 4; + output[7] += output[17] << 1; + output[7] += output[17]; + output[6] += output[16] << 4; + output[6] += output[16] << 1; + output[6] += output[16]; + output[5] += output[15] << 4; + output[5] += output[15] << 1; + output[5] += output[15]; + output[4] += output[14] << 4; + output[4] += output[14] << 1; + output[4] += output[14]; + output[3] += output[13] << 4; + output[3] += output[13] << 1; + output[3] += output[13]; + output[2] += output[12] << 4; + output[2] += output[12] << 1; + output[2] += output[12]; + output[1] += output[11] << 4; + output[1] += output[11] << 1; + output[1] += output[11]; + output[0] += output[10] << 4; + output[0] += output[10] << 1; + output[0] += output[10]; + } + + /** + * Reduce all coefficients of the short form input so that |x| < 2^26. + * <p> + * On entry: |output[i]| < 280*2^54 + */ + static void reduceCoefficients(long[] output) { + output[10] = 0; + + for (int i = 0; i < LIMB_CNT; i += 2) { + long over = output[i] / TWO_TO_26; + // The entry condition (that |output[i]| < 280*2^54) means that over is, at most, 280*2^28 in + // the first iteration of this loop. This is added to the next limb and we can approximate the + // resulting bound of that limb by 281*2^54. + output[i] -= over << 26; + output[i + 1] += over; + + // For the first iteration, |output[i+1]| < 281*2^54, thus |over| < 281*2^29. When this is + // added to the next limb, the resulting bound can be approximated as 281*2^54. + // + // For subsequent iterations of the loop, 281*2^54 remains a conservative bound and no + // overflow occurs. + over = output[i + 1] / TWO_TO_25; + output[i + 1] -= over << 25; + output[i + 2] += over; + } + // Now |output[10]| < 281*2^29 and all other coefficients are reduced. + output[0] += output[10] << 4; + output[0] += output[10] << 1; + output[0] += output[10]; + + output[10] = 0; + // Now output[1..9] are reduced, and |output[0]| < 2^26 + 19*281*2^29 so |over| will be no more + // than 2^16. + long over = output[0] / TWO_TO_26; + output[0] -= over << 26; + output[1] += over; + // Now output[0,2..9] are reduced, and |output[1]| < 2^25 + 2^16 < 2^26. The bound on + // |output[1]| is sufficient to meet our needs. + } + + /** + * A helpful wrapper around {@ref Field25519#product}: output = in * in2. + * <p> + * On entry: |in[i]| < 2^27 and |in2[i]| < 2^27. + * <p> + * The output is reduced degree (indeed, one need only provide storage for 10 limbs) and + * |output[i]| < 2^26. + */ + static void mult(long[] output, long[] in, long[] in2) { + long[] t = new long[19]; + product(t, in, in2); + // |t[i]| < 2^26 + reduce(t, output); + } + + /** + * Square a number: out = in**2 + * <p> + * output must be distinct from the input. The inputs are reduced coefficient form, the output is + * not. + * <p> + * out[x] <= 14 * the largest product of the input limbs. + */ + private static void squareInner(long[] out, long[] in) { + out[0] = in[0] * in[0]; + out[1] = 2 * in[0] * in[1]; + out[2] = 2 * (in[1] * in[1] + in[0] * in[2]); + out[3] = 2 * (in[1] * in[2] + in[0] * in[3]); + out[4] = in[2] * in[2] + + 4 * in[1] * in[3] + + 2 * in[0] * in[4]; + out[5] = 2 * (in[2] * in[3] + in[1] * in[4] + in[0] * in[5]); + out[6] = 2 * (in[3] * in[3] + in[2] * in[4] + in[0] * in[6] + 2 * in[1] * in[5]); + out[7] = 2 * (in[3] * in[4] + in[2] * in[5] + in[1] * in[6] + in[0] * in[7]); + out[8] = in[4] * in[4] + + 2 * (in[2] * in[6] + in[0] * in[8] + 2 * (in[1] * in[7] + in[3] * in[5])); + out[9] = 2 * (in[4] * in[5] + in[3] * in[6] + in[2] * in[7] + in[1] * in[8] + in[0] * in[9]); + out[10] = 2 * (in[5] * in[5] + + in[4] * in[6] + + in[2] * in[8] + + 2 * (in[3] * in[7] + in[1] * in[9])); + out[11] = 2 * (in[5] * in[6] + in[4] * in[7] + in[3] * in[8] + in[2] * in[9]); + out[12] = in[6] * in[6] + + 2 * (in[4] * in[8] + 2 * (in[5] * in[7] + in[3] * in[9])); + out[13] = 2 * (in[6] * in[7] + in[5] * in[8] + in[4] * in[9]); + out[14] = 2 * (in[7] * in[7] + in[6] * in[8] + 2 * in[5] * in[9]); + out[15] = 2 * (in[7] * in[8] + in[6] * in[9]); + out[16] = in[8] * in[8] + 4 * in[7] * in[9]; + out[17] = 2 * in[8] * in[9]; + out[18] = 2 * in[9] * in[9]; + } + + /** + * Returns in^2. + * <p> + * On entry: The |in| argument is in reduced coefficients form and |in[i]| < 2^27. + * <p> + * On exit: The |output| argument is in reduced coefficients form (indeed, one need only provide + * storage for 10 limbs) and |out[i]| < 2^26. + */ + static void square(long[] output, long[] in) { + long[] t = new long[19]; + squareInner(t, in); + // |t[i]| < 14*2^54 because the largest product of two limbs will be < 2^(27+27) and SquareInner + // adds together, at most, 14 of those products. + reduce(t, output); + } + + /** + * Takes a little-endian, 32-byte number and expands it into mixed radix form. + */ + static long[] expand(byte[] input) { + long[] output = new long[LIMB_CNT]; + for (int i = 0; i < LIMB_CNT; i++) { + output[i] = ((((long) (input[EXPAND_START[i]] & 0xff)) + | ((long) (input[EXPAND_START[i] + 1] & 0xff)) << 8 + | ((long) (input[EXPAND_START[i] + 2] & 0xff)) << 16 + | ((long) (input[EXPAND_START[i] + 3] & 0xff)) << 24) >> EXPAND_SHIFT[i]) & MASK[i & 1]; + } + return output; + } + + /** + * Takes a fully reduced mixed radix form number and contract it into a little-endian, 32-byte + * array. + * <p> + * On entry: |input_limbs[i]| < 2^26 + */ + @SuppressWarnings("NarrowingCompoundAssignment") + static byte[] contract(long[] inputLimbs) { + long[] input = Arrays.copyOf(inputLimbs, LIMB_CNT); + for (int j = 0; j < 2; j++) { + for (int i = 0; i < 9; i++) { + // This calculation is a time-invariant way to make input[i] non-negative by borrowing + // from the next-larger limb. + int carry = -(int) ((input[i] & (input[i] >> 31)) >> SHIFT[i & 1]); + input[i] = input[i] + (carry << SHIFT[i & 1]); + input[i + 1] -= carry; + } + + // There's no greater limb for input[9] to borrow from, but we can multiply by 19 and borrow + // from input[0], which is valid mod 2^255-19. + { + int carry = -(int) ((input[9] & (input[9] >> 31)) >> 25); + input[9] += (carry << 25); + input[0] -= (carry * 19); + } + + // After the first iteration, input[1..9] are non-negative and fit within 25 or 26 bits, + // depending on position. However, input[0] may be negative. + } + + // The first borrow-propagation pass above ended with every limb except (possibly) input[0] + // non-negative. + // + // If input[0] was negative after the first pass, then it was because of a carry from input[9]. + // On entry, input[9] < 2^26 so the carry was, at most, one, since (2**26-1) >> 25 = 1. Thus + // input[0] >= -19. + // + // In the second pass, each limb is decreased by at most one. Thus the second borrow-propagation + // pass could only have wrapped around to decrease input[0] again if the first pass left + // input[0] negative *and* input[1] through input[9] were all zero. In that case, input[1] is + // now 2^25 - 1, and this last borrow-propagation step will leave input[1] non-negative. + { + int carry = -(int) ((input[0] & (input[0] >> 31)) >> 26); + input[0] += (carry << 26); + input[1] -= carry; + } + + // All input[i] are now non-negative. However, there might be values between 2^25 and 2^26 in a + // limb which is, nominally, 25 bits wide. + for (int j = 0; j < 2; j++) { + for (int i = 0; i < 9; i++) { + int carry = (int) (input[i] >> SHIFT[i & 1]); + input[i] &= MASK[i & 1]; + input[i + 1] += carry; + } + } + + { + int carry = (int) (input[9] >> 25); + input[9] &= 0x1ffffff; + input[0] += 19 * carry; + } + + // If the first carry-chain pass, just above, ended up with a carry from input[9], and that + // caused input[0] to be out-of-bounds, then input[0] was < 2^26 + 2*19, because the carry was, + // at most, two. + // + // If the second pass carried from input[9] again then input[0] is < 2*19 and the input[9] -> + // input[0] carry didn't push input[0] out of bounds. + + // It still remains the case that input might be between 2^255-19 and 2^255. In this case, + // input[1..9] must take their maximum value and input[0] must be >= (2^255-19) & 0x3ffffff, + // which is 0x3ffffed. + int mask = gte((int) input[0], 0x3ffffed); + for (int i = 1; i < LIMB_CNT; i++) { + mask &= eq((int) input[i], MASK[i & 1]); + } + + // mask is either 0xffffffff (if input >= 2^255-19) and zero otherwise. Thus this conditionally + // subtracts 2^255-19. + input[0] -= mask & 0x3ffffed; + input[1] -= mask & 0x1ffffff; + for (int i = 2; i < LIMB_CNT; i += 2) { + input[i] -= mask & 0x3ffffff; + input[i + 1] -= mask & 0x1ffffff; + } + + for (int i = 0; i < LIMB_CNT; i++) { + input[i] <<= EXPAND_SHIFT[i]; + } + byte[] output = new byte[FIELD_LEN]; + for (int i = 0; i < LIMB_CNT; i++) { + output[EXPAND_START[i]] |= input[i] & 0xff; + output[EXPAND_START[i] + 1] |= (input[i] >> 8) & 0xff; + output[EXPAND_START[i] + 2] |= (input[i] >> 16) & 0xff; + output[EXPAND_START[i] + 3] |= (input[i] >> 24) & 0xff; + } + return output; + } + + /** + * Computes inverse of z = z(2^255 - 21) + * <p> + * Shamelessly copied from agl's code which was shamelessly copied from djb's code. Only the + * comment format and the variable namings are different from those. + */ + static void inverse(long[] out, long[] z) { + long[] z2 = new long[Field25519.LIMB_CNT]; + long[] z9 = new long[Field25519.LIMB_CNT]; + long[] z11 = new long[Field25519.LIMB_CNT]; + long[] z2To5Minus1 = new long[Field25519.LIMB_CNT]; + long[] z2To10Minus1 = new long[Field25519.LIMB_CNT]; + long[] z2To20Minus1 = new long[Field25519.LIMB_CNT]; + long[] z2To50Minus1 = new long[Field25519.LIMB_CNT]; + long[] z2To100Minus1 = new long[Field25519.LIMB_CNT]; + long[] t0 = new long[Field25519.LIMB_CNT]; + long[] t1 = new long[Field25519.LIMB_CNT]; + + square(z2, z); // 2 + square(t1, z2); // 4 + square(t0, t1); // 8 + mult(z9, t0, z); // 9 + mult(z11, z9, z2); // 11 + square(t0, z11); // 22 + mult(z2To5Minus1, t0, z9); // 2^5 - 2^0 = 31 + + square(t0, z2To5Minus1); // 2^6 - 2^1 + square(t1, t0); // 2^7 - 2^2 + square(t0, t1); // 2^8 - 2^3 + square(t1, t0); // 2^9 - 2^4 + square(t0, t1); // 2^10 - 2^5 + mult(z2To10Minus1, t0, z2To5Minus1); // 2^10 - 2^0 + + square(t0, z2To10Minus1); // 2^11 - 2^1 + square(t1, t0); // 2^12 - 2^2 + for (int i = 2; i < 10; i += 2) { // 2^20 - 2^10 + square(t0, t1); + square(t1, t0); + } + mult(z2To20Minus1, t1, z2To10Minus1); // 2^20 - 2^0 + + square(t0, z2To20Minus1); // 2^21 - 2^1 + square(t1, t0); // 2^22 - 2^2 + for (int i = 2; i < 20; i += 2) { // 2^40 - 2^20 + square(t0, t1); + square(t1, t0); + } + mult(t0, t1, z2To20Minus1); // 2^40 - 2^0 + + square(t1, t0); // 2^41 - 2^1 + square(t0, t1); // 2^42 - 2^2 + for (int i = 2; i < 10; i += 2) { // 2^50 - 2^10 + square(t1, t0); + square(t0, t1); + } + mult(z2To50Minus1, t0, z2To10Minus1); // 2^50 - 2^0 + + square(t0, z2To50Minus1); // 2^51 - 2^1 + square(t1, t0); // 2^52 - 2^2 + for (int i = 2; i < 50; i += 2) { // 2^100 - 2^50 + square(t0, t1); + square(t1, t0); + } + mult(z2To100Minus1, t1, z2To50Minus1); // 2^100 - 2^0 + + square(t1, z2To100Minus1); // 2^101 - 2^1 + square(t0, t1); // 2^102 - 2^2 + for (int i = 2; i < 100; i += 2) { // 2^200 - 2^100 + square(t1, t0); + square(t0, t1); + } + mult(t1, t0, z2To100Minus1); // 2^200 - 2^0 + + square(t0, t1); // 2^201 - 2^1 + square(t1, t0); // 2^202 - 2^2 + for (int i = 2; i < 50; i += 2) { // 2^250 - 2^50 + square(t0, t1); + square(t1, t0); + } + mult(t0, t1, z2To50Minus1); // 2^250 - 2^0 + + square(t1, t0); // 2^251 - 2^1 + square(t0, t1); // 2^252 - 2^2 + square(t1, t0); // 2^253 - 2^3 + square(t0, t1); // 2^254 - 2^4 + square(t1, t0); // 2^255 - 2^5 + mult(out, t1, z11); // 2^255 - 21 + } + + + /** + * Returns 0xffffffff iff a == b and zero otherwise. + */ + private static int eq(int a, int b) { + a = ~(a ^ b); + a &= a << 16; + a &= a << 8; + a &= a << 4; + a &= a << 2; + a &= a << 1; + return a >> 31; + } + + /** + * returns 0xffffffff if a >= b and zero otherwise, where a and b are both non-negative. + */ + private static int gte(int a, int b) { + a -= b; + // a >= 0 iff a >= b. + return ~(a >> 31); + } + } + + // (x = 0, y = 1) point + private static final CachedXYT CACHED_NEUTRAL = new CachedXYT( + new long[]{1, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + new long[]{1, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + new long[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + private static final PartialXYZT NEUTRAL = new PartialXYZT( + new XYZ(new long[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + new long[]{1, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + new long[]{1, 0, 0, 0, 0, 0, 0, 0, 0, 0}), + new long[]{1, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + + /** + * Projective point representation (X:Y:Z) satisfying x = X/Z, y = Y/Z + * <p> + * Note that this is referred as ge_p2 in ref10 impl. + * Also note that x = X, y = Y and z = Z below following Java coding style. + * <p> + * See + * Koyama K., Tsuruoka Y. (1993) Speeding up Elliptic Cryptosystems by Using a Signed Binary + * Window Method. + * <p> + * https://hyperelliptic.org/EFD/g1p/auto-twisted-projective.html + */ + private static class XYZ { + + final long[] x; + final long[] y; + final long[] z; + + XYZ() { + this(new long[Field25519.LIMB_CNT], new long[Field25519.LIMB_CNT], new long[Field25519.LIMB_CNT]); + } + + XYZ(long[] x, long[] y, long[] z) { + this.x = x; + this.y = y; + this.z = z; + } + + XYZ(XYZ xyz) { + x = Arrays.copyOf(xyz.x, Field25519.LIMB_CNT); + y = Arrays.copyOf(xyz.y, Field25519.LIMB_CNT); + z = Arrays.copyOf(xyz.z, Field25519.LIMB_CNT); + } + + XYZ(PartialXYZT partialXYZT) { + this(); + fromPartialXYZT(this, partialXYZT); + } + + /** + * ge_p1p1_to_p2.c + */ + static XYZ fromPartialXYZT(XYZ out, PartialXYZT in) { + Field25519.mult(out.x, in.xyz.x, in.t); + Field25519.mult(out.y, in.xyz.y, in.xyz.z); + Field25519.mult(out.z, in.xyz.z, in.t); + return out; + } + + /** + * Encodes this point to bytes. + */ + byte[] toBytes() { + long[] recip = new long[Field25519.LIMB_CNT]; + long[] x = new long[Field25519.LIMB_CNT]; + long[] y = new long[Field25519.LIMB_CNT]; + Field25519.inverse(recip, z); + Field25519.mult(x, this.x, recip); + Field25519.mult(y, this.y, recip); + byte[] s = Field25519.contract(y); + s[31] = (byte) (s[31] ^ (getLsb(x) << 7)); + return s; + } + + + /** + * Best effort fix-timing array comparison. + * + * @return true if two arrays are equal. + */ + private static boolean bytesEqual(final byte[] x, final byte[] y) { + if (x == null || y == null) { + return false; + } + if (x.length != y.length) { + return false; + } + int res = 0; + for (int i = 0; i < x.length; i++) { + res |= x[i] ^ y[i]; + } + return res == 0; + } + + /** + * Checks that the point is on curve + */ + boolean isOnCurve() { + long[] x2 = new long[Field25519.LIMB_CNT]; + Field25519.square(x2, x); + long[] y2 = new long[Field25519.LIMB_CNT]; + Field25519.square(y2, y); + long[] z2 = new long[Field25519.LIMB_CNT]; + Field25519.square(z2, z); + long[] z4 = new long[Field25519.LIMB_CNT]; + Field25519.square(z4, z2); + long[] lhs = new long[Field25519.LIMB_CNT]; + // lhs = y^2 - x^2 + Field25519.sub(lhs, y2, x2); + // lhs = z^2 * (y2 - x2) + Field25519.mult(lhs, lhs, z2); + long[] rhs = new long[Field25519.LIMB_CNT]; + // rhs = x^2 * y^2 + Field25519.mult(rhs, x2, y2); + // rhs = D * x^2 * y^2 + Field25519.mult(rhs, rhs, D); + // rhs = z^4 + D * x^2 * y^2 + Field25519.sum(rhs, z4); + // Field25519.mult reduces its output, but Field25519.sum does not, so we have to manually + // reduce it here. + Field25519.reduce(rhs, rhs); + // z^2 (y^2 - x^2) == z^4 + D * x^2 * y^2 + return bytesEqual(Field25519.contract(lhs), Field25519.contract(rhs)); + } + } + + /** + * Represents extended projective point representation (X:Y:Z:T) satisfying x = X/Z, y = Y/Z, + * XY = ZT + * <p> + * Note that this is referred as ge_p3 in ref10 impl. + * Also note that t = T below following Java coding style. + * <p> + * See + * Hisil H., Wong K.KH., Carter G., Dawson E. (2008) Twisted Edwards Curves Revisited. + * <p> + * https://hyperelliptic.org/EFD/g1p/auto-twisted-extended.html + */ + private static class XYZT { + + final XYZ xyz; + final long[] t; + + XYZT() { + this(new XYZ(), new long[Field25519.LIMB_CNT]); + } + + XYZT(XYZ xyz, long[] t) { + this.xyz = xyz; + this.t = t; + } + + XYZT(PartialXYZT partialXYZT) { + this(); + fromPartialXYZT(this, partialXYZT); + } + + /** + * ge_p1p1_to_p2.c + */ + private static XYZT fromPartialXYZT(XYZT out, PartialXYZT in) { + Field25519.mult(out.xyz.x, in.xyz.x, in.t); + Field25519.mult(out.xyz.y, in.xyz.y, in.xyz.z); + Field25519.mult(out.xyz.z, in.xyz.z, in.t); + Field25519.mult(out.t, in.xyz.x, in.xyz.y); + return out; + } + + /** + * Decodes {@code s} into an extented projective point. + * See Section 5.1.3 Decoding in https://tools.ietf.org/html/rfc8032#section-5.1.3 + */ + private static XYZT fromBytesNegateVarTime(byte[] s) throws GeneralSecurityException { + long[] x = new long[Field25519.LIMB_CNT]; + long[] y = Field25519.expand(s); + long[] z = new long[Field25519.LIMB_CNT]; + z[0] = 1; + long[] t = new long[Field25519.LIMB_CNT]; + long[] u = new long[Field25519.LIMB_CNT]; + long[] v = new long[Field25519.LIMB_CNT]; + long[] vxx = new long[Field25519.LIMB_CNT]; + long[] check = new long[Field25519.LIMB_CNT]; + Field25519.square(u, y); + Field25519.mult(v, u, D); + Field25519.sub(u, u, z); // u = y^2 - 1 + Field25519.sum(v, v, z); // v = dy^2 + 1 + + long[] v3 = new long[Field25519.LIMB_CNT]; + Field25519.square(v3, v); + Field25519.mult(v3, v3, v); // v3 = v^3 + Field25519.square(x, v3); + Field25519.mult(x, x, v); + Field25519.mult(x, x, u); // x = uv^7 + + pow2252m3(x, x); // x = (uv^7)^((q-5)/8) + Field25519.mult(x, x, v3); + Field25519.mult(x, x, u); // x = uv^3(uv^7)^((q-5)/8) + + Field25519.square(vxx, x); + Field25519.mult(vxx, vxx, v); + Field25519.sub(check, vxx, u); // vx^2-u + if (isNonZeroVarTime(check)) { + Field25519.sum(check, vxx, u); // vx^2+u + if (isNonZeroVarTime(check)) { + throw new GeneralSecurityException("Cannot convert given bytes to extended projective " + + "coordinates. No square root exists for modulo 2^255-19"); + } + Field25519.mult(x, x, SQRTM1); + } + + if (!isNonZeroVarTime(x) && (s[31] & 0xff) >> 7 != 0) { + throw new GeneralSecurityException("Cannot convert given bytes to extended projective " + + "coordinates. Computed x is zero and encoded x's least significant bit is not zero"); + } + if (getLsb(x) == ((s[31] & 0xff) >> 7)) { + neg(x, x); + } + + Field25519.mult(t, x, y); + return new XYZT(new XYZ(x, y, z), t); + } + } + + /** + * Partial projective point representation ((X:Z),(Y:T)) satisfying x=X/Z, y=Y/T + * <p> + * Note that this is referred as complete form in the original ref10 impl (ge_p1p1). + * Also note that t = T below following Java coding style. + * <p> + * Although this has the same types as XYZT, it is redefined to have its own type so that it is + * readable and 1:1 corresponds to ref10 impl. + * <p> + * Can be converted to XYZT as follows: + * X1 = X * T = x * Z * T = x * Z1 + * Y1 = Y * Z = y * T * Z = y * Z1 + * Z1 = Z * T = Z * T + * T1 = X * Y = x * Z * y * T = x * y * Z1 = X1Y1 / Z1 + */ + private static class PartialXYZT { + + final XYZ xyz; + final long[] t; + + PartialXYZT() { + this(new XYZ(), new long[Field25519.LIMB_CNT]); + } + + PartialXYZT(XYZ xyz, long[] t) { + this.xyz = xyz; + this.t = t; + } + + PartialXYZT(PartialXYZT other) { + xyz = new XYZ(other.xyz); + t = Arrays.copyOf(other.t, Field25519.LIMB_CNT); + } + } + + /** + * Corresponds to the caching mentioned in the last paragraph of Section 3.1 of + * Hisil H., Wong K.KH., Carter G., Dawson E. (2008) Twisted Edwards Curves Revisited. + * with Z = 1. + */ + private static class CachedXYT { + + final long[] yPlusX; + final long[] yMinusX; + final long[] t2d; + + /** + * Creates a cached XYZT with Z = 1 + * + * @param yPlusX y + x + * @param yMinusX y - x + * @param t2d 2d * xy + */ + CachedXYT(long[] yPlusX, long[] yMinusX, long[] t2d) { + this.yPlusX = yPlusX; + this.yMinusX = yMinusX; + this.t2d = t2d; + } + + CachedXYT(CachedXYT other) { + yPlusX = Arrays.copyOf(other.yPlusX, Field25519.LIMB_CNT); + yMinusX = Arrays.copyOf(other.yMinusX, Field25519.LIMB_CNT); + t2d = Arrays.copyOf(other.t2d, Field25519.LIMB_CNT); + } + + // z is one implicitly, so this just copies {@code in} to {@code output}. + void multByZ(long[] output, long[] in) { + System.arraycopy(in, 0, output, 0, Field25519.LIMB_CNT); + } + + /** + * If icopy is 1, copies {@code other} into this point. Time invariant wrt to icopy value. + */ + void copyConditional(CachedXYT other, int icopy) { + copyConditional(yPlusX, other.yPlusX, icopy); + copyConditional(yMinusX, other.yMinusX, icopy); + copyConditional(t2d, other.t2d, icopy); + } + + /** + * Conditionally copies a reduced-form limb arrays {@code b} into {@code a} if {@code icopy} is 1, + * but leave {@code a} unchanged if 'iswap' is 0. Runs in data-invariant time to avoid + * side-channel attacks. + * + * <p>NOTE that this function requires that {@code icopy} be 1 or 0; other values give wrong + * results. Also, the two limb arrays must be in reduced-coefficient, reduced-degree form: the + * values in a[10..19] or b[10..19] aren't swapped, and all all values in a[0..9],b[0..9] must + * have magnitude less than Integer.MAX_VALUE. + */ + static void copyConditional(long[] a, long[] b, int icopy) { + int copy = -icopy; + for (int i = 0; i < Field25519.LIMB_CNT; i++) { + int x = copy & (((int) a[i]) ^ ((int) b[i])); + a[i] = ((int) a[i]) ^ x; + } + } + } + + private static class CachedXYZT extends CachedXYT { + + private final long[] z; + + CachedXYZT() { + this(new long[Field25519.LIMB_CNT], new long[Field25519.LIMB_CNT], new long[Field25519.LIMB_CNT], new long[Field25519.LIMB_CNT]); + } + + /** + * ge_p3_to_cached.c + */ + CachedXYZT(XYZT xyzt) { + this(); + Field25519.sum(yPlusX, xyzt.xyz.y, xyzt.xyz.x); + Field25519.sub(yMinusX, xyzt.xyz.y, xyzt.xyz.x); + System.arraycopy(xyzt.xyz.z, 0, z, 0, Field25519.LIMB_CNT); + Field25519.mult(t2d, xyzt.t, D2); + } + + /** + * Creates a cached XYZT + * + * @param yPlusX Y + X + * @param yMinusX Y - X + * @param z Z + * @param t2d 2d * (XY/Z) + */ + CachedXYZT(long[] yPlusX, long[] yMinusX, long[] z, long[] t2d) { + super(yPlusX, yMinusX, t2d); + this.z = z; + } + + @Override + public void multByZ(long[] output, long[] in) { + Field25519.mult(output, in, z); + } + } + + /** + * Addition defined in Section 3.1 of + * Hisil H., Wong K.KH., Carter G., Dawson E. (2008) Twisted Edwards Curves Revisited. + * <p> + * Please note that this is a partial of the operation listed there leaving out the final + * conversion from PartialXYZT to XYZT. + * + * @param extended extended projective point input + * @param cached cached projective point input + */ + private static void add(PartialXYZT partialXYZT, XYZT extended, CachedXYT cached) { + long[] t = new long[Field25519.LIMB_CNT]; + + // Y1 + X1 + Field25519.sum(partialXYZT.xyz.x, extended.xyz.y, extended.xyz.x); + + // Y1 - X1 + Field25519.sub(partialXYZT.xyz.y, extended.xyz.y, extended.xyz.x); + + // A = (Y1 - X1) * (Y2 - X2) + Field25519.mult(partialXYZT.xyz.y, partialXYZT.xyz.y, cached.yMinusX); + + // B = (Y1 + X1) * (Y2 + X2) + Field25519.mult(partialXYZT.xyz.z, partialXYZT.xyz.x, cached.yPlusX); + + // C = T1 * 2d * T2 = 2d * T1 * T2 (2d is written as k in the paper) + Field25519.mult(partialXYZT.t, extended.t, cached.t2d); + + // Z1 * Z2 + cached.multByZ(partialXYZT.xyz.x, extended.xyz.z); + + // D = 2 * Z1 * Z2 + Field25519.sum(t, partialXYZT.xyz.x, partialXYZT.xyz.x); + + // X3 = B - A + Field25519.sub(partialXYZT.xyz.x, partialXYZT.xyz.z, partialXYZT.xyz.y); + + // Y3 = B + A + Field25519.sum(partialXYZT.xyz.y, partialXYZT.xyz.z, partialXYZT.xyz.y); + + // Z3 = D + C + Field25519.sum(partialXYZT.xyz.z, t, partialXYZT.t); + + // T3 = D - C + Field25519.sub(partialXYZT.t, t, partialXYZT.t); + } + + /** + * Based on the addition defined in Section 3.1 of + * Hisil H., Wong K.KH., Carter G., Dawson E. (2008) Twisted Edwards Curves Revisited. + * <p> + * Please note that this is a partial of the operation listed there leaving out the final + * conversion from PartialXYZT to XYZT. + * + * @param extended extended projective point input + * @param cached cached projective point input + */ + private static void sub(PartialXYZT partialXYZT, XYZT extended, CachedXYT cached) { + long[] t = new long[Field25519.LIMB_CNT]; + + // Y1 + X1 + Field25519.sum(partialXYZT.xyz.x, extended.xyz.y, extended.xyz.x); + + // Y1 - X1 + Field25519.sub(partialXYZT.xyz.y, extended.xyz.y, extended.xyz.x); + + // A = (Y1 - X1) * (Y2 + X2) + Field25519.mult(partialXYZT.xyz.y, partialXYZT.xyz.y, cached.yPlusX); + + // B = (Y1 + X1) * (Y2 - X2) + Field25519.mult(partialXYZT.xyz.z, partialXYZT.xyz.x, cached.yMinusX); + + // C = T1 * 2d * T2 = 2d * T1 * T2 (2d is written as k in the paper) + Field25519.mult(partialXYZT.t, extended.t, cached.t2d); + + // Z1 * Z2 + cached.multByZ(partialXYZT.xyz.x, extended.xyz.z); + + // D = 2 * Z1 * Z2 + Field25519.sum(t, partialXYZT.xyz.x, partialXYZT.xyz.x); + + // X3 = B - A + Field25519.sub(partialXYZT.xyz.x, partialXYZT.xyz.z, partialXYZT.xyz.y); + + // Y3 = B + A + Field25519.sum(partialXYZT.xyz.y, partialXYZT.xyz.z, partialXYZT.xyz.y); + + // Z3 = D - C + Field25519.sub(partialXYZT.xyz.z, t, partialXYZT.t); + + // T3 = D + C + Field25519.sum(partialXYZT.t, t, partialXYZT.t); + } + + /** + * Doubles {@code p} and puts the result into this PartialXYZT. + * <p> + * This is based on the addition defined in formula 7 in Section 3.3 of + * Hisil H., Wong K.KH., Carter G., Dawson E. (2008) Twisted Edwards Curves Revisited. + * <p> + * Please note that this is a partial of the operation listed there leaving out the final + * conversion from PartialXYZT to XYZT and also this fixes a typo in calculation of Y3 and T3 in + * the paper, H should be replaced with A+B. + */ + private static void doubleXYZ(PartialXYZT partialXYZT, XYZ p) { + long[] t0 = new long[Field25519.LIMB_CNT]; + + // XX = X1^2 + Field25519.square(partialXYZT.xyz.x, p.x); + + // YY = Y1^2 + Field25519.square(partialXYZT.xyz.z, p.y); + + // B' = Z1^2 + Field25519.square(partialXYZT.t, p.z); + + // B = 2 * B' + Field25519.sum(partialXYZT.t, partialXYZT.t, partialXYZT.t); + + // A = X1 + Y1 + Field25519.sum(partialXYZT.xyz.y, p.x, p.y); + + // AA = A^2 + Field25519.square(t0, partialXYZT.xyz.y); + + // Y3 = YY + XX + Field25519.sum(partialXYZT.xyz.y, partialXYZT.xyz.z, partialXYZT.xyz.x); + + // Z3 = YY - XX + Field25519.sub(partialXYZT.xyz.z, partialXYZT.xyz.z, partialXYZT.xyz.x); + + // X3 = AA - Y3 + Field25519.sub(partialXYZT.xyz.x, t0, partialXYZT.xyz.y); + + // T3 = B - Z3 + Field25519.sub(partialXYZT.t, partialXYZT.t, partialXYZT.xyz.z); + } + + /** + * Doubles {@code p} and puts the result into this PartialXYZT. + */ + private static void doubleXYZT(PartialXYZT partialXYZT, XYZT p) { + doubleXYZ(partialXYZT, p.xyz); + } + + /** + * Compares two byte values in constant time. + */ + private static int eq(int a, int b) { + int r = ~(a ^ b) & 0xff; + r &= r << 4; + r &= r << 2; + r &= r << 1; + return (r >> 7) & 1; + } + + /** + * This is a constant time operation where point b*B*256^pos is stored in {@code t}. + * When b is 0, t remains the same (i.e., neutral point). + * <p> + * Although B_TABLE[32][8] (B_TABLE[i][j] = (j+1)*B*256^i) has j values in [0, 7], the select + * method negates the corresponding point if b is negative (which is straight forward in elliptic + * curves by just negating y coordinate). Therefore we can get multiples of B with the half of + * memory requirements. + * + * @param t neutral element (i.e., point 0), also serves as output. + * @param pos in B[pos][j] = (j+1)*B*256^pos + * @param b value in [-8, 8] range. + */ + private static void select(CachedXYT t, int pos, byte b) { + int bnegative = (b & 0xff) >> 7; + int babs = b - (((-bnegative) & b) << 1); + + t.copyConditional(B_TABLE[pos][0], eq(babs, 1)); + t.copyConditional(B_TABLE[pos][1], eq(babs, 2)); + t.copyConditional(B_TABLE[pos][2], eq(babs, 3)); + t.copyConditional(B_TABLE[pos][3], eq(babs, 4)); + t.copyConditional(B_TABLE[pos][4], eq(babs, 5)); + t.copyConditional(B_TABLE[pos][5], eq(babs, 6)); + t.copyConditional(B_TABLE[pos][6], eq(babs, 7)); + t.copyConditional(B_TABLE[pos][7], eq(babs, 8)); + + long[] yPlusX = Arrays.copyOf(t.yMinusX, Field25519.LIMB_CNT); + long[] yMinusX = Arrays.copyOf(t.yPlusX, Field25519.LIMB_CNT); + long[] t2d = Arrays.copyOf(t.t2d, Field25519.LIMB_CNT); + neg(t2d, t2d); + CachedXYT minust = new CachedXYT(yPlusX, yMinusX, t2d); + t.copyConditional(minust, bnegative); + } + + /** + * Computes {@code a}*B + * where a = a[0]+256*a[1]+...+256^31 a[31] and + * B is the Ed25519 base point (x,4/5) with x positive. + * <p> + * Preconditions: + * a[31] <= 127 + * + * @throws IllegalStateException iff there is arithmetic error. + */ + @SuppressWarnings("NarrowingCompoundAssignment") + private static XYZ scalarMultWithBase(byte[] a) { + byte[] e = new byte[2 * Field25519.FIELD_LEN]; + for (int i = 0; i < Field25519.FIELD_LEN; i++) { + e[2 * i + 0] = (byte) (((a[i] & 0xff) >> 0) & 0xf); + e[2 * i + 1] = (byte) (((a[i] & 0xff) >> 4) & 0xf); + } + // each e[i] is between 0 and 15 + // e[63] is between 0 and 7 + + // Rewrite e in a way that each e[i] is in [-8, 8]. + // This can be done since a[63] is in [0, 7], the carry-over onto the most significant byte + // a[63] can be at most 1. + int carry = 0; + for (int i = 0; i < e.length - 1; i++) { + e[i] += carry; + carry = e[i] + 8; + carry >>= 4; + e[i] -= carry << 4; + } + e[e.length - 1] += carry; + + PartialXYZT ret = new PartialXYZT(NEUTRAL); + XYZT xyzt = new XYZT(); + // Although B_TABLE's i can be at most 31 (stores only 32 4bit multiples of B) and we have 64 + // 4bit values in e array, the below for loop adds cached values by iterating e by two in odd + // indices. After the result, we can double the result point 4 times to shift the multiplication + // scalar by 4 bits. + for (int i = 1; i < e.length; i += 2) { + CachedXYT t = new CachedXYT(CACHED_NEUTRAL); + select(t, i / 2, e[i]); + add(ret, XYZT.fromPartialXYZT(xyzt, ret), t); + } + + // Doubles the result 4 times to shift the multiplication scalar 4 bits to get the actual result + // for the odd indices in e. + XYZ xyz = new XYZ(); + doubleXYZ(ret, XYZ.fromPartialXYZT(xyz, ret)); + doubleXYZ(ret, XYZ.fromPartialXYZT(xyz, ret)); + doubleXYZ(ret, XYZ.fromPartialXYZT(xyz, ret)); + doubleXYZ(ret, XYZ.fromPartialXYZT(xyz, ret)); + + // Add multiples of B for even indices of e. + for (int i = 0; i < e.length; i += 2) { + CachedXYT t = new CachedXYT(CACHED_NEUTRAL); + select(t, i / 2, e[i]); + add(ret, XYZT.fromPartialXYZT(xyzt, ret), t); + } + + // This check is to protect against flaws, i.e. if there is a computation error through a + // faulty CPU or if the implementation contains a bug. + XYZ result = new XYZ(ret); + if (!result.isOnCurve()) { + throw new IllegalStateException("arithmetic error in scalar multiplication"); + } + return result; + } + + @SuppressWarnings("NarrowingCompoundAssignment") + private static byte[] slide(byte[] a) { + byte[] r = new byte[256]; + // Writes each bit in a[0..31] into r[0..255]: + // a = a[0]+256*a[1]+...+256^31*a[31] is equal to + // r = r[0]+2*r[1]+...+2^255*r[255] + for (int i = 0; i < 256; i++) { + r[i] = (byte) (1 & ((a[i >> 3] & 0xff) >> (i & 7))); + } + + // Transforms r[i] as odd values in [-15, 15] + for (int i = 0; i < 256; i++) { + if (r[i] != 0) { + for (int b = 1; b <= 6 && i + b < 256; b++) { + if (r[i + b] != 0) { + if (r[i] + (r[i + b] << b) <= 15) { + r[i] += r[i + b] << b; + r[i + b] = 0; + } else if (r[i] - (r[i + b] << b) >= -15) { + r[i] -= r[i + b] << b; + for (int k = i + b; k < 256; k++) { + if (r[k] == 0) { + r[k] = 1; + break; + } + r[k] = 0; + } + } else { + break; + } + } + } + } + } + return r; + } + + /** + * Computes {@code a}*{@code pointA}+{@code b}*B + * where a = a[0]+256*a[1]+...+256^31*a[31]. + * and b = b[0]+256*b[1]+...+256^31*b[31]. + * B is the Ed25519 base point (x,4/5) with x positive. + * <p> + * Note that execution time varies based on the input since this will only be used in verification + * of signatures. + */ + private static XYZ doubleScalarMultVarTime(byte[] a, XYZT pointA, byte[] b) { + // pointA, 3*pointA, 5*pointA, 7*pointA, 9*pointA, 11*pointA, 13*pointA, 15*pointA + CachedXYZT[] pointAArray = new CachedXYZT[8]; + pointAArray[0] = new CachedXYZT(pointA); + PartialXYZT t = new PartialXYZT(); + doubleXYZT(t, pointA); + XYZT doubleA = new XYZT(t); + for (int i = 1; i < pointAArray.length; i++) { + add(t, doubleA, pointAArray[i - 1]); + pointAArray[i] = new CachedXYZT(new XYZT(t)); + } + + byte[] aSlide = slide(a); + byte[] bSlide = slide(b); + t = new PartialXYZT(NEUTRAL); + XYZT u = new XYZT(); + int i = 255; + for (; i >= 0; i--) { + if (aSlide[i] != 0 || bSlide[i] != 0) { + break; + } + } + for (; i >= 0; i--) { + doubleXYZ(t, new XYZ(t)); + if (aSlide[i] > 0) { + add(t, XYZT.fromPartialXYZT(u, t), pointAArray[aSlide[i] / 2]); + } else if (aSlide[i] < 0) { + sub(t, XYZT.fromPartialXYZT(u, t), pointAArray[-aSlide[i] / 2]); + } + if (bSlide[i] > 0) { + add(t, XYZT.fromPartialXYZT(u, t), B2[bSlide[i] / 2]); + } else if (bSlide[i] < 0) { + sub(t, XYZT.fromPartialXYZT(u, t), B2[-bSlide[i] / 2]); + } + } + + return new XYZ(t); + } + + /** + * Returns true if {@code in} is nonzero. + * <p> + * Note that execution time might depend on the input {@code in}. + */ + private static boolean isNonZeroVarTime(long[] in) { + long[] inCopy = new long[in.length + 1]; + System.arraycopy(in, 0, inCopy, 0, in.length); + Field25519.reduceCoefficients(inCopy); + byte[] bytes = Field25519.contract(inCopy); + for (byte b : bytes) { + if (b != 0) { + return true; + } + } + return false; + } + + /** + * Returns the least significant bit of {@code in}. + */ + private static int getLsb(long[] in) { + return Field25519.contract(in)[0] & 1; + } + + /** + * Negates all values in {@code in} and store it in {@code out}. + */ + private static void neg(long[] out, long[] in) { + for (int i = 0; i < in.length; i++) { + out[i] = -in[i]; + } + } + + /** + * Computes {@code in}^(2^252-3) mod 2^255-19 and puts the result in {@code out}. + */ + private static void pow2252m3(long[] out, long[] in) { + long[] t0 = new long[Field25519.LIMB_CNT]; + long[] t1 = new long[Field25519.LIMB_CNT]; + long[] t2 = new long[Field25519.LIMB_CNT]; + + // z2 = z1^2^1 + Field25519.square(t0, in); + + // z8 = z2^2^2 + Field25519.square(t1, t0); + for (int i = 1; i < 2; i++) { + Field25519.square(t1, t1); + } + + // z9 = z1*z8 + Field25519.mult(t1, in, t1); + + // z11 = z2*z9 + Field25519.mult(t0, t0, t1); + + // z22 = z11^2^1 + Field25519.square(t0, t0); + + // z_5_0 = z9*z22 + Field25519.mult(t0, t1, t0); + + // z_10_5 = z_5_0^2^5 + Field25519.square(t1, t0); + for (int i = 1; i < 5; i++) { + Field25519.square(t1, t1); + } + + // z_10_0 = z_10_5*z_5_0 + Field25519.mult(t0, t1, t0); + + // z_20_10 = z_10_0^2^10 + Field25519.square(t1, t0); + for (int i = 1; i < 10; i++) { + Field25519.square(t1, t1); + } + + // z_20_0 = z_20_10*z_10_0 + Field25519.mult(t1, t1, t0); + + // z_40_20 = z_20_0^2^20 + Field25519.square(t2, t1); + for (int i = 1; i < 20; i++) { + Field25519.square(t2, t2); + } + + // z_40_0 = z_40_20*z_20_0 + Field25519.mult(t1, t2, t1); + + // z_50_10 = z_40_0^2^10 + Field25519.square(t1, t1); + for (int i = 1; i < 10; i++) { + Field25519.square(t1, t1); + } + + // z_50_0 = z_50_10*z_10_0 + Field25519.mult(t0, t1, t0); + + // z_100_50 = z_50_0^2^50 + Field25519.square(t1, t0); + for (int i = 1; i < 50; i++) { + Field25519.square(t1, t1); + } + + // z_100_0 = z_100_50*z_50_0 + Field25519.mult(t1, t1, t0); + + // z_200_100 = z_100_0^2^100 + Field25519.square(t2, t1); + for (int i = 1; i < 100; i++) { + Field25519.square(t2, t2); + } + + // z_200_0 = z_200_100*z_100_0 + Field25519.mult(t1, t2, t1); + + // z_250_50 = z_200_0^2^50 + Field25519.square(t1, t1); + for (int i = 1; i < 50; i++) { + Field25519.square(t1, t1); + } + + // z_250_0 = z_250_50*z_50_0 + Field25519.mult(t0, t1, t0); + + // z_252_2 = z_250_0^2^2 + Field25519.square(t0, t0); + for (int i = 1; i < 2; i++) { + Field25519.square(t0, t0); + } + + // z_252_3 = z_252_2*z1 + Field25519.mult(out, t0, in); + } + + /** + * Returns 3 bytes of {@code in} starting from {@code idx} in Little-Endian format. + */ + private static long load3(byte[] in, int idx) { + long result; + result = (long) in[idx] & 0xff; + result |= (long) (in[idx + 1] & 0xff) << 8; + result |= (long) (in[idx + 2] & 0xff) << 16; + return result; + } + + /** + * Returns 4 bytes of {@code in} starting from {@code idx} in Little-Endian format. + */ + private static long load4(byte[] in, int idx) { + long result = load3(in, idx); + result |= (long) (in[idx + 3] & 0xff) << 24; + return result; + } + + /** + * Input: + * s[0]+256*s[1]+...+256^63*s[63] = s + * <p> + * Output: + * s[0]+256*s[1]+...+256^31*s[31] = s mod l + * where l = 2^252 + 27742317777372353535851937790883648493. + * Overwrites s in place. + */ + private static void reduce(byte[] s) { + // Observation: + // 2^252 mod l is equivalent to -27742317777372353535851937790883648493 mod l + // Let m = -27742317777372353535851937790883648493 + // Thus a*2^252+b mod l is equivalent to a*m+b mod l + // + // First s is divided into chunks of 21 bits as follows: + // s0+2^21*s1+2^42*s3+...+2^462*s23 = s[0]+256*s[1]+...+256^63*s[63] + long s0 = 2097151 & load3(s, 0); + long s1 = 2097151 & (load4(s, 2) >> 5); + long s2 = 2097151 & (load3(s, 5) >> 2); + long s3 = 2097151 & (load4(s, 7) >> 7); + long s4 = 2097151 & (load4(s, 10) >> 4); + long s5 = 2097151 & (load3(s, 13) >> 1); + long s6 = 2097151 & (load4(s, 15) >> 6); + long s7 = 2097151 & (load3(s, 18) >> 3); + long s8 = 2097151 & load3(s, 21); + long s9 = 2097151 & (load4(s, 23) >> 5); + long s10 = 2097151 & (load3(s, 26) >> 2); + long s11 = 2097151 & (load4(s, 28) >> 7); + long s12 = 2097151 & (load4(s, 31) >> 4); + long s13 = 2097151 & (load3(s, 34) >> 1); + long s14 = 2097151 & (load4(s, 36) >> 6); + long s15 = 2097151 & (load3(s, 39) >> 3); + long s16 = 2097151 & load3(s, 42); + long s17 = 2097151 & (load4(s, 44) >> 5); + long s18 = 2097151 & (load3(s, 47) >> 2); + long s19 = 2097151 & (load4(s, 49) >> 7); + long s20 = 2097151 & (load4(s, 52) >> 4); + long s21 = 2097151 & (load3(s, 55) >> 1); + long s22 = 2097151 & (load4(s, 57) >> 6); + long s23 = (load4(s, 60) >> 3); + long carry0; + long carry1; + long carry2; + long carry3; + long carry4; + long carry5; + long carry6; + long carry7; + long carry8; + long carry9; + long carry10; + long carry11; + long carry12; + long carry13; + long carry14; + long carry15; + long carry16; + + // s23*2^462 = s23*2^210*2^252 is equivalent to s23*2^210*m in mod l + // As m is a 125 bit number, the result needs to scattered to 6 limbs (125/21 ceil is 6) + // starting from s11 (s11*2^210) + // m = [666643, 470296, 654183, -997805, 136657, -683901] in 21-bit limbs + s11 += s23 * 666643; + s12 += s23 * 470296; + s13 += s23 * 654183; + s14 -= s23 * 997805; + s15 += s23 * 136657; + s16 -= s23 * 683901; + // s23 = 0; + + s10 += s22 * 666643; + s11 += s22 * 470296; + s12 += s22 * 654183; + s13 -= s22 * 997805; + s14 += s22 * 136657; + s15 -= s22 * 683901; + // s22 = 0; + + s9 += s21 * 666643; + s10 += s21 * 470296; + s11 += s21 * 654183; + s12 -= s21 * 997805; + s13 += s21 * 136657; + s14 -= s21 * 683901; + // s21 = 0; + + s8 += s20 * 666643; + s9 += s20 * 470296; + s10 += s20 * 654183; + s11 -= s20 * 997805; + s12 += s20 * 136657; + s13 -= s20 * 683901; + // s20 = 0; + + s7 += s19 * 666643; + s8 += s19 * 470296; + s9 += s19 * 654183; + s10 -= s19 * 997805; + s11 += s19 * 136657; + s12 -= s19 * 683901; + // s19 = 0; + + s6 += s18 * 666643; + s7 += s18 * 470296; + s8 += s18 * 654183; + s9 -= s18 * 997805; + s10 += s18 * 136657; + s11 -= s18 * 683901; + // s18 = 0; + + // Reduce the bit length of limbs from s6 to s15 to 21-bits. + carry6 = (s6 + (1 << 20)) >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry8 = (s8 + (1 << 20)) >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry10 = (s10 + (1 << 20)) >> 21; + s11 += carry10; + s10 -= carry10 << 21; + carry12 = (s12 + (1 << 20)) >> 21; + s13 += carry12; + s12 -= carry12 << 21; + carry14 = (s14 + (1 << 20)) >> 21; + s15 += carry14; + s14 -= carry14 << 21; + carry16 = (s16 + (1 << 20)) >> 21; + s17 += carry16; + s16 -= carry16 << 21; + + carry7 = (s7 + (1 << 20)) >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry9 = (s9 + (1 << 20)) >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry11 = (s11 + (1 << 20)) >> 21; + s12 += carry11; + s11 -= carry11 << 21; + carry13 = (s13 + (1 << 20)) >> 21; + s14 += carry13; + s13 -= carry13 << 21; + carry15 = (s15 + (1 << 20)) >> 21; + s16 += carry15; + s15 -= carry15 << 21; + + // Resume reduction where we left off. + s5 += s17 * 666643; + s6 += s17 * 470296; + s7 += s17 * 654183; + s8 -= s17 * 997805; + s9 += s17 * 136657; + s10 -= s17 * 683901; + // s17 = 0; + + s4 += s16 * 666643; + s5 += s16 * 470296; + s6 += s16 * 654183; + s7 -= s16 * 997805; + s8 += s16 * 136657; + s9 -= s16 * 683901; + // s16 = 0; + + s3 += s15 * 666643; + s4 += s15 * 470296; + s5 += s15 * 654183; + s6 -= s15 * 997805; + s7 += s15 * 136657; + s8 -= s15 * 683901; + // s15 = 0; + + s2 += s14 * 666643; + s3 += s14 * 470296; + s4 += s14 * 654183; + s5 -= s14 * 997805; + s6 += s14 * 136657; + s7 -= s14 * 683901; + // s14 = 0; + + s1 += s13 * 666643; + s2 += s13 * 470296; + s3 += s13 * 654183; + s4 -= s13 * 997805; + s5 += s13 * 136657; + s6 -= s13 * 683901; + // s13 = 0; + + s0 += s12 * 666643; + s1 += s12 * 470296; + s2 += s12 * 654183; + s3 -= s12 * 997805; + s4 += s12 * 136657; + s5 -= s12 * 683901; + s12 = 0; + + // Reduce the range of limbs from s0 to s11 to 21-bits. + carry0 = (s0 + (1 << 20)) >> 21; + s1 += carry0; + s0 -= carry0 << 21; + carry2 = (s2 + (1 << 20)) >> 21; + s3 += carry2; + s2 -= carry2 << 21; + carry4 = (s4 + (1 << 20)) >> 21; + s5 += carry4; + s4 -= carry4 << 21; + carry6 = (s6 + (1 << 20)) >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry8 = (s8 + (1 << 20)) >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry10 = (s10 + (1 << 20)) >> 21; + s11 += carry10; + s10 -= carry10 << 21; + + carry1 = (s1 + (1 << 20)) >> 21; + s2 += carry1; + s1 -= carry1 << 21; + carry3 = (s3 + (1 << 20)) >> 21; + s4 += carry3; + s3 -= carry3 << 21; + carry5 = (s5 + (1 << 20)) >> 21; + s6 += carry5; + s5 -= carry5 << 21; + carry7 = (s7 + (1 << 20)) >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry9 = (s9 + (1 << 20)) >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry11 = (s11 + (1 << 20)) >> 21; + s12 += carry11; + s11 -= carry11 << 21; + + s0 += s12 * 666643; + s1 += s12 * 470296; + s2 += s12 * 654183; + s3 -= s12 * 997805; + s4 += s12 * 136657; + s5 -= s12 * 683901; + s12 = 0; + + // Carry chain reduction to propagate excess bits from s0 to s5 to the most significant limbs. + carry0 = s0 >> 21; + s1 += carry0; + s0 -= carry0 << 21; + carry1 = s1 >> 21; + s2 += carry1; + s1 -= carry1 << 21; + carry2 = s2 >> 21; + s3 += carry2; + s2 -= carry2 << 21; + carry3 = s3 >> 21; + s4 += carry3; + s3 -= carry3 << 21; + carry4 = s4 >> 21; + s5 += carry4; + s4 -= carry4 << 21; + carry5 = s5 >> 21; + s6 += carry5; + s5 -= carry5 << 21; + carry6 = s6 >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry7 = s7 >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry8 = s8 >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry9 = s9 >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry10 = s10 >> 21; + s11 += carry10; + s10 -= carry10 << 21; + carry11 = s11 >> 21; + s12 += carry11; + s11 -= carry11 << 21; + + // Do one last reduction as s12 might be 1. + s0 += s12 * 666643; + s1 += s12 * 470296; + s2 += s12 * 654183; + s3 -= s12 * 997805; + s4 += s12 * 136657; + s5 -= s12 * 683901; + // s12 = 0; + + carry0 = s0 >> 21; + s1 += carry0; + s0 -= carry0 << 21; + carry1 = s1 >> 21; + s2 += carry1; + s1 -= carry1 << 21; + carry2 = s2 >> 21; + s3 += carry2; + s2 -= carry2 << 21; + carry3 = s3 >> 21; + s4 += carry3; + s3 -= carry3 << 21; + carry4 = s4 >> 21; + s5 += carry4; + s4 -= carry4 << 21; + carry5 = s5 >> 21; + s6 += carry5; + s5 -= carry5 << 21; + carry6 = s6 >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry7 = s7 >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry8 = s8 >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry9 = s9 >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry10 = s10 >> 21; + s11 += carry10; + s10 -= carry10 << 21; + + // Serialize the result into the s. + s[0] = (byte) s0; + s[1] = (byte) (s0 >> 8); + s[2] = (byte) ((s0 >> 16) | (s1 << 5)); + s[3] = (byte) (s1 >> 3); + s[4] = (byte) (s1 >> 11); + s[5] = (byte) ((s1 >> 19) | (s2 << 2)); + s[6] = (byte) (s2 >> 6); + s[7] = (byte) ((s2 >> 14) | (s3 << 7)); + s[8] = (byte) (s3 >> 1); + s[9] = (byte) (s3 >> 9); + s[10] = (byte) ((s3 >> 17) | (s4 << 4)); + s[11] = (byte) (s4 >> 4); + s[12] = (byte) (s4 >> 12); + s[13] = (byte) ((s4 >> 20) | (s5 << 1)); + s[14] = (byte) (s5 >> 7); + s[15] = (byte) ((s5 >> 15) | (s6 << 6)); + s[16] = (byte) (s6 >> 2); + s[17] = (byte) (s6 >> 10); + s[18] = (byte) ((s6 >> 18) | (s7 << 3)); + s[19] = (byte) (s7 >> 5); + s[20] = (byte) (s7 >> 13); + s[21] = (byte) s8; + s[22] = (byte) (s8 >> 8); + s[23] = (byte) ((s8 >> 16) | (s9 << 5)); + s[24] = (byte) (s9 >> 3); + s[25] = (byte) (s9 >> 11); + s[26] = (byte) ((s9 >> 19) | (s10 << 2)); + s[27] = (byte) (s10 >> 6); + s[28] = (byte) ((s10 >> 14) | (s11 << 7)); + s[29] = (byte) (s11 >> 1); + s[30] = (byte) (s11 >> 9); + s[31] = (byte) (s11 >> 17); + } + + /** + * Input: + * a[0]+256*a[1]+...+256^31*a[31] = a + * b[0]+256*b[1]+...+256^31*b[31] = b + * c[0]+256*c[1]+...+256^31*c[31] = c + * <p> + * Output: + * s[0]+256*s[1]+...+256^31*s[31] = (ab+c) mod l + * where l = 2^252 + 27742317777372353535851937790883648493. + */ + private static void mulAdd(byte[] s, byte[] a, byte[] b, byte[] c) { + // This is very similar to Ed25519.reduce, the difference in here is that it computes ab+c + // See Ed25519.reduce for related comments. + long a0 = 2097151 & load3(a, 0); + long a1 = 2097151 & (load4(a, 2) >> 5); + long a2 = 2097151 & (load3(a, 5) >> 2); + long a3 = 2097151 & (load4(a, 7) >> 7); + long a4 = 2097151 & (load4(a, 10) >> 4); + long a5 = 2097151 & (load3(a, 13) >> 1); + long a6 = 2097151 & (load4(a, 15) >> 6); + long a7 = 2097151 & (load3(a, 18) >> 3); + long a8 = 2097151 & load3(a, 21); + long a9 = 2097151 & (load4(a, 23) >> 5); + long a10 = 2097151 & (load3(a, 26) >> 2); + long a11 = (load4(a, 28) >> 7); + long b0 = 2097151 & load3(b, 0); + long b1 = 2097151 & (load4(b, 2) >> 5); + long b2 = 2097151 & (load3(b, 5) >> 2); + long b3 = 2097151 & (load4(b, 7) >> 7); + long b4 = 2097151 & (load4(b, 10) >> 4); + long b5 = 2097151 & (load3(b, 13) >> 1); + long b6 = 2097151 & (load4(b, 15) >> 6); + long b7 = 2097151 & (load3(b, 18) >> 3); + long b8 = 2097151 & load3(b, 21); + long b9 = 2097151 & (load4(b, 23) >> 5); + long b10 = 2097151 & (load3(b, 26) >> 2); + long b11 = (load4(b, 28) >> 7); + long c0 = 2097151 & load3(c, 0); + long c1 = 2097151 & (load4(c, 2) >> 5); + long c2 = 2097151 & (load3(c, 5) >> 2); + long c3 = 2097151 & (load4(c, 7) >> 7); + long c4 = 2097151 & (load4(c, 10) >> 4); + long c5 = 2097151 & (load3(c, 13) >> 1); + long c6 = 2097151 & (load4(c, 15) >> 6); + long c7 = 2097151 & (load3(c, 18) >> 3); + long c8 = 2097151 & load3(c, 21); + long c9 = 2097151 & (load4(c, 23) >> 5); + long c10 = 2097151 & (load3(c, 26) >> 2); + long c11 = (load4(c, 28) >> 7); + long s0; + long s1; + long s2; + long s3; + long s4; + long s5; + long s6; + long s7; + long s8; + long s9; + long s10; + long s11; + long s12; + long s13; + long s14; + long s15; + long s16; + long s17; + long s18; + long s19; + long s20; + long s21; + long s22; + long s23; + long carry0; + long carry1; + long carry2; + long carry3; + long carry4; + long carry5; + long carry6; + long carry7; + long carry8; + long carry9; + long carry10; + long carry11; + long carry12; + long carry13; + long carry14; + long carry15; + long carry16; + long carry17; + long carry18; + long carry19; + long carry20; + long carry21; + long carry22; + + s0 = c0 + a0 * b0; + s1 = c1 + a0 * b1 + a1 * b0; + s2 = c2 + a0 * b2 + a1 * b1 + a2 * b0; + s3 = c3 + a0 * b3 + a1 * b2 + a2 * b1 + a3 * b0; + s4 = c4 + a0 * b4 + a1 * b3 + a2 * b2 + a3 * b1 + a4 * b0; + s5 = c5 + a0 * b5 + a1 * b4 + a2 * b3 + a3 * b2 + a4 * b1 + a5 * b0; + s6 = c6 + a0 * b6 + a1 * b5 + a2 * b4 + a3 * b3 + a4 * b2 + a5 * b1 + a6 * b0; + s7 = c7 + a0 * b7 + a1 * b6 + a2 * b5 + a3 * b4 + a4 * b3 + a5 * b2 + a6 * b1 + a7 * b0; + s8 = c8 + a0 * b8 + a1 * b7 + a2 * b6 + a3 * b5 + a4 * b4 + a5 * b3 + a6 * b2 + a7 * b1 + + a8 * b0; + s9 = c9 + a0 * b9 + a1 * b8 + a2 * b7 + a3 * b6 + a4 * b5 + a5 * b4 + a6 * b3 + a7 * b2 + + a8 * b1 + a9 * b0; + s10 = c10 + a0 * b10 + a1 * b9 + a2 * b8 + a3 * b7 + a4 * b6 + a5 * b5 + a6 * b4 + a7 * b3 + + a8 * b2 + a9 * b1 + a10 * b0; + s11 = c11 + a0 * b11 + a1 * b10 + a2 * b9 + a3 * b8 + a4 * b7 + a5 * b6 + a6 * b5 + a7 * b4 + + a8 * b3 + a9 * b2 + a10 * b1 + a11 * b0; + s12 = a1 * b11 + a2 * b10 + a3 * b9 + a4 * b8 + a5 * b7 + a6 * b6 + a7 * b5 + a8 * b4 + a9 * b3 + + a10 * b2 + a11 * b1; + s13 = a2 * b11 + a3 * b10 + a4 * b9 + a5 * b8 + a6 * b7 + a7 * b6 + a8 * b5 + a9 * b4 + a10 * b3 + + a11 * b2; + s14 = a3 * b11 + a4 * b10 + a5 * b9 + a6 * b8 + a7 * b7 + a8 * b6 + a9 * b5 + a10 * b4 + + a11 * b3; + s15 = a4 * b11 + a5 * b10 + a6 * b9 + a7 * b8 + a8 * b7 + a9 * b6 + a10 * b5 + a11 * b4; + s16 = a5 * b11 + a6 * b10 + a7 * b9 + a8 * b8 + a9 * b7 + a10 * b6 + a11 * b5; + s17 = a6 * b11 + a7 * b10 + a8 * b9 + a9 * b8 + a10 * b7 + a11 * b6; + s18 = a7 * b11 + a8 * b10 + a9 * b9 + a10 * b8 + a11 * b7; + s19 = a8 * b11 + a9 * b10 + a10 * b9 + a11 * b8; + s20 = a9 * b11 + a10 * b10 + a11 * b9; + s21 = a10 * b11 + a11 * b10; + s22 = a11 * b11; + s23 = 0; + + carry0 = (s0 + (1 << 20)) >> 21; + s1 += carry0; + s0 -= carry0 << 21; + carry2 = (s2 + (1 << 20)) >> 21; + s3 += carry2; + s2 -= carry2 << 21; + carry4 = (s4 + (1 << 20)) >> 21; + s5 += carry4; + s4 -= carry4 << 21; + carry6 = (s6 + (1 << 20)) >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry8 = (s8 + (1 << 20)) >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry10 = (s10 + (1 << 20)) >> 21; + s11 += carry10; + s10 -= carry10 << 21; + carry12 = (s12 + (1 << 20)) >> 21; + s13 += carry12; + s12 -= carry12 << 21; + carry14 = (s14 + (1 << 20)) >> 21; + s15 += carry14; + s14 -= carry14 << 21; + carry16 = (s16 + (1 << 20)) >> 21; + s17 += carry16; + s16 -= carry16 << 21; + carry18 = (s18 + (1 << 20)) >> 21; + s19 += carry18; + s18 -= carry18 << 21; + carry20 = (s20 + (1 << 20)) >> 21; + s21 += carry20; + s20 -= carry20 << 21; + carry22 = (s22 + (1 << 20)) >> 21; + s23 += carry22; + s22 -= carry22 << 21; + + carry1 = (s1 + (1 << 20)) >> 21; + s2 += carry1; + s1 -= carry1 << 21; + carry3 = (s3 + (1 << 20)) >> 21; + s4 += carry3; + s3 -= carry3 << 21; + carry5 = (s5 + (1 << 20)) >> 21; + s6 += carry5; + s5 -= carry5 << 21; + carry7 = (s7 + (1 << 20)) >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry9 = (s9 + (1 << 20)) >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry11 = (s11 + (1 << 20)) >> 21; + s12 += carry11; + s11 -= carry11 << 21; + carry13 = (s13 + (1 << 20)) >> 21; + s14 += carry13; + s13 -= carry13 << 21; + carry15 = (s15 + (1 << 20)) >> 21; + s16 += carry15; + s15 -= carry15 << 21; + carry17 = (s17 + (1 << 20)) >> 21; + s18 += carry17; + s17 -= carry17 << 21; + carry19 = (s19 + (1 << 20)) >> 21; + s20 += carry19; + s19 -= carry19 << 21; + carry21 = (s21 + (1 << 20)) >> 21; + s22 += carry21; + s21 -= carry21 << 21; + + s11 += s23 * 666643; + s12 += s23 * 470296; + s13 += s23 * 654183; + s14 -= s23 * 997805; + s15 += s23 * 136657; + s16 -= s23 * 683901; + // s23 = 0; + + s10 += s22 * 666643; + s11 += s22 * 470296; + s12 += s22 * 654183; + s13 -= s22 * 997805; + s14 += s22 * 136657; + s15 -= s22 * 683901; + // s22 = 0; + + s9 += s21 * 666643; + s10 += s21 * 470296; + s11 += s21 * 654183; + s12 -= s21 * 997805; + s13 += s21 * 136657; + s14 -= s21 * 683901; + // s21 = 0; + + s8 += s20 * 666643; + s9 += s20 * 470296; + s10 += s20 * 654183; + s11 -= s20 * 997805; + s12 += s20 * 136657; + s13 -= s20 * 683901; + // s20 = 0; + + s7 += s19 * 666643; + s8 += s19 * 470296; + s9 += s19 * 654183; + s10 -= s19 * 997805; + s11 += s19 * 136657; + s12 -= s19 * 683901; + // s19 = 0; + + s6 += s18 * 666643; + s7 += s18 * 470296; + s8 += s18 * 654183; + s9 -= s18 * 997805; + s10 += s18 * 136657; + s11 -= s18 * 683901; + // s18 = 0; + + carry6 = (s6 + (1 << 20)) >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry8 = (s8 + (1 << 20)) >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry10 = (s10 + (1 << 20)) >> 21; + s11 += carry10; + s10 -= carry10 << 21; + carry12 = (s12 + (1 << 20)) >> 21; + s13 += carry12; + s12 -= carry12 << 21; + carry14 = (s14 + (1 << 20)) >> 21; + s15 += carry14; + s14 -= carry14 << 21; + carry16 = (s16 + (1 << 20)) >> 21; + s17 += carry16; + s16 -= carry16 << 21; + + carry7 = (s7 + (1 << 20)) >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry9 = (s9 + (1 << 20)) >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry11 = (s11 + (1 << 20)) >> 21; + s12 += carry11; + s11 -= carry11 << 21; + carry13 = (s13 + (1 << 20)) >> 21; + s14 += carry13; + s13 -= carry13 << 21; + carry15 = (s15 + (1 << 20)) >> 21; + s16 += carry15; + s15 -= carry15 << 21; + + s5 += s17 * 666643; + s6 += s17 * 470296; + s7 += s17 * 654183; + s8 -= s17 * 997805; + s9 += s17 * 136657; + s10 -= s17 * 683901; + // s17 = 0; + + s4 += s16 * 666643; + s5 += s16 * 470296; + s6 += s16 * 654183; + s7 -= s16 * 997805; + s8 += s16 * 136657; + s9 -= s16 * 683901; + // s16 = 0; + + s3 += s15 * 666643; + s4 += s15 * 470296; + s5 += s15 * 654183; + s6 -= s15 * 997805; + s7 += s15 * 136657; + s8 -= s15 * 683901; + // s15 = 0; + + s2 += s14 * 666643; + s3 += s14 * 470296; + s4 += s14 * 654183; + s5 -= s14 * 997805; + s6 += s14 * 136657; + s7 -= s14 * 683901; + // s14 = 0; + + s1 += s13 * 666643; + s2 += s13 * 470296; + s3 += s13 * 654183; + s4 -= s13 * 997805; + s5 += s13 * 136657; + s6 -= s13 * 683901; + // s13 = 0; + + s0 += s12 * 666643; + s1 += s12 * 470296; + s2 += s12 * 654183; + s3 -= s12 * 997805; + s4 += s12 * 136657; + s5 -= s12 * 683901; + s12 = 0; + + carry0 = (s0 + (1 << 20)) >> 21; + s1 += carry0; + s0 -= carry0 << 21; + carry2 = (s2 + (1 << 20)) >> 21; + s3 += carry2; + s2 -= carry2 << 21; + carry4 = (s4 + (1 << 20)) >> 21; + s5 += carry4; + s4 -= carry4 << 21; + carry6 = (s6 + (1 << 20)) >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry8 = (s8 + (1 << 20)) >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry10 = (s10 + (1 << 20)) >> 21; + s11 += carry10; + s10 -= carry10 << 21; + + carry1 = (s1 + (1 << 20)) >> 21; + s2 += carry1; + s1 -= carry1 << 21; + carry3 = (s3 + (1 << 20)) >> 21; + s4 += carry3; + s3 -= carry3 << 21; + carry5 = (s5 + (1 << 20)) >> 21; + s6 += carry5; + s5 -= carry5 << 21; + carry7 = (s7 + (1 << 20)) >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry9 = (s9 + (1 << 20)) >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry11 = (s11 + (1 << 20)) >> 21; + s12 += carry11; + s11 -= carry11 << 21; + + s0 += s12 * 666643; + s1 += s12 * 470296; + s2 += s12 * 654183; + s3 -= s12 * 997805; + s4 += s12 * 136657; + s5 -= s12 * 683901; + s12 = 0; + + carry0 = s0 >> 21; + s1 += carry0; + s0 -= carry0 << 21; + carry1 = s1 >> 21; + s2 += carry1; + s1 -= carry1 << 21; + carry2 = s2 >> 21; + s3 += carry2; + s2 -= carry2 << 21; + carry3 = s3 >> 21; + s4 += carry3; + s3 -= carry3 << 21; + carry4 = s4 >> 21; + s5 += carry4; + s4 -= carry4 << 21; + carry5 = s5 >> 21; + s6 += carry5; + s5 -= carry5 << 21; + carry6 = s6 >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry7 = s7 >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry8 = s8 >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry9 = s9 >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry10 = s10 >> 21; + s11 += carry10; + s10 -= carry10 << 21; + carry11 = s11 >> 21; + s12 += carry11; + s11 -= carry11 << 21; + + s0 += s12 * 666643; + s1 += s12 * 470296; + s2 += s12 * 654183; + s3 -= s12 * 997805; + s4 += s12 * 136657; + s5 -= s12 * 683901; + // s12 = 0; + + carry0 = s0 >> 21; + s1 += carry0; + s0 -= carry0 << 21; + carry1 = s1 >> 21; + s2 += carry1; + s1 -= carry1 << 21; + carry2 = s2 >> 21; + s3 += carry2; + s2 -= carry2 << 21; + carry3 = s3 >> 21; + s4 += carry3; + s3 -= carry3 << 21; + carry4 = s4 >> 21; + s5 += carry4; + s4 -= carry4 << 21; + carry5 = s5 >> 21; + s6 += carry5; + s5 -= carry5 << 21; + carry6 = s6 >> 21; + s7 += carry6; + s6 -= carry6 << 21; + carry7 = s7 >> 21; + s8 += carry7; + s7 -= carry7 << 21; + carry8 = s8 >> 21; + s9 += carry8; + s8 -= carry8 << 21; + carry9 = s9 >> 21; + s10 += carry9; + s9 -= carry9 << 21; + carry10 = s10 >> 21; + s11 += carry10; + s10 -= carry10 << 21; + + s[0] = (byte) s0; + s[1] = (byte) (s0 >> 8); + s[2] = (byte) ((s0 >> 16) | (s1 << 5)); + s[3] = (byte) (s1 >> 3); + s[4] = (byte) (s1 >> 11); + s[5] = (byte) ((s1 >> 19) | (s2 << 2)); + s[6] = (byte) (s2 >> 6); + s[7] = (byte) ((s2 >> 14) | (s3 << 7)); + s[8] = (byte) (s3 >> 1); + s[9] = (byte) (s3 >> 9); + s[10] = (byte) ((s3 >> 17) | (s4 << 4)); + s[11] = (byte) (s4 >> 4); + s[12] = (byte) (s4 >> 12); + s[13] = (byte) ((s4 >> 20) | (s5 << 1)); + s[14] = (byte) (s5 >> 7); + s[15] = (byte) ((s5 >> 15) | (s6 << 6)); + s[16] = (byte) (s6 >> 2); + s[17] = (byte) (s6 >> 10); + s[18] = (byte) ((s6 >> 18) | (s7 << 3)); + s[19] = (byte) (s7 >> 5); + s[20] = (byte) (s7 >> 13); + s[21] = (byte) s8; + s[22] = (byte) (s8 >> 8); + s[23] = (byte) ((s8 >> 16) | (s9 << 5)); + s[24] = (byte) (s9 >> 3); + s[25] = (byte) (s9 >> 11); + s[26] = (byte) ((s9 >> 19) | (s10 << 2)); + s[27] = (byte) (s10 >> 6); + s[28] = (byte) ((s10 >> 14) | (s11 << 7)); + s[29] = (byte) (s11 >> 1); + s[30] = (byte) (s11 >> 9); + s[31] = (byte) (s11 >> 17); + } + + // The order of the generator as unsigned bytes in little endian order. + // (2^252 + 0x14def9dea2f79cd65812631a5cf5d3ed, cf. RFC 7748) + private static final byte[] GROUP_ORDER = { + (byte) 0xed, (byte) 0xd3, (byte) 0xf5, (byte) 0x5c, + (byte) 0x1a, (byte) 0x63, (byte) 0x12, (byte) 0x58, + (byte) 0xd6, (byte) 0x9c, (byte) 0xf7, (byte) 0xa2, + (byte) 0xde, (byte) 0xf9, (byte) 0xde, (byte) 0x14, + (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, + (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, + (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, + (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x10}; + + // Checks whether s represents an integer smaller than the order of the group. + // This is needed to ensure that EdDSA signatures are non-malleable, as failing to check + // the range of S allows to modify signatures (cf. RFC 8032, Section 5.2.7 and Section 8.4.) + // @param s an integer in little-endian order. + private static boolean isSmallerThanGroupOrder(byte[] s) { + for (int j = Field25519.FIELD_LEN - 1; j >= 0; j--) { + // compare unsigned bytes + int a = s[j] & 0xff; + int b = GROUP_ORDER[j] & 0xff; + if (a != b) { + return a < b; + } + } + return false; + } + + /** + * Returns true if the EdDSA {@code signature} with {@code message}, can be verified with + * {@code publicKey}. + */ + public static boolean verify(final byte[] message, final byte[] signature, + final byte[] publicKey) { + try { + if (signature.length != SIGNATURE_LEN) { + return false; + } + if (publicKey.length != PUBLIC_KEY_LEN) { + return false; + } + byte[] s = Arrays.copyOfRange(signature, Field25519.FIELD_LEN, SIGNATURE_LEN); + if (!isSmallerThanGroupOrder(s)) { + return false; + } + MessageDigest digest = MessageDigest.getInstance("SHA-512"); + digest.update(signature, 0, Field25519.FIELD_LEN); + digest.update(publicKey); + digest.update(message); + byte[] h = digest.digest(); + reduce(h); + + XYZT negPublicKey = XYZT.fromBytesNegateVarTime(publicKey); + XYZ xyz = doubleScalarMultVarTime(h, negPublicKey, s); + byte[] expectedR = xyz.toBytes(); + for (int i = 0; i < Field25519.FIELD_LEN; i++) { + if (expectedR[i] != signature[i]) { + return false; + } + } + return true; + } catch (final GeneralSecurityException ignored) { + return false; + } + } +} diff --git a/ui/src/main/java/com/wireguard/android/updater/SnackbarUpdateShower.kt b/ui/src/main/java/com/wireguard/android/updater/SnackbarUpdateShower.kt new file mode 100644 index 00000000..e6134991 --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/updater/SnackbarUpdateShower.kt @@ -0,0 +1,173 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.updater + +import android.content.Intent +import android.net.Uri +import android.view.View +import android.widget.Toast +import androidx.activity.result.contract.ActivityResultContracts +import androidx.fragment.app.Fragment +import androidx.lifecycle.lifecycleScope +import com.google.android.material.dialog.MaterialAlertDialogBuilder +import com.google.android.material.snackbar.BaseTransientBottomBar +import com.google.android.material.snackbar.Snackbar +import com.wireguard.android.R +import com.wireguard.android.util.ErrorMessages +import com.wireguard.android.util.QuantityFormatter +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.launchIn +import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.launch +import kotlin.time.Duration.Companion.seconds + +class SnackbarUpdateShower(private val fragment: Fragment) { + private var lastUserIntervention: Updater.Progress.NeedsUserIntervention? = null + private val intentLauncher = fragment.registerForActivityResult(ActivityResultContracts.StartActivityForResult()) { + lastUserIntervention?.markAsDone() + } + + private class SwapableSnackbar(fragment: Fragment, view: View, anchor: View?) { + private val actionSnackbar = makeSnackbar(fragment, view, anchor) + private val statusSnackbar = makeSnackbar(fragment, view, anchor) + private var showingAction: Boolean = false + private var showingStatus: Boolean = false + + private fun makeSnackbar(fragment: Fragment, view: View, anchor: View?): Snackbar { + val snackbar = Snackbar.make(fragment.requireContext(), view, "", Snackbar.LENGTH_INDEFINITE) + if (anchor != null) + snackbar.anchorView = anchor + snackbar.setTextMaxLines(6) + snackbar.behavior = object : BaseTransientBottomBar.Behavior() { + override fun canSwipeDismissView(child: View): Boolean { + return false + } + } + snackbar.addCallback(object : BaseTransientBottomBar.BaseCallback<Snackbar>() { + override fun onDismissed(snackbar: Snackbar?, @DismissEvent event: Int) { + super.onDismissed(snackbar, event) + if (event == DISMISS_EVENT_MANUAL || event == DISMISS_EVENT_ACTION || + (snackbar == actionSnackbar && !showingAction) || (snackbar == statusSnackbar && !showingStatus) + ) + return + fragment.lifecycleScope.launch { + delay(5.seconds) + snackbar?.show() + } + } + }) + return snackbar + } + + fun showAction(text: String, action: String, listener: View.OnClickListener) { + if (showingStatus) { + showingStatus = false + statusSnackbar.dismiss() + } + actionSnackbar.setText(text) + actionSnackbar.setAction(action, listener) + if (!showingAction) { + actionSnackbar.show() + showingAction = true + } + } + + fun showText(text: String) { + if (showingAction) { + showingAction = false + actionSnackbar.dismiss() + } + statusSnackbar.setText(text) + if (!showingStatus) { + statusSnackbar.show() + showingStatus = true + } + } + + fun dismiss() { + actionSnackbar.dismiss() + statusSnackbar.dismiss() + showingAction = false + showingStatus = false + } + } + + fun attach(view: View, anchor: View?) { + val snackbar = SwapableSnackbar(fragment, view, anchor) + val context = fragment.requireContext() + + Updater.state.onEach { progress -> + when (progress) { + is Updater.Progress.Complete -> + snackbar.dismiss() + + is Updater.Progress.Available -> + snackbar.showAction(context.getString(R.string.updater_avalable), context.getString(R.string.updater_action)) { + progress.update() + } + + is Updater.Progress.NeedsUserIntervention -> { + lastUserIntervention = progress + intentLauncher.launch(progress.intent) + } + + is Updater.Progress.Installing -> + snackbar.showText(context.getString(R.string.updater_installing)) + + is Updater.Progress.Rechecking -> + snackbar.showText(context.getString(R.string.updater_rechecking)) + + is Updater.Progress.Downloading -> { + if (progress.bytesTotal != 0UL) { + snackbar.showText( + context.getString( + R.string.updater_download_progress, + QuantityFormatter.formatBytes(progress.bytesDownloaded.toLong()), + QuantityFormatter.formatBytes(progress.bytesTotal.toLong()), + progress.bytesDownloaded.toFloat() * 100.0 / progress.bytesTotal.toFloat() + ) + ) + } else { + snackbar.showText( + context.getString( + R.string.updater_download_progress_nototal, + QuantityFormatter.formatBytes(progress.bytesDownloaded.toLong()) + ) + ) + } + } + + is Updater.Progress.Failure -> { + snackbar.showText(context.getString(R.string.updater_failure, ErrorMessages[progress.error])) + delay(5.seconds) + progress.retry() + } + + is Updater.Progress.Corrupt -> { + MaterialAlertDialogBuilder(context) + .setTitle(R.string.updater_corrupt_title) + .setMessage(R.string.updater_corrupt_message) + .setPositiveButton(R.string.updater_corrupt_navigate) { _, _ -> + val intent = Intent(Intent.ACTION_VIEW) + intent.data = Uri.parse(progress.downloadUrl) + try { + context.startActivity(intent) + } catch (e: Throwable) { + Toast.makeText(context, ErrorMessages[e], Toast.LENGTH_SHORT).show() + } + }.setCancelable(false).setOnDismissListener { + val intent = Intent(Intent.ACTION_MAIN) + intent.addCategory(Intent.CATEGORY_HOME) + intent.addFlags(Intent.FLAG_ACTIVITY_CLEAR_TASK) + intent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK) + context.startActivity(intent) + System.exit(0) + }.show() + } + } + }.launchIn(fragment.lifecycleScope) + } +}
\ No newline at end of file diff --git a/ui/src/main/java/com/wireguard/android/updater/Updater.kt b/ui/src/main/java/com/wireguard/android/updater/Updater.kt new file mode 100644 index 00000000..46d0fe34 --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/updater/Updater.kt @@ -0,0 +1,451 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package com.wireguard.android.updater + +import android.Manifest +import android.app.PendingIntent +import android.content.BroadcastReceiver +import android.content.Context +import android.content.Intent +import android.content.IntentFilter +import android.content.pm.PackageInstaller +import android.content.pm.PackageManager +import android.os.Build +import android.util.Base64 +import android.util.Log +import androidx.core.content.ContextCompat +import androidx.core.content.IntentCompat +import com.wireguard.android.Application +import com.wireguard.android.BuildConfig +import com.wireguard.android.activity.MainActivity +import com.wireguard.android.util.UserKnobs +import com.wireguard.android.util.applicationScope +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.flow.launchIn +import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import java.io.IOException +import java.net.HttpURLConnection +import java.net.URL +import java.nio.charset.StandardCharsets +import java.security.InvalidKeyException +import java.security.InvalidParameterException +import java.security.MessageDigest +import java.util.UUID +import kotlin.math.max +import kotlin.time.Duration.Companion.minutes +import kotlin.time.Duration.Companion.seconds + +object Updater { + private const val TAG = "WireGuard/Updater" + private const val UPDATE_URL_FMT = "https://download.wireguard.com/android-client/%s" + private const val APK_NAME_PREFIX = BuildConfig.APPLICATION_ID + "-" + private const val APK_NAME_SUFFIX = ".apk" + private const val LATEST_FILE = "latest.sig" + private const val RELEASE_PUBLIC_KEY_BASE64 = "RWTAzwGRYr3EC9px0Ia3fbttz8WcVN6wrOwWp2delz4el6SI8XmkKSMp" + private val CURRENT_VERSION by lazy { Version(BuildConfig.VERSION_NAME) } + + private val updaterScope = CoroutineScope(Job() + Dispatchers.IO) + + private fun installer(context: Context): String = try { + val packageName = context.packageName + val pm = context.packageManager + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.R) { + pm.getInstallSourceInfo(packageName).installingPackageName ?: "" + } else { + @Suppress("DEPRECATION") + pm.getInstallerPackageName(packageName) ?: "" + } + } catch (_: Throwable) { + "" + } + + fun installerIsGooglePlay(context: Context): Boolean = installer(context) == "com.android.vending" + + sealed class Progress { + object Complete : Progress() + class Available(val version: String) : Progress() { + fun update() { + applicationScope.launch { + UserKnobs.setUpdaterNewerVersionConsented(version) + } + } + } + + object Rechecking : Progress() + class Downloading(val bytesDownloaded: ULong, val bytesTotal: ULong) : Progress() + object Installing : Progress() + class NeedsUserIntervention(val intent: Intent, private val id: Int) : Progress() { + + private suspend fun installerActive(): Boolean { + if (mutableState.firstOrNull() != this@NeedsUserIntervention) + return true + try { + if (Application.get().packageManager.packageInstaller.getSessionInfo(id)?.isActive == true) + return true + } catch (_: SecurityException) { + return true + } + return false + } + + fun markAsDone() { + applicationScope.launch { + if (installerActive()) + return@launch + delay(7.seconds) + if (installerActive()) + return@launch + emitProgress(Failure(Exception("Ignored by user"))) + } + } + } + + class Failure(val error: Throwable) : Progress() { + fun retry() { + updaterScope.launch { + downloadAndUpdateWrapErrors() + } + } + } + + class Corrupt(private val betterFile: String?) : Progress() { + val downloadUrl: String + get() = UPDATE_URL_FMT.format(betterFile ?: "") + } + } + + private val mutableState = MutableStateFlow<Progress>(Progress.Complete) + val state = mutableState.asStateFlow() + + private suspend fun emitProgress(progress: Progress, force: Boolean = false) { + if (force || mutableState.firstOrNull()?.javaClass != progress.javaClass) + mutableState.emit(progress) + } + + private class Sha256Digest(hex: String) { + val bytes: ByteArray + + init { + if (hex.length != 64) + throw InvalidParameterException("SHA256 hashes must be 32 bytes long") + bytes = hex.chunked(2).map { it.toInt(16).toByte() }.toByteArray() + } + } + + @OptIn(ExperimentalUnsignedTypes::class) + private class Version(version: String) : Comparable<Version> { + val parts: ULongArray + + init { + val strParts = version.split(".") + if (strParts.isEmpty()) + throw InvalidParameterException("Version has no parts") + parts = ULongArray(strParts.size) + for (i in parts.indices) { + parts[i] = strParts[i].toULong() + } + } + + override fun toString(): String { + return parts.joinToString(".") + } + + override fun compareTo(other: Version): Int { + for (i in 0 until max(parts.size, other.parts.size)) { + val lhsPart = if (i < parts.size) parts[i] else 0UL + val rhsPart = if (i < other.parts.size) other.parts[i] else 0UL + if (lhsPart > rhsPart) + return 1 + else if (lhsPart < rhsPart) + return -1 + } + return 0 + } + } + + private class Update(val fileName: String, val version: Version, val hash: Sha256Digest) + + private fun versionOfFile(name: String): Version? { + if (!name.startsWith(APK_NAME_PREFIX) || !name.endsWith(APK_NAME_SUFFIX)) + return null + return try { + Version(name.substring(APK_NAME_PREFIX.length, name.length - APK_NAME_SUFFIX.length)) + } catch (_: Throwable) { + null + } + } + + private fun verifySignedFileList(signifyDigest: String): List<Update> { + val updates = ArrayList<Update>(1) + val publicKeyBytes = Base64.decode(RELEASE_PUBLIC_KEY_BASE64, Base64.DEFAULT) + if (publicKeyBytes == null || publicKeyBytes.size != 32 + 10 || publicKeyBytes[0] != 'E'.code.toByte() || publicKeyBytes[1] != 'd'.code.toByte()) + throw InvalidKeyException("Invalid public key") + val lines = signifyDigest.split("\n", limit = 3) + if (lines.size != 3) + throw InvalidParameterException("Invalid signature format: too few lines") + if (!lines[0].startsWith("untrusted comment: ")) + throw InvalidParameterException("Invalid signature format: missing comment") + val signatureBytes = Base64.decode(lines[1], Base64.DEFAULT) + if (signatureBytes == null || signatureBytes.size != 64 + 10) + throw InvalidParameterException("Invalid signature format: wrong sized or missing signature") + for (i in 0..9) { + if (signatureBytes[i] != publicKeyBytes[i]) + throw InvalidParameterException("Invalid signature format: wrong signer") + } + if (!Ed25519.verify( + lines[2].toByteArray(StandardCharsets.UTF_8), + signatureBytes.sliceArray(10 until 10 + 64), + publicKeyBytes.sliceArray(10 until 10 + 32) + ) + ) + throw SecurityException("Invalid signature") + for (line in lines[2].split("\n").dropLastWhile { it.isEmpty() }) { + val components = line.split(" ", limit = 2) + if (components.size != 2) + throw InvalidParameterException("Invalid file list format: too few components") + /* If version is null, it's not a file we understand, but still a legitimate entry, so don't throw. */ + val version = versionOfFile(components[1]) ?: continue + updates.add(Update(components[1], version, Sha256Digest(components[0]))) + } + return updates + } + + private fun checkForUpdates(): Update? { + val connection = URL(UPDATE_URL_FMT.format(LATEST_FILE)).openConnection() as HttpURLConnection + connection.setRequestProperty("User-Agent", Application.USER_AGENT) + connection.connect() + if (connection.responseCode != HttpURLConnection.HTTP_OK) + throw IOException(connection.responseMessage) + var fileListBytes = ByteArray(1024 * 512 /* 512 KiB */) + connection.inputStream.use { + val len = it.read(fileListBytes) + if (len <= 0) + throw IOException("File list is empty") + fileListBytes = fileListBytes.sliceArray(0 until len) + } + return verifySignedFileList(fileListBytes.decodeToString()).maxByOrNull { it.version } + } + + private suspend fun downloadAndUpdate() = withContext(Dispatchers.IO) { + val receiver = InstallReceiver() + val context = Application.get().applicationContext + val pendingIntent = withContext(Dispatchers.Main) { + ContextCompat.registerReceiver(context, receiver, IntentFilter(receiver.sessionId), ContextCompat.RECEIVER_NOT_EXPORTED) + PendingIntent.getBroadcast( + context, + 0, + Intent(receiver.sessionId).setPackage(context.packageName), + PendingIntent.FLAG_UPDATE_CURRENT or PendingIntent.FLAG_MUTABLE + ) + } + + emitProgress(Progress.Rechecking) + val update = checkForUpdates() + if (update == null || update.version <= CURRENT_VERSION) { + emitProgress(Progress.Complete) + return@withContext + } + + emitProgress(Progress.Downloading(0UL, 0UL), true) + val connection = URL(UPDATE_URL_FMT.format(update.fileName)).openConnection() as HttpURLConnection + connection.setRequestProperty("User-Agent", Application.USER_AGENT) + connection.connect() + if (connection.responseCode != HttpURLConnection.HTTP_OK) + throw IOException("Update could not be fetched: ${connection.responseCode}") + + var downloadedByteLen: ULong = 0UL + val totalByteLen = (if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) connection.contentLengthLong else connection.contentLength).toLong().toULong() + val fileBytes = ByteArray(1024 * 32 /* 32 KiB */) + val digest = MessageDigest.getInstance("SHA-256") + emitProgress(Progress.Downloading(downloadedByteLen, totalByteLen), true) + + val installer = context.packageManager.packageInstaller + val params = PackageInstaller.SessionParams(PackageInstaller.SessionParams.MODE_FULL_INSTALL) + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) + params.setRequireUserAction(PackageInstaller.SessionParams.USER_ACTION_NOT_REQUIRED) + params.setAppPackageName(context.packageName) /* Enforces updates; disallows new apps. */ + val session = installer.openSession(installer.createSession(params)) + var sessionFailure = true + try { + val installDest = session.openWrite(receiver.sessionId, 0, -1) + + installDest.use { dest -> + connection.inputStream.use { src -> + while (true) { + val readLen = src.read(fileBytes) + if (readLen <= 0) + break + + digest.update(fileBytes, 0, readLen) + dest.write(fileBytes, 0, readLen) + + downloadedByteLen += readLen.toUInt() + emitProgress(Progress.Downloading(downloadedByteLen, totalByteLen), true) + + if (downloadedByteLen >= 1024UL * 1024UL * 100UL /* 100 MiB */) + throw IOException("File too large") + } + } + } + + emitProgress(Progress.Installing) + if (!digest.digest().contentEquals(update.hash.bytes)) + throw SecurityException("Update has invalid hash") + sessionFailure = false + } finally { + if (sessionFailure) { + session.abandon() + session.close() + } + } + session.commit(pendingIntent.intentSender) + session.close() + } + + private var updating = false + private suspend fun downloadAndUpdateWrapErrors() { + if (updating) + return + updating = true + try { + downloadAndUpdate() + } catch (e: Throwable) { + Log.e(TAG, "Update failure", e) + emitProgress(Progress.Failure(e)) + } + updating = false + } + + private class InstallReceiver : BroadcastReceiver() { + val sessionId = UUID.randomUUID().toString() + + override fun onReceive(context: Context, intent: Intent) { + if (sessionId != intent.action) + return + + when (val status = intent.getIntExtra(PackageInstaller.EXTRA_STATUS, PackageInstaller.STATUS_FAILURE_INVALID)) { + PackageInstaller.STATUS_PENDING_USER_ACTION -> { + val id = intent.getIntExtra(PackageInstaller.EXTRA_SESSION_ID, 0) + val userIntervention = IntentCompat.getParcelableExtra(intent, Intent.EXTRA_INTENT, Intent::class.java)!! + applicationScope.launch { + emitProgress(Progress.NeedsUserIntervention(userIntervention, id)) + } + } + + PackageInstaller.STATUS_SUCCESS -> { + applicationScope.launch { + emitProgress(Progress.Complete) + } + context.applicationContext.unregisterReceiver(this) + } + + else -> { + val id = intent.getIntExtra(PackageInstaller.EXTRA_SESSION_ID, 0) + try { + context.applicationContext.packageManager.packageInstaller.abandonSession(id) + } catch (_: SecurityException) { + } + val message = intent.getStringExtra(PackageInstaller.EXTRA_STATUS_MESSAGE) ?: "Installation error $status" + applicationScope.launch { + val e = Exception(message) + Log.e(TAG, "Update failure", e) + emitProgress(Progress.Failure(e)) + } + context.applicationContext.unregisterReceiver(this) + } + } + } + } + + fun monitorForUpdates() { + if (BuildConfig.DEBUG) + return + + val context = Application.get() + + if (installerIsGooglePlay(context)) + return + + if (if (Build.VERSION.SDK_INT < Build.VERSION_CODES.TIRAMISU) { + @Suppress("DEPRECATION") + context.packageManager.getPackageInfo(context.packageName, PackageManager.GET_PERMISSIONS) + } else { + context.packageManager.getPackageInfo(context.packageName, PackageManager.PackageInfoFlags.of(PackageManager.GET_PERMISSIONS.toLong())) + }.requestedPermissions?.contains(Manifest.permission.REQUEST_INSTALL_PACKAGES) != true + ) { + if (installer(context).isNotEmpty()) { + updaterScope.launch { + val update = try { + checkForUpdates() + } catch (_: Throwable) { + null + } + emitProgress(Progress.Corrupt(update?.fileName)) + } + } + return + } + + updaterScope.launch { + if (UserKnobs.updaterNewerVersionSeen.firstOrNull()?.let { Version(it) > CURRENT_VERSION } == true) + return@launch + + var waitTime = 15 + while (true) { + try { + val update = checkForUpdates() ?: continue + if (update.version > CURRENT_VERSION) { + Log.i(TAG, "Update available: ${update.version}") + UserKnobs.setUpdaterNewerVersionSeen(update.version.toString()) + return@launch + } + } catch (_: Throwable) { + } + delay(waitTime.minutes) + waitTime = 45 + } + } + + UserKnobs.updaterNewerVersionSeen.onEach { ver -> + if ( + ver != null && + Version(ver) > CURRENT_VERSION && + UserKnobs.updaterNewerVersionConsented.firstOrNull()?.let { Version(it) > CURRENT_VERSION } != true + ) + emitProgress(Progress.Available(ver)) + }.launchIn(applicationScope) + + UserKnobs.updaterNewerVersionConsented.onEach { ver -> + if (ver != null && Version(ver) > CURRENT_VERSION) + updaterScope.launch { + downloadAndUpdateWrapErrors() + } + }.launchIn(applicationScope) + } + + class AppUpdatedReceiver : BroadcastReceiver() { + override fun onReceive(context: Context, intent: Intent) { + if (intent.action != Intent.ACTION_MY_PACKAGE_REPLACED) + return + + if (installer(context) != context.packageName) + return + + /* TODO: does not work because of restrictions placed on broadcast receivers. */ + val start = Intent(context, MainActivity::class.java) + start.addFlags(Intent.FLAG_ACTIVITY_CLEAR_TOP) + start.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK) + context.startActivity(start) + } + } +} diff --git a/ui/src/main/java/com/wireguard/android/util/AdminKnobs.kt b/ui/src/main/java/com/wireguard/android/util/AdminKnobs.kt new file mode 100644 index 00000000..2c23910b --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/util/AdminKnobs.kt @@ -0,0 +1,17 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.util + +import android.content.RestrictionsManager +import androidx.core.content.getSystemService +import com.wireguard.android.Application + +object AdminKnobs { + private val restrictions: RestrictionsManager? = Application.get().getSystemService() + val disableConfigExport: Boolean + get() = restrictions?.applicationRestrictions?.getBoolean("disable_config_export", false) + ?: false +} diff --git a/ui/src/main/java/com/wireguard/android/util/AsyncWorker.kt b/ui/src/main/java/com/wireguard/android/util/AsyncWorker.kt deleted file mode 100644 index a6e5d4be..00000000 --- a/ui/src/main/java/com/wireguard/android/util/AsyncWorker.kt +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright © 2017-2020 WireGuard LLC. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package com.wireguard.android.util - -import android.os.Handler -import java9.util.concurrent.CompletableFuture -import java9.util.concurrent.CompletionStage -import java.util.concurrent.Executor - -/** - * Helper class for running asynchronous tasks and ensuring they are completed on the main thread. - */ - -class AsyncWorker(private val executor: Executor, private val handler: Handler) { - - fun runAsync(run: () -> Unit): CompletionStage<Void> { - val future = CompletableFuture<Void>() - executor.execute { - try { - run() - handler.post { future.complete(null) } - } catch (t: Throwable) { - handler.post { future.completeExceptionally(t) } - } - } - return future - } - - fun <T> supplyAsync(get: () -> T?): CompletionStage<T> { - val future = CompletableFuture<T>() - executor.execute { - try { - val result = get() - handler.post { future.complete(result) } - } catch (t: Throwable) { - handler.post { future.completeExceptionally(t) } - } - } - return future - } -} diff --git a/ui/src/main/java/com/wireguard/android/util/BiometricAuthenticator.kt b/ui/src/main/java/com/wireguard/android/util/BiometricAuthenticator.kt index aed8a4f2..064ea04d 100644 --- a/ui/src/main/java/com/wireguard/android/util/BiometricAuthenticator.kt +++ b/ui/src/main/java/com/wireguard/android/util/BiometricAuthenticator.kt @@ -1,28 +1,27 @@ /* - * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.util -import android.annotation.SuppressLint -import android.app.KeyguardManager -import android.content.Context -import android.os.Build import android.os.Handler +import android.os.Looper import android.util.Log import androidx.annotation.StringRes -import androidx.biometric.BiometricConstants import androidx.biometric.BiometricManager +import androidx.biometric.BiometricManager.Authenticators import androidx.biometric.BiometricPrompt -import androidx.core.content.getSystemService import androidx.fragment.app.Fragment import com.wireguard.android.R object BiometricAuthenticator { private const val TAG = "WireGuard/BiometricAuthenticator" - private val handler = Handler() + + // Not all devices support strong biometric auth so we're allowing both device credentials as + // well as weak biometrics. + private const val allowedAuthenticators = Authenticators.DEVICE_CREDENTIAL or Authenticators.BIOMETRIC_WEAK sealed class Result { data class Success(val cryptoObject: BiometricPrompt.CryptoObject?) : Result() @@ -31,40 +30,30 @@ object BiometricAuthenticator { object Cancelled : Result() } - @SuppressLint("PrivateApi") - private fun isPinEnabled(context: Context): Boolean { - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) - return context.getSystemService<KeyguardManager>()!!.isDeviceSecure - return try { - val lockUtilsClass = Class.forName("com.android.internal.widget.LockPatternUtils") - val lockUtils = lockUtilsClass.getConstructor(Context::class.java).newInstance(context) - val method = lockUtilsClass.getMethod("isLockScreenDisabled") - !(method.invoke(lockUtils) as Boolean) - } catch (e: Exception) { - false - } - } - fun authenticate( - @StringRes dialogTitleRes: Int, - fragment: Fragment, - callback: (Result) -> Unit + @StringRes dialogTitleRes: Int, + fragment: Fragment, + callback: (Result) -> Unit ) { val authCallback = object : BiometricPrompt.AuthenticationCallback() { override fun onAuthenticationError(errorCode: Int, errString: CharSequence) { super.onAuthenticationError(errorCode, errString) Log.d(TAG, "BiometricAuthentication error: errorCode=$errorCode, msg=$errString") - callback(when (errorCode) { - BiometricConstants.ERROR_CANCELED, BiometricConstants.ERROR_USER_CANCELED, - BiometricConstants.ERROR_NEGATIVE_BUTTON -> { - Result.Cancelled - } - BiometricConstants.ERROR_HW_NOT_PRESENT, BiometricConstants.ERROR_HW_UNAVAILABLE, - BiometricConstants.ERROR_NO_BIOMETRICS, BiometricConstants.ERROR_NO_DEVICE_CREDENTIAL -> { - Result.HardwareUnavailableOrDisabled + callback( + when (errorCode) { + BiometricPrompt.ERROR_CANCELED, BiometricPrompt.ERROR_USER_CANCELED, + BiometricPrompt.ERROR_NEGATIVE_BUTTON -> { + Result.Cancelled + } + + BiometricPrompt.ERROR_HW_NOT_PRESENT, BiometricPrompt.ERROR_HW_UNAVAILABLE, + BiometricPrompt.ERROR_NO_BIOMETRICS, BiometricPrompt.ERROR_NO_DEVICE_CREDENTIAL -> { + Result.HardwareUnavailableOrDisabled + } + + else -> Result.Failure(errorCode, fragment.getString(R.string.biometric_auth_error_reason, errString)) } - else -> Result.Failure(errorCode, fragment.getString(R.string.biometric_auth_error_reason, errString)) - }) + ) } override fun onAuthenticationFailed() { @@ -77,12 +66,12 @@ object BiometricAuthenticator { callback(Result.Success(result.cryptoObject)) } } - val biometricPrompt = BiometricPrompt(fragment, { handler.post(it) }, authCallback) + val biometricPrompt = BiometricPrompt(fragment, { Handler(Looper.getMainLooper()).post(it) }, authCallback) val promptInfo = BiometricPrompt.PromptInfo.Builder() - .setTitle(fragment.getString(dialogTitleRes)) - .setDeviceCredentialAllowed(true) - .build() - if (BiometricManager.from(fragment.requireContext()).canAuthenticate() == BiometricManager.BIOMETRIC_SUCCESS || isPinEnabled(fragment.requireContext())) { + .setTitle(fragment.getString(dialogTitleRes)) + .setAllowedAuthenticators(allowedAuthenticators) + .build() + if (BiometricManager.from(fragment.requireContext()).canAuthenticate(allowedAuthenticators) == BiometricManager.BIOMETRIC_SUCCESS) { biometricPrompt.authenticate(promptInfo) } else { callback(Result.HardwareUnavailableOrDisabled) diff --git a/ui/src/main/java/com/wireguard/android/util/ClipboardUtils.kt b/ui/src/main/java/com/wireguard/android/util/ClipboardUtils.kt index 51f6486f..8968979f 100644 --- a/ui/src/main/java/com/wireguard/android/util/ClipboardUtils.kt +++ b/ui/src/main/java/com/wireguard/android/util/ClipboardUtils.kt @@ -1,16 +1,18 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.util import android.content.ClipData import android.content.ClipboardManager +import android.os.Build import android.view.View import android.widget.TextView import androidx.core.content.getSystemService import com.google.android.material.snackbar.Snackbar import com.google.android.material.textfield.TextInputEditText +import com.wireguard.android.R /** * Standalone utilities for interacting with the system clipboard. @@ -28,6 +30,8 @@ object ClipboardUtils { } val service = view.context.getSystemService<ClipboardManager>() ?: return service.setPrimaryClip(ClipData.newPlainText(data.second, data.first)) - Snackbar.make(view, "${data.second} copied to clipboard", Snackbar.LENGTH_LONG).show() + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.TIRAMISU) { + Snackbar.make(view, view.context.getString(R.string.copied_to_clipboard, data.second), Snackbar.LENGTH_LONG).show() + } } } diff --git a/ui/src/main/java/com/wireguard/android/util/DownloadsFileSaver.kt b/ui/src/main/java/com/wireguard/android/util/DownloadsFileSaver.kt index a0f0e1fb..f78094b6 100644 --- a/ui/src/main/java/com/wireguard/android/util/DownloadsFileSaver.kt +++ b/ui/src/main/java/com/wireguard/android/util/DownloadsFileSaver.kt @@ -1,68 +1,100 @@ /* - * Copyright © 2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.util +import android.Manifest import android.content.ContentValues import android.content.Context +import android.content.pm.PackageManager import android.net.Uri import android.os.Build import android.os.Environment import android.provider.MediaStore import android.provider.MediaStore.MediaColumns +import androidx.activity.ComponentActivity +import androidx.activity.result.ActivityResultLauncher +import androidx.activity.result.contract.ActivityResultContracts +import androidx.core.content.ContextCompat import com.wireguard.android.R +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext import java.io.File import java.io.FileOutputStream import java.io.IOException import java.io.OutputStream -object DownloadsFileSaver { - @Throws(Exception::class) - fun save(context: Context, name: String, mimeType: String?, overwriteExisting: Boolean) = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) { - val contentResolver = context.contentResolver - if (overwriteExisting) - contentResolver.delete(MediaStore.Downloads.EXTERNAL_CONTENT_URI, String.format("%s = ?", MediaColumns.DISPLAY_NAME), arrayOf(name)) - val contentValues = ContentValues() - contentValues.put(MediaColumns.DISPLAY_NAME, name) - contentValues.put(MediaColumns.MIME_TYPE, mimeType) - val contentUri = contentResolver.insert(MediaStore.Downloads.EXTERNAL_CONTENT_URI, contentValues) +class DownloadsFileSaver(private val context: ComponentActivity) { + private lateinit var activityResult: ActivityResultLauncher<String> + private lateinit var futureGrant: CompletableDeferred<Boolean> + + init { + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.Q) { + futureGrant = CompletableDeferred() + activityResult = context.registerForActivityResult(ActivityResultContracts.RequestPermission()) { ret -> futureGrant.complete(ret) } + } + } + + suspend fun save(name: String, mimeType: String?, overwriteExisting: Boolean) = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) { + withContext(Dispatchers.IO) { + val contentResolver = context.contentResolver + if (overwriteExisting) + contentResolver.delete(MediaStore.Downloads.EXTERNAL_CONTENT_URI, String.format("%s = ?", MediaColumns.DISPLAY_NAME), arrayOf(name)) + val contentValues = ContentValues() + contentValues.put(MediaColumns.DISPLAY_NAME, name) + contentValues.put(MediaColumns.MIME_TYPE, mimeType) + val contentUri = contentResolver.insert(MediaStore.Downloads.EXTERNAL_CONTENT_URI, contentValues) ?: throw IOException(context.getString(R.string.create_downloads_file_error)) - val contentStream = contentResolver.openOutputStream(contentUri) + val contentStream = contentResolver.openOutputStream(contentUri) ?: throw IOException(context.getString(R.string.create_downloads_file_error)) - @Suppress("DEPRECATION") var cursor = contentResolver.query(contentUri, arrayOf(MediaColumns.DATA), null, null, null) - var path: String? = null - if (cursor != null) { - try { - if (cursor.moveToFirst()) - path = cursor.getString(0) - } finally { - cursor.close() - } - } - if (path == null) { - path = "Download/" - cursor = contentResolver.query(contentUri, arrayOf(MediaColumns.DISPLAY_NAME), null, null, null) + @Suppress("DEPRECATION") var cursor = contentResolver.query(contentUri, arrayOf(MediaColumns.DATA), null, null, null) + var path: String? = null if (cursor != null) { try { if (cursor.moveToFirst()) - path += cursor.getString(0) + path = cursor.getString(0) } finally { cursor.close() } } + if (path == null) { + path = "Download/" + cursor = contentResolver.query(contentUri, arrayOf(MediaColumns.DISPLAY_NAME), null, null, null) + if (cursor != null) { + try { + if (cursor.moveToFirst()) + path += cursor.getString(0) + } finally { + cursor.close() + } + } + } + DownloadsFile(context, contentStream, path, contentUri) } - DownloadsFile(context, contentStream, path, contentUri) } else { - @Suppress("DEPRECATION") val path = Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOWNLOADS) - val file = File(path, name) - if (!path.isDirectory && !path.mkdirs()) - throw IOException(context.getString(R.string.create_output_dir_error)) - DownloadsFile(context, FileOutputStream(file), file.absolutePath, null) + withContext(Dispatchers.Main.immediate) { + if (ContextCompat.checkSelfPermission(context, Manifest.permission.WRITE_EXTERNAL_STORAGE) != PackageManager.PERMISSION_GRANTED) { + activityResult.launch(Manifest.permission.WRITE_EXTERNAL_STORAGE) + val granted = futureGrant.await() + if (!granted) { + futureGrant = CompletableDeferred() + return@withContext null + } + } + @Suppress("DEPRECATION") val path = Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOWNLOADS) + withContext(Dispatchers.IO) { + val file = File(path, name) + if (!path.isDirectory && !path.mkdirs()) + throw IOException(context.getString(R.string.create_output_dir_error)) + DownloadsFile(context, FileOutputStream(file), file.absolutePath, null) + } + } } class DownloadsFile(private val context: Context, val outputStream: OutputStream, val fileName: String, private val uri: Uri?) { - fun delete() { + suspend fun delete() = withContext(Dispatchers.IO) { if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) context.contentResolver.delete(uri!!, null, null) else diff --git a/ui/src/main/java/com/wireguard/android/util/ErrorMessages.kt b/ui/src/main/java/com/wireguard/android/util/ErrorMessages.kt index d8ac94d9..4157ebf2 100644 --- a/ui/src/main/java/com/wireguard/android/util/ErrorMessages.kt +++ b/ui/src/main/java/com/wireguard/android/util/ErrorMessages.kt @@ -1,12 +1,14 @@ /* - * Copyright © 2018-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.util import android.content.res.Resources import android.os.RemoteException -import com.wireguard.android.Application.Companion.get +import com.google.zxing.ChecksumException +import com.google.zxing.NotFoundException +import com.wireguard.android.Application import com.wireguard.android.R import com.wireguard.android.backend.BackendException import com.wireguard.android.util.RootShell.RootShellException @@ -20,50 +22,51 @@ import java.net.InetAddress object ErrorMessages { private val BCE_REASON_MAP = mapOf( - BadConfigException.Reason.INVALID_KEY to R.string.bad_config_reason_invalid_key, - BadConfigException.Reason.INVALID_NUMBER to R.string.bad_config_reason_invalid_number, - BadConfigException.Reason.INVALID_VALUE to R.string.bad_config_reason_invalid_value, - BadConfigException.Reason.MISSING_ATTRIBUTE to R.string.bad_config_reason_missing_attribute, - BadConfigException.Reason.MISSING_SECTION to R.string.bad_config_reason_missing_section, - BadConfigException.Reason.SYNTAX_ERROR to R.string.bad_config_reason_syntax_error, - BadConfigException.Reason.UNKNOWN_ATTRIBUTE to R.string.bad_config_reason_unknown_attribute, - BadConfigException.Reason.UNKNOWN_SECTION to R.string.bad_config_reason_unknown_section + BadConfigException.Reason.INVALID_KEY to R.string.bad_config_reason_invalid_key, + BadConfigException.Reason.INVALID_NUMBER to R.string.bad_config_reason_invalid_number, + BadConfigException.Reason.INVALID_VALUE to R.string.bad_config_reason_invalid_value, + BadConfigException.Reason.MISSING_ATTRIBUTE to R.string.bad_config_reason_missing_attribute, + BadConfigException.Reason.MISSING_SECTION to R.string.bad_config_reason_missing_section, + BadConfigException.Reason.SYNTAX_ERROR to R.string.bad_config_reason_syntax_error, + BadConfigException.Reason.UNKNOWN_ATTRIBUTE to R.string.bad_config_reason_unknown_attribute, + BadConfigException.Reason.UNKNOWN_SECTION to R.string.bad_config_reason_unknown_section ) private val BE_REASON_MAP = mapOf( - BackendException.Reason.UNKNOWN_KERNEL_MODULE_NAME to R.string.module_version_error, - BackendException.Reason.WG_QUICK_CONFIG_ERROR_CODE to R.string.tunnel_config_error, - BackendException.Reason.TUNNEL_MISSING_CONFIG to R.string.no_config_error, - BackendException.Reason.VPN_NOT_AUTHORIZED to R.string.vpn_not_authorized_error, - BackendException.Reason.UNABLE_TO_START_VPN to R.string.vpn_start_error, - BackendException.Reason.TUN_CREATION_ERROR to R.string.tun_create_error, - BackendException.Reason.GO_ACTIVATION_ERROR_CODE to R.string.tunnel_on_error + BackendException.Reason.UNKNOWN_KERNEL_MODULE_NAME to R.string.module_version_error, + BackendException.Reason.WG_QUICK_CONFIG_ERROR_CODE to R.string.tunnel_config_error, + BackendException.Reason.TUNNEL_MISSING_CONFIG to R.string.no_config_error, + BackendException.Reason.VPN_NOT_AUTHORIZED to R.string.vpn_not_authorized_error, + BackendException.Reason.UNABLE_TO_START_VPN to R.string.vpn_start_error, + BackendException.Reason.TUN_CREATION_ERROR to R.string.tun_create_error, + BackendException.Reason.GO_ACTIVATION_ERROR_CODE to R.string.tunnel_on_error, + BackendException.Reason.DNS_RESOLUTION_FAILURE to R.string.tunnel_dns_failure ) private val KFE_FORMAT_MAP = mapOf( - Key.Format.BASE64 to R.string.key_length_explanation_base64, - Key.Format.BINARY to R.string.key_length_explanation_binary, - Key.Format.HEX to R.string.key_length_explanation_hex + Key.Format.BASE64 to R.string.key_length_explanation_base64, + Key.Format.BINARY to R.string.key_length_explanation_binary, + Key.Format.HEX to R.string.key_length_explanation_hex ) private val KFE_TYPE_MAP = mapOf( - KeyFormatException.Type.CONTENTS to R.string.key_contents_error, - KeyFormatException.Type.LENGTH to R.string.key_length_error + KeyFormatException.Type.CONTENTS to R.string.key_contents_error, + KeyFormatException.Type.LENGTH to R.string.key_length_error ) private val PE_CLASS_MAP = mapOf( - InetAddress::class.java to R.string.parse_error_inet_address, - InetEndpoint::class.java to R.string.parse_error_inet_endpoint, - InetNetwork::class.java to R.string.parse_error_inet_network, - Int::class.java to R.string.parse_error_integer + InetAddress::class.java to R.string.parse_error_inet_address, + InetEndpoint::class.java to R.string.parse_error_inet_endpoint, + InetNetwork::class.java to R.string.parse_error_inet_network, + Int::class.java to R.string.parse_error_integer ) private val RSE_REASON_MAP = mapOf( - RootShellException.Reason.NO_ROOT_ACCESS to R.string.error_root, - RootShellException.Reason.SHELL_MARKER_COUNT_ERROR to R.string.shell_marker_count_error, - RootShellException.Reason.SHELL_EXIT_STATUS_READ_ERROR to R.string.shell_exit_status_read_error, - RootShellException.Reason.SHELL_START_ERROR to R.string.shell_start_error, - RootShellException.Reason.CREATE_BIN_DIR_ERROR to R.string.create_bin_dir_error, - RootShellException.Reason.CREATE_TEMP_DIR_ERROR to R.string.create_temp_dir_error + RootShellException.Reason.NO_ROOT_ACCESS to R.string.error_root, + RootShellException.Reason.SHELL_MARKER_COUNT_ERROR to R.string.shell_marker_count_error, + RootShellException.Reason.SHELL_EXIT_STATUS_READ_ERROR to R.string.shell_exit_status_read_error, + RootShellException.Reason.SHELL_START_ERROR to R.string.shell_start_error, + RootShellException.Reason.CREATE_BIN_DIR_ERROR to R.string.create_bin_dir_error, + RootShellException.Reason.CREATE_TEMP_DIR_ERROR to R.string.create_temp_dir_error ) operator fun get(throwable: Throwable?): String { - val resources = get().resources + val resources = Application.get().resources if (throwable == null) return resources.getString(R.string.unknown_error) val rootCause = rootCause(throwable) return when { @@ -77,15 +80,27 @@ object ErrorMessages { val explanation = getBadConfigExceptionExplanation(resources, rootCause) resources.getString(R.string.bad_config_error, reason, context) + explanation } + rootCause is BackendException -> { resources.getString(BE_REASON_MAP.getValue(rootCause.reason), *rootCause.format) } + rootCause is RootShellException -> { resources.getString(RSE_REASON_MAP.getValue(rootCause.reason), *rootCause.format) } - rootCause.message != null -> { - rootCause.message!! + + rootCause is NotFoundException -> { + resources.getString(R.string.error_no_qr_found) + } + + rootCause is ChecksumException -> { + resources.getString(R.string.error_qr_checksum) + } + + rootCause.localizedMessage != null -> { + rootCause.localizedMessage!! } + else -> { val errorType = rootCause.javaClass.simpleName resources.getString(R.string.generic_error, errorType) @@ -93,14 +108,16 @@ object ErrorMessages { } } - private fun getBadConfigExceptionExplanation(resources: Resources, - bce: BadConfigException): String { + private fun getBadConfigExceptionExplanation( + resources: Resources, + bce: BadConfigException + ): String { if (bce.cause is KeyFormatException) { val kfe = bce.cause as KeyFormatException? if (kfe!!.type == KeyFormatException.Type.LENGTH) return resources.getString(KFE_FORMAT_MAP.getValue(kfe.format)) } else if (bce.cause is ParseException) { val pe = bce.cause as ParseException? - if (pe!!.message != null) return ": ${pe.message}" + if (pe!!.localizedMessage != null) return ": ${pe.localizedMessage}" } else if (bce.location == BadConfigException.Location.LISTEN_PORT) { return resources.getString(R.string.bad_config_explanation_udp_port) } else if (bce.location == BadConfigException.Location.MTU) { @@ -111,8 +128,10 @@ object ErrorMessages { return "" } - private fun getBadConfigExceptionReason(resources: Resources, - bce: BadConfigException): String { + private fun getBadConfigExceptionReason( + resources: Resources, + bce: BadConfigException + ): String { if (bce.cause is KeyFormatException) { val kfe = bce.cause as KeyFormatException? return resources.getString(KFE_TYPE_MAP.getValue(kfe!!.type)) @@ -128,7 +147,8 @@ object ErrorMessages { var cause = throwable while (cause.cause != null) { if (cause is BadConfigException || cause is BackendException || - cause is RootShellException) break + cause is RootShellException + ) break val nextCause = cause.cause!! if (nextCause is RemoteException) break cause = nextCause diff --git a/ui/src/main/java/com/wireguard/android/util/ExceptionLoggers.kt b/ui/src/main/java/com/wireguard/android/util/ExceptionLoggers.kt deleted file mode 100644 index 4470134c..00000000 --- a/ui/src/main/java/com/wireguard/android/util/ExceptionLoggers.kt +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package com.wireguard.android.util - -import android.util.Log -import java9.util.function.BiConsumer - -/** - * Helpers for logging exceptions from asynchronous tasks. These can be passed to - * `CompletionStage.whenComplete()` at the end of an asynchronous future chain. - */ -enum class ExceptionLoggers(private val priority: Int) : BiConsumer<Any?, Throwable?> { - D(Log.DEBUG), E(Log.ERROR); - - override fun accept(result: Any?, throwable: Throwable?) { - if (throwable != null) - Log.println(Log.ERROR, TAG, Log.getStackTraceString(throwable)) - else if (priority <= Log.DEBUG) - Log.println(priority, TAG, "Future completed successfully") - } - - companion object { - private const val TAG = "WireGuard/ExceptionLoggers" - } -} diff --git a/ui/src/main/java/com/wireguard/android/util/Extensions.kt b/ui/src/main/java/com/wireguard/android/util/Extensions.kt index a705401f..c4b43951 100644 --- a/ui/src/main/java/com/wireguard/android/util/Extensions.kt +++ b/ui/src/main/java/com/wireguard/android/util/Extensions.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ @@ -8,7 +8,11 @@ package com.wireguard.android.util import android.content.Context import android.util.TypedValue import androidx.annotation.AttrRes -import androidx.fragment.app.Fragment +import androidx.lifecycle.lifecycleScope +import androidx.preference.Preference +import com.wireguard.android.Application +import com.wireguard.android.activity.SettingsActivity +import kotlinx.coroutines.CoroutineScope fun Context.resolveAttribute(@AttrRes attrRes: Int): Int { val typedValue = TypedValue() @@ -16,6 +20,12 @@ fun Context.resolveAttribute(@AttrRes attrRes: Int): Int { return typedValue.data } -fun Fragment.requireTargetFragment(): Fragment { - return requireNotNull(targetFragment) { "A target fragment should always be set for $this" } -} +val Any.applicationScope: CoroutineScope + get() = Application.getCoroutineScope() + +val Preference.activity: SettingsActivity + get() = context as? SettingsActivity + ?: throw IllegalStateException("Failed to resolve SettingsActivity") + +val Preference.lifecycleScope: CoroutineScope + get() = activity.lifecycleScope diff --git a/ui/src/main/java/com/wireguard/android/util/FragmentUtils.kt b/ui/src/main/java/com/wireguard/android/util/FragmentUtils.kt deleted file mode 100644 index 90e7ab0c..00000000 --- a/ui/src/main/java/com/wireguard/android/util/FragmentUtils.kt +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package com.wireguard.android.util - -import android.view.ContextThemeWrapper -import androidx.preference.Preference -import com.wireguard.android.activity.SettingsActivity - -object FragmentUtils { - fun getPrefActivity(preference: Preference): SettingsActivity { - val context = preference.context - if (context is ContextThemeWrapper) { - if (context is SettingsActivity) { - return context - } - } - throw IllegalStateException("Failed to resolve SettingsActivity") - } -} diff --git a/ui/src/main/java/com/wireguard/android/util/QrCodeFromFileScanner.kt b/ui/src/main/java/com/wireguard/android/util/QrCodeFromFileScanner.kt new file mode 100644 index 00000000..4ea2dc7a --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/util/QrCodeFromFileScanner.kt @@ -0,0 +1,84 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.util + +import android.content.ContentResolver +import android.graphics.Bitmap +import android.graphics.BitmapFactory +import android.net.Uri +import android.util.Log +import com.google.zxing.BinaryBitmap +import com.google.zxing.DecodeHintType +import com.google.zxing.NotFoundException +import com.google.zxing.RGBLuminanceSource +import com.google.zxing.Reader +import com.google.zxing.Result +import com.google.zxing.common.HybridBinarizer +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext + +/** + * Encapsulates the logic of scanning a barcode from a file, + * @property contentResolver - Resolver to read the incoming data + * @property reader - An instance of zxing's [Reader] class to parse the image + */ +class QrCodeFromFileScanner( + private val contentResolver: ContentResolver, + private val reader: Reader, +) { + private fun scanBitmapForResult(source: Bitmap): Result { + val width = source.width + val height = source.height + val pixels = IntArray(width * height) + source.getPixels(pixels, 0, width, 0, 0, width, height) + + val bBitmap = BinaryBitmap(HybridBinarizer(RGBLuminanceSource(width, height, pixels))) + return reader.decode(bBitmap, mapOf(DecodeHintType.TRY_HARDER to true)) + } + + private fun doScan(data: Uri): Result { + Log.d(TAG, "Starting to scan an image: $data") + contentResolver.openInputStream(data).use { inputStream -> + var bitmap: Bitmap? = null + var firstException: Throwable? = null + for (i in arrayOf(1, 2, 4, 8, 16, 32, 64, 128)) { + try { + val options = BitmapFactory.Options() + options.inSampleSize = i + bitmap = BitmapFactory.decodeStream(inputStream, null, options) + ?: throw IllegalArgumentException("Can't decode stream for bitmap") + return scanBitmapForResult(bitmap) + } catch (e: Throwable) { + bitmap?.recycle() + System.gc() + Log.e(TAG, "Original image scan at scale factor $i finished with error: $e") + if (firstException == null) + firstException = e + } + } + throw Exception(firstException) + } + } + + /** + * Attempts to parse incoming data + * @return result of the decoding operation + * @throws NotFoundException when parser didn't find QR code in the image + */ + suspend fun scan(data: Uri) = withContext(Dispatchers.Default) { doScan(data) } + + companion object { + private const val TAG = "QrCodeFromFileScanner" + + /** + * Given a reference to a file, check if this file could be parsed by this class + * @return true if the file can be parsed, false if not + */ + fun validContentType(contentResolver: ContentResolver, data: Uri): Boolean { + return contentResolver.getType(data)?.startsWith("image/") == true + } + } +} diff --git a/ui/src/main/java/com/wireguard/android/util/QuantityFormatter.kt b/ui/src/main/java/com/wireguard/android/util/QuantityFormatter.kt new file mode 100644 index 00000000..2fbb5c29 --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/util/QuantityFormatter.kt @@ -0,0 +1,66 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.util + +import android.icu.text.ListFormatter +import android.icu.text.MeasureFormat +import android.icu.text.RelativeDateTimeFormatter +import android.icu.util.Measure +import android.icu.util.MeasureUnit +import android.os.Build +import com.wireguard.android.Application +import com.wireguard.android.R +import java.util.Locale +import kotlin.time.Duration.Companion.seconds + +object QuantityFormatter { + fun formatBytes(bytes: Long): String { + val context = Application.get().applicationContext + return when { + bytes < 1024 -> context.getString(R.string.transfer_bytes, bytes) + bytes < 1024 * 1024 -> context.getString(R.string.transfer_kibibytes, bytes / 1024.0) + bytes < 1024 * 1024 * 1024 -> context.getString(R.string.transfer_mibibytes, bytes / (1024.0 * 1024.0)) + bytes < 1024 * 1024 * 1024 * 1024L -> context.getString(R.string.transfer_gibibytes, bytes / (1024.0 * 1024.0 * 1024.0)) + else -> context.getString(R.string.transfer_tibibytes, bytes / (1024.0 * 1024.0 * 1024.0) / 1024.0) + } + } + + fun formatEpochAgo(epochMillis: Long): String { + var span = (System.currentTimeMillis() - epochMillis) / 1000 + + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.N) + return Application.get().applicationContext.getString(R.string.latest_handshake_ago, span.seconds.toString()) + + if (span <= 0L) + return RelativeDateTimeFormatter.getInstance().format(RelativeDateTimeFormatter.Direction.PLAIN, RelativeDateTimeFormatter.AbsoluteUnit.NOW) + val measureFormat = MeasureFormat.getInstance(Locale.getDefault(), MeasureFormat.FormatWidth.WIDE) + val parts = ArrayList<CharSequence>(4) + if (span >= 24 * 60 * 60L) { + val v = span / (24 * 60 * 60L) + parts.add(measureFormat.format(Measure(v, MeasureUnit.DAY))) + span -= v * (24 * 60 * 60L) + } + if (span >= 60 * 60L) { + val v = span / (60 * 60L) + parts.add(measureFormat.format(Measure(v, MeasureUnit.HOUR))) + span -= v * (60 * 60L) + } + if (span >= 60L) { + val v = span / 60L + parts.add(measureFormat.format(Measure(v, MeasureUnit.MINUTE))) + span -= v * 60L + } + if (span > 0L) + parts.add(measureFormat.format(Measure(span, MeasureUnit.SECOND))) + + val joined = if (Build.VERSION.SDK_INT < Build.VERSION_CODES.TIRAMISU) + parts.joinToString() + else + ListFormatter.getInstance(Locale.getDefault(), ListFormatter.Type.UNITS, ListFormatter.Width.SHORT).format(parts) + + return Application.get().applicationContext.getString(R.string.latest_handshake_ago, joined) + } +}
\ No newline at end of file diff --git a/ui/src/main/java/com/wireguard/android/util/TunnelImporter.kt b/ui/src/main/java/com/wireguard/android/util/TunnelImporter.kt new file mode 100644 index 00000000..18a37ef6 --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/util/TunnelImporter.kt @@ -0,0 +1,152 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.util + +import android.content.ContentResolver +import android.net.Uri +import android.provider.OpenableColumns +import android.util.Log +import androidx.fragment.app.FragmentManager +import com.wireguard.android.Application +import com.wireguard.android.R +import com.wireguard.android.fragment.ConfigNamingDialogFragment +import com.wireguard.android.model.ObservableTunnel +import com.wireguard.config.Config +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.async +import kotlinx.coroutines.withContext +import java.io.BufferedReader +import java.io.ByteArrayInputStream +import java.io.InputStreamReader +import java.nio.charset.StandardCharsets +import java.util.zip.ZipEntry +import java.util.zip.ZipInputStream + +object TunnelImporter { + suspend fun importTunnel(contentResolver: ContentResolver, uri: Uri, messageCallback: (CharSequence) -> Unit) = withContext(Dispatchers.IO) { + val context = Application.get().applicationContext + val futureTunnels = ArrayList<Deferred<ObservableTunnel>>() + val throwables = ArrayList<Throwable>() + try { + val columns = arrayOf(OpenableColumns.DISPLAY_NAME) + var name = "" + contentResolver.query(uri, columns, null, null, null)?.use { cursor -> + if (cursor.moveToFirst() && !cursor.isNull(0)) { + name = cursor.getString(0) + } + } + if (name.isEmpty()) { + name = Uri.decode(uri.lastPathSegment) + } + var idx = name.lastIndexOf('/') + if (idx >= 0) { + require(idx < name.length - 1) { context.getString(R.string.illegal_filename_error, name) } + name = name.substring(idx + 1) + } + val isZip = name.lowercase().endsWith(".zip") + if (name.lowercase().endsWith(".conf")) { + name = name.substring(0, name.length - ".conf".length) + } else { + require(isZip) { context.getString(R.string.bad_extension_error) } + } + + if (isZip) { + ZipInputStream(contentResolver.openInputStream(uri)).use { zip -> + val reader = BufferedReader(InputStreamReader(zip, StandardCharsets.UTF_8)) + var entry: ZipEntry? + while (true) { + entry = zip.nextEntry ?: break + name = entry.name + idx = name.lastIndexOf('/') + if (idx >= 0) { + if (idx >= name.length - 1) { + continue + } + name = name.substring(name.lastIndexOf('/') + 1) + } + if (name.lowercase().endsWith(".conf")) { + name = name.substring(0, name.length - ".conf".length) + } else { + continue + } + try { + Config.parse(reader) + } catch (e: Throwable) { + throwables.add(e) + null + }?.let { + val nameCopy = name + futureTunnels.add(async(SupervisorJob()) { Application.getTunnelManager().create(nameCopy, it) }) + } + } + } + } else { + futureTunnels.add(async(SupervisorJob()) { Application.getTunnelManager().create(name, Config.parse(contentResolver.openInputStream(uri)!!)) }) + } + + if (futureTunnels.isEmpty()) { + if (throwables.size == 1) { + throw throwables[0] + } else { + require(throwables.isNotEmpty()) { context.getString(R.string.no_configs_error) } + } + } + val tunnels = futureTunnels.mapNotNull { + try { + it.await() + } catch (e: Throwable) { + throwables.add(e) + null + } + } + withContext(Dispatchers.Main.immediate) { onTunnelImportFinished(tunnels, throwables, messageCallback) } + } catch (e: Throwable) { + withContext(Dispatchers.Main.immediate) { onTunnelImportFinished(emptyList(), listOf(e), messageCallback) } + } + } + + fun importTunnel(parentFragmentManager: FragmentManager, configText: String, messageCallback: (CharSequence) -> Unit) { + try { + // Ensure the config text is parseable before proceeding… + Config.parse(ByteArrayInputStream(configText.toByteArray(StandardCharsets.UTF_8))) + + // Config text is valid, now create the tunnel… + ConfigNamingDialogFragment.newInstance(configText).show(parentFragmentManager, null) + } catch (e: Throwable) { + onTunnelImportFinished(emptyList(), listOf<Throwable>(e), messageCallback) + } + } + + private fun onTunnelImportFinished(tunnels: List<ObservableTunnel>, throwables: Collection<Throwable>, messageCallback: (CharSequence) -> Unit) { + val context = Application.get().applicationContext + var message = "" + for (throwable in throwables) { + val error = ErrorMessages[throwable] + message = context.getString(R.string.import_error, error) + Log.e(TAG, message, throwable) + } + if (tunnels.size == 1 && throwables.isEmpty()) + message = context.getString(R.string.import_success, tunnels[0].name) + else if (tunnels.isEmpty() && throwables.size == 1) + else if (throwables.isEmpty()) + message = context.resources.getQuantityString( + R.plurals.import_total_success, + tunnels.size, tunnels.size + ) + else if (!throwables.isEmpty()) + message = context.resources.getQuantityString( + R.plurals.import_partial_success, + tunnels.size + throwables.size, + tunnels.size, tunnels.size + throwables.size + ) + + messageCallback(message) + } + + private const val TAG = "WireGuard/TunnelImporter" +}
\ No newline at end of file diff --git a/ui/src/main/java/com/wireguard/android/util/UserKnobs.kt b/ui/src/main/java/com/wireguard/android/util/UserKnobs.kt new file mode 100644 index 00000000..ca051739 --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/util/UserKnobs.kt @@ -0,0 +1,121 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.util + +import androidx.datastore.preferences.core.booleanPreferencesKey +import androidx.datastore.preferences.core.edit +import androidx.datastore.preferences.core.stringPreferencesKey +import androidx.datastore.preferences.core.stringSetPreferencesKey +import com.wireguard.android.Application +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map + +object UserKnobs { + private val ENABLE_KERNEL_MODULE = booleanPreferencesKey("enable_kernel_module") + val enableKernelModule: Flow<Boolean> + get() = Application.getPreferencesDataStore().data.map { + it[ENABLE_KERNEL_MODULE] ?: false + } + + suspend fun setEnableKernelModule(enable: Boolean?) { + Application.getPreferencesDataStore().edit { + if (enable == null) + it.remove(ENABLE_KERNEL_MODULE) + else + it[ENABLE_KERNEL_MODULE] = enable + } + } + + private val MULTIPLE_TUNNELS = booleanPreferencesKey("multiple_tunnels") + val multipleTunnels: Flow<Boolean> + get() = Application.getPreferencesDataStore().data.map { + it[MULTIPLE_TUNNELS] ?: false + } + + private val DARK_THEME = booleanPreferencesKey("dark_theme") + val darkTheme: Flow<Boolean> + get() = Application.getPreferencesDataStore().data.map { + it[DARK_THEME] ?: false + } + + suspend fun setDarkTheme(on: Boolean) { + Application.getPreferencesDataStore().edit { + it[DARK_THEME] = on + } + } + + private val ALLOW_REMOTE_CONTROL_INTENTS = booleanPreferencesKey("allow_remote_control_intents") + val allowRemoteControlIntents: Flow<Boolean> + get() = Application.getPreferencesDataStore().data.map { + it[ALLOW_REMOTE_CONTROL_INTENTS] ?: false + } + + private val RESTORE_ON_BOOT = booleanPreferencesKey("restore_on_boot") + val restoreOnBoot: Flow<Boolean> + get() = Application.getPreferencesDataStore().data.map { + it[RESTORE_ON_BOOT] ?: false + } + + private val LAST_USED_TUNNEL = stringPreferencesKey("last_used_tunnel") + val lastUsedTunnel: Flow<String?> + get() = Application.getPreferencesDataStore().data.map { + it[LAST_USED_TUNNEL] + } + + suspend fun setLastUsedTunnel(lastUsedTunnel: String?) { + Application.getPreferencesDataStore().edit { + if (lastUsedTunnel == null) + it.remove(LAST_USED_TUNNEL) + else + it[LAST_USED_TUNNEL] = lastUsedTunnel + } + } + + private val RUNNING_TUNNELS = stringSetPreferencesKey("enabled_configs") + val runningTunnels: Flow<Set<String>> + get() = Application.getPreferencesDataStore().data.map { + it[RUNNING_TUNNELS] ?: emptySet() + } + + suspend fun setRunningTunnels(runningTunnels: Set<String>) { + Application.getPreferencesDataStore().edit { + if (runningTunnels.isEmpty()) + it.remove(RUNNING_TUNNELS) + else + it[RUNNING_TUNNELS] = runningTunnels + } + } + + private val UPDATER_NEWER_VERSION_SEEN = stringPreferencesKey("updater_newer_version_seen") + val updaterNewerVersionSeen: Flow<String?> + get() = Application.getPreferencesDataStore().data.map { + it[UPDATER_NEWER_VERSION_SEEN] + } + + suspend fun setUpdaterNewerVersionSeen(newerVersionSeen: String?) { + Application.getPreferencesDataStore().edit { + if (newerVersionSeen == null) + it.remove(UPDATER_NEWER_VERSION_SEEN) + else + it[UPDATER_NEWER_VERSION_SEEN] = newerVersionSeen + } + } + + private val UPDATER_NEWER_VERSION_CONSENTED = stringPreferencesKey("updater_newer_version_consented") + val updaterNewerVersionConsented: Flow<String?> + get() = Application.getPreferencesDataStore().data.map { + it[UPDATER_NEWER_VERSION_CONSENTED] + } + + suspend fun setUpdaterNewerVersionConsented(newerVersionConsented: String?) { + Application.getPreferencesDataStore().edit { + if (newerVersionConsented == null) + it.remove(UPDATER_NEWER_VERSION_CONSENTED) + else + it[UPDATER_NEWER_VERSION_CONSENTED] = newerVersionConsented + } + } +} diff --git a/ui/src/main/java/com/wireguard/android/viewmodel/ConfigProxy.kt b/ui/src/main/java/com/wireguard/android/viewmodel/ConfigProxy.kt index 36774c3f..7f39b461 100644 --- a/ui/src/main/java/com/wireguard/android/viewmodel/ConfigProxy.kt +++ b/ui/src/main/java/com/wireguard/android/viewmodel/ConfigProxy.kt @@ -1,25 +1,30 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.viewmodel +import android.os.Build import android.os.Parcel import android.os.Parcelable +import androidx.core.os.ParcelCompat import androidx.databinding.ObservableArrayList import androidx.databinding.ObservableList import com.wireguard.config.BadConfigException import com.wireguard.config.Config import com.wireguard.config.Peer -import java.util.ArrayList class ConfigProxy : Parcelable { val `interface`: InterfaceProxy val peers: ObservableList<PeerProxy> = ObservableArrayList() private constructor(parcel: Parcel) { - `interface` = parcel.readParcelable(InterfaceProxy::class.java.classLoader)!! - parcel.readTypedList(peers, PeerProxy.CREATOR) + `interface` = ParcelCompat.readParcelable(parcel, InterfaceProxy::class.java.classLoader, InterfaceProxy::class.java) ?: InterfaceProxy() + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) { + ParcelCompat.readParcelableList(parcel, peers, PeerProxy::class.java.classLoader, PeerProxy::class.java) + } else { + parcel.readTypedList(peers, PeerProxy.CREATOR) + } peers.forEach { it.bind(this) } } @@ -50,14 +55,18 @@ class ConfigProxy : Parcelable { val resolvedPeers: MutableCollection<Peer> = ArrayList() peers.forEach { resolvedPeers.add(it.resolve()) } return Config.Builder() - .setInterface(`interface`.resolve()) - .addPeers(resolvedPeers) - .build() + .setInterface(`interface`.resolve()) + .addPeers(resolvedPeers) + .build() } override fun writeToParcel(dest: Parcel, flags: Int) { dest.writeParcelable(`interface`, flags) - dest.writeTypedList(peers) + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) { + dest.writeParcelableList(peers, flags) + } else { + dest.writeTypedList(peers) + } } private class ConfigProxyCreator : Parcelable.Creator<ConfigProxy> { diff --git a/ui/src/main/java/com/wireguard/android/viewmodel/InterfaceProxy.kt b/ui/src/main/java/com/wireguard/android/viewmodel/InterfaceProxy.kt index bd2a9831..25c2fd19 100644 --- a/ui/src/main/java/com/wireguard/android/viewmodel/InterfaceProxy.kt +++ b/ui/src/main/java/com/wireguard/android/viewmodel/InterfaceProxy.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.viewmodel @@ -81,7 +81,7 @@ class InterfaceProxy : BaseObservable, Parcelable { constructor(other: Interface) { addresses = Attribute.join(other.addresses) - val dnsServerStrings = other.dnsServers.map { it.hostAddress } + val dnsServerStrings = other.dnsServers.map { it.hostAddress }.plus(other.dnsSearchDomains) dnsServers = Attribute.join(dnsServerStrings) excludedApplications.addAll(other.excludedApplications) includedApplications.addAll(other.includedApplications) diff --git a/ui/src/main/java/com/wireguard/android/viewmodel/PeerProxy.kt b/ui/src/main/java/com/wireguard/android/viewmodel/PeerProxy.kt index 6ac04f72..15bf8a08 100644 --- a/ui/src/main/java/com/wireguard/android/viewmodel/PeerProxy.kt +++ b/ui/src/main/java/com/wireguard/android/viewmodel/PeerProxy.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.viewmodel @@ -16,8 +16,6 @@ import com.wireguard.config.Attribute import com.wireguard.config.BadConfigException import com.wireguard.config.Peer import java.lang.ref.WeakReference -import java.util.ArrayList -import java.util.LinkedHashSet class PeerProxy : BaseObservable, Parcelable { private val dnsRoutes: MutableList<String?> = ArrayList() @@ -240,24 +238,32 @@ class PeerProxy : BaseObservable, Parcelable { peerProxy.setTotalPeers(sender.size) } - override fun onItemRangeChanged(sender: ObservableList<PeerProxy?>, - positionStart: Int, itemCount: Int) { + override fun onItemRangeChanged( + sender: ObservableList<PeerProxy?>, + positionStart: Int, itemCount: Int + ) { // Do nothing. } - override fun onItemRangeInserted(sender: ObservableList<PeerProxy?>, - positionStart: Int, itemCount: Int) { + override fun onItemRangeInserted( + sender: ObservableList<PeerProxy?>, + positionStart: Int, itemCount: Int + ) { onChanged(sender) } - override fun onItemRangeMoved(sender: ObservableList<PeerProxy?>, - fromPosition: Int, toPosition: Int, - itemCount: Int) { + override fun onItemRangeMoved( + sender: ObservableList<PeerProxy?>, + fromPosition: Int, toPosition: Int, + itemCount: Int + ) { // Do nothing. } - override fun onItemRangeRemoved(sender: ObservableList<PeerProxy?>, - positionStart: Int, itemCount: Int) { + override fun onItemRangeRemoved( + sender: ObservableList<PeerProxy?>, + positionStart: Int, itemCount: Int + ) { onChanged(sender) } } @@ -276,12 +282,12 @@ class PeerProxy : BaseObservable, Parcelable { @JvmField val CREATOR: Parcelable.Creator<PeerProxy> = PeerProxyCreator() private val IPV4_PUBLIC_NETWORKS = setOf( - "0.0.0.0/5", "8.0.0.0/7", "11.0.0.0/8", "12.0.0.0/6", "16.0.0.0/4", "32.0.0.0/3", - "64.0.0.0/2", "128.0.0.0/3", "160.0.0.0/5", "168.0.0.0/6", "172.0.0.0/12", - "172.32.0.0/11", "172.64.0.0/10", "172.128.0.0/9", "173.0.0.0/8", "174.0.0.0/7", - "176.0.0.0/4", "192.0.0.0/9", "192.128.0.0/11", "192.160.0.0/13", "192.169.0.0/16", - "192.170.0.0/15", "192.172.0.0/14", "192.176.0.0/12", "192.192.0.0/10", - "193.0.0.0/8", "194.0.0.0/7", "196.0.0.0/6", "200.0.0.0/5", "208.0.0.0/4" + "0.0.0.0/5", "8.0.0.0/7", "11.0.0.0/8", "12.0.0.0/6", "16.0.0.0/4", "32.0.0.0/3", + "64.0.0.0/2", "128.0.0.0/3", "160.0.0.0/5", "168.0.0.0/6", "172.0.0.0/12", + "172.32.0.0/11", "172.64.0.0/10", "172.128.0.0/9", "173.0.0.0/8", "174.0.0.0/7", + "176.0.0.0/4", "192.0.0.0/9", "192.128.0.0/11", "192.160.0.0/13", "192.169.0.0/16", + "192.170.0.0/15", "192.172.0.0/14", "192.176.0.0/12", "192.192.0.0/10", + "193.0.0.0/8", "194.0.0.0/7", "196.0.0.0/6", "200.0.0.0/5", "208.0.0.0/4" ) private val IPV4_WILDCARD = setOf("0.0.0.0/0") } diff --git a/ui/src/main/java/com/wireguard/android/widget/EdgeToEdge.kt b/ui/src/main/java/com/wireguard/android/widget/EdgeToEdge.kt deleted file mode 100644 index 3f109b37..00000000 --- a/ui/src/main/java/com/wireguard/android/widget/EdgeToEdge.kt +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright © 2020 WireGuard LLC. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package com.wireguard.android.widget - -import android.view.View -import android.view.ViewGroup -import androidx.core.view.marginBottom -import androidx.core.view.marginLeft -import androidx.core.view.marginRight -import androidx.core.view.marginTop -import androidx.core.view.updateLayoutParams -import androidx.core.view.updatePadding -import com.google.android.material.floatingactionbutton.FloatingActionButton - -/** - * A utility for edge-to-edge display. It provides several features needed to make the app - * displayed edge-to-edge on Android Q with gestural navigation. - */ - -object EdgeToEdge { - @JvmStatic - fun setUpRoot(root: ViewGroup) { - root.systemUiVisibility = - View.SYSTEM_UI_FLAG_LAYOUT_HIDE_NAVIGATION or View.SYSTEM_UI_FLAG_LAYOUT_STABLE - } - - @JvmStatic - fun setUpScrollingContent(scrollingContent: ViewGroup, fab: FloatingActionButton?) { - val originalPaddingLeft = scrollingContent.paddingLeft - val originalPaddingRight = scrollingContent.paddingRight - val originalPaddingBottom = scrollingContent.paddingBottom - - val fabPaddingBottom = fab?.height ?: 0 - - val originalMarginTop = scrollingContent.marginTop - - scrollingContent.setOnApplyWindowInsetsListener { _, windowInsets -> - scrollingContent.updatePadding( - left = originalPaddingLeft + windowInsets.systemWindowInsetLeft, - right = originalPaddingRight + windowInsets.systemWindowInsetRight, - bottom = originalPaddingBottom + fabPaddingBottom + windowInsets.systemWindowInsetBottom - ) - scrollingContent.updateLayoutParams<ViewGroup.MarginLayoutParams> { - topMargin = originalMarginTop + windowInsets.systemWindowInsetTop - } - windowInsets - } - } - - @JvmStatic - fun setUpFAB(fab: FloatingActionButton) { - val originalMarginLeft = fab.marginLeft - val originalMarginRight = fab.marginRight - val originalMarginBottom = fab.marginBottom - fab.setOnApplyWindowInsetsListener { _, windowInsets -> - fab.updateLayoutParams<ViewGroup.MarginLayoutParams> { - leftMargin = originalMarginLeft + windowInsets.systemWindowInsetLeft - rightMargin = originalMarginRight + windowInsets.systemWindowInsetRight - bottomMargin = originalMarginBottom + windowInsets.systemWindowInsetBottom - } - windowInsets - } - } -} diff --git a/ui/src/main/java/com/wireguard/android/widget/KeyInputFilter.kt b/ui/src/main/java/com/wireguard/android/widget/KeyInputFilter.kt index 951af699..548760c5 100644 --- a/ui/src/main/java/com/wireguard/android/widget/KeyInputFilter.kt +++ b/ui/src/main/java/com/wireguard/android/widget/KeyInputFilter.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.widget @@ -13,10 +13,12 @@ import com.wireguard.crypto.Key * InputFilter for entering WireGuard private/public keys encoded with base64. */ class KeyInputFilter : InputFilter { - override fun filter(source: CharSequence, - sStart: Int, sEnd: Int, - dest: Spanned, - dStart: Int, dEnd: Int): CharSequence? { + override fun filter( + source: CharSequence, + sStart: Int, sEnd: Int, + dest: Spanned, + dStart: Int, dEnd: Int + ): CharSequence? { var replacement: SpannableStringBuilder? = null var rIndex = 0 val dLength = dest.length @@ -26,8 +28,9 @@ class KeyInputFilter : InputFilter { // Restrict characters to the base64 character set. // Ensure adding this character does not push the length over the limit. if ((dIndex + 1 < Key.Format.BASE64.length && isAllowed(c) || - dIndex + 1 == Key.Format.BASE64.length && c == '=') && - dLength + (sIndex - sStart) < Key.Format.BASE64.length) { + dIndex + 1 == Key.Format.BASE64.length && c == '=') && + dLength + (sIndex - sStart) < Key.Format.BASE64.length + ) { ++rIndex } else { if (replacement == null) replacement = SpannableStringBuilder(source, sStart, sEnd) diff --git a/ui/src/main/java/com/wireguard/android/widget/MonkeyedTextInputEditText.kt b/ui/src/main/java/com/wireguard/android/widget/MonkeyedTextInputEditText.kt deleted file mode 100644 index 97746c09..00000000 --- a/ui/src/main/java/com/wireguard/android/widget/MonkeyedTextInputEditText.kt +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright © 2020 WireGuard LLC. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package com.wireguard.android.widget - -import android.content.Context -import android.text.Editable -import android.text.SpannableStringBuilder -import android.util.AttributeSet -import com.google.android.material.R -import com.google.android.material.textfield.TextInputEditText -import com.google.android.material.textfield.TextInputLayout - -class MonkeyedTextInputEditText @JvmOverloads constructor(context: Context, attrs: AttributeSet? = null, defStyleAttr: Int = R.attr.editTextStyle) : TextInputEditText(context, attrs, defStyleAttr) { - @Override - override fun getText(): Editable? { - val text = super.getText() - if (!text.isNullOrEmpty()) - return text - /* We want this expression in TextInputLayout.java to be true if there's a hint set: - * final boolean hasText = editText != null && !TextUtils.isEmpty(editText.getText()); - * But for everyone else it should return the real value, so we check the caller. - */ - if (!hint.isNullOrEmpty() && Thread.currentThread().stackTrace[3].className == TextInputLayout::class.qualifiedName) - return SpannableStringBuilder(hint) - return text - } -} diff --git a/ui/src/main/java/com/wireguard/android/widget/MultiselectableRelativeLayout.kt b/ui/src/main/java/com/wireguard/android/widget/MultiselectableRelativeLayout.kt index 5312f5b9..9b3ec401 100644 --- a/ui/src/main/java/com/wireguard/android/widget/MultiselectableRelativeLayout.kt +++ b/ui/src/main/java/com/wireguard/android/widget/MultiselectableRelativeLayout.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.widget @@ -11,10 +11,10 @@ import android.widget.RelativeLayout import com.wireguard.android.R class MultiselectableRelativeLayout @JvmOverloads constructor( - context: Context? = null, - attrs: AttributeSet? = null, - defStyleAttr: Int = 0, - defStyleRes: Int = 0 + context: Context? = null, + attrs: AttributeSet? = null, + defStyleAttr: Int = 0, + defStyleRes: Int = 0 ) : RelativeLayout(context, attrs, defStyleAttr, defStyleRes) { private var multiselected = false diff --git a/ui/src/main/java/com/wireguard/android/widget/NameInputFilter.kt b/ui/src/main/java/com/wireguard/android/widget/NameInputFilter.kt index ab894195..93b77ba9 100644 --- a/ui/src/main/java/com/wireguard/android/widget/NameInputFilter.kt +++ b/ui/src/main/java/com/wireguard/android/widget/NameInputFilter.kt @@ -1,5 +1,5 @@ /* - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.widget @@ -13,10 +13,12 @@ import com.wireguard.android.backend.Tunnel * InputFilter for entering WireGuard configuration names (Linux interface names). */ class NameInputFilter : InputFilter { - override fun filter(source: CharSequence, - sStart: Int, sEnd: Int, - dest: Spanned, - dStart: Int, dEnd: Int): CharSequence? { + override fun filter( + source: CharSequence, + sStart: Int, sEnd: Int, + dest: Spanned, + dStart: Int, dEnd: Int + ): CharSequence? { var replacement: SpannableStringBuilder? = null var rIndex = 0 val dLength = dest.length @@ -26,7 +28,8 @@ class NameInputFilter : InputFilter { // Restrict characters to those valid in interfaces. // Ensure adding this character does not push the length over the limit. if (dIndex < Tunnel.NAME_MAX_LENGTH && isAllowed(c) && - dLength + (sIndex - sStart) < Tunnel.NAME_MAX_LENGTH) { + dLength + (sIndex - sStart) < Tunnel.NAME_MAX_LENGTH + ) { ++rIndex } else { if (replacement == null) replacement = SpannableStringBuilder(source, sStart, sEnd) diff --git a/ui/src/main/java/com/wireguard/android/widget/SlashDrawable.kt b/ui/src/main/java/com/wireguard/android/widget/SlashDrawable.kt index ee4d278f..79dc3338 100644 --- a/ui/src/main/java/com/wireguard/android/widget/SlashDrawable.kt +++ b/ui/src/main/java/com/wireguard/android/widget/SlashDrawable.kt @@ -1,6 +1,6 @@ /* * Copyright © 2018 The Android Open Source Project - * Copyright © 2018-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.widget @@ -35,10 +35,10 @@ class SlashDrawable(private val mDrawable: Drawable) : Drawable() { val radiusX = scale(CORNER_RADIUS, width) val radiusY = scale(CORNER_RADIUS, height) updateRect( - scale(LEFT, width), - scale(TOP, height), - scale(RIGHT, width), - scale(TOP + mCurrentSlashLength, height) + scale(LEFT, width), + scale(TOP, height), + scale(RIGHT, width), + scale(TOP + mCurrentSlashLength, height) ) mPath.reset() // Draw the slash vertically @@ -69,6 +69,7 @@ class SlashDrawable(private val mDrawable: Drawable) : Drawable() { override fun getIntrinsicWidth() = mDrawable.intrinsicWidth + @Deprecated("Deprecated in API level 29") override fun getOpacity() = PixelFormat.OPAQUE override fun onBoundsChange(bounds: Rect) { diff --git a/ui/src/main/java/com/wireguard/android/widget/ToggleSwitch.kt b/ui/src/main/java/com/wireguard/android/widget/ToggleSwitch.kt index c97cb934..9b79706b 100644 --- a/ui/src/main/java/com/wireguard/android/widget/ToggleSwitch.kt +++ b/ui/src/main/java/com/wireguard/android/widget/ToggleSwitch.kt @@ -1,6 +1,6 @@ /* * Copyright © 2013 The Android Open Source Project - * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.android.widget @@ -8,9 +8,9 @@ package com.wireguard.android.widget import android.content.Context import android.os.Parcelable import android.util.AttributeSet -import android.widget.Switch +import com.google.android.material.materialswitch.MaterialSwitch -class ToggleSwitch @JvmOverloads constructor(context: Context?, attrs: AttributeSet? = null) : Switch(context, attrs) { +class ToggleSwitch @JvmOverloads constructor(context: Context, attrs: AttributeSet? = null) : MaterialSwitch(context, attrs) { private var isRestoringState = false private var listener: OnBeforeCheckedChangeListener? = null override fun onRestoreInstanceState(state: Parcelable) { diff --git a/ui/src/main/java/com/wireguard/android/widget/TvCardView.kt b/ui/src/main/java/com/wireguard/android/widget/TvCardView.kt new file mode 100644 index 00000000..de301313 --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/widget/TvCardView.kt @@ -0,0 +1,44 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.widget + +import android.content.Context +import android.util.AttributeSet +import android.view.View +import com.google.android.material.card.MaterialCardView +import com.wireguard.android.R + +class TvCardView(context: Context?, attrs: AttributeSet?) : MaterialCardView(context, attrs) { + var isUp: Boolean = false + set(value) { + field = value + refreshDrawableState() + } + var isDeleting: Boolean = false + set(value) { + field = value + refreshDrawableState() + } + + override fun onCreateDrawableState(extraSpace: Int): IntArray { + if (isUp || isDeleting) { + val drawableState = super.onCreateDrawableState(extraSpace + (if (isUp) 1 else 0) + (if (isDeleting) 1 else 0)) + if (isUp) { + View.mergeDrawableStates(drawableState, STATE_IS_UP) + } + if (isDeleting) { + View.mergeDrawableStates(drawableState, STATE_IS_DELETING) + } + return drawableState + } + return super.onCreateDrawableState(extraSpace) + } + + companion object { + private val STATE_IS_UP = intArrayOf(R.attr.state_isUp) + private val STATE_IS_DELETING = intArrayOf(R.attr.state_isDeleting) + } +}
\ No newline at end of file |