diff options
Diffstat (limited to 'Sources/Shared')
-rw-r--r-- | Sources/Shared/FileManager+Extension.swift | 50 | ||||
-rw-r--r-- | Sources/Shared/Keychain.swift | 114 | ||||
-rw-r--r-- | Sources/Shared/Logging/Logger.swift | 65 | ||||
-rw-r--r-- | Sources/Shared/Logging/ringlogger.c | 173 | ||||
-rw-r--r-- | Sources/Shared/Logging/ringlogger.h | 18 | ||||
-rw-r--r-- | Sources/Shared/Logging/test_ringlogger.c | 63 | ||||
-rw-r--r-- | Sources/Shared/Model/NETunnelProviderProtocol+Extension.swift | 106 | ||||
-rw-r--r-- | Sources/Shared/Model/String+ArrayConversion.swift | 32 | ||||
-rw-r--r-- | Sources/Shared/Model/TunnelConfiguration+WgQuickConfig.swift | 252 | ||||
-rw-r--r-- | Sources/Shared/NotificationToken.swift | 33 |
10 files changed, 906 insertions, 0 deletions
diff --git a/Sources/Shared/FileManager+Extension.swift b/Sources/Shared/FileManager+Extension.swift new file mode 100644 index 0000000..48fa33f --- /dev/null +++ b/Sources/Shared/FileManager+Extension.swift @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright © 2018-2023 WireGuard LLC. All Rights Reserved. + +import Foundation +import os.log + +extension FileManager { + static var appGroupId: String? { + #if os(iOS) + let appGroupIdInfoDictionaryKey = "com.wireguard.ios.app_group_id" + #elseif os(macOS) + let appGroupIdInfoDictionaryKey = "com.wireguard.macos.app_group_id" + #else + #error("Unimplemented") + #endif + return Bundle.main.object(forInfoDictionaryKey: appGroupIdInfoDictionaryKey) as? String + } + private static var sharedFolderURL: URL? { + guard let appGroupId = FileManager.appGroupId else { + os_log("Cannot obtain app group ID from bundle", log: OSLog.default, type: .error) + return nil + } + guard let sharedFolderURL = FileManager.default.containerURL(forSecurityApplicationGroupIdentifier: appGroupId) else { + wg_log(.error, message: "Cannot obtain shared folder URL") + return nil + } + return sharedFolderURL + } + + static var logFileURL: URL? { + return sharedFolderURL?.appendingPathComponent("tunnel-log.bin") + } + + static var networkExtensionLastErrorFileURL: URL? { + return sharedFolderURL?.appendingPathComponent("last-error.txt") + } + + static var loginHelperTimestampURL: URL? { + return sharedFolderURL?.appendingPathComponent("login-helper-timestamp.bin") + } + + static func deleteFile(at url: URL) -> Bool { + do { + try FileManager.default.removeItem(at: url) + } catch { + return false + } + return true + } +} diff --git a/Sources/Shared/Keychain.swift b/Sources/Shared/Keychain.swift new file mode 100644 index 0000000..2e0e7f0 --- /dev/null +++ b/Sources/Shared/Keychain.swift @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright © 2018-2023 WireGuard LLC. All Rights Reserved. + +import Foundation +import Security + +class Keychain { + static func openReference(called ref: Data) -> String? { + var result: CFTypeRef? + let ret = SecItemCopyMatching([kSecValuePersistentRef: ref, + kSecReturnData: true] as CFDictionary, + &result) + if ret != errSecSuccess || result == nil { + wg_log(.error, message: "Unable to open config from keychain: \(ret)") + return nil + } + guard let data = result as? Data else { return nil } + return String(data: data, encoding: String.Encoding.utf8) + } + + static func makeReference(containing value: String, called name: String, previouslyReferencedBy oldRef: Data? = nil) -> Data? { + var ret: OSStatus + guard var bundleIdentifier = Bundle.main.bundleIdentifier else { + wg_log(.error, staticMessage: "Unable to determine bundle identifier") + return nil + } + if bundleIdentifier.hasSuffix(".network-extension") { + bundleIdentifier.removeLast(".network-extension".count) + } + let itemLabel = "WireGuard Tunnel: \(name)" + var items: [CFString: Any] = [kSecClass: kSecClassGenericPassword, + kSecAttrLabel: itemLabel, + kSecAttrAccount: name + ": " + UUID().uuidString, + kSecAttrDescription: "wg-quick(8) config", + kSecAttrService: bundleIdentifier, + kSecValueData: value.data(using: .utf8) as Any, + kSecReturnPersistentRef: true] + + #if os(iOS) + items[kSecAttrAccessGroup] = FileManager.appGroupId + items[kSecAttrAccessible] = kSecAttrAccessibleAfterFirstUnlock + #elseif os(macOS) + items[kSecAttrSynchronizable] = false + items[kSecAttrAccessible] = kSecAttrAccessibleAfterFirstUnlockThisDeviceOnly + + guard let extensionPath = Bundle.main.builtInPlugInsURL?.appendingPathComponent("WireGuardNetworkExtension.appex", isDirectory: true).path else { + wg_log(.error, staticMessage: "Unable to determine app extension path") + return nil + } + var extensionApp: SecTrustedApplication? + var mainApp: SecTrustedApplication? + ret = SecTrustedApplicationCreateFromPath(extensionPath, &extensionApp) + if ret != kOSReturnSuccess || extensionApp == nil { + wg_log(.error, message: "Unable to create keychain extension trusted application object: \(ret)") + return nil + } + ret = SecTrustedApplicationCreateFromPath(nil, &mainApp) + if ret != errSecSuccess || mainApp == nil { + wg_log(.error, message: "Unable to create keychain local trusted application object: \(ret)") + return nil + } + var access: SecAccess? + ret = SecAccessCreate(itemLabel as CFString, [extensionApp!, mainApp!] as CFArray, &access) + if ret != errSecSuccess || access == nil { + wg_log(.error, message: "Unable to create keychain ACL object: \(ret)") + return nil + } + items[kSecAttrAccess] = access! + #else + #error("Unimplemented") + #endif + + var ref: CFTypeRef? + ret = SecItemAdd(items as CFDictionary, &ref) + if ret != errSecSuccess || ref == nil { + wg_log(.error, message: "Unable to add config to keychain: \(ret)") + return nil + } + if let oldRef = oldRef { + deleteReference(called: oldRef) + } + return ref as? Data + } + + static func deleteReference(called ref: Data) { + let ret = SecItemDelete([kSecValuePersistentRef: ref] as CFDictionary) + if ret != errSecSuccess { + wg_log(.error, message: "Unable to delete config from keychain: \(ret)") + } + } + + static func deleteReferences(except whitelist: Set<Data>) { + var result: CFTypeRef? + let ret = SecItemCopyMatching([kSecClass: kSecClassGenericPassword, + kSecAttrService: Bundle.main.bundleIdentifier as Any, + kSecMatchLimit: kSecMatchLimitAll, + kSecReturnPersistentRef: true] as CFDictionary, + &result) + if ret != errSecSuccess || result == nil { + return + } + guard let items = result as? [Data] else { return } + for item in items { + if !whitelist.contains(item) { + deleteReference(called: item) + } + } + } + + static func verifyReference(called ref: Data) -> Bool { + return SecItemCopyMatching([kSecValuePersistentRef: ref] as CFDictionary, + nil) != errSecItemNotFound + } +} diff --git a/Sources/Shared/Logging/Logger.swift b/Sources/Shared/Logging/Logger.swift new file mode 100644 index 0000000..f3ee2b7 --- /dev/null +++ b/Sources/Shared/Logging/Logger.swift @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright © 2018-2023 WireGuard LLC. All Rights Reserved. + +import Foundation +import os.log + +public class Logger { + enum LoggerError: Error { + case openFailure + } + + static var global: Logger? + + var log: OpaquePointer + var tag: String + + init(tagged tag: String, withFilePath filePath: String) throws { + guard let log = open_log(filePath) else { throw LoggerError.openFailure } + self.log = log + self.tag = tag + } + + deinit { + close_log(self.log) + } + + func log(message: String) { + write_msg_to_log(log, tag, message.trimmingCharacters(in: .newlines)) + } + + func writeLog(to targetFile: String) -> Bool { + return write_log_to_file(targetFile, self.log) == 0 + } + + static func configureGlobal(tagged tag: String, withFilePath filePath: String?) { + if Logger.global != nil { + return + } + guard let filePath = filePath else { + os_log("Unable to determine log destination path. Log will not be saved to file.", log: OSLog.default, type: .error) + return + } + guard let logger = try? Logger(tagged: tag, withFilePath: filePath) else { + os_log("Unable to open log file for writing. Log will not be saved to file.", log: OSLog.default, type: .error) + return + } + Logger.global = logger + var appVersion = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String ?? "Unknown version" + if let appBuild = Bundle.main.infoDictionary?["CFBundleVersion"] as? String { + appVersion += " (\(appBuild))" + } + + Logger.global?.log(message: "App version: \(appVersion)") + } +} + +func wg_log(_ type: OSLogType, staticMessage msg: StaticString) { + os_log(msg, log: OSLog.default, type: type) + Logger.global?.log(message: "\(msg)") +} + +func wg_log(_ type: OSLogType, message msg: String) { + os_log("%{public}s", log: OSLog.default, type: type, msg) + Logger.global?.log(message: msg) +} diff --git a/Sources/Shared/Logging/ringlogger.c b/Sources/Shared/Logging/ringlogger.c new file mode 100644 index 0000000..9bb0d13 --- /dev/null +++ b/Sources/Shared/Logging/ringlogger.c @@ -0,0 +1,173 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright © 2018-2023 WireGuard LLC. All Rights Reserved. + */ + +#include <string.h> +#include <stdio.h> +#include <stdint.h> +#include <stdlib.h> +#include <stdatomic.h> +#include <stdbool.h> +#include <time.h> +#include <errno.h> +#include <unistd.h> +#include <fcntl.h> +#include <sys/types.h> +#include <sys/stat.h> +#include <sys/time.h> +#include <sys/mman.h> +#include "ringlogger.h" + +enum { + MAX_LOG_LINE_LENGTH = 512, + MAX_LINES = 2048, + MAGIC = 0xabadbeefU +}; + +struct log_line { + atomic_uint_fast64_t time_ns; + char line[MAX_LOG_LINE_LENGTH]; +}; + +struct log { + atomic_uint_fast32_t next_index; + struct log_line lines[MAX_LINES]; + uint32_t magic; +}; + +void write_msg_to_log(struct log *log, const char *tag, const char *msg) +{ + uint32_t index; + struct log_line *line; + struct timespec ts; + + // Race: This isn't synchronized with the fetch_add below, so items might be slightly out of order. + clock_gettime(CLOCK_REALTIME, &ts); + + // Race: More than MAX_LINES writers and this will clash. + index = atomic_fetch_add(&log->next_index, 1); + line = &log->lines[index % MAX_LINES]; + + // Race: Before this line executes, we'll display old data after new data. + atomic_store(&line->time_ns, 0); + memset(line->line, 0, MAX_LOG_LINE_LENGTH); + + snprintf(line->line, MAX_LOG_LINE_LENGTH, "[%s] %s", tag, msg); + atomic_store(&line->time_ns, ts.tv_sec * 1000000000ULL + ts.tv_nsec); + + msync(&log->next_index, sizeof(log->next_index), MS_ASYNC); + msync(line, sizeof(*line), MS_ASYNC); +} + +int write_log_to_file(const char *file_name, const struct log *input_log) +{ + struct log *log; + uint32_t l, i; + FILE *file; + int ret; + + log = malloc(sizeof(*log)); + if (!log) + return -errno; + memcpy(log, input_log, sizeof(*log)); + + file = fopen(file_name, "w"); + if (!file) { + free(log); + return -errno; + } + + for (l = 0, i = log->next_index; l < MAX_LINES; ++l, ++i) { + const struct log_line *line = &log->lines[i % MAX_LINES]; + time_t seconds = line->time_ns / 1000000000ULL; + uint32_t useconds = (line->time_ns % 1000000000ULL) / 1000ULL; + struct tm tm; + + if (!line->time_ns) + continue; + + if (!localtime_r(&seconds, &tm)) + goto err; + + if (fprintf(file, "%04d-%02d-%02d %02d:%02d:%02d.%06d: %s\n", + tm.tm_year + 1900, tm.tm_mon + 1, tm.tm_mday, + tm.tm_hour, tm.tm_min, tm.tm_sec, useconds, + line->line) < 0) + goto err; + + + } + errno = 0; + +err: + ret = -errno; + fclose(file); + free(log); + return ret; +} + +uint32_t view_lines_from_cursor(const struct log *input_log, uint32_t cursor, void *ctx, void(*cb)(const char *, uint64_t, void *)) +{ + struct log *log; + uint32_t l, i = cursor; + + log = malloc(sizeof(*log)); + if (!log) + return cursor; + memcpy(log, input_log, sizeof(*log)); + + if (i == -1) + i = log->next_index; + + for (l = 0; l < MAX_LINES; ++l, ++i) { + const struct log_line *line = &log->lines[i % MAX_LINES]; + + if (cursor != -1 && i % MAX_LINES == log->next_index % MAX_LINES) + break; + + if (!line->time_ns) { + if (cursor == -1) + continue; + else + break; + } + cb(line->line, line->time_ns, ctx); + cursor = (i + 1) % MAX_LINES; + } + free(log); + return cursor; +} + +struct log *open_log(const char *file_name) +{ + int fd; + struct log *log; + + fd = open(file_name, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR); + if (fd < 0) + return NULL; + if (ftruncate(fd, sizeof(*log))) + goto err; + log = mmap(NULL, sizeof(*log), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (log == MAP_FAILED) + goto err; + close(fd); + + if (log->magic != MAGIC) { + memset(log, 0, sizeof(*log)); + log->magic = MAGIC; + msync(log, sizeof(*log), MS_ASYNC); + } + + return log; + +err: + close(fd); + return NULL; +} + +void close_log(struct log *log) +{ + munmap(log, sizeof(*log)); +} diff --git a/Sources/Shared/Logging/ringlogger.h b/Sources/Shared/Logging/ringlogger.h new file mode 100644 index 0000000..0e28c93 --- /dev/null +++ b/Sources/Shared/Logging/ringlogger.h @@ -0,0 +1,18 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright © 2018-2023 WireGuard LLC. All Rights Reserved. + */ + +#ifndef RINGLOGGER_H +#define RINGLOGGER_H + +#include <stdint.h> + +struct log; +void write_msg_to_log(struct log *log, const char *tag, const char *msg); +int write_log_to_file(const char *file_name, const struct log *input_log); +uint32_t view_lines_from_cursor(const struct log *input_log, uint32_t cursor, void *ctx, void(*)(const char *, uint64_t, void *)); +struct log *open_log(const char *file_name); +void close_log(struct log *log); + +#endif diff --git a/Sources/Shared/Logging/test_ringlogger.c b/Sources/Shared/Logging/test_ringlogger.c new file mode 100644 index 0000000..ae3f4a9 --- /dev/null +++ b/Sources/Shared/Logging/test_ringlogger.c @@ -0,0 +1,63 @@ +#include "ringlogger.h" +#include <stdio.h> +#include <stdbool.h> +#include <string.h> +#include <unistd.h> +#include <inttypes.h> +#include <sys/wait.h> + +static void forkwrite(void) +{ + struct log *log = open_log("/tmp/test_log"); + char c[512]; + int i, base; + bool in_fork = !fork(); + + base = 10000 * in_fork; + for (i = 0; i < 1024; ++i) { + snprintf(c, 512, "bla bla bla %d", base + i); + write_msg_to_log(log, "HMM", c); + } + + + if (in_fork) + _exit(0); + wait(NULL); + + write_log_to_file("/dev/stdout", log); + close_log(log); +} + +static void writetext(const char *text) +{ + struct log *log = open_log("/tmp/test_log"); + write_msg_to_log(log, "TXT", text); + close_log(log); +} + +static void show_line(const char *line, uint64_t time_ns) +{ + printf("%" PRIu64 ": %s\n", time_ns, line); +} + +static void follow(void) +{ + uint32_t cursor = -1; + struct log *log = open_log("/tmp/test_log"); + + for (;;) { + cursor = view_lines_from_cursor(log, cursor, show_line); + usleep(1000 * 300); + } +} + +int main(int argc, char *argv[]) +{ + if (!strcmp(argv[1], "fork")) + forkwrite(); + else if (!strcmp(argv[1], "write")) + writetext(argv[2]); + else if (!strcmp(argv[1], "follow")) + follow(); + return 0; +} diff --git a/Sources/Shared/Model/NETunnelProviderProtocol+Extension.swift b/Sources/Shared/Model/NETunnelProviderProtocol+Extension.swift new file mode 100644 index 0000000..0a303f4 --- /dev/null +++ b/Sources/Shared/Model/NETunnelProviderProtocol+Extension.swift @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: MIT +// Copyright © 2018-2023 WireGuard LLC. All Rights Reserved. + +import NetworkExtension + +enum PacketTunnelProviderError: String, Error { + case savedProtocolConfigurationIsInvalid + case dnsResolutionFailure + case couldNotStartBackend + case couldNotDetermineFileDescriptor + case couldNotSetNetworkSettings +} + +extension NETunnelProviderProtocol { + convenience init?(tunnelConfiguration: TunnelConfiguration, previouslyFrom old: NEVPNProtocol? = nil) { + self.init() + + guard let name = tunnelConfiguration.name else { return nil } + guard let appId = Bundle.main.bundleIdentifier else { return nil } + providerBundleIdentifier = "\(appId).network-extension" + passwordReference = Keychain.makeReference(containing: tunnelConfiguration.asWgQuickConfig(), called: name, previouslyReferencedBy: old?.passwordReference) + if passwordReference == nil { + return nil + } + #if os(macOS) + providerConfiguration = ["UID": getuid()] + #endif + + let endpoints = tunnelConfiguration.peers.compactMap { $0.endpoint } + if endpoints.count == 1 { + serverAddress = endpoints[0].stringRepresentation + } else if endpoints.isEmpty { + serverAddress = "Unspecified" + } else { + serverAddress = "Multiple endpoints" + } + } + + func asTunnelConfiguration(called name: String? = nil) -> TunnelConfiguration? { + if let passwordReference = passwordReference, + let config = Keychain.openReference(called: passwordReference) { + return try? TunnelConfiguration(fromWgQuickConfig: config, called: name) + } + if let oldConfig = providerConfiguration?["WgQuickConfig"] as? String { + return try? TunnelConfiguration(fromWgQuickConfig: oldConfig, called: name) + } + return nil + } + + func destroyConfigurationReference() { + guard let ref = passwordReference else { return } + Keychain.deleteReference(called: ref) + } + + func verifyConfigurationReference() -> Bool { + guard let ref = passwordReference else { return false } + return Keychain.verifyReference(called: ref) + } + + @discardableResult + func migrateConfigurationIfNeeded(called name: String) -> Bool { + /* This is how we did things before we switched to putting items + * in the keychain. But it's still useful to keep the migration + * around so that .mobileconfig files are easier. + */ + if let oldConfig = providerConfiguration?["WgQuickConfig"] as? String { + #if os(macOS) + providerConfiguration = ["UID": getuid()] + #elseif os(iOS) + providerConfiguration = nil + #else + #error("Unimplemented") + #endif + guard passwordReference == nil else { return true } + wg_log(.info, message: "Migrating tunnel configuration '\(name)'") + passwordReference = Keychain.makeReference(containing: oldConfig, called: name) + return true + } + #if os(macOS) + if passwordReference != nil && providerConfiguration?["UID"] == nil && verifyConfigurationReference() { + providerConfiguration = ["UID": getuid()] + return true + } + #elseif os(iOS) + /* Update the stored reference from the old iOS 14 one to the canonical iOS 15 one. + * The iOS 14 ones are 96 bits, while the iOS 15 ones are 160 bits. We do this so + * that we can have fast set exclusion in deleteReferences safely. */ + if passwordReference != nil && passwordReference!.count == 12 { + var result: CFTypeRef? + let ret = SecItemCopyMatching([kSecValuePersistentRef: passwordReference!, + kSecReturnPersistentRef: true] as CFDictionary, + &result) + if ret != errSecSuccess || result == nil { + return false + } + guard let newReference = result as? Data else { return false } + if !newReference.elementsEqual(passwordReference!) { + wg_log(.info, message: "Migrating iOS 14-style keychain reference to iOS 15-style keychain reference for '\(name)'") + passwordReference = newReference + return true + } + } + #endif + return false + } +} diff --git a/Sources/Shared/Model/String+ArrayConversion.swift b/Sources/Shared/Model/String+ArrayConversion.swift new file mode 100644 index 0000000..97984f8 --- /dev/null +++ b/Sources/Shared/Model/String+ArrayConversion.swift @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright © 2018-2023 WireGuard LLC. All Rights Reserved. + +import Foundation + +extension String { + + func splitToArray(separator: Character = ",", trimmingCharacters: CharacterSet? = nil) -> [String] { + return split(separator: separator) + .map { + if let charSet = trimmingCharacters { + return $0.trimmingCharacters(in: charSet) + } else { + return String($0) + } + } + } + +} + +extension Optional where Wrapped == String { + + func splitToArray(separator: Character = ",", trimmingCharacters: CharacterSet? = nil) -> [String] { + switch self { + case .none: + return [] + case .some(let wrapped): + return wrapped.splitToArray(separator: separator, trimmingCharacters: trimmingCharacters) + } + } + +} diff --git a/Sources/Shared/Model/TunnelConfiguration+WgQuickConfig.swift b/Sources/Shared/Model/TunnelConfiguration+WgQuickConfig.swift new file mode 100644 index 0000000..86af010 --- /dev/null +++ b/Sources/Shared/Model/TunnelConfiguration+WgQuickConfig.swift @@ -0,0 +1,252 @@ +// SPDX-License-Identifier: MIT +// Copyright © 2018-2023 WireGuard LLC. All Rights Reserved. + +import Foundation + +extension TunnelConfiguration { + + enum ParserState { + case inInterfaceSection + case inPeerSection + case notInASection + } + + enum ParseError: Error { + case invalidLine(String.SubSequence) + case noInterface + case multipleInterfaces + case interfaceHasNoPrivateKey + case interfaceHasInvalidPrivateKey(String) + case interfaceHasInvalidListenPort(String) + case interfaceHasInvalidAddress(String) + case interfaceHasInvalidDNS(String) + case interfaceHasInvalidMTU(String) + case interfaceHasUnrecognizedKey(String) + case peerHasNoPublicKey + case peerHasInvalidPublicKey(String) + case peerHasInvalidPreSharedKey(String) + case peerHasInvalidAllowedIP(String) + case peerHasInvalidEndpoint(String) + case peerHasInvalidPersistentKeepAlive(String) + case peerHasInvalidTransferBytes(String) + case peerHasInvalidLastHandshakeTime(String) + case peerHasUnrecognizedKey(String) + case multiplePeersWithSamePublicKey + case multipleEntriesForKey(String) + } + + convenience init(fromWgQuickConfig wgQuickConfig: String, called name: String? = nil) throws { + var interfaceConfiguration: InterfaceConfiguration? + var peerConfigurations = [PeerConfiguration]() + + let lines = wgQuickConfig.split { $0.isNewline } + + var parserState = ParserState.notInASection + var attributes = [String: String]() + + for (lineIndex, line) in lines.enumerated() { + var trimmedLine: String + if let commentRange = line.range(of: "#") { + trimmedLine = String(line[..<commentRange.lowerBound]) + } else { + trimmedLine = String(line) + } + + trimmedLine = trimmedLine.trimmingCharacters(in: .whitespacesAndNewlines) + let lowercasedLine = trimmedLine.lowercased() + + if !trimmedLine.isEmpty { + if let equalsIndex = trimmedLine.firstIndex(of: "=") { + // Line contains an attribute + let keyWithCase = trimmedLine[..<equalsIndex].trimmingCharacters(in: .whitespacesAndNewlines) + let key = keyWithCase.lowercased() + let value = trimmedLine[trimmedLine.index(equalsIndex, offsetBy: 1)...].trimmingCharacters(in: .whitespacesAndNewlines) + let keysWithMultipleEntriesAllowed: Set<String> = ["address", "allowedips", "dns"] + if let presentValue = attributes[key] { + if keysWithMultipleEntriesAllowed.contains(key) { + attributes[key] = presentValue + "," + value + } else { + throw ParseError.multipleEntriesForKey(keyWithCase) + } + } else { + attributes[key] = value + } + let interfaceSectionKeys: Set<String> = ["privatekey", "listenport", "address", "dns", "mtu"] + let peerSectionKeys: Set<String> = ["publickey", "presharedkey", "allowedips", "endpoint", "persistentkeepalive"] + if parserState == .inInterfaceSection { + guard interfaceSectionKeys.contains(key) else { + throw ParseError.interfaceHasUnrecognizedKey(keyWithCase) + } + } else if parserState == .inPeerSection { + guard peerSectionKeys.contains(key) else { + throw ParseError.peerHasUnrecognizedKey(keyWithCase) + } + } + } else if lowercasedLine != "[interface]" && lowercasedLine != "[peer]" { + throw ParseError.invalidLine(line) + } + } + + let isLastLine = lineIndex == lines.count - 1 + + if isLastLine || lowercasedLine == "[interface]" || lowercasedLine == "[peer]" { + // Previous section has ended; process the attributes collected so far + if parserState == .inInterfaceSection { + let interface = try TunnelConfiguration.collate(interfaceAttributes: attributes) + guard interfaceConfiguration == nil else { throw ParseError.multipleInterfaces } + interfaceConfiguration = interface + } else if parserState == .inPeerSection { + let peer = try TunnelConfiguration.collate(peerAttributes: attributes) + peerConfigurations.append(peer) + } + } + + if lowercasedLine == "[interface]" { + parserState = .inInterfaceSection + attributes.removeAll() + } else if lowercasedLine == "[peer]" { + parserState = .inPeerSection + attributes.removeAll() + } + } + + let peerPublicKeysArray = peerConfigurations.map { $0.publicKey } + let peerPublicKeysSet = Set<PublicKey>(peerPublicKeysArray) + if peerPublicKeysArray.count != peerPublicKeysSet.count { + throw ParseError.multiplePeersWithSamePublicKey + } + + if let interfaceConfiguration = interfaceConfiguration { + self.init(name: name, interface: interfaceConfiguration, peers: peerConfigurations) + } else { + throw ParseError.noInterface + } + } + + func asWgQuickConfig() -> String { + var output = "[Interface]\n" + output.append("PrivateKey = \(interface.privateKey.base64Key)\n") + if let listenPort = interface.listenPort { + output.append("ListenPort = \(listenPort)\n") + } + if !interface.addresses.isEmpty { + let addressString = interface.addresses.map { $0.stringRepresentation }.joined(separator: ", ") + output.append("Address = \(addressString)\n") + } + if !interface.dns.isEmpty || !interface.dnsSearch.isEmpty { + var dnsLine = interface.dns.map { $0.stringRepresentation } + dnsLine.append(contentsOf: interface.dnsSearch) + let dnsString = dnsLine.joined(separator: ", ") + output.append("DNS = \(dnsString)\n") + } + if let mtu = interface.mtu { + output.append("MTU = \(mtu)\n") + } + + for peer in peers { + output.append("\n[Peer]\n") + output.append("PublicKey = \(peer.publicKey.base64Key)\n") + if let preSharedKey = peer.preSharedKey?.base64Key { + output.append("PresharedKey = \(preSharedKey)\n") + } + if !peer.allowedIPs.isEmpty { + let allowedIPsString = peer.allowedIPs.map { $0.stringRepresentation }.joined(separator: ", ") + output.append("AllowedIPs = \(allowedIPsString)\n") + } + if let endpoint = peer.endpoint { + output.append("Endpoint = \(endpoint.stringRepresentation)\n") + } + if let persistentKeepAlive = peer.persistentKeepAlive { + output.append("PersistentKeepalive = \(persistentKeepAlive)\n") + } + } + + return output + } + + private static func collate(interfaceAttributes attributes: [String: String]) throws -> InterfaceConfiguration { + guard let privateKeyString = attributes["privatekey"] else { + throw ParseError.interfaceHasNoPrivateKey + } + guard let privateKey = PrivateKey(base64Key: privateKeyString) else { + throw ParseError.interfaceHasInvalidPrivateKey(privateKeyString) + } + var interface = InterfaceConfiguration(privateKey: privateKey) + if let listenPortString = attributes["listenport"] { + guard let listenPort = UInt16(listenPortString) else { + throw ParseError.interfaceHasInvalidListenPort(listenPortString) + } + interface.listenPort = listenPort + } + if let addressesString = attributes["address"] { + var addresses = [IPAddressRange]() + for addressString in addressesString.splitToArray(trimmingCharacters: .whitespacesAndNewlines) { + guard let address = IPAddressRange(from: addressString) else { + throw ParseError.interfaceHasInvalidAddress(addressString) + } + addresses.append(address) + } + interface.addresses = addresses + } + if let dnsString = attributes["dns"] { + var dnsServers = [DNSServer]() + var dnsSearch = [String]() + for dnsServerString in dnsString.splitToArray(trimmingCharacters: .whitespacesAndNewlines) { + if let dnsServer = DNSServer(from: dnsServerString) { + dnsServers.append(dnsServer) + } else { + dnsSearch.append(dnsServerString) + } + } + interface.dns = dnsServers + interface.dnsSearch = dnsSearch + } + if let mtuString = attributes["mtu"] { + guard let mtu = UInt16(mtuString) else { + throw ParseError.interfaceHasInvalidMTU(mtuString) + } + interface.mtu = mtu + } + return interface + } + + private static func collate(peerAttributes attributes: [String: String]) throws -> PeerConfiguration { + guard let publicKeyString = attributes["publickey"] else { + throw ParseError.peerHasNoPublicKey + } + guard let publicKey = PublicKey(base64Key: publicKeyString) else { + throw ParseError.peerHasInvalidPublicKey(publicKeyString) + } + var peer = PeerConfiguration(publicKey: publicKey) + if let preSharedKeyString = attributes["presharedkey"] { + guard let preSharedKey = PreSharedKey(base64Key: preSharedKeyString) else { + throw ParseError.peerHasInvalidPreSharedKey(preSharedKeyString) + } + peer.preSharedKey = preSharedKey + } + if let allowedIPsString = attributes["allowedips"] { + var allowedIPs = [IPAddressRange]() + for allowedIPString in allowedIPsString.splitToArray(trimmingCharacters: .whitespacesAndNewlines) { + guard let allowedIP = IPAddressRange(from: allowedIPString) else { + throw ParseError.peerHasInvalidAllowedIP(allowedIPString) + } + allowedIPs.append(allowedIP) + } + peer.allowedIPs = allowedIPs + } + if let endpointString = attributes["endpoint"] { + guard let endpoint = Endpoint(from: endpointString) else { + throw ParseError.peerHasInvalidEndpoint(endpointString) + } + peer.endpoint = endpoint + } + if let persistentKeepAliveString = attributes["persistentkeepalive"] { + guard let persistentKeepAlive = UInt16(persistentKeepAliveString) else { + throw ParseError.peerHasInvalidPersistentKeepAlive(persistentKeepAliveString) + } + peer.persistentKeepAlive = persistentKeepAlive + } + return peer + } + +} diff --git a/Sources/Shared/NotificationToken.swift b/Sources/Shared/NotificationToken.swift new file mode 100644 index 0000000..78d36ba --- /dev/null +++ b/Sources/Shared/NotificationToken.swift @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright © 2018-2023 WireGuard LLC. All Rights Reserved. + +import Foundation + +/// This source file contains bits of code from: +/// https://oleb.net/blog/2018/01/notificationcenter-removeobserver/ + +/// Wraps the observer token received from +/// `NotificationCenter.addObserver(forName:object:queue:using:)` +/// and unregisters it in deinit. +final class NotificationToken { + let notificationCenter: NotificationCenter + let token: Any + + init(notificationCenter: NotificationCenter = .default, token: Any) { + self.notificationCenter = notificationCenter + self.token = token + } + + deinit { + notificationCenter.removeObserver(token) + } +} + +extension NotificationCenter { + /// Convenience wrapper for addObserver(forName:object:queue:using:) + /// that returns our custom `NotificationToken`. + func observe(name: NSNotification.Name?, object obj: Any?, queue: OperationQueue?, using block: @escaping (Notification) -> Void) -> NotificationToken { + let token = addObserver(forName: name, object: obj, queue: queue, using: block) + return NotificationToken(notificationCenter: self, token: token) + } +} |