diff options
Diffstat (limited to 'WireGuard/WireGuardNetworkExtension/DNSResolver.swift')
-rw-r--r-- | WireGuard/WireGuardNetworkExtension/DNSResolver.swift | 41 |
1 files changed, 23 insertions, 18 deletions
diff --git a/WireGuard/WireGuardNetworkExtension/DNSResolver.swift b/WireGuard/WireGuardNetworkExtension/DNSResolver.swift index 4ce89b2..57093c8 100644 --- a/WireGuard/WireGuardNetworkExtension/DNSResolver.swift +++ b/WireGuard/WireGuardNetworkExtension/DNSResolver.swift @@ -66,25 +66,12 @@ extension DNSResolver { // Based on DNS resolution code by Jason Donenfeld <jason@zx2c4.com> // in parse_endpoint() in src/tools/config.c in the WireGuard codebase private static func resolveSync(endpoint: Endpoint) -> Endpoint? { - var hints = addrinfo( - ai_flags: 0, - ai_family: AF_UNSPEC, - ai_socktype: SOCK_DGRAM, // WireGuard is UDP-only - ai_protocol: IPPROTO_UDP, // WireGuard is UDP-only - ai_addrlen: 0, - ai_canonname: nil, - ai_addr: nil, - ai_next: nil) - var resultPointer = UnsafeMutablePointer<addrinfo>(OpaquePointer(bitPattern: 0)) switch endpoint.host { case .name(let name, _): + var resultPointer = UnsafeMutablePointer<addrinfo>(OpaquePointer(bitPattern: 0)) + // The endpoint is a hostname and needs DNS resolution - let returnValue = getaddrinfo( - name.cString(using: .utf8), // Hostname - "\(endpoint.port)".cString(using: .utf8), // Port - &hints, - &resultPointer) - if returnValue == 0 { + if addressInfo(for: name, port: endpoint.port, resultPointer: &resultPointer) == 0 { // getaddrinfo succeeded let ipv4Buffer = UnsafeMutablePointer<Int8>.allocate(capacity: Int(INET_ADDRSTRLEN)) let ipv6Buffer = UnsafeMutablePointer<Int8>.allocate(capacity: Int(INET6_ADDRSTRLEN)) @@ -115,9 +102,9 @@ extension DNSResolver { ipv6Buffer.deallocate() // We prefer an IPv4 address over an IPv6 address if let ipv4AddressString = ipv4AddressString, let ipv4Address = IPv4Address(ipv4AddressString) { - return Endpoint(host: NWEndpoint.Host.ipv4(ipv4Address), port: endpoint.port) + return Endpoint(host: .ipv4(ipv4Address), port: endpoint.port) } else if let ipv6AddressString = ipv6AddressString, let ipv6Address = IPv6Address(ipv6AddressString) { - return Endpoint(host: NWEndpoint.Host.ipv6(ipv6Address), port: endpoint.port) + return Endpoint(host: .ipv6(ipv6Address), port: endpoint.port) } else { return nil } @@ -130,4 +117,22 @@ extension DNSResolver { return endpoint } } + + private static func addressInfo(for name: String, port: NWEndpoint.Port, resultPointer: inout UnsafeMutablePointer<addrinfo>?) -> Int32 { + var hints = addrinfo( + ai_flags: 0, + ai_family: AF_UNSPEC, + ai_socktype: SOCK_DGRAM, // WireGuard is UDP-only + ai_protocol: IPPROTO_UDP, // WireGuard is UDP-only + ai_addrlen: 0, + ai_canonname: nil, + ai_addr: nil, + ai_next: nil) + + return getaddrinfo( + name.cString(using: .utf8), // Hostname + "\(port)".cString(using: .utf8), // Port + &hints, + &resultPointer) + } } |