aboutsummaryrefslogtreecommitdiffstats
path: root/Sources/WireGuardKit/IPAddressRange.swift
diff options
context:
space:
mode:
Diffstat (limited to 'Sources/WireGuardKit/IPAddressRange.swift')
-rw-r--r--Sources/WireGuardKit/IPAddressRange.swift115
1 files changed, 115 insertions, 0 deletions
diff --git a/Sources/WireGuardKit/IPAddressRange.swift b/Sources/WireGuardKit/IPAddressRange.swift
new file mode 100644
index 0000000..60430af
--- /dev/null
+++ b/Sources/WireGuardKit/IPAddressRange.swift
@@ -0,0 +1,115 @@
+// SPDX-License-Identifier: MIT
+// Copyright © 2018-2023 WireGuard LLC. All Rights Reserved.
+
+import Foundation
+import Network
+
+public struct IPAddressRange {
+ public let address: IPAddress
+ public let networkPrefixLength: UInt8
+
+ init(address: IPAddress, networkPrefixLength: UInt8) {
+ self.address = address
+ self.networkPrefixLength = networkPrefixLength
+ }
+}
+
+extension IPAddressRange: Equatable {
+ public static func == (lhs: IPAddressRange, rhs: IPAddressRange) -> Bool {
+ return lhs.address.rawValue == rhs.address.rawValue && lhs.networkPrefixLength == rhs.networkPrefixLength
+ }
+}
+
+extension IPAddressRange: Hashable {
+ public func hash(into hasher: inout Hasher) {
+ hasher.combine(address.rawValue)
+ hasher.combine(networkPrefixLength)
+ }
+}
+
+extension IPAddressRange {
+ public var stringRepresentation: String {
+ return "\(address)/\(networkPrefixLength)"
+ }
+
+ public init?(from string: String) {
+ guard let parsed = IPAddressRange.parseAddressString(string) else { return nil }
+ address = parsed.0
+ networkPrefixLength = parsed.1
+ }
+
+ private static func parseAddressString(_ string: String) -> (IPAddress, UInt8)? {
+ let endOfIPAddress = string.lastIndex(of: "/") ?? string.endIndex
+ let addressString = String(string[string.startIndex ..< endOfIPAddress])
+ let address: IPAddress
+ if let addr = IPv4Address(addressString) {
+ address = addr
+ } else if let addr = IPv6Address(addressString) {
+ address = addr
+ } else {
+ return nil
+ }
+
+ let maxNetworkPrefixLength: UInt8 = address is IPv4Address ? 32 : 128
+ var networkPrefixLength: UInt8
+ if endOfIPAddress < string.endIndex { // "/" was located
+ let indexOfNetworkPrefixLength = string.index(after: endOfIPAddress)
+ guard indexOfNetworkPrefixLength < string.endIndex else { return nil }
+ let networkPrefixLengthSubstring = string[indexOfNetworkPrefixLength ..< string.endIndex]
+ guard let npl = UInt8(networkPrefixLengthSubstring) else { return nil }
+ networkPrefixLength = min(npl, maxNetworkPrefixLength)
+ } else {
+ networkPrefixLength = maxNetworkPrefixLength
+ }
+
+ return (address, networkPrefixLength)
+ }
+
+ public func subnetMask() -> IPAddress {
+ if address is IPv4Address {
+ let mask = networkPrefixLength > 0 ? ~UInt32(0) << (32 - networkPrefixLength) : UInt32(0)
+ let bytes = Data([
+ UInt8(truncatingIfNeeded: mask >> 24),
+ UInt8(truncatingIfNeeded: mask >> 16),
+ UInt8(truncatingIfNeeded: mask >> 8),
+ UInt8(truncatingIfNeeded: mask >> 0)
+ ])
+ return IPv4Address(bytes)!
+ }
+ if address is IPv6Address {
+ var bytes = Data(repeating: 0, count: 16)
+ for i in 0..<Int(networkPrefixLength/8) {
+ bytes[i] = 0xff
+ }
+ let nibble = networkPrefixLength % 32
+ if nibble != 0 {
+ let mask = ~UInt32(0) << (32 - nibble)
+ let i = Int(networkPrefixLength / 32 * 4)
+ bytes[i + 0] = UInt8(truncatingIfNeeded: mask >> 24)
+ bytes[i + 1] = UInt8(truncatingIfNeeded: mask >> 16)
+ bytes[i + 2] = UInt8(truncatingIfNeeded: mask >> 8)
+ bytes[i + 3] = UInt8(truncatingIfNeeded: mask >> 0)
+ }
+ return IPv6Address(bytes)!
+ }
+ fatalError()
+ }
+
+ public func maskedAddress() -> IPAddress {
+ let subnet = subnetMask().rawValue
+ var masked = Data(address.rawValue)
+ if subnet.count != masked.count {
+ fatalError()
+ }
+ for i in 0..<subnet.count {
+ masked[i] &= subnet[i]
+ }
+ if subnet.count == 4 {
+ return IPv4Address(masked)!
+ }
+ if subnet.count == 16 {
+ return IPv6Address(masked)!
+ }
+ fatalError()
+ }
+}