diff options
Diffstat (limited to 'ui/src/main/java/com/wireguard/android/updater/Updater.kt')
-rw-r--r-- | ui/src/main/java/com/wireguard/android/updater/Updater.kt | 451 |
1 files changed, 451 insertions, 0 deletions
diff --git a/ui/src/main/java/com/wireguard/android/updater/Updater.kt b/ui/src/main/java/com/wireguard/android/updater/Updater.kt new file mode 100644 index 00000000..651e3cd7 --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/updater/Updater.kt @@ -0,0 +1,451 @@ +/* + * Copyright © 2017-2023 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package com.wireguard.android.updater + +import android.Manifest +import android.app.PendingIntent +import android.content.BroadcastReceiver +import android.content.Context +import android.content.Intent +import android.content.IntentFilter +import android.content.pm.PackageInstaller +import android.content.pm.PackageManager +import android.os.Build +import android.util.Base64 +import android.util.Log +import androidx.core.content.ContextCompat +import androidx.core.content.IntentCompat +import com.wireguard.android.Application +import com.wireguard.android.BuildConfig +import com.wireguard.android.activity.MainActivity +import com.wireguard.android.util.UserKnobs +import com.wireguard.android.util.applicationScope +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.flow.launchIn +import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import java.io.IOException +import java.net.HttpURLConnection +import java.net.URL +import java.nio.charset.StandardCharsets +import java.security.InvalidKeyException +import java.security.InvalidParameterException +import java.security.MessageDigest +import java.util.UUID +import kotlin.math.max +import kotlin.time.Duration.Companion.minutes +import kotlin.time.Duration.Companion.seconds + +object Updater { + private const val TAG = "WireGuard/Updater" + private const val UPDATE_URL_FMT = "https://download.wireguard.com/android-client/%s" + private const val APK_NAME_PREFIX = BuildConfig.APPLICATION_ID + "-" + private const val APK_NAME_SUFFIX = ".apk" + private const val LATEST_FILE = "latest.sig" + private const val RELEASE_PUBLIC_KEY_BASE64 = "RWTAzwGRYr3EC9px0Ia3fbttz8WcVN6wrOwWp2delz4el6SI8XmkKSMp" + private val CURRENT_VERSION by lazy { Version(BuildConfig.VERSION_NAME) } + + private val updaterScope = CoroutineScope(Job() + Dispatchers.IO) + + private fun installer(context: Context): String = try { + val packageName = context.packageName + val pm = context.packageManager + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.R) { + pm.getInstallSourceInfo(packageName).installingPackageName ?: "" + } else { + @Suppress("DEPRECATION") + pm.getInstallerPackageName(packageName) ?: "" + } + } catch (_: Throwable) { + "" + } + + fun installerIsGooglePlay(context: Context): Boolean = installer(context) == "com.android.vending" + + sealed class Progress { + object Complete : Progress() + class Available(val version: String) : Progress() { + fun update() { + applicationScope.launch { + UserKnobs.setUpdaterNewerVersionConsented(version) + } + } + } + + object Rechecking : Progress() + class Downloading(val bytesDownloaded: ULong, val bytesTotal: ULong) : Progress() + object Installing : Progress() + class NeedsUserIntervention(val intent: Intent, private val id: Int) : Progress() { + + private suspend fun installerActive(): Boolean { + if (mutableState.firstOrNull() != this@NeedsUserIntervention) + return true + try { + if (Application.get().packageManager.packageInstaller.getSessionInfo(id)?.isActive == true) + return true + } catch (_: SecurityException) { + return true + } + return false + } + + fun markAsDone() { + applicationScope.launch { + if (installerActive()) + return@launch + delay(7.seconds) + if (installerActive()) + return@launch + emitProgress(Failure(Exception("Ignored by user"))) + } + } + } + + class Failure(val error: Throwable) : Progress() { + fun retry() { + updaterScope.launch { + downloadAndUpdateWrapErrors() + } + } + } + + class Corrupt(private val betterFile: String?) : Progress() { + val downloadUrl: String + get() = UPDATE_URL_FMT.format(betterFile ?: "") + } + } + + private val mutableState = MutableStateFlow<Progress>(Progress.Complete) + val state = mutableState.asStateFlow() + + private suspend fun emitProgress(progress: Progress, force: Boolean = false) { + if (force || mutableState.firstOrNull()?.javaClass != progress.javaClass) + mutableState.emit(progress) + } + + private class Sha256Digest(hex: String) { + val bytes: ByteArray + + init { + if (hex.length != 64) + throw InvalidParameterException("SHA256 hashes must be 32 bytes long") + bytes = hex.chunked(2).map { it.toInt(16).toByte() }.toByteArray() + } + } + + @OptIn(ExperimentalUnsignedTypes::class) + private class Version(version: String) : Comparable<Version> { + val parts: ULongArray + + init { + val strParts = version.split(".") + if (strParts.isEmpty()) + throw InvalidParameterException("Version has no parts") + parts = ULongArray(strParts.size) + for (i in parts.indices) { + parts[i] = strParts[i].toULong() + } + } + + override fun toString(): String { + return parts.joinToString(".") + } + + override fun compareTo(other: Version): Int { + for (i in 0 until max(parts.size, other.parts.size)) { + val lhsPart = if (i < parts.size) parts[i] else 0UL + val rhsPart = if (i < other.parts.size) other.parts[i] else 0UL + if (lhsPart > rhsPart) + return 1 + else if (lhsPart < rhsPart) + return -1 + } + return 0 + } + } + + private class Update(val fileName: String, val version: Version, val hash: Sha256Digest) + + private fun versionOfFile(name: String): Version? { + if (!name.startsWith(APK_NAME_PREFIX) || !name.endsWith(APK_NAME_SUFFIX)) + return null + return try { + Version(name.substring(APK_NAME_PREFIX.length, name.length - APK_NAME_SUFFIX.length)) + } catch (_: Throwable) { + null + } + } + + private fun verifySignedFileList(signifyDigest: String): List<Update> { + val updates = ArrayList<Update>(1) + val publicKeyBytes = Base64.decode(RELEASE_PUBLIC_KEY_BASE64, Base64.DEFAULT) + if (publicKeyBytes == null || publicKeyBytes.size != 32 + 10 || publicKeyBytes[0] != 'E'.code.toByte() || publicKeyBytes[1] != 'd'.code.toByte()) + throw InvalidKeyException("Invalid public key") + val lines = signifyDigest.split("\n", limit = 3) + if (lines.size != 3) + throw InvalidParameterException("Invalid signature format: too few lines") + if (!lines[0].startsWith("untrusted comment: ")) + throw InvalidParameterException("Invalid signature format: missing comment") + val signatureBytes = Base64.decode(lines[1], Base64.DEFAULT) + if (signatureBytes == null || signatureBytes.size != 64 + 10) + throw InvalidParameterException("Invalid signature format: wrong sized or missing signature") + for (i in 0..9) { + if (signatureBytes[i] != publicKeyBytes[i]) + throw InvalidParameterException("Invalid signature format: wrong signer") + } + if (!Ed25519.verify( + lines[2].toByteArray(StandardCharsets.UTF_8), + signatureBytes.sliceArray(10 until 10 + 64), + publicKeyBytes.sliceArray(10 until 10 + 32) + ) + ) + throw SecurityException("Invalid signature") + for (line in lines[2].split("\n").dropLastWhile { it.isEmpty() }) { + val components = line.split(" ", limit = 2) + if (components.size != 2) + throw InvalidParameterException("Invalid file list format: too few components") + /* If version is null, it's not a file we understand, but still a legitimate entry, so don't throw. */ + val version = versionOfFile(components[1]) ?: continue + updates.add(Update(components[1], version, Sha256Digest(components[0]))) + } + return updates + } + + private fun checkForUpdates(): Update? { + val connection = URL(UPDATE_URL_FMT.format(LATEST_FILE)).openConnection() as HttpURLConnection + connection.setRequestProperty("User-Agent", Application.USER_AGENT) + connection.connect() + if (connection.responseCode != HttpURLConnection.HTTP_OK) + throw IOException(connection.responseMessage) + var fileListBytes = ByteArray(1024 * 512 /* 512 KiB */) + connection.inputStream.use { + val len = it.read(fileListBytes) + if (len <= 0) + throw IOException("File list is empty") + fileListBytes = fileListBytes.sliceArray(0 until len) + } + return verifySignedFileList(fileListBytes.decodeToString()).maxByOrNull { it.version } + } + + private suspend fun downloadAndUpdate() = withContext(Dispatchers.IO) { + val receiver = InstallReceiver() + val context = Application.get().applicationContext + val pendingIntent = withContext(Dispatchers.Main) { + ContextCompat.registerReceiver(context, receiver, IntentFilter(receiver.sessionId), ContextCompat.RECEIVER_NOT_EXPORTED) + PendingIntent.getBroadcast( + context, + 0, + Intent(receiver.sessionId).setPackage(context.packageName), + PendingIntent.FLAG_UPDATE_CURRENT or PendingIntent.FLAG_MUTABLE + ) + } + + emitProgress(Progress.Rechecking) + val update = checkForUpdates() + if (update == null || update.version <= CURRENT_VERSION) { + emitProgress(Progress.Complete) + return@withContext + } + + emitProgress(Progress.Downloading(0UL, 0UL), true) + val connection = URL(UPDATE_URL_FMT.format(update.fileName)).openConnection() as HttpURLConnection + connection.setRequestProperty("User-Agent", Application.USER_AGENT) + connection.connect() + if (connection.responseCode != HttpURLConnection.HTTP_OK) + throw IOException("Update could not be fetched: ${connection.responseCode}") + + var downloadedByteLen: ULong = 0UL + val totalByteLen = (if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) connection.contentLengthLong else connection.contentLength).toLong().toULong() + val fileBytes = ByteArray(1024 * 32 /* 32 KiB */) + val digest = MessageDigest.getInstance("SHA-256") + emitProgress(Progress.Downloading(downloadedByteLen, totalByteLen), true) + + val installer = context.packageManager.packageInstaller + val params = PackageInstaller.SessionParams(PackageInstaller.SessionParams.MODE_FULL_INSTALL) + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) + params.setRequireUserAction(PackageInstaller.SessionParams.USER_ACTION_NOT_REQUIRED) + params.setAppPackageName(context.packageName) /* Enforces updates; disallows new apps. */ + val session = installer.openSession(installer.createSession(params)) + var sessionFailure = true + try { + val installDest = session.openWrite(receiver.sessionId, 0, -1) + + installDest.use { dest -> + connection.inputStream.use { src -> + while (true) { + val readLen = src.read(fileBytes) + if (readLen <= 0) + break + + digest.update(fileBytes, 0, readLen) + dest.write(fileBytes, 0, readLen) + + downloadedByteLen += readLen.toUInt() + emitProgress(Progress.Downloading(downloadedByteLen, totalByteLen), true) + + if (downloadedByteLen >= 1024UL * 1024UL * 100UL /* 100 MiB */) + throw IOException("File too large") + } + } + } + + emitProgress(Progress.Installing) + if (!digest.digest().contentEquals(update.hash.bytes)) + throw SecurityException("Update has invalid hash") + sessionFailure = false + } finally { + if (sessionFailure) { + session.abandon() + session.close() + } + } + session.commit(pendingIntent.intentSender) + session.close() + } + + private var updating = false + private suspend fun downloadAndUpdateWrapErrors() { + if (updating) + return + updating = true + try { + downloadAndUpdate() + } catch (e: Throwable) { + Log.e(TAG, "Update failure", e) + emitProgress(Progress.Failure(e)) + } + updating = false + } + + private class InstallReceiver : BroadcastReceiver() { + val sessionId = UUID.randomUUID().toString() + + override fun onReceive(context: Context, intent: Intent) { + if (sessionId != intent.action) + return + + when (val status = intent.getIntExtra(PackageInstaller.EXTRA_STATUS, PackageInstaller.STATUS_FAILURE_INVALID)) { + PackageInstaller.STATUS_PENDING_USER_ACTION -> { + val id = intent.getIntExtra(PackageInstaller.EXTRA_SESSION_ID, 0) + val userIntervention = IntentCompat.getParcelableExtra(intent, Intent.EXTRA_INTENT, Intent::class.java)!! + applicationScope.launch { + emitProgress(Progress.NeedsUserIntervention(userIntervention, id)) + } + } + + PackageInstaller.STATUS_SUCCESS -> { + applicationScope.launch { + emitProgress(Progress.Complete) + } + context.applicationContext.unregisterReceiver(this) + } + + else -> { + val id = intent.getIntExtra(PackageInstaller.EXTRA_SESSION_ID, 0) + try { + context.applicationContext.packageManager.packageInstaller.abandonSession(id) + } catch (_: SecurityException) { + } + val message = intent.getStringExtra(PackageInstaller.EXTRA_STATUS_MESSAGE) ?: "Installation error $status" + applicationScope.launch { + val e = Exception(message) + Log.e(TAG, "Update failure", e) + emitProgress(Progress.Failure(e)) + } + context.applicationContext.unregisterReceiver(this) + } + } + } + } + + fun monitorForUpdates() { + if (BuildConfig.DEBUG) + return + + val context = Application.get() + + if (installerIsGooglePlay(context)) + return + + if (!if (Build.VERSION.SDK_INT < Build.VERSION_CODES.TIRAMISU) { + @Suppress("DEPRECATION") + context.packageManager.getPackageInfo(context.packageName, PackageManager.GET_PERMISSIONS) + } else { + context.packageManager.getPackageInfo(context.packageName, PackageManager.PackageInfoFlags.of(PackageManager.GET_PERMISSIONS.toLong())) + }.requestedPermissions.contains(Manifest.permission.REQUEST_INSTALL_PACKAGES) + ) { + if (installer(context).isNotEmpty()) { + updaterScope.launch { + val update = try { + checkForUpdates() + } catch (_: Throwable) { + null + } + emitProgress(Progress.Corrupt(update?.fileName)) + } + } + return + } + + updaterScope.launch { + if (UserKnobs.updaterNewerVersionSeen.firstOrNull()?.let { Version(it) > CURRENT_VERSION } == true) + return@launch + + var waitTime = 15 + while (true) { + try { + val update = checkForUpdates() ?: continue + if (update.version > CURRENT_VERSION) { + Log.i(TAG, "Update available: ${update.version}") + UserKnobs.setUpdaterNewerVersionSeen(update.version.toString()) + return@launch + } + } catch (_: Throwable) { + } + delay(waitTime.minutes) + waitTime = 45 + } + } + + UserKnobs.updaterNewerVersionSeen.onEach { ver -> + if ( + ver != null && + Version(ver) > CURRENT_VERSION && + UserKnobs.updaterNewerVersionConsented.firstOrNull()?.let { Version(it) > CURRENT_VERSION } != true + ) + emitProgress(Progress.Available(ver)) + }.launchIn(applicationScope) + + UserKnobs.updaterNewerVersionConsented.onEach { ver -> + if (ver != null && Version(ver) > CURRENT_VERSION) + updaterScope.launch { + downloadAndUpdateWrapErrors() + } + }.launchIn(applicationScope) + } + + class AppUpdatedReceiver : BroadcastReceiver() { + override fun onReceive(context: Context, intent: Intent) { + if (intent.action != Intent.ACTION_MY_PACKAGE_REPLACED) + return + + if (installer(context) != context.packageName) + return + + /* TODO: does not work because of restrictions placed on broadcast receivers. */ + val start = Intent(context, MainActivity::class.java) + start.addFlags(Intent.FLAG_ACTIVITY_CLEAR_TOP) + start.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK) + context.startActivity(start) + } + } +} |