aboutsummaryrefslogtreecommitdiffstats
path: root/WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift106
1 files changed, 11 insertions, 95 deletions
diff --git a/WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift b/WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift
index ddd2677..f6f2bb4 100644
--- a/WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift
+++ b/WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift
@@ -17,7 +17,6 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
// MARK: Properties
private var wgHandle: Int32?
- private var wgContext: WireGuardContext?
// MARK: NEPacketTunnelProvider
@@ -64,9 +63,14 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
}
configureLogger()
- wgContext = WireGuardContext(packetFlow: self.packetFlow)
- let handle = connect(interfaceName: interfaceName, settings: wireguardSettings, mtu: mtu.uint16Value)
+ let fd = packetFlow.value(forKeyPath: "socket.fileDescriptor") as! Int32
+ if fd < 0 {
+ os_log("Starting tunnel failed: Could not determine file descriptor", log: OSLog.default, type: .error)
+ startTunnelCompletionHandler(PacketTunnelProviderError.couldNotStartWireGuard)
+ return
+ }
+ let handle = connect(interfaceName: interfaceName, settings: wireguardSettings, fd: fd)
if handle < 0 {
os_log("Starting tunnel failed: Could not start WireGuard", log: OSLog.default, type: .error)
@@ -114,9 +118,8 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
// MTU
if (mtu == 0) {
- // 0 imples automatic MTU, where we set overhead as 95 bytes,
- // 80 for WireGuard and the 15 to make sure WireGuard's padding will work.
- networkSettings.tunnelOverheadBytes = 95
+ // 0 imples automatic MTU, where we set overhead as 80 bytes, which is the worst case for WireGuard
+ networkSettings.tunnelOverheadBytes = 80
} else {
networkSettings.mtu = mtu
}
@@ -134,7 +137,6 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
/// Begin the process of stopping the tunnel.
override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) {
os_log("Stopping tunnel", log: OSLog.default, type: .info)
- wgContext?.closeTunnel()
if let handle = wgHandle {
wgTurnOff(handle)
}
@@ -159,99 +161,13 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
}
}
- private func connect(interfaceName: String, settings: String, mtu: UInt16) -> Int32 { // swiftlint:disable:this cyclomatic_complexity
+ private func connect(interfaceName: String, settings: String, fd: Int32) -> Int32 { // swiftlint:disable:this cyclomatic_complexity
return withStringsAsGoStrings(interfaceName, settings) { (nameGoStr, settingsGoStr) -> Int32 in
- return withUnsafeMutablePointer(to: &wgContext) { (wgCtxPtr) -> Int32 in
- return wgTurnOn(nameGoStr, settingsGoStr, mtu, { (wgCtxPtr, buf, len) -> Int in
- autoreleasepool {
- // read_fn: Read from the TUN interface and pass it on to WireGuard
- guard let wgCtxPtr = wgCtxPtr else { return 0 }
- guard let buf = buf else { return 0 }
- let wgContext = wgCtxPtr.bindMemory(to: WireGuardContext.self, capacity: 1).pointee
- var isTunnelClosed = false
- let packet = wgContext.readPacket(isTunnelClosed: &isTunnelClosed)
- if isTunnelClosed { return -1 }
- guard let packetData = packet?.data else { return 0 }
- if packetData.count <= len {
- packetData.copyBytes(to: buf, count: packetData.count)
- return packetData.count
- }
- return 0
- }
- }, { (wgCtxPtr, buf, len) -> Int in
- autoreleasepool {
- // write_fn: Receive packets from WireGuard and write to the TUN interface
- guard let wgCtxPtr = wgCtxPtr else { return 0 }
- guard let buf = buf else { return 0 }
- guard len > 0 else { return 0 }
- let wgContext = wgCtxPtr.bindMemory(to: WireGuardContext.self, capacity: 1).pointee
- let ipVersionBits = (buf[0] & 0xf0) >> 4
- let ipVersion: sa_family_t? = {
- if ipVersionBits == 4 { return sa_family_t(AF_INET) } // IPv4
- if ipVersionBits == 6 { return sa_family_t(AF_INET6) } // IPv6
- return nil
- }()
- guard let protocolFamily = ipVersion else { fatalError("Unknown IP version") }
- let packet = NEPacket(data: Data(bytes: buf, count: len), protocolFamily: protocolFamily)
- var isTunnelClosed = false
- let isWritten = wgContext.writePacket(packet: packet, isTunnelClosed: &isTunnelClosed)
- if isTunnelClosed { return -1 }
- if isWritten {
- return len
- }
- return 0
- }
- },
- wgCtxPtr)
- }
+ return wgTurnOn(nameGoStr, settingsGoStr, fd)
}
}
}
-class WireGuardContext {
- private var packetFlow: NEPacketTunnelFlow
- private var outboundPackets: [NEPacket] = []
- private var isTunnelClosed: Bool = false
- private var readPacketCondition = NSCondition()
-
- init(packetFlow: NEPacketTunnelFlow) {
- self.packetFlow = packetFlow
- }
-
- func closeTunnel() {
- isTunnelClosed = true
- readPacketCondition.signal()
- }
-
- func packetsRead(packets: [NEPacket]) {
- readPacketCondition.lock()
- outboundPackets.append(contentsOf: packets)
- readPacketCondition.unlock()
- readPacketCondition.signal()
- }
-
- func readPacket(isTunnelClosed: inout Bool) -> NEPacket? {
- if outboundPackets.isEmpty {
- readPacketCondition.lock()
- packetFlow.readPacketObjects(completionHandler: packetsRead)
- while outboundPackets.isEmpty && !self.isTunnelClosed {
- readPacketCondition.wait()
- }
- readPacketCondition.unlock()
- }
- isTunnelClosed = self.isTunnelClosed
- if !outboundPackets.isEmpty {
- return outboundPackets.removeFirst()
- }
- return nil
- }
-
- func writePacket(packet: NEPacket, isTunnelClosed: inout Bool) -> Bool {
- isTunnelClosed = self.isTunnelClosed
- return packetFlow.writePacketObjects([packet])
- }
-}
-
private func withStringsAsGoStrings<R>(_ str1: String, _ str2: String, closure: (gostring_t, gostring_t) -> R) -> R {
return str1.withCString { (s1cStr) -> R in
let gstr1 = gostring_t(p: s1cStr, n: str1.utf8.count)