diff options
Diffstat (limited to 'ui/src/main/java/com/wireguard/android/viewmodel')
3 files changed, 522 insertions, 0 deletions
diff --git a/ui/src/main/java/com/wireguard/android/viewmodel/ConfigProxy.kt b/ui/src/main/java/com/wireguard/android/viewmodel/ConfigProxy.kt new file mode 100644 index 00000000..c73b1efc --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/viewmodel/ConfigProxy.kt @@ -0,0 +1,86 @@ +/* + * Copyright © 2017-2023 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package com.wireguard.android.viewmodel + +import android.os.Build +import android.os.Parcel +import android.os.Parcelable +import androidx.core.os.ParcelCompat +import androidx.databinding.ObservableArrayList +import androidx.databinding.ObservableList +import com.wireguard.config.BadConfigException +import com.wireguard.config.Config +import com.wireguard.config.Peer + +class ConfigProxy : Parcelable { + val `interface`: InterfaceProxy + val peers: ObservableList<PeerProxy> = ObservableArrayList() + + private constructor(parcel: Parcel) { + `interface` = ParcelCompat.readParcelable(parcel, InterfaceProxy::class.java.classLoader, InterfaceProxy::class.java) ?: InterfaceProxy() + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) { + ParcelCompat.readParcelableList(parcel, peers, PeerProxy::class.java.classLoader, PeerProxy::class.java) + } else { + parcel.readTypedList(peers, PeerProxy.CREATOR) + } + peers.forEach { it.bind(this) } + } + + constructor(other: Config) { + `interface` = InterfaceProxy(other.getInterface()) + other.peers.forEach { + val proxy = PeerProxy(it) + peers.add(proxy) + proxy.bind(this) + } + } + + constructor() { + `interface` = InterfaceProxy() + } + + fun addPeer(): PeerProxy { + val proxy = PeerProxy() + peers.add(proxy) + proxy.bind(this) + return proxy + } + + override fun describeContents() = 0 + + @Throws(BadConfigException::class) + fun resolve(): Config { + val resolvedPeers: MutableCollection<Peer> = ArrayList() + peers.forEach { resolvedPeers.add(it.resolve()) } + return Config.Builder() + .setInterface(`interface`.resolve()) + .addPeers(resolvedPeers) + .build() + } + + override fun writeToParcel(dest: Parcel, flags: Int) { + dest.writeParcelable(`interface`, flags) + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) { + dest.writeParcelableList(peers, flags) + } else { + dest.writeTypedList(peers) + } + } + + private class ConfigProxyCreator : Parcelable.Creator<ConfigProxy> { + override fun createFromParcel(parcel: Parcel): ConfigProxy { + return ConfigProxy(parcel) + } + + override fun newArray(size: Int): Array<ConfigProxy?> { + return arrayOfNulls(size) + } + } + + companion object { + @JvmField + val CREATOR: Parcelable.Creator<ConfigProxy> = ConfigProxyCreator() + } +} diff --git a/ui/src/main/java/com/wireguard/android/viewmodel/InterfaceProxy.kt b/ui/src/main/java/com/wireguard/android/viewmodel/InterfaceProxy.kt new file mode 100644 index 00000000..004ebed1 --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/viewmodel/InterfaceProxy.kt @@ -0,0 +1,142 @@ +/* + * Copyright © 2017-2023 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package com.wireguard.android.viewmodel + +import android.os.Parcel +import android.os.Parcelable +import androidx.databinding.BaseObservable +import androidx.databinding.Bindable +import androidx.databinding.ObservableArrayList +import androidx.databinding.ObservableList +import com.wireguard.android.BR +import com.wireguard.config.Attribute +import com.wireguard.config.BadConfigException +import com.wireguard.config.Interface +import com.wireguard.crypto.Key +import com.wireguard.crypto.KeyFormatException +import com.wireguard.crypto.KeyPair + +class InterfaceProxy : BaseObservable, Parcelable { + @get:Bindable + val excludedApplications: ObservableList<String> = ObservableArrayList() + + @get:Bindable + val includedApplications: ObservableList<String> = ObservableArrayList() + + @get:Bindable + var addresses: String = "" + set(value) { + field = value + notifyPropertyChanged(BR.addresses) + } + + @get:Bindable + var dnsServers: String = "" + set(value) { + field = value + notifyPropertyChanged(BR.dnsServers) + } + + @get:Bindable + var listenPort: String = "" + set(value) { + field = value + notifyPropertyChanged(BR.listenPort) + } + + @get:Bindable + var mtu: String = "" + set(value) { + field = value + notifyPropertyChanged(BR.mtu) + } + + @get:Bindable + var privateKey: String = "" + set(value) { + field = value + notifyPropertyChanged(BR.privateKey) + notifyPropertyChanged(BR.publicKey) + } + + @get:Bindable + val publicKey: String + get() = try { + KeyPair(Key.fromBase64(privateKey)).publicKey.toBase64() + } catch (ignored: KeyFormatException) { + "" + } + + private constructor(parcel: Parcel) { + addresses = parcel.readString() ?: "" + dnsServers = parcel.readString() ?: "" + parcel.readStringList(excludedApplications) + parcel.readStringList(includedApplications) + listenPort = parcel.readString() ?: "" + mtu = parcel.readString() ?: "" + privateKey = parcel.readString() ?: "" + } + + constructor(other: Interface) { + addresses = Attribute.join(other.addresses) + val dnsServerStrings = other.dnsServers.map { it.hostAddress }.plus(other.dnsSearchDomains) + dnsServers = Attribute.join(dnsServerStrings) + excludedApplications.addAll(other.excludedApplications) + includedApplications.addAll(other.includedApplications) + listenPort = other.listenPort.map { it.toString() }.orElse("") + mtu = other.mtu.map { it.toString() }.orElse("") + val keyPair = other.keyPair + privateKey = keyPair.privateKey.toBase64() + } + + constructor() + + override fun describeContents() = 0 + + fun generateKeyPair() { + val keyPair = KeyPair() + privateKey = keyPair.privateKey.toBase64() + notifyPropertyChanged(BR.privateKey) + notifyPropertyChanged(BR.publicKey) + } + + @Throws(BadConfigException::class) + fun resolve(): Interface { + val builder = Interface.Builder() + if (addresses.isNotEmpty()) builder.parseAddresses(addresses) + if (dnsServers.isNotEmpty()) builder.parseDnsServers(dnsServers) + if (excludedApplications.isNotEmpty()) builder.excludeApplications(excludedApplications) + if (includedApplications.isNotEmpty()) builder.includeApplications(includedApplications) + if (listenPort.isNotEmpty()) builder.parseListenPort(listenPort) + if (mtu.isNotEmpty()) builder.parseMtu(mtu) + if (privateKey.isNotEmpty()) builder.parsePrivateKey(privateKey) + return builder.build() + } + + override fun writeToParcel(dest: Parcel, flags: Int) { + dest.writeString(addresses) + dest.writeString(dnsServers) + dest.writeStringList(excludedApplications) + dest.writeStringList(includedApplications) + dest.writeString(listenPort) + dest.writeString(mtu) + dest.writeString(privateKey) + } + + private class InterfaceProxyCreator : Parcelable.Creator<InterfaceProxy> { + override fun createFromParcel(parcel: Parcel): InterfaceProxy { + return InterfaceProxy(parcel) + } + + override fun newArray(size: Int): Array<InterfaceProxy?> { + return arrayOfNulls(size) + } + } + + companion object { + @JvmField + val CREATOR: Parcelable.Creator<InterfaceProxy> = InterfaceProxyCreator() + } +} diff --git a/ui/src/main/java/com/wireguard/android/viewmodel/PeerProxy.kt b/ui/src/main/java/com/wireguard/android/viewmodel/PeerProxy.kt new file mode 100644 index 00000000..e78d0826 --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/viewmodel/PeerProxy.kt @@ -0,0 +1,294 @@ +/* + * Copyright © 2017-2023 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package com.wireguard.android.viewmodel + +import android.os.Parcel +import android.os.Parcelable +import androidx.databinding.BaseObservable +import androidx.databinding.Bindable +import androidx.databinding.Observable +import androidx.databinding.Observable.OnPropertyChangedCallback +import androidx.databinding.ObservableList +import com.wireguard.android.BR +import com.wireguard.config.Attribute +import com.wireguard.config.BadConfigException +import com.wireguard.config.Peer +import java.lang.ref.WeakReference + +class PeerProxy : BaseObservable, Parcelable { + private val dnsRoutes: MutableList<String?> = ArrayList() + private var allowedIpsState = AllowedIpsState.INVALID + private var interfaceDnsListener: InterfaceDnsListener? = null + private var peerListListener: PeerListListener? = null + private var owner: ConfigProxy? = null + private var totalPeers = 0 + + @get:Bindable + var allowedIps: String = "" + set(value) { + field = value + notifyPropertyChanged(BR.allowedIps) + calculateAllowedIpsState() + } + + @get:Bindable + var endpoint: String = "" + set(value) { + field = value + notifyPropertyChanged(BR.endpoint) + } + + @get:Bindable + var persistentKeepalive: String = "" + set(value) { + field = value + notifyPropertyChanged(BR.persistentKeepalive) + } + + @get:Bindable + var preSharedKey: String = "" + set(value) { + field = value + notifyPropertyChanged(BR.preSharedKey) + } + + @get:Bindable + var publicKey: String = "" + set(value) { + field = value + notifyPropertyChanged(BR.publicKey) + } + + @get:Bindable + val isAbleToExcludePrivateIps: Boolean + get() = allowedIpsState == AllowedIpsState.CONTAINS_IPV4_PUBLIC_NETWORKS || allowedIpsState == AllowedIpsState.CONTAINS_IPV4_WILDCARD + + @get:Bindable + val isExcludingPrivateIps: Boolean + get() = allowedIpsState == AllowedIpsState.CONTAINS_IPV4_PUBLIC_NETWORKS + + private constructor(parcel: Parcel) { + allowedIps = parcel.readString() ?: "" + endpoint = parcel.readString() ?: "" + persistentKeepalive = parcel.readString() ?: "" + preSharedKey = parcel.readString() ?: "" + publicKey = parcel.readString() ?: "" + } + + constructor(other: Peer) { + allowedIps = Attribute.join(other.allowedIps) + endpoint = other.endpoint.map { it.toString() }.orElse("") + persistentKeepalive = other.persistentKeepalive.map { it.toString() }.orElse("") + preSharedKey = other.preSharedKey.map { it.toBase64() }.orElse("") + publicKey = other.publicKey.toBase64() + } + + constructor() + + fun bind(owner: ConfigProxy) { + val interfaze: InterfaceProxy = owner.`interface` + val peers = owner.peers + if (interfaceDnsListener == null) interfaceDnsListener = InterfaceDnsListener(this) + interfaze.addOnPropertyChangedCallback(interfaceDnsListener!!) + setInterfaceDns(interfaze.dnsServers) + if (peerListListener == null) peerListListener = PeerListListener(this) + peers.addOnListChangedCallback(peerListListener) + setTotalPeers(peers.size) + this.owner = owner + } + + private fun calculateAllowedIpsState() { + val newState: AllowedIpsState + newState = if (totalPeers == 1) { + // String comparison works because we only care if allowedIps is a superset of one of + // the above sets of (valid) *networks*. We are not checking for a superset based on + // the individual addresses in each set. + val networkStrings: Collection<String> = getAllowedIpsSet() + // If allowedIps contains both the wildcard and the public networks, then private + // networks aren't excluded! + if (networkStrings.containsAll(IPV4_WILDCARD)) + AllowedIpsState.CONTAINS_IPV4_WILDCARD + else if (networkStrings.containsAll(IPV4_PUBLIC_NETWORKS)) + AllowedIpsState.CONTAINS_IPV4_PUBLIC_NETWORKS + else + AllowedIpsState.OTHER + } else { + AllowedIpsState.INVALID + } + if (newState != allowedIpsState) { + allowedIpsState = newState + notifyPropertyChanged(BR.ableToExcludePrivateIps) + notifyPropertyChanged(BR.excludingPrivateIps) + } + } + + override fun describeContents() = 0 + + private fun getAllowedIpsSet() = setOf(*Attribute.split(allowedIps)) + + // Replace the first instance of the wildcard with the public network list, or vice versa. + // DNS servers only need to handled specially when we're excluding private IPs. + fun setExcludingPrivateIps(excludingPrivateIps: Boolean) { + if (!isAbleToExcludePrivateIps || isExcludingPrivateIps == excludingPrivateIps) return + val oldNetworks = if (excludingPrivateIps) IPV4_WILDCARD else IPV4_PUBLIC_NETWORKS + val newNetworks = if (excludingPrivateIps) IPV4_PUBLIC_NETWORKS else IPV4_WILDCARD + val input: Collection<String> = getAllowedIpsSet() + val outputSize = input.size - oldNetworks.size + newNetworks.size + val output: MutableCollection<String?> = LinkedHashSet(outputSize) + var replaced = false + // Replace the first instance of the wildcard with the public network list, or vice versa. + for (network in input) { + if (oldNetworks.contains(network)) { + if (!replaced) { + for (replacement in newNetworks) if (!output.contains(replacement)) output.add(replacement) + replaced = true + } + } else if (!output.contains(network)) { + output.add(network) + } + } + // DNS servers only need to handled specially when we're excluding private IPs. + if (excludingPrivateIps) output.addAll(dnsRoutes) else output.removeAll(dnsRoutes) + allowedIps = Attribute.join(output) + allowedIpsState = if (excludingPrivateIps) AllowedIpsState.CONTAINS_IPV4_PUBLIC_NETWORKS else AllowedIpsState.CONTAINS_IPV4_WILDCARD + notifyPropertyChanged(BR.allowedIps) + notifyPropertyChanged(BR.excludingPrivateIps) + } + + @Throws(BadConfigException::class) + fun resolve(): Peer { + val builder = Peer.Builder() + if (allowedIps.isNotEmpty()) builder.parseAllowedIPs(allowedIps) + if (endpoint.isNotEmpty()) builder.parseEndpoint(endpoint) + if (persistentKeepalive.isNotEmpty()) builder.parsePersistentKeepalive(persistentKeepalive) + if (preSharedKey.isNotEmpty()) builder.parsePreSharedKey(preSharedKey) + if (publicKey.isNotEmpty()) builder.parsePublicKey(publicKey) + return builder.build() + } + + private fun setInterfaceDns(dnsServers: CharSequence) { + val newDnsRoutes = Attribute.split(dnsServers).filter { !it.contains(":") }.map { "$it/32" } + if (allowedIpsState == AllowedIpsState.CONTAINS_IPV4_PUBLIC_NETWORKS) { + val input = getAllowedIpsSet() + // Yes, this is quadratic in the number of DNS servers, but most users have 1 or 2. + val output = input.filter { !dnsRoutes.contains(it) || newDnsRoutes.contains(it) }.plus(newDnsRoutes).distinct() + // None of the public networks are /32s, so this cannot change the AllowedIPs state. + allowedIps = Attribute.join(output) + notifyPropertyChanged(BR.allowedIps) + } + dnsRoutes.clear() + dnsRoutes.addAll(newDnsRoutes) + } + + private fun setTotalPeers(totalPeers: Int) { + if (this.totalPeers == totalPeers) return + this.totalPeers = totalPeers + calculateAllowedIpsState() + } + + fun unbind() { + if (owner == null) return + val interfaze: InterfaceProxy = owner!!.`interface` + val peers = owner!!.peers + if (interfaceDnsListener != null) interfaze.removeOnPropertyChangedCallback(interfaceDnsListener!!) + if (peerListListener != null) peers.removeOnListChangedCallback(peerListListener) + peers.remove(this) + setInterfaceDns("") + setTotalPeers(0) + owner = null + } + + override fun writeToParcel(dest: Parcel, flags: Int) { + dest.writeString(allowedIps) + dest.writeString(endpoint) + dest.writeString(persistentKeepalive) + dest.writeString(preSharedKey) + dest.writeString(publicKey) + } + + private enum class AllowedIpsState { + CONTAINS_IPV4_PUBLIC_NETWORKS, CONTAINS_IPV4_WILDCARD, INVALID, OTHER + } + + private class InterfaceDnsListener constructor(peerProxy: PeerProxy) : OnPropertyChangedCallback() { + private val weakPeerProxy: WeakReference<PeerProxy> = WeakReference(peerProxy) + override fun onPropertyChanged(sender: Observable, propertyId: Int) { + val peerProxy = weakPeerProxy.get() + if (peerProxy == null) { + sender.removeOnPropertyChangedCallback(this) + return + } + // This shouldn't be possible, but try to avoid a ClassCastException anyway. + if (sender !is InterfaceProxy) return + if (!(propertyId == BR._all || propertyId == BR.dnsServers)) return + peerProxy.setInterfaceDns(sender.dnsServers) + } + } + + private class PeerListListener(peerProxy: PeerProxy) : ObservableList.OnListChangedCallback<ObservableList<PeerProxy?>>() { + private val weakPeerProxy: WeakReference<PeerProxy> = WeakReference(peerProxy) + override fun onChanged(sender: ObservableList<PeerProxy?>) { + val peerProxy = weakPeerProxy.get() + if (peerProxy == null) { + sender.removeOnListChangedCallback(this) + return + } + peerProxy.setTotalPeers(sender.size) + } + + override fun onItemRangeChanged( + sender: ObservableList<PeerProxy?>, + positionStart: Int, itemCount: Int + ) { + // Do nothing. + } + + override fun onItemRangeInserted( + sender: ObservableList<PeerProxy?>, + positionStart: Int, itemCount: Int + ) { + onChanged(sender) + } + + override fun onItemRangeMoved( + sender: ObservableList<PeerProxy?>, + fromPosition: Int, toPosition: Int, + itemCount: Int + ) { + // Do nothing. + } + + override fun onItemRangeRemoved( + sender: ObservableList<PeerProxy?>, + positionStart: Int, itemCount: Int + ) { + onChanged(sender) + } + } + + private class PeerProxyCreator : Parcelable.Creator<PeerProxy> { + override fun createFromParcel(parcel: Parcel): PeerProxy { + return PeerProxy(parcel) + } + + override fun newArray(size: Int): Array<PeerProxy?> { + return arrayOfNulls(size) + } + } + + companion object { + @JvmField + val CREATOR: Parcelable.Creator<PeerProxy> = PeerProxyCreator() + private val IPV4_PUBLIC_NETWORKS = setOf( + "0.0.0.0/5", "8.0.0.0/7", "11.0.0.0/8", "12.0.0.0/6", "16.0.0.0/4", "32.0.0.0/3", + "64.0.0.0/2", "128.0.0.0/3", "160.0.0.0/5", "168.0.0.0/6", "172.0.0.0/12", + "172.32.0.0/11", "172.64.0.0/10", "172.128.0.0/9", "173.0.0.0/8", "174.0.0.0/7", + "176.0.0.0/4", "192.0.0.0/9", "192.128.0.0/11", "192.160.0.0/13", "192.169.0.0/16", + "192.170.0.0/15", "192.172.0.0/14", "192.176.0.0/12", "192.192.0.0/10", + "193.0.0.0/8", "194.0.0.0/7", "196.0.0.0/6", "200.0.0.0/5", "208.0.0.0/4" + ) + private val IPV4_WILDCARD = setOf("0.0.0.0/0") + } +} |