aboutsummaryrefslogtreecommitdiffstats
path: root/WireGuardNetworkExtension/PacketTunnelProvider.swift
blob: ce37c8a5479cf36a3a7639a008e1551f1a214aba (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
//
//  PacketTunnelProvider.swift
//  WireGuardNetworkExtension
//
//  Created by Jeroen Leenarts on 19-06-18.
//  Copyright © 2018 Jason A. Donenfeld <Jason@zx2c4.com>. All rights reserved.
//

import NetworkExtension
import os.log

/// Errors that `PacketTunnelProvider` reports through its start-tunnel completion handler.
enum PacketTunnelProviderError: Swift.Error {
    /// The tunnel could not be brought up; used for every start-up failure path.
    case tunnelSetupFailed
}

/// A packet tunnel provider object.
class PacketTunnelProvider: NEPacketTunnelProvider {

    // MARK: Properties

    /// Handle returned by `wgTurnOn`; `nil` while WireGuard is not running.
    var wgHandle: Int32?
    /// Bridge between `packetFlow` and WireGuard's read/write callbacks. It is handed to
    /// WireGuard as an *unretained* pointer, so this property is what keeps it alive.
    var wgContext: WireGuardContext?

    // MARK: NEPacketTunnelProvider

    /// Begin the process of establishing the tunnel.
    ///
    /// Reads the interface name, WireGuard settings, endpoints, addresses, DNS servers and MTU
    /// from the provider configuration, turns WireGuard on, then applies matching network
    /// settings. `startTunnelCompletionHandler` receives `nil` on success and
    /// `PacketTunnelProviderError.tunnelSetupFailed` on any failure.
    override func startTunnel(options: [String: NSObject]?, completionHandler startTunnelCompletionHandler: @escaping (Error?) -> Void) {
        os_log("Starting tunnel", log: Log.general, type: .info)

        // Fail the start instead of crashing the extension on a malformed configuration.
        guard let config = protocolConfiguration as? NETunnelProviderProtocol,
            let providerConfiguration = config.providerConfiguration,
            let interfaceName = providerConfiguration[PCKeys.title.rawValue] as? String,
            let settings = providerConfiguration[PCKeys.settings.rawValue] as? String else {
                os_log("Invalid provider configuration", log: Log.general, type: .error)
                startTunnelCompletionHandler(PacketTunnelProviderError.tunnelSetupFailed)
                return
        }

        // NOTE(review): the trailing `compactMap { $0 }` is kept on purpose; if these
        // initializers are both failable and throwing, Swift 4's `try?` yields a double
        // optional that the second pass flattens.
        let validatedEndpoints = ((providerConfiguration[PCKeys.endpoints.rawValue] as? String) ?? "")
            .commaSeparatedToArray()
            .compactMap { try? Endpoint(endpointString: String($0)) }
            .compactMap { $0 }
        let validatedAddresses = ((providerConfiguration[PCKeys.addresses.rawValue] as? String) ?? "")
            .commaSeparatedToArray()
            .compactMap { try? CIDRAddress(stringRepresentation: String($0)) }
            .compactMap { $0 }

        guard let firstEndpoint = validatedEndpoints.first else {
            os_log("No valid endpoint in configuration", log: Log.general, type: .error)
            startTunnelCompletionHandler(PacketTunnelProviderError.tunnelSetupFailed)
            return
        }

        configureLogger()
        let context = WireGuardContext(packetFlow: packetFlow)
        wgContext = context

        let handle = connect(interfaceName: interfaceName, settings: settings, context: context)
        if handle < 0 {
            os_log("wgTurnOn failed: %d", log: Log.general, type: .error, handle)
            startTunnelCompletionHandler(PacketTunnelProviderError.tunnelSetupFailed)
            return
        }
        wgHandle = handle

        // We use the first endpoint for the ipAddress
        let newSettings = makeNetworkSettings(remoteAddress: firstEndpoint.ipAddress,
                                              endpoints: validatedEndpoints,
                                              addresses: validatedAddresses,
                                              providerConfiguration: providerConfiguration)

        setTunnelNetworkSettings(newSettings) { error in
            if let error = error {
                os_log("Error setting network settings: %@", log: Log.general, type: .error, error.localizedDescription)
                // Don't leave WireGuard running behind a start that reported failure.
                self.turnOffWireGuard()
                startTunnelCompletionHandler(PacketTunnelProviderError.tunnelSetupFailed)
            } else {
                startTunnelCompletionHandler(nil /* No errors */)
            }
        }
    }

    /// Begin the process of stopping the tunnel.
    override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) {
        os_log("Stopping tunnel", log: Log.general, type: .info)
        turnOffWireGuard()
        completionHandler()
    }

    /// Handle IPC messages from the app: logs the UTF-8 message and replies "Hello app".
    /// Replies `nil` if the message is not valid UTF-8.
    override func handleAppMessage(_ messageData: Data, completionHandler: ((Data?) -> Void)?) {
        guard let messageString = String(data: messageData, encoding: .utf8) else {
            completionHandler?(nil)
            return
        }

        // `%@`, not `%s`: String/NSString arguments reach os_log as objects, not C strings.
        os_log("Got a message from the app: %@", log: Log.general, type: .info, messageString)

        completionHandler?("Hello app".data(using: .utf8))
    }

    // MARK: Private helpers

    /// Builds the tunnel's network settings.
    ///
    /// For each address family present in `addresses`, the settings contain:
    /// - the interface addresses and a default route;
    /// - host routes that keep the endpoint addresses outside the tunnel.
    ///
    /// The optional DNS servers and a positive MTU are taken from `providerConfiguration`.
    private func makeNetworkSettings(remoteAddress: String,
                                     endpoints: [Endpoint],
                                     addresses: [CIDRAddress],
                                     providerConfiguration: [String: Any]) -> NEPacketTunnelNetworkSettings {
        let newSettings = NEPacketTunnelNetworkSettings(tunnelRemoteAddress: remoteAddress)
        newSettings.tunnelOverheadBytes = 80

        // IPv4 settings
        let ipv4Addresses = addresses.filter { $0.addressType == .IPv4 }
        if !ipv4Addresses.isEmpty {
            let ipv4Settings = NEIPv4Settings(addresses: ipv4Addresses.map { $0.ipAddress },
                                              subnetMasks: ipv4Addresses.map { $0.subnetString })
            ipv4Settings.includedRoutes = [NEIPv4Route.default()]
            // /32 host routes: only the endpoint addresses themselves bypass the tunnel.
            ipv4Settings.excludedRoutes = endpoints.filter { $0.addressType == .IPv4 }.map {
                NEIPv4Route(destinationAddress: $0.ipAddress, subnetMask: "255.255.255.255")
            }
            newSettings.ipv4Settings = ipv4Settings
        }

        // IPv6 settings
        let ipv6Addresses = addresses.filter { $0.addressType == .IPv6 }
        if !ipv6Addresses.isEmpty {
            let ipv6Settings = NEIPv6Settings(addresses: ipv6Addresses.map { $0.ipAddress },
                                              networkPrefixLengths: ipv6Addresses.map { NSNumber(value: $0.subnet) })
            ipv6Settings.includedRoutes = [NEIPv6Route.default()]
            // /128 host routes, mirroring IPv4's /32. A prefix length of 0 would describe the
            // entire IPv6 address space rather than the single endpoint address.
            ipv6Settings.excludedRoutes = endpoints.filter { $0.addressType == .IPv6 }.map {
                NEIPv6Route(destinationAddress: $0.ipAddress, networkPrefixLength: 128)
            }
            newSettings.ipv6Settings = ipv6Settings
        }

        if let dns = providerConfiguration[PCKeys.dns.rawValue] as? String {
            newSettings.dnsSettings = NEDNSSettings(servers: dns.commaSeparatedToArray())
        }

        if let mtu = providerConfiguration[PCKeys.mtu.rawValue] as? NSNumber, mtu.intValue > 0 {
            newSettings.mtu = mtu
        }

        return newSettings
    }

    /// Closes the packet bridge and turns WireGuard off if it is running.
    /// Safe to call more than once.
    private func turnOffWireGuard() {
        // Mark the context closed first so that read/write callbacks invoked during shutdown
        // see the tunnel as closed and report -1.
        wgContext?.closeTunnel()
        if let handle = wgHandle {
            wgTurnOff(handle)
            wgHandle = nil
        }
    }

    /// Routes WireGuard's log callback to os_log.
    ///
    /// Levels 0/1/2 map to debug/info/error; any other level maps to default.
    private func configureLogger() {
        wgSetLogger { (level, tagCStr, msgCStr) in
            let logType: OSLogType
            switch level {
            case 0:
                logType = .debug
            case 1:
                logType = .info
            case 2:
                logType = .error
            default:
                logType = .default
            }
            let tag = tagCStr.map { String(cString: $0) } ?? ""
            let msg = msgCStr.map { String(cString: $0) } ?? ""
            os_log("wg log: %{public}@: %{public}@", log: Log.general, type: logType, tag, msg)
        }
    }

    /// Turns WireGuard on for `interfaceName` with `settings`, wiring its read/write callbacks
    /// to `context`.
    ///
    /// `context` is passed to WireGuard as an unretained opaque pointer. The caller must keep
    /// it alive (here via `wgContext`) for as long as WireGuard may invoke the callbacks.
    /// - Returns: the handle from `wgTurnOn`; a negative value means failure.
    private func connect(interfaceName: String, settings: String, context: WireGuardContext) -> Int32 { // swiftlint:disable:this cyclomatic_complexity
        // A stable pointer to the object itself. A pointer obtained from
        // `withUnsafeMutablePointer(to: &wgContext)` is only valid inside that closure,
        // but WireGuard keeps using it afterwards.
        let contextPointer = Unmanaged.passUnretained(context).toOpaque()
        return withStringsAsGoStrings(interfaceName, settings) { (nameGoStr, settingsGoStr) -> Int32 in
            return wgTurnOn(nameGoStr, settingsGoStr,
                // read_fn: Read from the TUN interface and pass it on to WireGuard
                { (wgCtxPtr, buf, len) -> Int in
                    guard let wgCtxPtr = wgCtxPtr, let buf = buf else { return 0 }
                    let wgContext = Unmanaged<WireGuardContext>.fromOpaque(wgCtxPtr).takeUnretainedValue()
                    var isTunnelClosed = false
                    let packet = wgContext.readPacket(isTunnelClosed: &isTunnelClosed)
                    // Check closure before looking at the packet: a closed tunnel yields no
                    // packet, and WireGuard must see -1 rather than an empty read.
                    if isTunnelClosed { return -1 }
                    guard let packetData = packet?.data else { return 0 }
                    // A packet that does not fit WireGuard's buffer is dropped.
                    guard packetData.count <= len else { return 0 }
                    packetData.copyBytes(to: buf, count: packetData.count)
                    return packetData.count
                },
                // write_fn: Receive packets from WireGuard and write to the TUN interface
                { (wgCtxPtr, buf, len) -> Int in
                    guard let wgCtxPtr = wgCtxPtr, let buf = buf, len > 0 else { return 0 }
                    let wgContext = Unmanaged<WireGuardContext>.fromOpaque(wgCtxPtr).takeUnretainedValue()
                    // The IP version is the high nibble of the first header byte.
                    let protocolFamily: sa_family_t
                    switch (buf[0] & 0xf0) >> 4 {
                    case 4:
                        protocolFamily = sa_family_t(AF_INET) // IPv4
                    case 6:
                        protocolFamily = sa_family_t(AF_INET6) // IPv6
                    default:
                        // Drop the malformed packet instead of crashing the whole extension.
                        return 0
                    }
                    let packet = NEPacket(data: Data(bytes: buf, count: len), protocolFamily: protocolFamily)
                    var isTunnelClosed = false
                    let isWritten = wgContext.writePacket(packet: packet, isTunnelClosed: &isTunnelClosed)
                    if isTunnelClosed { return -1 }
                    return isWritten ? len : 0
                },
                contextPointer)
        }
    }
}

class WireGuardContext {
    // Hands packets between the system's NEPacketTunnelFlow and WireGuard's read/write
    // callbacks. NOTE(review): those callbacks presumably arrive on WireGuard-owned threads
    // and may block here; confirm against the bridge.

    private let packetFlow: NEPacketTunnelFlow
    /// Packets obtained from `packetFlow` that have not yet been handed to WireGuard.
    private var outboundPackets: [NEPacket] = []
    /// Set by `closeTunnel()`. Always read and written with `readPacketCondition` locked.
    private var isTunnelClosed: Bool = false
    /// Guards `isTunnelClosed`, and the hand-off of packets from `readPacketObjects`'
    /// completion handler to the thread blocked in `readPacket(isTunnelClosed:)`.
    private let readPacketCondition = NSCondition()

    init(packetFlow: NEPacketTunnelFlow) {
        self.packetFlow = packetFlow
    }

    /// Marks the tunnel as closed and wakes any thread blocked in `readPacket(isTunnelClosed:)`.
    func closeTunnel() {
        readPacketCondition.lock()
        isTunnelClosed = true
        // Signal the same condition the reader waits on; otherwise the reader stays blocked
        // until the flow happens to deliver packets.
        readPacketCondition.broadcast()
        readPacketCondition.unlock()
    }

    /// Returns the next packet from `packetFlow`, blocking until a batch arrives or the
    /// tunnel is closed.
    ///
    /// - Parameter isTunnelClosed: set to whether `closeTunnel()` has been called.
    /// - Returns: the next packet, or `nil` if none is available (for example, the tunnel was
    ///   closed while waiting, or the flow delivered an empty batch).
    func readPacket(isTunnelClosed: inout Bool) -> NEPacket? {
        // Don't start new reads once the tunnel is closed.
        if outboundPackets.isEmpty && !isClosed() {
            outboundPackets = waitForPacketsFromFlow()
        }
        isTunnelClosed = isClosed()
        return outboundPackets.isEmpty ? nil : outboundPackets.removeFirst()
    }

    /// Writes `packet` to `packetFlow` unless the tunnel has been closed.
    ///
    /// - Parameter isTunnelClosed: set to whether `closeTunnel()` has been called.
    /// - Returns: the result of `writePacketObjects`, or `false` if the tunnel is closed.
    func writePacket(packet: NEPacket, isTunnelClosed: inout Bool) -> Bool {
        isTunnelClosed = isClosed()
        guard !isTunnelClosed else { return false }
        return packetFlow.writePacketObjects([packet])
    }

    /// Issues one `readPacketObjects` call and blocks until its completion handler delivers
    /// packets or the tunnel is closed. Returns `[]` in the latter case.
    private func waitForPacketsFromFlow() -> [NEPacket] {
        let condition = readPacketCondition
        var packetsObtained: [NEPacket]?
        // The read is issued before this thread locks, and the handler takes the lock itself.
        // - The handler's store and signal cannot land between the check and `wait()`
        //   below, so no wakeup is lost.
        // - A handler that runs before we lock cannot deadlock on the non-recursive
        //   NSCondition.
        packetFlow.readPacketObjects { packets in
            condition.lock()
            packetsObtained = packets
            condition.broadcast()
            condition.unlock()
        }
        condition.lock()
        defer { condition.unlock() }
        while packetsObtained == nil && !isTunnelClosed {
            condition.wait()
        }
        return packetsObtained ?? []
    }

    /// Reads `isTunnelClosed` while holding `readPacketCondition`'s lock.
    private func isClosed() -> Bool {
        readPacketCondition.lock()
        defer { readPacketCondition.unlock() }
        return isTunnelClosed
    }
}

/// Calls `closure` with `str1` and `str2` exposed as Go strings (`gostring_t`).
///
/// Each `gostring_t` points at the string's UTF-8 bytes and carries their byte length.
/// The pointers are only valid while `closure` runs and must not be kept after it returns.
private func withStringsAsGoStrings<R>(_ str1: String, _ str2: String, closure: (gostring_t, gostring_t) -> R) -> R {
    let byteCount1 = str1.utf8.count
    let byteCount2 = str2.utf8.count
    return str1.withCString { (firstPointer) -> R in
        str2.withCString { (secondPointer) -> R in
            closure(gostring_t(p: firstPointer, n: byteCount1),
                    gostring_t(p: secondPointer, n: byteCount2))
        }
    }
}