aboutsummaryrefslogtreecommitdiffstats
path: root/WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift
blob: da4372e3cbe5a73e296939a749327e152ddee97f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
//
//  Copyright © 2018 WireGuard LLC. All Rights Reserved.
//

import NetworkExtension
import os.log

/// Failures that `PacketTunnelProvider` reports through the `startTunnel` completion handler.
///
/// - `invalidOptions`: the options dictionary was missing or could not be decoded.
/// - `couldNotStartWireGuard`: turning WireGuard on returned a negative handle.
/// - `coultNotSetNetworkSettings`: applying the tunnel network settings failed
///   (the misspelled name is kept so existing references keep compiling).
enum PacketTunnelProviderError: Error {
    case invalidOptions, couldNotStartWireGuard, coultNotSetNetworkSettings
}

/// A packet tunnel provider object.
///
/// `startTunnel` decodes the tunnel configuration from the options dictionary, turns
/// WireGuard on through `wgTurnOn`, and then applies the matching network settings.
class PacketTunnelProvider: NEPacketTunnelProvider {

    // MARK: Properties

    /// Handle returned by `wgTurnOn`; `nil` when WireGuard is not (or no longer) running.
    private var wgHandle: Int32?

    /// Packet bridge used by the WireGuard read/write callbacks. This property is what keeps
    /// the object alive while WireGuard holds an unretained pointer to it (see `connect`).
    private var wgContext: WireGuardContext?

    // MARK: NEPacketTunnelProvider

    /// Begin the process of establishing the tunnel.
    ///
    /// - Parameters:
    ///   - options: Tunnel configuration. `nil` possibly means the connection was requested
    ///     from preferences instead of from the WireGuard app.
    ///   - startTunnelCompletionHandler: Called with `nil` on success, or with a
    ///     `PacketTunnelProviderError` on failure.
    override func startTunnel(options: [String: NSObject]?,
                              completionHandler startTunnelCompletionHandler: @escaping (Error?) -> Void) {
        os_log("Starting tunnel", log: OSLog.default, type: .info)

        guard let options = options else {
            os_log("Starting tunnel failed: No options passed. Possible connection request from preferences", log: OSLog.default, type: .error)
            // displayMessage is deprecated API
            displayMessage("Please use the WireGuard app to start WireGuard tunnels") { (_) in
                startTunnelCompletionHandler(PacketTunnelProviderError.invalidOptions)
            }
            return
        }

        guard let interfaceName = options[.interfaceName] as? String,
            let wireguardSettings = options[.wireguardSettings] as? String,
            let remoteAddress = options[.remoteAddress] as? String,
            let dnsServers = options[.dnsServers] as? [String],
            let mtu = options[.mtu] as? NSNumber,

            // IPv4 settings
            let ipv4Addresses = options[.ipv4Addresses] as? [String],
            let ipv4SubnetMasks = options[.ipv4SubnetMasks] as? [String],
            let ipv4IncludedRouteAddresses = options[.ipv4IncludedRouteAddresses] as? [String],
            let ipv4IncludedRouteSubnetMasks = options[.ipv4IncludedRouteSubnetMasks] as? [String],
            let ipv4ExcludedRouteAddresses = options[.ipv4ExcludedRouteAddresses] as? [String],
            let ipv4ExcludedRouteSubnetMasks = options[.ipv4ExcludedRouteSubnetMasks] as? [String],

            // IPv6 settings
            let ipv6Addresses = options[.ipv6Addresses] as? [String],
            let ipv6NetworkPrefixLengths = options[.ipv6NetworkPrefixLengths] as? [NSNumber],
            let ipv6IncludedRouteAddresses = options[.ipv6IncludedRouteAddresses] as? [String],
            let ipv6IncludedRouteNetworkPrefixLengths = options[.ipv6IncludedRouteNetworkPrefixLengths] as? [NSNumber],
            let ipv6ExcludedRouteAddresses = options[.ipv6ExcludedRouteAddresses] as? [String],
            let ipv6ExcludedRouteNetworkPrefixLengths = options[.ipv6ExcludedRouteNetworkPrefixLengths] as? [NSNumber]

            else {
                os_log("Starting tunnel failed: Invalid options passed", log: OSLog.default, type: .error)
                startTunnelCompletionHandler(PacketTunnelProviderError.invalidOptions)
                return
        }

        configureLogger()
        wgContext = WireGuardContext(packetFlow: self.packetFlow)

        let handle = connect(interfaceName: interfaceName, settings: wireguardSettings, mtu: mtu.uint16Value)

        if handle < 0 {
            os_log("Starting tunnel failed: Could not start WireGuard", log: OSLog.default, type: .error)
            startTunnelCompletionHandler(PacketTunnelProviderError.couldNotStartWireGuard)
            return
        }

        wgHandle = handle

        // Network settings
        let networkSettings = NEPacketTunnelNetworkSettings(tunnelRemoteAddress: remoteAddress)

        // IPv4 settings
        let ipv4Settings = NEIPv4Settings(addresses: ipv4Addresses, subnetMasks: ipv4SubnetMasks)
        assert(ipv4IncludedRouteAddresses.count == ipv4IncludedRouteSubnetMasks.count)
        ipv4Settings.includedRoutes = zip(ipv4IncludedRouteAddresses, ipv4IncludedRouteSubnetMasks).map {
            NEIPv4Route(destinationAddress: $0.0, subnetMask: $0.1)
        }
        assert(ipv4ExcludedRouteAddresses.count == ipv4ExcludedRouteSubnetMasks.count)
        ipv4Settings.excludedRoutes = zip(ipv4ExcludedRouteAddresses, ipv4ExcludedRouteSubnetMasks).map {
            NEIPv4Route(destinationAddress: $0.0, subnetMask: $0.1)
        }
        networkSettings.ipv4Settings = ipv4Settings

        // IPv6 settings
        let ipv6Settings = NEIPv6Settings(addresses: ipv6Addresses, networkPrefixLengths: ipv6NetworkPrefixLengths)
        assert(ipv6IncludedRouteAddresses.count == ipv6IncludedRouteNetworkPrefixLengths.count)
        ipv6Settings.includedRoutes = zip(ipv6IncludedRouteAddresses, ipv6IncludedRouteNetworkPrefixLengths).map {
            NEIPv6Route(destinationAddress: $0.0, networkPrefixLength: $0.1)
        }
        assert(ipv6ExcludedRouteAddresses.count == ipv6ExcludedRouteNetworkPrefixLengths.count)
        ipv6Settings.excludedRoutes = zip(ipv6ExcludedRouteAddresses, ipv6ExcludedRouteNetworkPrefixLengths).map {
            NEIPv6Route(destinationAddress: $0.0, networkPrefixLength: $0.1)
        }
        networkSettings.ipv6Settings = ipv6Settings

        // DNS
        networkSettings.dnsSettings = NEDNSSettings(servers: dnsServers)

        // MTU
        if mtu.intValue == 0 {
            // 0 implies automatic MTU, where we set overhead as 95 bytes,
            // 80 for WireGuard and the 15 to make sure WireGuard's padding will work.
            networkSettings.tunnelOverheadBytes = 95
        } else {
            networkSettings.mtu = mtu
        }

        setTunnelNetworkSettings(networkSettings) { (error) in
            if let error = error {
                // `%@`, not `%s`: a Swift String is passed to os_log as an object, not a C string.
                os_log("Starting tunnel failed: Error setting network settings: %@", log: OSLog.default, type: .error, error.localizedDescription)
                // Don't leave WireGuard running behind a tunnel that failed to start.
                self.turnOffWireGuard()
                startTunnelCompletionHandler(PacketTunnelProviderError.coultNotSetNetworkSettings)
            } else {
                startTunnelCompletionHandler(nil /* No errors */)
            }
        }
    }

    /// Begin the process of stopping the tunnel.
    override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) {
        os_log("Stopping tunnel", log: OSLog.default, type: .info)
        turnOffWireGuard()
        completionHandler()
    }

    // MARK: Private

    /// Closes the packet bridge and turns WireGuard off. Safe to call more than once:
    /// the handle is cleared after `wgTurnOff`, so it is never turned off twice.
    private func turnOffWireGuard() {
        // Close the context first so a read callback blocked in `readPacket` wakes up
        // and reports the tunnel as closed (-1).
        wgContext?.closeTunnel()
        if let handle = wgHandle {
            wgTurnOff(handle)
            wgHandle = nil
        }
    }

    /// Routes WireGuard's log output to os_log, mapping levels 0/1/2 to debug/info/error.
    private func configureLogger() {
        wgSetLogger { (level, msgCStr) in
            let logType: OSLogType
            switch level {
            case 0:
                logType = .debug
            case 1:
                logType = .info
            case 2:
                logType = .error
            default:
                logType = .default
            }
            let msg = msgCStr.map { String(cString: $0) } ?? ""
            // `%@`, not `%s`: a Swift String is passed to os_log as an object, not a C string.
            os_log("%{public}@", log: OSLog.default, type: logType, msg)
        }
    }

    /// Turns WireGuard on. The read and write callbacks move packets between the TUN
    /// interface (through `wgContext`) and WireGuard.
    ///
    /// - Returns: The handle from `wgTurnOn` (negative on failure), or -1 if `wgContext`
    ///   has not been created yet.
    private func connect(interfaceName: String, settings: String, mtu: UInt16) -> Int32 { // swiftlint:disable:this cyclomatic_complexity
        guard let context = wgContext else { return -1 }
        // WireGuard keeps this pointer and calls back with it long after `wgTurnOn` returns.
        // That is why it must be the object's own address rather than a pointer from
        // `withUnsafeMutablePointer`, which is only valid inside its closure.
        // `self.wgContext` keeps the object alive, so an unretained reference is enough.
        let contextPointer = Unmanaged.passUnretained(context).toOpaque()
        return withStringsAsGoStrings(interfaceName, settings) { (nameGoStr, settingsGoStr) -> Int32 in
            return wgTurnOn(nameGoStr, settingsGoStr, mtu, { (wgCtxPtr, buf, len) -> Int in
                autoreleasepool {
                    // read_fn: Read from the TUN interface and pass it on to WireGuard
                    guard let wgCtxPtr = wgCtxPtr else { return 0 }
                    guard let buf = buf else { return 0 }
                    let wgContext = Unmanaged<WireGuardContext>.fromOpaque(wgCtxPtr).takeUnretainedValue()
                    var isTunnelClosed = false
                    let packet = wgContext.readPacket(isTunnelClosed: &isTunnelClosed)
                    if isTunnelClosed { return -1 }
                    guard let packetData = packet?.data else { return 0 }
                    if packetData.count <= len {
                        packetData.copyBytes(to: buf, count: packetData.count)
                        return packetData.count
                    }
                    return 0
                }
            }, { (wgCtxPtr, buf, len) -> Int in
                autoreleasepool {
                    // write_fn: Receive packets from WireGuard and write to the TUN interface
                    guard let wgCtxPtr = wgCtxPtr else { return 0 }
                    guard let buf = buf else { return 0 }
                    guard len > 0 else { return 0 }
                    let wgContext = Unmanaged<WireGuardContext>.fromOpaque(wgCtxPtr).takeUnretainedValue()
                    // The IP version is the high nibble of the first byte.
                    let protocolFamily: sa_family_t
                    switch (buf[0] & 0xf0) >> 4 {
                    case 4:
                        protocolFamily = sa_family_t(AF_INET) // IPv4
                    case 6:
                        protocolFamily = sa_family_t(AF_INET6) // IPv6
                    default:
                        // Drop a packet that is neither IPv4 nor IPv6 instead of
                        // crashing the whole network extension.
                        return 0
                    }
                    let packet = NEPacket(data: Data(bytes: buf, count: len), protocolFamily: protocolFamily)
                    var isTunnelClosed = false
                    let isWritten = wgContext.writePacket(packet: packet, isTunnelClosed: &isTunnelClosed)
                    if isTunnelClosed { return -1 }
                    return isWritten ? len : 0
                }
            },
                contextPointer)
        }
    }
}

/// Bridges packets between an `NEPacketTunnelFlow` and the WireGuard read/write callbacks.
///
/// `readPacket` blocks the calling thread until `packetsRead` (the flow's read completion
/// handler) delivers packets or `closeTunnel` is called. The state those methods share is
/// only touched while holding `readPacketCondition`.
class WireGuardContext {
    private let packetFlow: NEPacketTunnelFlow
    /// Packets read from the flow and not yet handed to WireGuard. Guarded by `readPacketCondition`.
    private var outboundPackets: [NEPacket] = []
    /// Set by `closeTunnel()`. Guarded by `readPacketCondition`.
    private var isTunnelClosed: Bool = false
    /// Lock protecting the properties above, and the condition a blocked reader waits on.
    private let readPacketCondition = NSCondition()

    init(packetFlow: NEPacketTunnelFlow) {
        self.packetFlow = packetFlow
    }

    /// Marks the tunnel as closed and wakes any thread blocked in `readPacket`.
    func closeTunnel() {
        readPacketCondition.lock()
        defer { readPacketCondition.unlock() }
        isTunnelClosed = true
        // Signal while holding the lock. Otherwise a reader that has just checked the
        // predicate, but has not started waiting yet, could miss this wakeup and block forever.
        readPacketCondition.broadcast()
    }

    /// Completion handler for `readPacketObjects`: queues the packets and wakes the reader.
    func packetsRead(packets: [NEPacket]) {
        readPacketCondition.lock()
        defer { readPacketCondition.unlock() }
        outboundPackets.append(contentsOf: packets)
        readPacketCondition.signal()
    }

    /// Returns the next queued packet. If none is queued, requests a read from the flow
    /// and blocks until packets arrive or the tunnel is closed.
    ///
    /// - Parameter isTunnelClosed: Set to whether the tunnel had been closed.
    /// - Returns: The next packet, or `nil` if none is available.
    func readPacket(isTunnelClosed: inout Bool) -> NEPacket? {
        readPacketCondition.lock()
        defer { readPacketCondition.unlock() }
        if outboundPackets.isEmpty && !self.isTunnelClosed {
            packetFlow.readPacketObjects(completionHandler: packetsRead)
            while outboundPackets.isEmpty && !self.isTunnelClosed {
                readPacketCondition.wait()
            }
        }
        isTunnelClosed = self.isTunnelClosed
        return outboundPackets.isEmpty ? nil : outboundPackets.removeFirst()
    }

    /// Writes one packet to the flow.
    ///
    /// - Parameter isTunnelClosed: Set to whether the tunnel had been closed.
    /// - Returns: The result of `writePacketObjects`.
    func writePacket(packet: NEPacket, isTunnelClosed: inout Bool) -> Bool {
        readPacketCondition.lock()
        isTunnelClosed = self.isTunnelClosed
        readPacketCondition.unlock()
        return packetFlow.writePacketObjects([packet])
    }
}

/// Runs `closure` with `str1` and `str2` exposed as `gostring_t` values.
///
/// Each Go string points at the string's temporary C representation and carries its
/// UTF-8 byte length. The pointers are only valid for the duration of `closure`.
private func withStringsAsGoStrings<R>(_ str1: String, _ str2: String, closure: (gostring_t, gostring_t) -> R) -> R {
    return str1.withCString { firstCString in
        str2.withCString { secondCString in
            closure(gostring_t(p: firstCString, n: str1.utf8.count),
                    gostring_t(p: secondCString, n: str2.utf8.count))
        }
    }
}