aboutsummaryrefslogtreecommitdiffstats
path: root/Sources/Shared
diff options
context:
space:
mode:
Diffstat (limited to 'Sources/Shared')
-rw-r--r--Sources/Shared/FileManager+Extension.swift50
-rw-r--r--Sources/Shared/Keychain.swift114
-rw-r--r--Sources/Shared/Logging/Logger.swift65
-rw-r--r--Sources/Shared/Logging/ringlogger.c173
-rw-r--r--Sources/Shared/Logging/ringlogger.h18
-rw-r--r--Sources/Shared/Logging/test_ringlogger.c63
-rw-r--r--Sources/Shared/Model/NETunnelProviderProtocol+Extension.swift106
-rw-r--r--Sources/Shared/Model/String+ArrayConversion.swift32
-rw-r--r--Sources/Shared/Model/TunnelConfiguration+WgQuickConfig.swift252
-rw-r--r--Sources/Shared/NotificationToken.swift33
10 files changed, 906 insertions, 0 deletions
diff --git a/Sources/Shared/FileManager+Extension.swift b/Sources/Shared/FileManager+Extension.swift
new file mode 100644
index 0000000..48fa33f
--- /dev/null
+++ b/Sources/Shared/FileManager+Extension.swift
@@ -0,0 +1,50 @@
+// SPDX-License-Identifier: MIT
+// Copyright © 2018-2023 WireGuard LLC. All Rights Reserved.
+
+import Foundation
+import os.log
+
+extension FileManager {
+ static var appGroupId: String? {
+ #if os(iOS)
+ let appGroupIdInfoDictionaryKey = "com.wireguard.ios.app_group_id"
+ #elseif os(macOS)
+ let appGroupIdInfoDictionaryKey = "com.wireguard.macos.app_group_id"
+ #else
+ #error("Unimplemented")
+ #endif
+ return Bundle.main.object(forInfoDictionaryKey: appGroupIdInfoDictionaryKey) as? String
+ }
+ private static var sharedFolderURL: URL? {
+ guard let appGroupId = FileManager.appGroupId else {
+ os_log("Cannot obtain app group ID from bundle", log: OSLog.default, type: .error)
+ return nil
+ }
+ guard let sharedFolderURL = FileManager.default.containerURL(forSecurityApplicationGroupIdentifier: appGroupId) else {
+ wg_log(.error, message: "Cannot obtain shared folder URL")
+ return nil
+ }
+ return sharedFolderURL
+ }
+
+ static var logFileURL: URL? {
+ return sharedFolderURL?.appendingPathComponent("tunnel-log.bin")
+ }
+
+ static var networkExtensionLastErrorFileURL: URL? {
+ return sharedFolderURL?.appendingPathComponent("last-error.txt")
+ }
+
+ static var loginHelperTimestampURL: URL? {
+ return sharedFolderURL?.appendingPathComponent("login-helper-timestamp.bin")
+ }
+
+ static func deleteFile(at url: URL) -> Bool {
+ do {
+ try FileManager.default.removeItem(at: url)
+ } catch {
+ return false
+ }
+ return true
+ }
+}
diff --git a/Sources/Shared/Keychain.swift b/Sources/Shared/Keychain.swift
new file mode 100644
index 0000000..2e0e7f0
--- /dev/null
+++ b/Sources/Shared/Keychain.swift
@@ -0,0 +1,114 @@
+// SPDX-License-Identifier: MIT
+// Copyright © 2018-2023 WireGuard LLC. All Rights Reserved.
+
+import Foundation
+import Security
+
+class Keychain {
+ static func openReference(called ref: Data) -> String? {
+ var result: CFTypeRef?
+ let ret = SecItemCopyMatching([kSecValuePersistentRef: ref,
+ kSecReturnData: true] as CFDictionary,
+ &result)
+ if ret != errSecSuccess || result == nil {
+ wg_log(.error, message: "Unable to open config from keychain: \(ret)")
+ return nil
+ }
+ guard let data = result as? Data else { return nil }
+ return String(data: data, encoding: String.Encoding.utf8)
+ }
+
+ static func makeReference(containing value: String, called name: String, previouslyReferencedBy oldRef: Data? = nil) -> Data? {
+ var ret: OSStatus
+ guard var bundleIdentifier = Bundle.main.bundleIdentifier else {
+ wg_log(.error, staticMessage: "Unable to determine bundle identifier")
+ return nil
+ }
+ if bundleIdentifier.hasSuffix(".network-extension") {
+ bundleIdentifier.removeLast(".network-extension".count)
+ }
+ let itemLabel = "WireGuard Tunnel: \(name)"
+ var items: [CFString: Any] = [kSecClass: kSecClassGenericPassword,
+ kSecAttrLabel: itemLabel,
+ kSecAttrAccount: name + ": " + UUID().uuidString,
+ kSecAttrDescription: "wg-quick(8) config",
+ kSecAttrService: bundleIdentifier,
+ kSecValueData: value.data(using: .utf8) as Any,
+ kSecReturnPersistentRef: true]
+
+ #if os(iOS)
+ items[kSecAttrAccessGroup] = FileManager.appGroupId
+ items[kSecAttrAccessible] = kSecAttrAccessibleAfterFirstUnlock
+ #elseif os(macOS)
+ items[kSecAttrSynchronizable] = false
+ items[kSecAttrAccessible] = kSecAttrAccessibleAfterFirstUnlockThisDeviceOnly
+
+ guard let extensionPath = Bundle.main.builtInPlugInsURL?.appendingPathComponent("WireGuardNetworkExtension.appex", isDirectory: true).path else {
+ wg_log(.error, staticMessage: "Unable to determine app extension path")
+ return nil
+ }
+ var extensionApp: SecTrustedApplication?
+ var mainApp: SecTrustedApplication?
+ ret = SecTrustedApplicationCreateFromPath(extensionPath, &extensionApp)
+ if ret != kOSReturnSuccess || extensionApp == nil {
+ wg_log(.error, message: "Unable to create keychain extension trusted application object: \(ret)")
+ return nil
+ }
+ ret = SecTrustedApplicationCreateFromPath(nil, &mainApp)
+ if ret != errSecSuccess || mainApp == nil {
+ wg_log(.error, message: "Unable to create keychain local trusted application object: \(ret)")
+ return nil
+ }
+ var access: SecAccess?
+ ret = SecAccessCreate(itemLabel as CFString, [extensionApp!, mainApp!] as CFArray, &access)
+ if ret != errSecSuccess || access == nil {
+ wg_log(.error, message: "Unable to create keychain ACL object: \(ret)")
+ return nil
+ }
+ items[kSecAttrAccess] = access!
+ #else
+ #error("Unimplemented")
+ #endif
+
+ var ref: CFTypeRef?
+ ret = SecItemAdd(items as CFDictionary, &ref)
+ if ret != errSecSuccess || ref == nil {
+ wg_log(.error, message: "Unable to add config to keychain: \(ret)")
+ return nil
+ }
+ if let oldRef = oldRef {
+ deleteReference(called: oldRef)
+ }
+ return ref as? Data
+ }
+
+ static func deleteReference(called ref: Data) {
+ let ret = SecItemDelete([kSecValuePersistentRef: ref] as CFDictionary)
+ if ret != errSecSuccess {
+ wg_log(.error, message: "Unable to delete config from keychain: \(ret)")
+ }
+ }
+
+ static func deleteReferences(except whitelist: Set<Data>) {
+ var result: CFTypeRef?
+ let ret = SecItemCopyMatching([kSecClass: kSecClassGenericPassword,
+ kSecAttrService: Bundle.main.bundleIdentifier as Any,
+ kSecMatchLimit: kSecMatchLimitAll,
+ kSecReturnPersistentRef: true] as CFDictionary,
+ &result)
+ if ret != errSecSuccess || result == nil {
+ return
+ }
+ guard let items = result as? [Data] else { return }
+ for item in items {
+ if !whitelist.contains(item) {
+ deleteReference(called: item)
+ }
+ }
+ }
+
+ static func verifyReference(called ref: Data) -> Bool {
+ return SecItemCopyMatching([kSecValuePersistentRef: ref] as CFDictionary,
+ nil) != errSecItemNotFound
+ }
+}
diff --git a/Sources/Shared/Logging/Logger.swift b/Sources/Shared/Logging/Logger.swift
new file mode 100644
index 0000000..f3ee2b7
--- /dev/null
+++ b/Sources/Shared/Logging/Logger.swift
@@ -0,0 +1,65 @@
+// SPDX-License-Identifier: MIT
+// Copyright © 2018-2023 WireGuard LLC. All Rights Reserved.
+
+import Foundation
+import os.log
+
+public class Logger {
+ enum LoggerError: Error {
+ case openFailure
+ }
+
+ static var global: Logger?
+
+ var log: OpaquePointer
+ var tag: String
+
+ init(tagged tag: String, withFilePath filePath: String) throws {
+ guard let log = open_log(filePath) else { throw LoggerError.openFailure }
+ self.log = log
+ self.tag = tag
+ }
+
+ deinit {
+ close_log(self.log)
+ }
+
+ func log(message: String) {
+ write_msg_to_log(log, tag, message.trimmingCharacters(in: .newlines))
+ }
+
+ func writeLog(to targetFile: String) -> Bool {
+ return write_log_to_file(targetFile, self.log) == 0
+ }
+
+ static func configureGlobal(tagged tag: String, withFilePath filePath: String?) {
+ if Logger.global != nil {
+ return
+ }
+ guard let filePath = filePath else {
+ os_log("Unable to determine log destination path. Log will not be saved to file.", log: OSLog.default, type: .error)
+ return
+ }
+ guard let logger = try? Logger(tagged: tag, withFilePath: filePath) else {
+ os_log("Unable to open log file for writing. Log will not be saved to file.", log: OSLog.default, type: .error)
+ return
+ }
+ Logger.global = logger
+ var appVersion = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String ?? "Unknown version"
+ if let appBuild = Bundle.main.infoDictionary?["CFBundleVersion"] as? String {
+ appVersion += " (\(appBuild))"
+ }
+
+ Logger.global?.log(message: "App version: \(appVersion)")
+ }
+}
+
+func wg_log(_ type: OSLogType, staticMessage msg: StaticString) {
+ os_log(msg, log: OSLog.default, type: type)
+ Logger.global?.log(message: "\(msg)")
+}
+
+func wg_log(_ type: OSLogType, message msg: String) {
+ os_log("%{public}s", log: OSLog.default, type: type, msg)
+ Logger.global?.log(message: msg)
+}
diff --git a/Sources/Shared/Logging/ringlogger.c b/Sources/Shared/Logging/ringlogger.c
new file mode 100644
index 0000000..9bb0d13
--- /dev/null
+++ b/Sources/Shared/Logging/ringlogger.c
@@ -0,0 +1,173 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright © 2018-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+#include <string.h>
+#include <stdio.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <stdatomic.h>
+#include <stdbool.h>
+#include <time.h>
+#include <errno.h>
+#include <unistd.h>
+#include <fcntl.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <sys/time.h>
+#include <sys/mman.h>
+#include "ringlogger.h"
+
+enum {
+ MAX_LOG_LINE_LENGTH = 512,
+ MAX_LINES = 2048,
+ MAGIC = 0xabadbeefU
+};
+
+struct log_line {
+ atomic_uint_fast64_t time_ns;
+ char line[MAX_LOG_LINE_LENGTH];
+};
+
+struct log {
+ atomic_uint_fast32_t next_index;
+ struct log_line lines[MAX_LINES];
+ uint32_t magic;
+};
+
+void write_msg_to_log(struct log *log, const char *tag, const char *msg)
+{
+ uint32_t index;
+ struct log_line *line;
+ struct timespec ts;
+
+ // Race: This isn't synchronized with the fetch_add below, so items might be slightly out of order.
+ clock_gettime(CLOCK_REALTIME, &ts);
+
+ // Race: More than MAX_LINES writers and this will clash.
+ index = atomic_fetch_add(&log->next_index, 1);
+ line = &log->lines[index % MAX_LINES];
+
+ // Race: Before this line executes, we'll display old data after new data.
+ atomic_store(&line->time_ns, 0);
+ memset(line->line, 0, MAX_LOG_LINE_LENGTH);
+
+ snprintf(line->line, MAX_LOG_LINE_LENGTH, "[%s] %s", tag, msg);
+ atomic_store(&line->time_ns, ts.tv_sec * 1000000000ULL + ts.tv_nsec);
+
+ msync(&log->next_index, sizeof(log->next_index), MS_ASYNC);
+ msync(line, sizeof(*line), MS_ASYNC);
+}
+
+int write_log_to_file(const char *file_name, const struct log *input_log)
+{
+ struct log *log;
+ uint32_t l, i;
+ FILE *file;
+ int ret;
+
+ log = malloc(sizeof(*log));
+ if (!log)
+ return -errno;
+ memcpy(log, input_log, sizeof(*log));
+
+ file = fopen(file_name, "w");
+ if (!file) {
+ free(log);
+ return -errno;
+ }
+
+ for (l = 0, i = log->next_index; l < MAX_LINES; ++l, ++i) {
+ const struct log_line *line = &log->lines[i % MAX_LINES];
+ time_t seconds = line->time_ns / 1000000000ULL;
+ uint32_t useconds = (line->time_ns % 1000000000ULL) / 1000ULL;
+ struct tm tm;
+
+ if (!line->time_ns)
+ continue;
+
+ if (!localtime_r(&seconds, &tm))
+ goto err;
+
+ if (fprintf(file, "%04d-%02d-%02d %02d:%02d:%02d.%06d: %s\n",
+ tm.tm_year + 1900, tm.tm_mon + 1, tm.tm_mday,
+ tm.tm_hour, tm.tm_min, tm.tm_sec, useconds,
+ line->line) < 0)
+ goto err;
+
+
+ }
+ errno = 0;
+
+err:
+ ret = -errno;
+ fclose(file);
+ free(log);
+ return ret;
+}
+
+uint32_t view_lines_from_cursor(const struct log *input_log, uint32_t cursor, void *ctx, void(*cb)(const char *, uint64_t, void *))
+{
+ struct log *log;
+ uint32_t l, i = cursor;
+
+ log = malloc(sizeof(*log));
+ if (!log)
+ return cursor;
+ memcpy(log, input_log, sizeof(*log));
+
+ if (i == -1)
+ i = log->next_index;
+
+ for (l = 0; l < MAX_LINES; ++l, ++i) {
+ const struct log_line *line = &log->lines[i % MAX_LINES];
+
+ if (cursor != -1 && i % MAX_LINES == log->next_index % MAX_LINES)
+ break;
+
+ if (!line->time_ns) {
+ if (cursor == -1)
+ continue;
+ else
+ break;
+ }
+ cb(line->line, line->time_ns, ctx);
+ cursor = (i + 1) % MAX_LINES;
+ }
+ free(log);
+ return cursor;
+}
+
+struct log *open_log(const char *file_name)
+{
+ int fd;
+ struct log *log;
+
+ fd = open(file_name, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
+ if (fd < 0)
+ return NULL;
+ if (ftruncate(fd, sizeof(*log)))
+ goto err;
+ log = mmap(NULL, sizeof(*log), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
+ if (log == MAP_FAILED)
+ goto err;
+ close(fd);
+
+ if (log->magic != MAGIC) {
+ memset(log, 0, sizeof(*log));
+ log->magic = MAGIC;
+ msync(log, sizeof(*log), MS_ASYNC);
+ }
+
+ return log;
+
+err:
+ close(fd);
+ return NULL;
+}
+
+void close_log(struct log *log)
+{
+ munmap(log, sizeof(*log));
+}
diff --git a/Sources/Shared/Logging/ringlogger.h b/Sources/Shared/Logging/ringlogger.h
new file mode 100644
index 0000000..0e28c93
--- /dev/null
+++ b/Sources/Shared/Logging/ringlogger.h
@@ -0,0 +1,18 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright © 2018-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+#ifndef RINGLOGGER_H
+#define RINGLOGGER_H
+
+#include <stdint.h>
+
+struct log;
+void write_msg_to_log(struct log *log, const char *tag, const char *msg);
+int write_log_to_file(const char *file_name, const struct log *input_log);
+uint32_t view_lines_from_cursor(const struct log *input_log, uint32_t cursor, void *ctx, void(*)(const char *, uint64_t, void *));
+struct log *open_log(const char *file_name);
+void close_log(struct log *log);
+
+#endif
diff --git a/Sources/Shared/Logging/test_ringlogger.c b/Sources/Shared/Logging/test_ringlogger.c
new file mode 100644
index 0000000..ae3f4a9
--- /dev/null
+++ b/Sources/Shared/Logging/test_ringlogger.c
@@ -0,0 +1,63 @@
+#include "ringlogger.h"
+#include <stdio.h>
+#include <stdbool.h>
+#include <string.h>
+#include <unistd.h>
+#include <inttypes.h>
+#include <sys/wait.h>
+
+static void forkwrite(void)
+{
+ struct log *log = open_log("/tmp/test_log");
+ char c[512];
+ int i, base;
+ bool in_fork = !fork();
+
+ base = 10000 * in_fork;
+ for (i = 0; i < 1024; ++i) {
+ snprintf(c, 512, "bla bla bla %d", base + i);
+ write_msg_to_log(log, "HMM", c);
+ }
+
+
+ if (in_fork)
+ _exit(0);
+ wait(NULL);
+
+ write_log_to_file("/dev/stdout", log);
+ close_log(log);
+}
+
+static void writetext(const char *text)
+{
+ struct log *log = open_log("/tmp/test_log");
+ write_msg_to_log(log, "TXT", text);
+ close_log(log);
+}
+
+static void show_line(const char *line, uint64_t time_ns)
+{
+ printf("%" PRIu64 ": %s\n", time_ns, line);
+}
+
+static void follow(void)
+{
+ uint32_t cursor = -1;
+ struct log *log = open_log("/tmp/test_log");
+
+ for (;;) {
+ cursor = view_lines_from_cursor(log, cursor, show_line);
+ usleep(1000 * 300);
+ }
+}
+
+int main(int argc, char *argv[])
+{
+ if (!strcmp(argv[1], "fork"))
+ forkwrite();
+ else if (!strcmp(argv[1], "write"))
+ writetext(argv[2]);
+ else if (!strcmp(argv[1], "follow"))
+ follow();
+ return 0;
+}
diff --git a/Sources/Shared/Model/NETunnelProviderProtocol+Extension.swift b/Sources/Shared/Model/NETunnelProviderProtocol+Extension.swift
new file mode 100644
index 0000000..0a303f4
--- /dev/null
+++ b/Sources/Shared/Model/NETunnelProviderProtocol+Extension.swift
@@ -0,0 +1,106 @@
+// SPDX-License-Identifier: MIT
+// Copyright © 2018-2023 WireGuard LLC. All Rights Reserved.
+
+import NetworkExtension
+
+enum PacketTunnelProviderError: String, Error {
+ case savedProtocolConfigurationIsInvalid
+ case dnsResolutionFailure
+ case couldNotStartBackend
+ case couldNotDetermineFileDescriptor
+ case couldNotSetNetworkSettings
+}
+
+extension NETunnelProviderProtocol {
+ convenience init?(tunnelConfiguration: TunnelConfiguration, previouslyFrom old: NEVPNProtocol? = nil) {
+ self.init()
+
+ guard let name = tunnelConfiguration.name else { return nil }
+ guard let appId = Bundle.main.bundleIdentifier else { return nil }
+ providerBundleIdentifier = "\(appId).network-extension"
+ passwordReference = Keychain.makeReference(containing: tunnelConfiguration.asWgQuickConfig(), called: name, previouslyReferencedBy: old?.passwordReference)
+ if passwordReference == nil {
+ return nil
+ }
+ #if os(macOS)
+ providerConfiguration = ["UID": getuid()]
+ #endif
+
+ let endpoints = tunnelConfiguration.peers.compactMap { $0.endpoint }
+ if endpoints.count == 1 {
+ serverAddress = endpoints[0].stringRepresentation
+ } else if endpoints.isEmpty {
+ serverAddress = "Unspecified"
+ } else {
+ serverAddress = "Multiple endpoints"
+ }
+ }
+
+ func asTunnelConfiguration(called name: String? = nil) -> TunnelConfiguration? {
+ if let passwordReference = passwordReference,
+ let config = Keychain.openReference(called: passwordReference) {
+ return try? TunnelConfiguration(fromWgQuickConfig: config, called: name)
+ }
+ if let oldConfig = providerConfiguration?["WgQuickConfig"] as? String {
+ return try? TunnelConfiguration(fromWgQuickConfig: oldConfig, called: name)
+ }
+ return nil
+ }
+
+ func destroyConfigurationReference() {
+ guard let ref = passwordReference else { return }
+ Keychain.deleteReference(called: ref)
+ }
+
+ func verifyConfigurationReference() -> Bool {
+ guard let ref = passwordReference else { return false }
+ return Keychain.verifyReference(called: ref)
+ }
+
+ @discardableResult
+ func migrateConfigurationIfNeeded(called name: String) -> Bool {
+ /* This is how we did things before we switched to putting items
+ * in the keychain. But it's still useful to keep the migration
+ * around so that .mobileconfig files are easier.
+ */
+ if let oldConfig = providerConfiguration?["WgQuickConfig"] as? String {
+ #if os(macOS)
+ providerConfiguration = ["UID": getuid()]
+ #elseif os(iOS)
+ providerConfiguration = nil
+ #else
+ #error("Unimplemented")
+ #endif
+ guard passwordReference == nil else { return true }
+ wg_log(.info, message: "Migrating tunnel configuration '\(name)'")
+ passwordReference = Keychain.makeReference(containing: oldConfig, called: name)
+ return true
+ }
+ #if os(macOS)
+ if passwordReference != nil && providerConfiguration?["UID"] == nil && verifyConfigurationReference() {
+ providerConfiguration = ["UID": getuid()]
+ return true
+ }
+ #elseif os(iOS)
+ /* Update the stored reference from the old iOS 14 one to the canonical iOS 15 one.
+ * The iOS 14 ones are 96 bits, while the iOS 15 ones are 160 bits. We do this so
+ * that we can have fast set exclusion in deleteReferences safely. */
+ if passwordReference != nil && passwordReference!.count == 12 {
+ var result: CFTypeRef?
+ let ret = SecItemCopyMatching([kSecValuePersistentRef: passwordReference!,
+ kSecReturnPersistentRef: true] as CFDictionary,
+ &result)
+ if ret != errSecSuccess || result == nil {
+ return false
+ }
+ guard let newReference = result as? Data else { return false }
+ if !newReference.elementsEqual(passwordReference!) {
+ wg_log(.info, message: "Migrating iOS 14-style keychain reference to iOS 15-style keychain reference for '\(name)'")
+ passwordReference = newReference
+ return true
+ }
+ }
+ #endif
+ return false
+ }
+}
diff --git a/Sources/Shared/Model/String+ArrayConversion.swift b/Sources/Shared/Model/String+ArrayConversion.swift
new file mode 100644
index 0000000..97984f8
--- /dev/null
+++ b/Sources/Shared/Model/String+ArrayConversion.swift
@@ -0,0 +1,32 @@
+// SPDX-License-Identifier: MIT
+// Copyright © 2018-2023 WireGuard LLC. All Rights Reserved.
+
+import Foundation
+
+extension String {
+
+ func splitToArray(separator: Character = ",", trimmingCharacters: CharacterSet? = nil) -> [String] {
+ return split(separator: separator)
+ .map {
+ if let charSet = trimmingCharacters {
+ return $0.trimmingCharacters(in: charSet)
+ } else {
+ return String($0)
+ }
+ }
+ }
+
+}
+
+extension Optional where Wrapped == String {
+
+ func splitToArray(separator: Character = ",", trimmingCharacters: CharacterSet? = nil) -> [String] {
+ switch self {
+ case .none:
+ return []
+ case .some(let wrapped):
+ return wrapped.splitToArray(separator: separator, trimmingCharacters: trimmingCharacters)
+ }
+ }
+
+}
diff --git a/Sources/Shared/Model/TunnelConfiguration+WgQuickConfig.swift b/Sources/Shared/Model/TunnelConfiguration+WgQuickConfig.swift
new file mode 100644
index 0000000..86af010
--- /dev/null
+++ b/Sources/Shared/Model/TunnelConfiguration+WgQuickConfig.swift
@@ -0,0 +1,252 @@
+// SPDX-License-Identifier: MIT
+// Copyright © 2018-2023 WireGuard LLC. All Rights Reserved.
+
+import Foundation
+
+extension TunnelConfiguration {
+
+ enum ParserState {
+ case inInterfaceSection
+ case inPeerSection
+ case notInASection
+ }
+
+ enum ParseError: Error {
+ case invalidLine(String.SubSequence)
+ case noInterface
+ case multipleInterfaces
+ case interfaceHasNoPrivateKey
+ case interfaceHasInvalidPrivateKey(String)
+ case interfaceHasInvalidListenPort(String)
+ case interfaceHasInvalidAddress(String)
+ case interfaceHasInvalidDNS(String)
+ case interfaceHasInvalidMTU(String)
+ case interfaceHasUnrecognizedKey(String)
+ case peerHasNoPublicKey
+ case peerHasInvalidPublicKey(String)
+ case peerHasInvalidPreSharedKey(String)
+ case peerHasInvalidAllowedIP(String)
+ case peerHasInvalidEndpoint(String)
+ case peerHasInvalidPersistentKeepAlive(String)
+ case peerHasInvalidTransferBytes(String)
+ case peerHasInvalidLastHandshakeTime(String)
+ case peerHasUnrecognizedKey(String)
+ case multiplePeersWithSamePublicKey
+ case multipleEntriesForKey(String)
+ }
+
+ convenience init(fromWgQuickConfig wgQuickConfig: String, called name: String? = nil) throws {
+ var interfaceConfiguration: InterfaceConfiguration?
+ var peerConfigurations = [PeerConfiguration]()
+
+ let lines = wgQuickConfig.split { $0.isNewline }
+
+ var parserState = ParserState.notInASection
+ var attributes = [String: String]()
+
+ for (lineIndex, line) in lines.enumerated() {
+ var trimmedLine: String
+ if let commentRange = line.range(of: "#") {
+ trimmedLine = String(line[..<commentRange.lowerBound])
+ } else {
+ trimmedLine = String(line)
+ }
+
+ trimmedLine = trimmedLine.trimmingCharacters(in: .whitespacesAndNewlines)
+ let lowercasedLine = trimmedLine.lowercased()
+
+ if !trimmedLine.isEmpty {
+ if let equalsIndex = trimmedLine.firstIndex(of: "=") {
+ // Line contains an attribute
+ let keyWithCase = trimmedLine[..<equalsIndex].trimmingCharacters(in: .whitespacesAndNewlines)
+ let key = keyWithCase.lowercased()
+ let value = trimmedLine[trimmedLine.index(equalsIndex, offsetBy: 1)...].trimmingCharacters(in: .whitespacesAndNewlines)
+ let keysWithMultipleEntriesAllowed: Set<String> = ["address", "allowedips", "dns"]
+ if let presentValue = attributes[key] {
+ if keysWithMultipleEntriesAllowed.contains(key) {
+ attributes[key] = presentValue + "," + value
+ } else {
+ throw ParseError.multipleEntriesForKey(keyWithCase)
+ }
+ } else {
+ attributes[key] = value
+ }
+ let interfaceSectionKeys: Set<String> = ["privatekey", "listenport", "address", "dns", "mtu"]
+ let peerSectionKeys: Set<String> = ["publickey", "presharedkey", "allowedips", "endpoint", "persistentkeepalive"]
+ if parserState == .inInterfaceSection {
+ guard interfaceSectionKeys.contains(key) else {
+ throw ParseError.interfaceHasUnrecognizedKey(keyWithCase)
+ }
+ } else if parserState == .inPeerSection {
+ guard peerSectionKeys.contains(key) else {
+ throw ParseError.peerHasUnrecognizedKey(keyWithCase)
+ }
+ }
+ } else if lowercasedLine != "[interface]" && lowercasedLine != "[peer]" {
+ throw ParseError.invalidLine(line)
+ }
+ }
+
+ let isLastLine = lineIndex == lines.count - 1
+
+ if isLastLine || lowercasedLine == "[interface]" || lowercasedLine == "[peer]" {
+ // Previous section has ended; process the attributes collected so far
+ if parserState == .inInterfaceSection {
+ let interface = try TunnelConfiguration.collate(interfaceAttributes: attributes)
+ guard interfaceConfiguration == nil else { throw ParseError.multipleInterfaces }
+ interfaceConfiguration = interface
+ } else if parserState == .inPeerSection {
+ let peer = try TunnelConfiguration.collate(peerAttributes: attributes)
+ peerConfigurations.append(peer)
+ }
+ }
+
+ if lowercasedLine == "[interface]" {
+ parserState = .inInterfaceSection
+ attributes.removeAll()
+ } else if lowercasedLine == "[peer]" {
+ parserState = .inPeerSection
+ attributes.removeAll()
+ }
+ }
+
+ let peerPublicKeysArray = peerConfigurations.map { $0.publicKey }
+ let peerPublicKeysSet = Set<PublicKey>(peerPublicKeysArray)
+ if peerPublicKeysArray.count != peerPublicKeysSet.count {
+ throw ParseError.multiplePeersWithSamePublicKey
+ }
+
+ if let interfaceConfiguration = interfaceConfiguration {
+ self.init(name: name, interface: interfaceConfiguration, peers: peerConfigurations)
+ } else {
+ throw ParseError.noInterface
+ }
+ }
+
+ func asWgQuickConfig() -> String {
+ var output = "[Interface]\n"
+ output.append("PrivateKey = \(interface.privateKey.base64Key)\n")
+ if let listenPort = interface.listenPort {
+ output.append("ListenPort = \(listenPort)\n")
+ }
+ if !interface.addresses.isEmpty {
+ let addressString = interface.addresses.map { $0.stringRepresentation }.joined(separator: ", ")
+ output.append("Address = \(addressString)\n")
+ }
+ if !interface.dns.isEmpty || !interface.dnsSearch.isEmpty {
+ var dnsLine = interface.dns.map { $0.stringRepresentation }
+ dnsLine.append(contentsOf: interface.dnsSearch)
+ let dnsString = dnsLine.joined(separator: ", ")
+ output.append("DNS = \(dnsString)\n")
+ }
+ if let mtu = interface.mtu {
+ output.append("MTU = \(mtu)\n")
+ }
+
+ for peer in peers {
+ output.append("\n[Peer]\n")
+ output.append("PublicKey = \(peer.publicKey.base64Key)\n")
+ if let preSharedKey = peer.preSharedKey?.base64Key {
+ output.append("PresharedKey = \(preSharedKey)\n")
+ }
+ if !peer.allowedIPs.isEmpty {
+ let allowedIPsString = peer.allowedIPs.map { $0.stringRepresentation }.joined(separator: ", ")
+ output.append("AllowedIPs = \(allowedIPsString)\n")
+ }
+ if let endpoint = peer.endpoint {
+ output.append("Endpoint = \(endpoint.stringRepresentation)\n")
+ }
+ if let persistentKeepAlive = peer.persistentKeepAlive {
+ output.append("PersistentKeepalive = \(persistentKeepAlive)\n")
+ }
+ }
+
+ return output
+ }
+
+ private static func collate(interfaceAttributes attributes: [String: String]) throws -> InterfaceConfiguration {
+ guard let privateKeyString = attributes["privatekey"] else {
+ throw ParseError.interfaceHasNoPrivateKey
+ }
+ guard let privateKey = PrivateKey(base64Key: privateKeyString) else {
+ throw ParseError.interfaceHasInvalidPrivateKey(privateKeyString)
+ }
+ var interface = InterfaceConfiguration(privateKey: privateKey)
+ if let listenPortString = attributes["listenport"] {
+ guard let listenPort = UInt16(listenPortString) else {
+ throw ParseError.interfaceHasInvalidListenPort(listenPortString)
+ }
+ interface.listenPort = listenPort
+ }
+ if let addressesString = attributes["address"] {
+ var addresses = [IPAddressRange]()
+ for addressString in addressesString.splitToArray(trimmingCharacters: .whitespacesAndNewlines) {
+ guard let address = IPAddressRange(from: addressString) else {
+ throw ParseError.interfaceHasInvalidAddress(addressString)
+ }
+ addresses.append(address)
+ }
+ interface.addresses = addresses
+ }
+ if let dnsString = attributes["dns"] {
+ var dnsServers = [DNSServer]()
+ var dnsSearch = [String]()
+ for dnsServerString in dnsString.splitToArray(trimmingCharacters: .whitespacesAndNewlines) {
+ if let dnsServer = DNSServer(from: dnsServerString) {
+ dnsServers.append(dnsServer)
+ } else {
+ dnsSearch.append(dnsServerString)
+ }
+ }
+ interface.dns = dnsServers
+ interface.dnsSearch = dnsSearch
+ }
+ if let mtuString = attributes["mtu"] {
+ guard let mtu = UInt16(mtuString) else {
+ throw ParseError.interfaceHasInvalidMTU(mtuString)
+ }
+ interface.mtu = mtu
+ }
+ return interface
+ }
+
+ private static func collate(peerAttributes attributes: [String: String]) throws -> PeerConfiguration {
+ guard let publicKeyString = attributes["publickey"] else {
+ throw ParseError.peerHasNoPublicKey
+ }
+ guard let publicKey = PublicKey(base64Key: publicKeyString) else {
+ throw ParseError.peerHasInvalidPublicKey(publicKeyString)
+ }
+ var peer = PeerConfiguration(publicKey: publicKey)
+ if let preSharedKeyString = attributes["presharedkey"] {
+ guard let preSharedKey = PreSharedKey(base64Key: preSharedKeyString) else {
+ throw ParseError.peerHasInvalidPreSharedKey(preSharedKeyString)
+ }
+ peer.preSharedKey = preSharedKey
+ }
+ if let allowedIPsString = attributes["allowedips"] {
+ var allowedIPs = [IPAddressRange]()
+ for allowedIPString in allowedIPsString.splitToArray(trimmingCharacters: .whitespacesAndNewlines) {
+ guard let allowedIP = IPAddressRange(from: allowedIPString) else {
+ throw ParseError.peerHasInvalidAllowedIP(allowedIPString)
+ }
+ allowedIPs.append(allowedIP)
+ }
+ peer.allowedIPs = allowedIPs
+ }
+ if let endpointString = attributes["endpoint"] {
+ guard let endpoint = Endpoint(from: endpointString) else {
+ throw ParseError.peerHasInvalidEndpoint(endpointString)
+ }
+ peer.endpoint = endpoint
+ }
+ if let persistentKeepAliveString = attributes["persistentkeepalive"] {
+ guard let persistentKeepAlive = UInt16(persistentKeepAliveString) else {
+ throw ParseError.peerHasInvalidPersistentKeepAlive(persistentKeepAliveString)
+ }
+ peer.persistentKeepAlive = persistentKeepAlive
+ }
+ return peer
+ }
+
+}
diff --git a/Sources/Shared/NotificationToken.swift b/Sources/Shared/NotificationToken.swift
new file mode 100644
index 0000000..78d36ba
--- /dev/null
+++ b/Sources/Shared/NotificationToken.swift
@@ -0,0 +1,33 @@
+// SPDX-License-Identifier: MIT
+// Copyright © 2018-2023 WireGuard LLC. All Rights Reserved.
+
+import Foundation
+
+/// This source file contains bits of code from:
+/// https://oleb.net/blog/2018/01/notificationcenter-removeobserver/
+
+/// Wraps the observer token received from
+/// `NotificationCenter.addObserver(forName:object:queue:using:)`
+/// and unregisters it in deinit.
+final class NotificationToken {
+ let notificationCenter: NotificationCenter
+ let token: Any
+
+ init(notificationCenter: NotificationCenter = .default, token: Any) {
+ self.notificationCenter = notificationCenter
+ self.token = token
+ }
+
+ deinit {
+ notificationCenter.removeObserver(token)
+ }
+}
+
+extension NotificationCenter {
+ /// Convenience wrapper for addObserver(forName:object:queue:using:)
+ /// that returns our custom `NotificationToken`.
+ func observe(name: NSNotification.Name?, object obj: Any?, queue: OperationQueue?, using block: @escaping (Notification) -> Void) -> NotificationToken {
+ let token = addObserver(forName: name, object: obj, queue: queue, using: block)
+ return NotificationToken(notificationCenter: self, token: token)
+ }
+}