aboutsummaryrefslogtreecommitdiffstats
path: root/WireGuard/WireGuard/VPN/DNSResolver.swift
blob: e0278520157b3ca84ffce9e4381129445e50b221 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
// SPDX-License-Identifier: MIT
// Copyright © 2018 WireGuard LLC. All rights reserved.

import Network
import Foundation

/// Resolves WireGuard endpoints whose host is a DNS name into concrete IP-address endpoints.
class DNSResolver {
    /// The endpoints to resolve. `nil` entries are carried through to the result unchanged.
    let endpoints: [Endpoint?]

    init(endpoints: [Endpoint?]) {
        self.endpoints = endpoints
    }

    /// Resolves every non-nil endpoint on a background queue.
    ///
    /// - Parameter completionHandler: Invoked on the main queue. Receives an array parallel to
    ///   `endpoints` (same order, `nil` entries preserved) when every endpoint resolved, or `nil`
    ///   if any endpoint failed to resolve.
    func resolve(completionHandler: @escaping ([Endpoint?]?) -> Void) {
        let endpoints = self.endpoints
        DispatchQueue.global(qos: .userInitiated).async {
            let resolvedEndpoints = DNSResolver.resolveAllSync(endpoints: endpoints)
            DispatchQueue.main.async {
                completionHandler(resolvedEndpoints)
            }
        }
    }

    /// Resolves `endpoints` in order, stopping at the first endpoint that cannot be resolved.
    ///
    /// - Returns: The resolved endpoints with `nil` entries preserved, or `nil` on any failure.
    private static func resolveAllSync(endpoints: [Endpoint?]) -> [Endpoint?]? {
        var resolvedEndpoints: [Endpoint?] = []
        resolvedEndpoints.reserveCapacity(endpoints.count)
        for endpoint in endpoints {
            guard let endpoint = endpoint else {
                resolvedEndpoints.append(nil)
                continue
            }
            guard let resolvedEndpoint = resolveSync(endpoint: endpoint) else {
                return nil
            }
            resolvedEndpoints.append(resolvedEndpoint)
        }
        return resolvedEndpoints
    }

    // Based on DNS resolution code by Jason Donenfeld <jason@zx2c4.com>
    // in parse_endpoint() in src/tools/config.c in the WireGuard codebase
    /// Synchronously resolves a single endpoint with `getaddrinfo`.
    ///
    /// An endpoint whose host is not a `.name` is returned unchanged. For a hostname, the first
    /// IPv4 result is preferred; if there is none, the first IPv6 result is used.
    ///
    /// - Returns: The resolved endpoint, or `nil` if the lookup failed or yielded no usable address.
    private static func resolveSync(endpoint: Endpoint) -> Endpoint? {
        guard case .name(let name, _) = endpoint.host else {
            // The endpoint is already resolved
            return endpoint
        }

        var hints = addrinfo(
            ai_flags: 0,
            ai_family: AF_UNSPEC,
            ai_socktype: SOCK_DGRAM, // WireGuard is UDP-only
            ai_protocol: IPPROTO_UDP, // WireGuard is UDP-only
            ai_addrlen: 0,
            ai_canonname: nil,
            ai_addr: nil,
            ai_next: nil)
        var resultPointer: UnsafeMutablePointer<addrinfo>? = nil
        let returnValue = getaddrinfo(name, "\(endpoint.port)", &hints, &resultPointer)
        guard returnValue == 0, let firstResult = resultPointer else {
            // getaddrinfo failed
            return nil
        }
        // getaddrinfo heap-allocates the result list; release it on every exit path.
        defer { freeaddrinfo(firstResult) }

        var ipv4Address: IPv4Address?
        var ipv6Address: IPv6Address?
        var nextResult: UnsafeMutablePointer<addrinfo>? = firstResult
        while let resultNode = nextResult {
            let result = resultNode.pointee
            nextResult = result.ai_next
            guard let socketAddress = result.ai_addr else { continue }
            // ai_addrlen is the size of the sockaddr structure, not of a textual address.
            let socketAddressLength = Int(result.ai_addrlen)
            if result.ai_family == AF_INET && socketAddressLength == MemoryLayout<sockaddr_in>.size {
                let sin = socketAddress.withMemoryRebound(to: sockaddr_in.self, capacity: 1) { $0.pointee }
                // The address bytes live in sin_addr (network byte order), not at the sockaddr start.
                let rawAddress = withUnsafeBytes(of: sin.sin_addr) { Data($0) }
                if let address = IPv4Address(rawAddress) {
                    ipv4Address = address
                    // If we found an IPv4 address, we can stop
                    break
                }
            } else if result.ai_family == AF_INET6 && socketAddressLength == MemoryLayout<sockaddr_in6>.size {
                // If we already have an IPv6 address, we can skip this one
                guard ipv6Address == nil else { continue }
                let sin6 = socketAddress.withMemoryRebound(to: sockaddr_in6.self, capacity: 1) { $0.pointee }
                let rawAddress = withUnsafeBytes(of: sin6.sin6_addr) { Data($0) }
                ipv6Address = IPv6Address(rawAddress)
            }
        }

        // We prefer an IPv4 address over an IPv6 address
        if let ipv4Address = ipv4Address {
            return Endpoint(host: NWEndpoint.Host.ipv4(ipv4Address), port: endpoint.port)
        }
        if let ipv6Address = ipv6Address {
            return Endpoint(host: NWEndpoint.Host.ipv6(ipv6Address), port: endpoint.port)
        }
        return nil
    }
}