diff options
Diffstat (limited to 'tunnel/src/main/java/com/wireguard/android/backend')
6 files changed, 922 insertions, 0 deletions
diff --git a/tunnel/src/main/java/com/wireguard/android/backend/Backend.java b/tunnel/src/main/java/com/wireguard/android/backend/Backend.java new file mode 100644 index 00000000..224d5849 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/Backend.java @@ -0,0 +1,67 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import com.wireguard.config.Config; +import com.wireguard.util.NonNullForAll; + +import java.util.Set; + +import androidx.annotation.Nullable; + +/** + * Interface for implementations of the WireGuard secure network tunnel. + */ + +@NonNullForAll +public interface Backend { + /** + * Enumerate names of currently-running tunnels. + * + * @return The set of running tunnel names. + */ + Set<String> getRunningTunnelNames(); + + /** + * Get the state of a tunnel. + * + * @param tunnel The tunnel to examine the state of. + * @return The state of the tunnel. + * @throws Exception Exception raised when retrieving tunnel's state. + */ + Tunnel.State getState(Tunnel tunnel) throws Exception; + + /** + * Get statistics about traffic and errors on this tunnel. If the tunnel is not running, the + * statistics object will be filled with zero values. + * + * @param tunnel The tunnel to retrieve statistics for. + * @return The statistics for the tunnel. + * @throws Exception Exception raised when retrieving statistics. + */ + Statistics getStatistics(Tunnel tunnel) throws Exception; + + /** + * Determine version of underlying backend. + * + * @return The version of the backend. + * @throws Exception Exception raised while retrieving version. + */ + String getVersion() throws Exception; + + /** + * Set the state of a tunnel, updating it's configuration. If the tunnel is already up, config + * may update the running configuration; config may be null when setting the tunnel down. + * + * @param tunnel The tunnel to control the state of. + * @param state The new state for this tunnel. Must be {@code UP}, {@code DOWN}, or + * {@code TOGGLE}. + * @param config The configuration for this tunnel, may be null if state is {@code DOWN}. + * @return The updated state of the tunnel. + * @throws Exception Exception raised while changing state. + */ + Tunnel.State setState(Tunnel tunnel, Tunnel.State state, @Nullable Config config) throws Exception; +} diff --git a/tunnel/src/main/java/com/wireguard/android/backend/BackendException.java b/tunnel/src/main/java/com/wireguard/android/backend/BackendException.java new file mode 100644 index 00000000..94f7b098 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/BackendException.java @@ -0,0 +1,61 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import com.wireguard.util.NonNullForAll; + +/** + * A subclass of {@link Exception} that encapsulates the reasons for a failure originating in + * implementations of {@link Backend}. + */ +@NonNullForAll +public final class BackendException extends Exception { + private final Object[] format; + private final Reason reason; + + /** + * Public constructor for BackendException. + * + * @param reason The {@link Reason} which caused this exception to be thrown + * @param format Format string values used when converting exceptions to user-facing strings. + */ + public BackendException(final Reason reason, final Object... format) { + this.reason = reason; + this.format = format; + } + + /** + * Get the format string values associated with the instance. + * + * @return Array of {@link Object} for string formatting purposes + */ + public Object[] getFormat() { + return format; + } + + /** + * Get the reason for this exception. + * + * @return Associated {@link Reason} for this exception. + */ + public Reason getReason() { + return reason; + } + + /** + * Enum class containing all known reasons for why a {@link BackendException} might be thrown. + */ + public enum Reason { + UNKNOWN_KERNEL_MODULE_NAME, + WG_QUICK_CONFIG_ERROR_CODE, + TUNNEL_MISSING_CONFIG, + VPN_NOT_AUTHORIZED, + UNABLE_TO_START_VPN, + TUN_CREATION_ERROR, + GO_ACTIVATION_ERROR_CODE, + DNS_RESOLUTION_FAILURE, + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java b/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java new file mode 100644 index 00000000..6b66f2c5 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java @@ -0,0 +1,439 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import android.content.Context; +import android.content.Intent; +import android.os.Build; +import android.os.ParcelFileDescriptor; +import android.system.OsConstants; +import android.util.Log; + +import com.wireguard.android.backend.BackendException.Reason; +import com.wireguard.android.backend.Tunnel.State; +import com.wireguard.android.util.SharedLibraryLoader; +import com.wireguard.config.Config; +import com.wireguard.config.InetEndpoint; +import com.wireguard.config.InetNetwork; +import com.wireguard.config.Peer; +import com.wireguard.crypto.Key; +import com.wireguard.crypto.KeyFormatException; +import com.wireguard.util.NonNullForAll; + +import java.net.InetAddress; +import java.time.Instant; +import java.util.Collections; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.FutureTask; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import androidx.annotation.Nullable; +import androidx.collection.ArraySet; + +/** + * Implementation of {@link Backend} that uses the wireguard-go userspace implementation to provide + * WireGuard tunnels. + */ +@NonNullForAll +public final class GoBackend implements Backend { + private static final int DNS_RESOLUTION_RETRIES = 10; + private static final String TAG = "WireGuard/GoBackend"; + @Nullable private static AlwaysOnCallback alwaysOnCallback; + private static GhettoCompletableFuture<VpnService> vpnService = new GhettoCompletableFuture<>(); + private final Context context; + @Nullable private Config currentConfig; + @Nullable private Tunnel currentTunnel; + private int currentTunnelHandle = -1; + + /** + * Public constructor for GoBackend. + * + * @param context An Android {@link Context} + */ + public GoBackend(final Context context) { + SharedLibraryLoader.loadSharedLibrary(context, "wg-go"); + this.context = context; + } + + /** + * Set a {@link AlwaysOnCallback} to be invoked when {@link VpnService} is started by the + * system's Always-On VPN mode. + * + * @param cb Callback to be invoked + */ + public static void setAlwaysOnCallback(final AlwaysOnCallback cb) { + alwaysOnCallback = cb; + } + + @Nullable private static native String wgGetConfig(int handle); + + private static native int wgGetSocketV4(int handle); + + private static native int wgGetSocketV6(int handle); + + private static native void wgTurnOff(int handle); + + private static native int wgTurnOn(String ifName, int tunFd, String settings); + + private static native String wgVersion(); + + /** + * Method to get the names of running tunnels. + * + * @return A set of string values denoting names of running tunnels. + */ + @Override + public Set<String> getRunningTunnelNames() { + if (currentTunnel != null) { + final Set<String> runningTunnels = new ArraySet<>(); + runningTunnels.add(currentTunnel.getName()); + return runningTunnels; + } + return Collections.emptySet(); + } + + /** + * Get the associated {@link State} for a given {@link Tunnel}. + * + * @param tunnel The tunnel to examine the state of. + * @return {@link State} associated with the given tunnel. + */ + @Override + public State getState(final Tunnel tunnel) { + return currentTunnel == tunnel ? State.UP : State.DOWN; + } + + /** + * Get the associated {@link Statistics} for a given {@link Tunnel}. + * + * @param tunnel The tunnel to retrieve statistics for. + * @return {@link Statistics} associated with the given tunnel. + */ + @Override + public Statistics getStatistics(final Tunnel tunnel) { + final Statistics stats = new Statistics(); + if (tunnel != currentTunnel || currentTunnelHandle == -1) + return stats; + final String config = wgGetConfig(currentTunnelHandle); + if (config == null) + return stats; + Key key = null; + long rx = 0; + long tx = 0; + long latestHandshakeMSec = 0; + for (final String line : config.split("\\n")) { + if (line.startsWith("public_key=")) { + if (key != null) + stats.add(key, rx, tx, latestHandshakeMSec); + rx = 0; + tx = 0; + latestHandshakeMSec = 0; + try { + key = Key.fromHex(line.substring(11)); + } catch (final KeyFormatException ignored) { + key = null; + } + } else if (line.startsWith("rx_bytes=")) { + if (key == null) + continue; + try { + rx = Long.parseLong(line.substring(9)); + } catch (final NumberFormatException ignored) { + rx = 0; + } + } else if (line.startsWith("tx_bytes=")) { + if (key == null) + continue; + try { + tx = Long.parseLong(line.substring(9)); + } catch (final NumberFormatException ignored) { + tx = 0; + } + } else if (line.startsWith("last_handshake_time_sec=")) { + if (key == null) + continue; + try { + latestHandshakeMSec += Long.parseLong(line.substring(24)) * 1000; + } catch (final NumberFormatException ignored) { + latestHandshakeMSec = 0; + } + } else if (line.startsWith("last_handshake_time_nsec=")) { + if (key == null) + continue; + try { + latestHandshakeMSec += Long.parseLong(line.substring(25)) / 1000000; + } catch (final NumberFormatException ignored) { + latestHandshakeMSec = 0; + } + } + } + if (key != null) + stats.add(key, rx, tx, latestHandshakeMSec); + return stats; + } + + /** + * Get the version of the underlying wireguard-go library. + * + * @return {@link String} value of the version of the wireguard-go library. + */ + @Override + public String getVersion() { + return wgVersion(); + } + + /** + * Change the state of a given {@link Tunnel}, optionally applying a given {@link Config}. + * + * @param tunnel The tunnel to control the state of. + * @param state The new state for this tunnel. Must be {@code UP}, {@code DOWN}, or + * {@code TOGGLE}. + * @param config The configuration for this tunnel, may be null if state is {@code DOWN}. + * @return {@link State} of the tunnel after state changes are applied. + * @throws Exception Exception raised while changing tunnel state. + */ + @Override + public State setState(final Tunnel tunnel, State state, @Nullable final Config config) throws Exception { + final State originalState = getState(tunnel); + + if (state == State.TOGGLE) + state = originalState == State.UP ? State.DOWN : State.UP; + if (state == originalState && tunnel == currentTunnel && config == currentConfig) + return originalState; + if (state == State.UP) { + final Config originalConfig = currentConfig; + final Tunnel originalTunnel = currentTunnel; + if (currentTunnel != null) + setStateInternal(currentTunnel, null, State.DOWN); + try { + setStateInternal(tunnel, config, state); + } catch (final Exception e) { + if (originalTunnel != null) + setStateInternal(originalTunnel, originalConfig, State.UP); + throw e; + } + } else if (state == State.DOWN && tunnel == currentTunnel) { + setStateInternal(tunnel, null, State.DOWN); + } + return getState(tunnel); + } + + private void setStateInternal(final Tunnel tunnel, @Nullable final Config config, final State state) + throws Exception { + Log.i(TAG, "Bringing tunnel " + tunnel.getName() + ' ' + state); + + if (state == State.UP) { + if (config == null) + throw new BackendException(Reason.TUNNEL_MISSING_CONFIG); + + if (VpnService.prepare(context) != null) + throw new BackendException(Reason.VPN_NOT_AUTHORIZED); + + final VpnService service; + if (!vpnService.isDone()) { + Log.d(TAG, "Requesting to start VpnService"); + context.startService(new Intent(context, VpnService.class)); + } + + try { + service = vpnService.get(2, TimeUnit.SECONDS); + } catch (final TimeoutException e) { + final Exception be = new BackendException(Reason.UNABLE_TO_START_VPN); + be.initCause(e); + throw be; + } + service.setOwner(this); + + if (currentTunnelHandle != -1) { + Log.w(TAG, "Tunnel already up"); + return; + } + + + dnsRetry: for (int i = 0; i < DNS_RESOLUTION_RETRIES; ++i) { + // Pre-resolve IPs so they're cached when building the userspace string + for (final Peer peer : config.getPeers()) { + final InetEndpoint ep = peer.getEndpoint().orElse(null); + if (ep == null) + continue; + if (ep.getResolved().orElse(null) == null) { + if (i < DNS_RESOLUTION_RETRIES - 1) { + Log.w(TAG, "DNS host \"" + ep.getHost() + "\" failed to resolve; trying again"); + Thread.sleep(1000); + continue dnsRetry; + } else + throw new BackendException(Reason.DNS_RESOLUTION_FAILURE, ep.getHost()); + } + } + break; + } + + // Build config + final String goConfig = config.toWgUserspaceString(); + + // Create the vpn tunnel with android API + final VpnService.Builder builder = service.getBuilder(); + builder.setSession(tunnel.getName()); + + for (final String excludedApplication : config.getInterface().getExcludedApplications()) + builder.addDisallowedApplication(excludedApplication); + + for (final String includedApplication : config.getInterface().getIncludedApplications()) + builder.addAllowedApplication(includedApplication); + + for (final InetNetwork addr : config.getInterface().getAddresses()) + builder.addAddress(addr.getAddress(), addr.getMask()); + + for (final InetAddress addr : config.getInterface().getDnsServers()) + builder.addDnsServer(addr.getHostAddress()); + + for (final String dnsSearchDomain : config.getInterface().getDnsSearchDomains()) + builder.addSearchDomain(dnsSearchDomain); + + boolean sawDefaultRoute = false; + for (final Peer peer : config.getPeers()) { + for (final InetNetwork addr : peer.getAllowedIps()) { + if (addr.getMask() == 0) + sawDefaultRoute = true; + builder.addRoute(addr.getAddress(), addr.getMask()); + } + } + + // "Kill-switch" semantics + if (!(sawDefaultRoute && config.getPeers().size() == 1)) { + builder.allowFamily(OsConstants.AF_INET); + builder.allowFamily(OsConstants.AF_INET6); + } + + builder.setMtu(config.getInterface().getMtu().orElse(1280)); + + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) + builder.setMetered(false); + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) + service.setUnderlyingNetworks(null); + + builder.setBlocking(true); + try (final ParcelFileDescriptor tun = builder.establish()) { + if (tun == null) + throw new BackendException(Reason.TUN_CREATION_ERROR); + Log.d(TAG, "Go backend " + wgVersion()); + currentTunnelHandle = wgTurnOn(tunnel.getName(), tun.detachFd(), goConfig); + } + if (currentTunnelHandle < 0) + throw new BackendException(Reason.GO_ACTIVATION_ERROR_CODE, currentTunnelHandle); + + currentTunnel = tunnel; + currentConfig = config; + + service.protect(wgGetSocketV4(currentTunnelHandle)); + service.protect(wgGetSocketV6(currentTunnelHandle)); + } else { + if (currentTunnelHandle == -1) { + Log.w(TAG, "Tunnel already down"); + return; + } + int handleToClose = currentTunnelHandle; + currentTunnel = null; + currentTunnelHandle = -1; + currentConfig = null; + wgTurnOff(handleToClose); + try { + vpnService.get(0, TimeUnit.NANOSECONDS).stopSelf(); + } catch (final TimeoutException ignored) { } + } + + tunnel.onStateChange(state); + } + + /** + * Callback for {@link GoBackend} that is invoked when {@link VpnService} is started by the + * system's Always-On VPN mode. + */ + public interface AlwaysOnCallback { + void alwaysOnTriggered(); + } + + // TODO: When we finally drop API 21 and move to API 24, delete this and replace with the ordinary CompletableFuture. + private static final class GhettoCompletableFuture<V> { + private final LinkedBlockingQueue<V> completion = new LinkedBlockingQueue<>(1); + private final FutureTask<V> result = new FutureTask<>(completion::peek); + + public boolean complete(final V value) { + final boolean offered = completion.offer(value); + if (offered) + result.run(); + return offered; + } + + public V get() throws ExecutionException, InterruptedException { + return result.get(); + } + + public V get(final long timeout, final TimeUnit unit) throws ExecutionException, InterruptedException, TimeoutException { + return result.get(timeout, unit); + } + + public boolean isDone() { + return !completion.isEmpty(); + } + + public GhettoCompletableFuture<V> newIncompleteFuture() { + return new GhettoCompletableFuture<>(); + } + } + + /** + * {@link android.net.VpnService} implementation for {@link GoBackend} + */ + public static class VpnService extends android.net.VpnService { + @Nullable private GoBackend owner; + + public Builder getBuilder() { + return new Builder(); + } + + @Override + public void onCreate() { + vpnService.complete(this); + super.onCreate(); + } + + @Override + public void onDestroy() { + if (owner != null) { + final Tunnel tunnel = owner.currentTunnel; + if (tunnel != null) { + if (owner.currentTunnelHandle != -1) + wgTurnOff(owner.currentTunnelHandle); + owner.currentTunnel = null; + owner.currentTunnelHandle = -1; + owner.currentConfig = null; + tunnel.onStateChange(State.DOWN); + } + } + vpnService = vpnService.newIncompleteFuture(); + super.onDestroy(); + } + + @Override + public int onStartCommand(@Nullable final Intent intent, final int flags, final int startId) { + vpnService.complete(this); + if (intent == null || intent.getComponent() == null || !intent.getComponent().getPackageName().equals(getPackageName())) { + Log.d(TAG, "Service started by Always-on VPN feature"); + if (alwaysOnCallback != null) + alwaysOnCallback.alwaysOnTriggered(); + } + return super.onStartCommand(intent, flags, startId); + } + + public void setOwner(final GoBackend owner) { + this.owner = owner; + } + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/backend/Statistics.java b/tunnel/src/main/java/com/wireguard/android/backend/Statistics.java new file mode 100644 index 00000000..d5d41c5f --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/Statistics.java @@ -0,0 +1,102 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import android.os.SystemClock; + +import com.wireguard.crypto.Key; +import com.wireguard.util.NonNullForAll; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import androidx.annotation.Nullable; + +/** + * Class representing transfer statistics for a {@link Tunnel} instance. + */ +@NonNullForAll +public class Statistics { + public record PeerStats(long rxBytes, long txBytes, long latestHandshakeEpochMillis) { } + private final Map<Key, PeerStats> stats = new HashMap<>(); + private long lastTouched = SystemClock.elapsedRealtime(); + + Statistics() { + } + + /** + * Add a peer and its current stats to the internal map. + * + * @param key A WireGuard public key bound to a particular peer + * @param rxBytes The received traffic for the {@link com.wireguard.config.Peer} referenced by + * the provided {@link Key}. This value is in bytes + * @param txBytes The transmitted traffic for the {@link com.wireguard.config.Peer} referenced by + * the provided {@link Key}. This value is in bytes. + * @param latestHandshake The timestamp of the latest handshake for the {@link com.wireguard.config.Peer} + * referenced by the provided {@link Key}. The value is in epoch milliseconds. + */ + void add(final Key key, final long rxBytes, final long txBytes, final long latestHandshake) { + stats.put(key, new PeerStats(rxBytes, txBytes, latestHandshake)); + lastTouched = SystemClock.elapsedRealtime(); + } + + /** + * Check if the statistics are stale, indicating the need for the {@link Backend} to update them. + * + * @return boolean indicating if the current statistics instance has stale values. + */ + public boolean isStale() { + return SystemClock.elapsedRealtime() - lastTouched > 900; + } + + /** + * Get the statistics for the {@link com.wireguard.config.Peer} referenced by the provided {@link Key} + * + * @param peer A {@link Key} representing a {@link com.wireguard.config.Peer}. + * @return a {@link PeerStats} representing various statistics about this peer. + */ + @Nullable + public PeerStats peer(final Key peer) { + return stats.get(peer); + } + + /** + * Get the list of peers being tracked by this instance. + * + * @return An array of {@link Key} instances representing WireGuard + * {@link com.wireguard.config.Peer}s + */ + public Key[] peers() { + return stats.keySet().toArray(new Key[0]); + } + + /** + * Get the total received traffic by all the peers being tracked by this instance + * + * @return a long representing the number of bytes received by the peers being tracked. + */ + public long totalRx() { + long rx = 0; + for (final PeerStats val : stats.values()) { + rx += val.rxBytes; + } + return rx; + } + + /** + * Get the total transmitted traffic by all the peers being tracked by this instance + * + * @return a long representing the number of bytes transmitted by the peers being tracked. + */ + public long totalTx() { + long tx = 0; + for (final PeerStats val : stats.values()) { + tx += val.txBytes; + } + return tx; + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/backend/Tunnel.java b/tunnel/src/main/java/com/wireguard/android/backend/Tunnel.java new file mode 100644 index 00000000..dbc91c27 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/Tunnel.java @@ -0,0 +1,57 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import com.wireguard.util.NonNullForAll; + +import java.util.regex.Pattern; + +/** + * Represents a WireGuard tunnel. + */ + +@NonNullForAll +public interface Tunnel { + int NAME_MAX_LENGTH = 15; + Pattern NAME_PATTERN = Pattern.compile("[a-zA-Z0-9_=+.-]{1,15}"); + + static boolean isNameInvalid(final CharSequence name) { + return !NAME_PATTERN.matcher(name).matches(); + } + + /** + * Get the name of the tunnel, which should always pass the !isNameInvalid test. + * + * @return The name of the tunnel. + */ + String getName(); + + /** + * React to a change in state of the tunnel. Should only be directly called by Backend. + * + * @param newState The new state of the tunnel. + */ + void onStateChange(State newState); + + /** + * Enum class to represent all possible states of a {@link Tunnel}. + */ + enum State { + DOWN, + TOGGLE, + UP; + + /** + * Get the state of a {@link Tunnel} + * + * @param running boolean indicating if the tunnel is running. + * @return State of the tunnel based on whether or not it is running. + */ + public static State of(final boolean running) { + return running ? UP : DOWN; + } + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/backend/WgQuickBackend.java b/tunnel/src/main/java/com/wireguard/android/backend/WgQuickBackend.java new file mode 100644 index 00000000..87fdf6e5 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/WgQuickBackend.java @@ -0,0 +1,196 @@ +/* + * Copyright © 2017-2025 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import android.content.Context; +import android.util.Log; +import android.util.Pair; + +import com.wireguard.android.backend.BackendException.Reason; +import com.wireguard.android.backend.Tunnel.State; +import com.wireguard.android.util.RootShell; +import com.wireguard.android.util.ToolsInstaller; +import com.wireguard.config.Config; +import com.wireguard.crypto.Key; +import com.wireguard.util.NonNullForAll; + +import java.io.File; +import java.io.FileOutputStream; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import androidx.annotation.Nullable; + +/** + * Implementation of {@link Backend} that uses the kernel module and {@code wg-quick} to provide + * WireGuard tunnels. + */ + +@NonNullForAll +public final class WgQuickBackend implements Backend { + private static final String TAG = "WireGuard/WgQuickBackend"; + private final File localTemporaryDir; + private final RootShell rootShell; + private final Map<Tunnel, Config> runningConfigs = new HashMap<>(); + private final ToolsInstaller toolsInstaller; + private boolean multipleTunnels; + + public WgQuickBackend(final Context context, final RootShell rootShell, final ToolsInstaller toolsInstaller) { + localTemporaryDir = new File(context.getCacheDir(), "tmp"); + this.rootShell = rootShell; + this.toolsInstaller = toolsInstaller; + } + + public static boolean hasKernelSupport() { + return new File("/sys/module/wireguard").exists(); + } + + @Override + public Set<String> getRunningTunnelNames() { + final List<String> output = new ArrayList<>(); + // Don't throw an exception here or nothing will show up in the UI. + try { + toolsInstaller.ensureToolsAvailable(); + if (rootShell.run(output, "wg show interfaces") != 0 || output.isEmpty()) + return Collections.emptySet(); + } catch (final Exception e) { + Log.w(TAG, "Unable to enumerate running tunnels", e); + return Collections.emptySet(); + } + // wg puts all interface names on the same line. Split them into separate elements. + return Set.of(output.get(0).split(" ")); + } + + @Override + public State getState(final Tunnel tunnel) { + return getRunningTunnelNames().contains(tunnel.getName()) ? State.UP : State.DOWN; + } + + @Override + public Statistics getStatistics(final Tunnel tunnel) { + final Statistics stats = new Statistics(); + final Collection<String> output = new ArrayList<>(); + try { + if (rootShell.run(output, String.format("wg show '%s' dump", tunnel.getName())) != 0) + return stats; + } catch (final Exception ignored) { + return stats; + } + for (final String line : output) { + final String[] parts = line.split("\\t"); + if (parts.length != 8) + continue; + try { + stats.add(Key.fromBase64(parts[0]), Long.parseLong(parts[5]), Long.parseLong(parts[6]), Long.parseLong(parts[4]) * 1000); + } catch (final Exception ignored) { + } + } + return stats; + } + + @Override + public String getVersion() throws Exception { + final List<String> output = new ArrayList<>(); + if (rootShell.run(output, "cat /sys/module/wireguard/version") != 0 || output.isEmpty()) + throw new BackendException(Reason.UNKNOWN_KERNEL_MODULE_NAME); + return output.get(0); + } + + public void setMultipleTunnels(final boolean on) { + multipleTunnels = on; + } + + @Override + public State setState(final Tunnel tunnel, State state, @Nullable final Config config) throws Exception { + final State originalState = getState(tunnel); + final Config originalConfig = runningConfigs.get(tunnel); + final Map<Tunnel, Config> runningConfigsSnapshot = new HashMap<>(runningConfigs); + + if (state == State.TOGGLE) + state = originalState == State.UP ? State.DOWN : State.UP; + if ((state == State.UP && originalState == State.UP && originalConfig != null && originalConfig == config) || + (state == State.DOWN && originalState == State.DOWN)) + return originalState; + if (state == State.UP) { + toolsInstaller.ensureToolsAvailable(); + if (!multipleTunnels && originalState == State.DOWN) { + final List<Pair<Tunnel, Config>> rewind = new LinkedList<>(); + try { + for (final Map.Entry<Tunnel, Config> entry : runningConfigsSnapshot.entrySet()) { + setStateInternal(entry.getKey(), entry.getValue(), State.DOWN); + rewind.add(Pair.create(entry.getKey(), entry.getValue())); + } + } catch (final Exception e) { + try { + for (final Pair<Tunnel, Config> entry : rewind) { + setStateInternal(entry.first, entry.second, State.UP); + } + } catch (final Exception ignored) { + } + throw e; + } + } + if (originalState == State.UP) + setStateInternal(tunnel, originalConfig == null ? config : originalConfig, State.DOWN); + try { + setStateInternal(tunnel, config, State.UP); + } catch (final Exception e) { + try { + if (originalState == State.UP && originalConfig != null) { + setStateInternal(tunnel, originalConfig, State.UP); + } + if (!multipleTunnels && originalState == State.DOWN) { + for (final Map.Entry<Tunnel, Config> entry : runningConfigsSnapshot.entrySet()) { + setStateInternal(entry.getKey(), entry.getValue(), State.UP); + } + } + } catch (final Exception ignored) { + } + throw e; + } + } else if (state == State.DOWN) { + setStateInternal(tunnel, originalConfig == null ? config : originalConfig, State.DOWN); + } + return state; + } + + private void setStateInternal(final Tunnel tunnel, @Nullable final Config config, final State state) throws Exception { + Log.i(TAG, "Bringing tunnel " + tunnel.getName() + ' ' + state); + + Objects.requireNonNull(config, "Trying to set state up with a null config"); + + final File tempFile = new File(localTemporaryDir, tunnel.getName() + ".conf"); + try (final FileOutputStream stream = new FileOutputStream(tempFile, false)) { + stream.write(config.toWgQuickString().getBytes(StandardCharsets.UTF_8)); + } + String command = String.format("wg-quick %s '%s'", + state.toString().toLowerCase(Locale.ENGLISH), tempFile.getAbsolutePath()); + if (state == State.UP) + command = "cat /sys/module/wireguard/version && " + command; + final int result = rootShell.run(null, command); + // noinspection ResultOfMethodCallIgnored + tempFile.delete(); + if (result != 0) + throw new BackendException(Reason.WG_QUICK_CONFIG_ERROR_CODE, result); + + if (state == State.UP) + runningConfigs.put(tunnel, config); + else + runningConfigs.remove(tunnel); + + tunnel.onStateChange(state); + } +} |