diff options
author | Roopesh Chander <roop@roopc.net> | 2018-11-08 15:09:45 +0530 |
---|---|---|
committer | Roopesh Chander <roop@roopc.net> | 2018-11-08 15:09:45 +0530 |
commit | 2f7e43720299816b4991f1af6d42491d2af0dd1e (patch) | |
tree | f2366c762a88818414dfce0ee264d67883ef0633 /WireGuard/WireGuardNetworkExtension | |
parent | Move logic to extension: .resolvingEndpointDomains is not longer a valid status (diff) | |
download | wireguard-apple-2f7e43720299816b4991f1af6d42491d2af0dd1e.tar.xz wireguard-apple-2f7e43720299816b4991f1af6d42491d2af0dd1e.zip |
Move logic to extension: Move DNSResolver to extension
Signed-off-by: Roopesh Chander <roop@roopc.net>
Diffstat (limited to 'WireGuard/WireGuardNetworkExtension')
-rw-r--r-- | WireGuard/WireGuardNetworkExtension/DNSResolver.swift | 158 |
1 files changed, 158 insertions, 0 deletions
diff --git a/WireGuard/WireGuardNetworkExtension/DNSResolver.swift b/WireGuard/WireGuardNetworkExtension/DNSResolver.swift new file mode 100644 index 0000000..6b7b73b --- /dev/null +++ b/WireGuard/WireGuardNetworkExtension/DNSResolver.swift @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: MIT +// Copyright © 2018 WireGuard LLC. All Rights Reserved. + +import Network +import Foundation + +class DNSResolver { + let endpoints: [Endpoint?] + let dispatchGroup: DispatchGroup + var dispatchWorkItems: [DispatchWorkItem] + static var cache = NSCache<NSString, NSString>() + + init(endpoints: [Endpoint?]) { + self.endpoints = endpoints + self.dispatchWorkItems = [] + self.dispatchGroup = DispatchGroup() + } + + func resolveWithoutNetworkRequests() -> [Endpoint?]? { + var resolvedEndpoints: [Endpoint?] = Array<Endpoint?>(repeating: nil, count: endpoints.count) + for (i, endpoint) in self.endpoints.enumerated() { + guard let endpoint = endpoint else { continue } + if (endpoint.hasHostAsIPAddress()) { + resolvedEndpoints[i] = endpoint + } else if let resolvedEndpointStringInCache = DNSResolver.cache.object(forKey: endpoint.stringRepresentation() as NSString), + let resolvedEndpointInCache = Endpoint(from: resolvedEndpointStringInCache as String) { + resolvedEndpoints[i] = resolvedEndpointInCache + } else { + return nil + } + } + return resolvedEndpoints + } + + func resolve(completionHandler: @escaping ([Endpoint?]?) -> Void) { + let endpoints = self.endpoints + let dispatchGroup = self.dispatchGroup + dispatchWorkItems = [] + var resolvedEndpoints: [Endpoint?] = Array<Endpoint?>(repeating: nil, count: endpoints.count) + var isResolvedByDNSRequest: [Bool] = Array<Bool>(repeating: false, count: endpoints.count) + for (i, endpoint) in self.endpoints.enumerated() { + guard let endpoint = endpoint else { continue } + if (endpoint.hasHostAsIPAddress()) { + resolvedEndpoints[i] = endpoint + } else if let resolvedEndpointStringInCache = DNSResolver.cache.object(forKey: endpoint.stringRepresentation() as NSString), + let resolvedEndpointInCache = Endpoint(from: resolvedEndpointStringInCache as String) { + resolvedEndpoints[i] = resolvedEndpointInCache + } else { + let workItem = DispatchWorkItem { + resolvedEndpoints[i] = DNSResolver.resolveSync(endpoint: endpoint) + isResolvedByDNSRequest[i] = true + } + dispatchWorkItems.append(workItem) + DispatchQueue.global(qos: .userInitiated).async(group: dispatchGroup, execute: workItem) + } + } + dispatchGroup.notify(queue: .main) { + assert(endpoints.count == resolvedEndpoints.count) + for (i, endpoint) in endpoints.enumerated() { + guard let endpoint = endpoint, let resolvedEndpoint = resolvedEndpoints[i] else { + completionHandler(nil) + return + } + if (isResolvedByDNSRequest[i]) { + DNSResolver.cache.setObject(resolvedEndpoint.stringRepresentation() as NSString, + forKey: endpoint.stringRepresentation() as NSString) + } + } + let numberOfEndpointsToResolve = endpoints.compactMap { $0 }.count + let numberOfResolvedEndpoints = resolvedEndpoints.compactMap { $0 }.count + if (numberOfResolvedEndpoints < numberOfEndpointsToResolve) { + completionHandler(nil) + } else { + completionHandler(resolvedEndpoints) + } + } + } + + func cancel() { + for workItem in dispatchWorkItems { + workItem.cancel() + } + } + + deinit { + cancel() + } +} + +extension DNSResolver { + // Based on DNS resolution code by Jason Donenfeld <jason@zx2c4.com> + // in parse_endpoint() in src/tools/config.c in the WireGuard codebase + private static func resolveSync(endpoint: Endpoint) -> Endpoint? { + var hints = addrinfo( + ai_flags: 0, + ai_family: AF_UNSPEC, + ai_socktype: SOCK_DGRAM, // WireGuard is UDP-only + ai_protocol: IPPROTO_UDP, // WireGuard is UDP-only + ai_addrlen: 0, + ai_canonname: nil, + ai_addr: nil, + ai_next: nil) + var resultPointer = UnsafeMutablePointer<addrinfo>(OpaquePointer(bitPattern: 0)) + switch (endpoint.host) { + case .name(let name, _): + // The endpoint is a hostname and needs DNS resolution + let returnValue = getaddrinfo( + name.cString(using: .utf8), // Hostname + "\(endpoint.port)".cString(using: .utf8), // Port + &hints, + &resultPointer) + if (returnValue == 0) { + // getaddrinfo succeeded + let ipv4Buffer = UnsafeMutablePointer<Int8>.allocate(capacity: Int(INET_ADDRSTRLEN)) + let ipv6Buffer = UnsafeMutablePointer<Int8>.allocate(capacity: Int(INET6_ADDRSTRLEN)) + var ipv4AddressString: String? + var ipv6AddressString: String? + while (resultPointer != nil) { + let result = resultPointer!.pointee + resultPointer = result.ai_next + if (result.ai_family == AF_INET && result.ai_addrlen == MemoryLayout<sockaddr_in>.size) { + var sa4 = UnsafeRawPointer(result.ai_addr)!.assumingMemoryBound(to: sockaddr_in.self).pointee + if (inet_ntop(result.ai_family, &sa4.sin_addr, ipv4Buffer, socklen_t(INET_ADDRSTRLEN)) != nil) { + ipv4AddressString = String(cString: ipv4Buffer) + // If we found an IPv4 address, we can stop + break + } + } else if (result.ai_family == AF_INET6 && result.ai_addrlen == MemoryLayout<sockaddr_in6>.size) { + if (ipv6AddressString != nil) { + // If we already have an IPv6 address, we can skip this one + continue + } + var sa6 = UnsafeRawPointer(result.ai_addr)!.assumingMemoryBound(to: sockaddr_in6.self).pointee + if (inet_ntop(result.ai_family, &sa6.sin6_addr, ipv6Buffer, socklen_t(INET6_ADDRSTRLEN)) != nil) { + ipv6AddressString = String(cString: ipv6Buffer) + } + } + } + ipv4Buffer.deallocate() + ipv6Buffer.deallocate() + // We prefer an IPv4 address over an IPv6 address + if let ipv4AddressString = ipv4AddressString, let ipv4Address = IPv4Address(ipv4AddressString) { + return Endpoint(host: NWEndpoint.Host.ipv4(ipv4Address), port: endpoint.port) + } else if let ipv6AddressString = ipv6AddressString, let ipv6Address = IPv6Address(ipv6AddressString) { + return Endpoint(host: NWEndpoint.Host.ipv6(ipv6Address), port: endpoint.port) + } else { + return nil + } + } else { + // getaddrinfo failed + return nil + } + default: + // The endpoint is already resolved + return endpoint + } + } +} |