aboutsummaryrefslogtreecommitdiffstats
path: root/Sources/WireGuardKit/DNSResolver.swift
diff options
context:
space:
mode:
Diffstat (limited to 'Sources/WireGuardKit/DNSResolver.swift')
-rw-r--r--Sources/WireGuardKit/DNSResolver.swift153
1 files changed, 153 insertions, 0 deletions
diff --git a/Sources/WireGuardKit/DNSResolver.swift b/Sources/WireGuardKit/DNSResolver.swift
new file mode 100644
index 0000000..379c698
--- /dev/null
+++ b/Sources/WireGuardKit/DNSResolver.swift
@@ -0,0 +1,153 @@
+// SPDX-License-Identifier: MIT
+// Copyright © 2018-2019 WireGuard LLC. All Rights Reserved.
+
+import Network
+import Foundation
+
+enum DNSResolver {}
+
+extension DNSResolver {
+
+ /// Concurrent queue used for DNS resolutions
+ private static let resolverQueue = DispatchQueue(label: "DNSResolverQueue", qos: .default, attributes: .concurrent)
+
+ static func resolveSync(endpoints: [Endpoint?]) -> [Result<Endpoint, DNSResolutionError>?] {
+ let isAllEndpointsAlreadyResolved = endpoints.allSatisfy { maybeEndpoint -> Bool in
+ return maybeEndpoint?.hasHostAsIPAddress() ?? true
+ }
+
+ if isAllEndpointsAlreadyResolved {
+ return endpoints.map { endpoint in
+ return endpoint.map { .success($0) }
+ }
+ }
+
+ return endpoints.concurrentMap(queue: resolverQueue) { endpoint -> Result<Endpoint, DNSResolutionError>? in
+ guard let endpoint = endpoint else { return nil }
+
+ if endpoint.hasHostAsIPAddress() {
+ return .success(endpoint)
+ } else {
+ return Result { try DNSResolver.resolveSync(endpoint: endpoint) }
+ .mapError { error -> DNSResolutionError in
+ // swiftlint:disable:next force_cast
+ return error as! DNSResolutionError
+ }
+ }
+ }
+ }
+
+ private static func resolveSync(endpoint: Endpoint) throws -> Endpoint {
+ guard case .name(let name, _) = endpoint.host else {
+ return endpoint
+ }
+
+ var hints = addrinfo()
+ hints.ai_flags = AI_ALL // We set this to ALL so that we get v4 addresses even on DNS64 networks
+ hints.ai_family = AF_UNSPEC
+ hints.ai_socktype = SOCK_DGRAM
+ hints.ai_protocol = IPPROTO_UDP
+
+ var resultPointer: UnsafeMutablePointer<addrinfo>?
+ defer {
+ resultPointer.flatMap { freeaddrinfo($0) }
+ }
+
+ let errorCode = getaddrinfo(name, "\(endpoint.port)", &hints, &resultPointer)
+ if errorCode != 0 {
+ throw DNSResolutionError(errorCode: errorCode, address: name)
+ }
+
+ var ipv4Address: IPv4Address?
+ var ipv6Address: IPv6Address?
+
+ var next: UnsafeMutablePointer<addrinfo>? = resultPointer
+ let iterator = AnyIterator { () -> addrinfo? in
+ let result = next?.pointee
+ next = result?.ai_next
+ return result
+ }
+
+ for addrInfo in iterator {
+ if let maybeIpv4Address = IPv4Address(addrInfo: addrInfo) {
+ ipv4Address = maybeIpv4Address
+ break // If we found an IPv4 address, we can stop
+ } else if let maybeIpv6Address = IPv6Address(addrInfo: addrInfo) {
+ ipv6Address = maybeIpv6Address
+ continue // If we already have an IPv6 address, we can skip this one
+ }
+ }
+
+ // We prefer an IPv4 address over an IPv6 address
+ if let ipv4Address = ipv4Address {
+ return Endpoint(host: .ipv4(ipv4Address), port: endpoint.port)
+ } else if let ipv6Address = ipv6Address {
+ return Endpoint(host: .ipv6(ipv6Address), port: endpoint.port)
+ } else {
+ // Must never happen
+ fatalError()
+ }
+ }
+}
+
+extension Endpoint {
+ func withReresolvedIP() throws -> Endpoint {
+ #if os(iOS)
+ let hostname: String
+ switch host {
+ case .name(let name, _):
+ hostname = name
+ case .ipv4(let address):
+ hostname = "\(address)"
+ case .ipv6(let address):
+ hostname = "\(address)"
+ @unknown default:
+ fatalError()
+ }
+
+ var hints = addrinfo()
+ hints.ai_family = AF_UNSPEC
+ hints.ai_socktype = SOCK_DGRAM
+ hints.ai_protocol = IPPROTO_UDP
+ hints.ai_flags = AI_DEFAULT
+
+ var result: UnsafeMutablePointer<addrinfo>?
+ defer {
+ result.flatMap { freeaddrinfo($0) }
+ }
+
+ let errorCode = getaddrinfo(hostname, "\(self.port)", &hints, &result)
+ if errorCode != 0 {
+ throw DNSResolutionError(errorCode: errorCode, address: hostname)
+ }
+
+ let addrInfo = result!.pointee
+ if let ipv4Address = IPv4Address(addrInfo: addrInfo) {
+ return Endpoint(host: .ipv4(ipv4Address), port: port)
+ } else if let ipv6Address = IPv6Address(addrInfo: addrInfo) {
+ return Endpoint(host: .ipv6(ipv6Address), port: port)
+ } else {
+ fatalError()
+ }
+ #elseif os(macOS)
+ return self
+ #else
+ #error("Unimplemented")
+ #endif
+ }
+}
+
+/// An error type describing DNS resolution error
+public struct DNSResolutionError: LocalizedError {
+ public let errorCode: Int32
+ public let address: String
+
+ init(errorCode: Int32, address: String) {
+ self.errorCode = errorCode
+ self.address = address
+ }
+
+ public var errorDescription: String? {
+ return String(cString: gai_strerror(errorCode))
+ }
+}