aboutsummaryrefslogtreecommitdiffstats
path: root/WireGuard
diff options
context:
space:
mode:
Diffstat (limited to 'WireGuard')
-rw-r--r--WireGuard/WireGuard/Tunnel/TunnelsManager.swift57
1 files changed, 39 insertions, 18 deletions
diff --git a/WireGuard/WireGuard/Tunnel/TunnelsManager.swift b/WireGuard/WireGuard/Tunnel/TunnelsManager.swift
index 5640e6c8..3efadb5a 100644
--- a/WireGuard/WireGuard/Tunnel/TunnelsManager.swift
+++ b/WireGuard/WireGuard/Tunnel/TunnelsManager.swift
@@ -25,10 +25,12 @@ class TunnelsManager {
weak var activationDelegate: TunnelsManagerActivationDelegate?
private var statusObservationToken: AnyObject?
private var waiteeObservationToken: AnyObject?
+ private var configurationsObservationToken: AnyObject?
init(tunnelProviders: [NETunnelProviderManager]) {
tunnels = tunnelProviders.map { TunnelContainer(tunnel: $0) }.sorted { $0.name < $1.name }
startObservingTunnelStatuses()
+ startObservingTunnelConfigurations()
}
static func create(completionHandler: @escaping (WireGuardResult<TunnelsManager>) -> Void) {
@@ -53,26 +55,33 @@ class TunnelsManager {
#endif
}
- func reload(completionHandler: @escaping (Bool) -> Void) {
- #if targetEnvironment(simulator)
- completionHandler(false)
- #else
- NETunnelProviderManager.loadAllFromPreferences { managers, _ in
- guard let managers = managers else {
- completionHandler(false)
- return
- }
+ func reload() {
+ NETunnelProviderManager.loadAllFromPreferences { [weak self] managers, _ in
+ guard let self = self else { return }
- let newTunnels = managers.map { TunnelContainer(tunnel: $0) }.sorted { $0.name < $1.name }
- let hasChanges = self.tunnels.map { $0.tunnelConfiguration } != newTunnels.map { $0.tunnelConfiguration }
- if hasChanges {
- self.tunnels = newTunnels
- completionHandler(true)
- } else {
- completionHandler(false)
+ let loadedTunnelProviders = managers ?? []
+
+ var numberOfRemovedTunnels = 0
+ for (index, currentTunnel) in self.tunnels.enumerated() {
+ if !loadedTunnelProviders.contains(where: { $0.tunnelConfiguration == currentTunnel.tunnelConfiguration }) {
+ // Tunnel was deleted outside the app
+ self.tunnels.remove(at: index - numberOfRemovedTunnels)
+ self.tunnelsListDelegate?.tunnelRemoved(at: index - numberOfRemovedTunnels)
+ numberOfRemovedTunnels += 1
+ }
+ }
+ for loadedTunnelProvider in loadedTunnelProviders {
+ if let matchingTunnel = self.tunnels.first(where: { $0.tunnelConfiguration == loadedTunnelProvider.tunnelConfiguration }) {
+ matchingTunnel.tunnelProvider = loadedTunnelProvider
+ } else {
+ // Tunnel was added outside the app
+ let tunnel = TunnelContainer(tunnel: loadedTunnelProvider)
+ self.tunnels.append(tunnel)
+ self.tunnels.sort { $0.name < $1.name }
+ self.tunnelsListDelegate?.tunnelAdded(at: self.tunnels.firstIndex(of: tunnel)!)
+ }
}
}
- #endif
}
func add(tunnelConfiguration: TunnelConfiguration, activateOnDemandSetting: ActivateOnDemandSetting = ActivateOnDemandSetting.defaultSetting, completionHandler: @escaping (WireGuardResult<TunnelContainer>) -> Void) {
@@ -319,6 +328,12 @@ class TunnelsManager {
}
}
+ func startObservingTunnelConfigurations() {
+ configurationsObservationToken = NotificationCenter.default.addObserver(forName: .NEVPNConfigurationChange, object: nil, queue: OperationQueue.main) { [weak self] _ in
+ self?.reload()
+ }
+ }
+
}
private func lastErrorTextFromNetworkExtension(for tunnel: TunnelContainer) -> (title: String, message: String)? {
@@ -367,7 +382,7 @@ class TunnelContainer: NSObject {
fileprivate var tunnelProvider: NETunnelProviderManager
var tunnelConfiguration: TunnelConfiguration? {
- return (tunnelProvider.protocolConfiguration as? NETunnelProviderProtocol)?.asTunnelConfiguration(called: tunnelProvider.localizedDescription)
+ return tunnelProvider.tunnelConfiguration
}
var activateOnDemandSetting: ActivateOnDemandSetting {
@@ -461,3 +476,9 @@ class TunnelContainer: NSObject {
(tunnelProvider.connection as? NETunnelProviderSession)?.stopTunnel()
}
}
+
+extension NETunnelProviderManager {
+ var tunnelConfiguration: TunnelConfiguration? {
+ return (protocolConfiguration as? NETunnelProviderProtocol)?.asTunnelConfiguration(called: localizedDescription)
+ }
+}