diff options
Diffstat (limited to '')
-rw-r--r-- | WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift | 106 |
1 files changed, 11 insertions, 95 deletions
diff --git a/WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift b/WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift index ddd2677..f6f2bb4 100644 --- a/WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift +++ b/WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift @@ -17,7 +17,6 @@ class PacketTunnelProvider: NEPacketTunnelProvider { // MARK: Properties private var wgHandle: Int32? - private var wgContext: WireGuardContext? // MARK: NEPacketTunnelProvider @@ -64,9 +63,14 @@ class PacketTunnelProvider: NEPacketTunnelProvider { } configureLogger() - wgContext = WireGuardContext(packetFlow: self.packetFlow) - let handle = connect(interfaceName: interfaceName, settings: wireguardSettings, mtu: mtu.uint16Value) + let fd = packetFlow.value(forKeyPath: "socket.fileDescriptor") as! Int32 + if fd < 0 { + os_log("Starting tunnel failed: Could not determine file descriptor", log: OSLog.default, type: .error) + startTunnelCompletionHandler(PacketTunnelProviderError.couldNotStartWireGuard) + return + } + let handle = connect(interfaceName: interfaceName, settings: wireguardSettings, fd: fd) if handle < 0 { os_log("Starting tunnel failed: Could not start WireGuard", log: OSLog.default, type: .error) @@ -114,9 +118,8 @@ class PacketTunnelProvider: NEPacketTunnelProvider { // MTU if (mtu == 0) { - // 0 imples automatic MTU, where we set overhead as 95 bytes, - // 80 for WireGuard and the 15 to make sure WireGuard's padding will work. - networkSettings.tunnelOverheadBytes = 95 + // 0 imples automatic MTU, where we set overhead as 80 bytes, which is the worst case for WireGuard + networkSettings.tunnelOverheadBytes = 80 } else { networkSettings.mtu = mtu } @@ -134,7 +137,6 @@ class PacketTunnelProvider: NEPacketTunnelProvider { /// Begin the process of stopping the tunnel. override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) { os_log("Stopping tunnel", log: OSLog.default, type: .info) - wgContext?.closeTunnel() if let handle = wgHandle { wgTurnOff(handle) } @@ -159,99 +161,13 @@ class PacketTunnelProvider: NEPacketTunnelProvider { } } - private func connect(interfaceName: String, settings: String, mtu: UInt16) -> Int32 { // swiftlint:disable:this cyclomatic_complexity + private func connect(interfaceName: String, settings: String, fd: Int32) -> Int32 { // swiftlint:disable:this cyclomatic_complexity return withStringsAsGoStrings(interfaceName, settings) { (nameGoStr, settingsGoStr) -> Int32 in - return withUnsafeMutablePointer(to: &wgContext) { (wgCtxPtr) -> Int32 in - return wgTurnOn(nameGoStr, settingsGoStr, mtu, { (wgCtxPtr, buf, len) -> Int in - autoreleasepool { - // read_fn: Read from the TUN interface and pass it on to WireGuard - guard let wgCtxPtr = wgCtxPtr else { return 0 } - guard let buf = buf else { return 0 } - let wgContext = wgCtxPtr.bindMemory(to: WireGuardContext.self, capacity: 1).pointee - var isTunnelClosed = false - let packet = wgContext.readPacket(isTunnelClosed: &isTunnelClosed) - if isTunnelClosed { return -1 } - guard let packetData = packet?.data else { return 0 } - if packetData.count <= len { - packetData.copyBytes(to: buf, count: packetData.count) - return packetData.count - } - return 0 - } - }, { (wgCtxPtr, buf, len) -> Int in - autoreleasepool { - // write_fn: Receive packets from WireGuard and write to the TUN interface - guard let wgCtxPtr = wgCtxPtr else { return 0 } - guard let buf = buf else { return 0 } - guard len > 0 else { return 0 } - let wgContext = wgCtxPtr.bindMemory(to: WireGuardContext.self, capacity: 1).pointee - let ipVersionBits = (buf[0] & 0xf0) >> 4 - let ipVersion: sa_family_t? = { - if ipVersionBits == 4 { return sa_family_t(AF_INET) } // IPv4 - if ipVersionBits == 6 { return sa_family_t(AF_INET6) } // IPv6 - return nil - }() - guard let protocolFamily = ipVersion else { fatalError("Unknown IP version") } - let packet = NEPacket(data: Data(bytes: buf, count: len), protocolFamily: protocolFamily) - var isTunnelClosed = false - let isWritten = wgContext.writePacket(packet: packet, isTunnelClosed: &isTunnelClosed) - if isTunnelClosed { return -1 } - if isWritten { - return len - } - return 0 - } - }, - wgCtxPtr) - } + return wgTurnOn(nameGoStr, settingsGoStr, fd) } } } -class WireGuardContext { - private var packetFlow: NEPacketTunnelFlow - private var outboundPackets: [NEPacket] = [] - private var isTunnelClosed: Bool = false - private var readPacketCondition = NSCondition() - - init(packetFlow: NEPacketTunnelFlow) { - self.packetFlow = packetFlow - } - - func closeTunnel() { - isTunnelClosed = true - readPacketCondition.signal() - } - - func packetsRead(packets: [NEPacket]) { - readPacketCondition.lock() - outboundPackets.append(contentsOf: packets) - readPacketCondition.unlock() - readPacketCondition.signal() - } - - func readPacket(isTunnelClosed: inout Bool) -> NEPacket? { - if outboundPackets.isEmpty { - readPacketCondition.lock() - packetFlow.readPacketObjects(completionHandler: packetsRead) - while outboundPackets.isEmpty && !self.isTunnelClosed { - readPacketCondition.wait() - } - readPacketCondition.unlock() - } - isTunnelClosed = self.isTunnelClosed - if !outboundPackets.isEmpty { - return outboundPackets.removeFirst() - } - return nil - } - - func writePacket(packet: NEPacket, isTunnelClosed: inout Bool) -> Bool { - isTunnelClosed = self.isTunnelClosed - return packetFlow.writePacketObjects([packet]) - } -} - private func withStringsAsGoStrings<R>(_ str1: String, _ str2: String, closure: (gostring_t, gostring_t) -> R) -> R { return str1.withCString { (s1cStr) -> R in let gstr1 = gostring_t(p: s1cStr, n: str1.utf8.count) |