aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network/WireGuard/Internal/Packet.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Network/WireGuard/Internal/Packet.hs')
-rw-r--r--src/Network/WireGuard/Internal/Packet.hs112
1 files changed, 112 insertions, 0 deletions
diff --git a/src/Network/WireGuard/Internal/Packet.hs b/src/Network/WireGuard/Internal/Packet.hs
new file mode 100644
index 0000000..ebc24fc
--- /dev/null
+++ b/src/Network/WireGuard/Internal/Packet.hs
@@ -0,0 +1,112 @@
+{-# LANGUAGE RecordWildCards #-}
+
+module Network.WireGuard.Internal.Packet
+ ( Packet(..)
+ , parsePacket
+ , buildPacket
+ ) where
+
+import Control.Monad (replicateM_, unless, when)
+import qualified Data.ByteString as BS
+import Foreign.Storable (sizeOf)
+
+import Data.Serialize
+
+import Network.WireGuard.Internal.Constant
+import Network.WireGuard.Internal.Types
+
+data Packet = HandshakeInitiation
+ { senderIndex :: !Index
+ , encryptedPayload :: !EncryptedPayload
+ }
+ | HandshakeResponse
+ { senderIndex :: !Index
+ , receiverIndex :: !Index
+ , encryptedPayload :: !EncryptedPayload
+ }
+ | PacketData
+ { receiverIndex :: !Index
+ , counter :: !Counter
+ , encryptedPayload :: !EncryptedPayload
+ , authTag :: !AuthTag
+ }
+ deriving (Show)
+
+parsePacket :: (BS.ByteString -> BS.ByteString) -> Get Packet
+parsePacket getMac1 = do
+ packetType <- lookAhead getWord8
+ case packetType of
+ 1 -> verifyLength (==handshakeInitiationPacketLength) $ verifyMac getMac1 parseHandshakeInitiation
+ 2 -> verifyLength (==handshakeResponsePacketLength) $ verifyMac getMac1 parseHandshakeResponse
+ 4 -> verifyLength (>=packetDataMinimumPacketLength) parsePacketData
+ _ -> fail "unknown packet"
+ where
+ handshakeInitiationPacketLength = 4 + indexSize + keyLength + aeadLength keyLength + aeadLength timestampLength + mac1Length + mac2Length
+ handshakeResponsePacketLength = 4 + indexSize + indexSize + keyLength + aeadLength 0 + mac1Length + mac2Length
+ packetDataMinimumPacketLength = 4 + indexSize + counterSize + aeadLength 0
+
+ indexSize = sizeOf (undefined :: Index)
+ counterSize = sizeOf (undefined :: Counter)
+
+parseHandshakeInitiation :: Get Packet
+parseHandshakeInitiation = do
+ skip 4
+ HandshakeInitiation <$> getWord32le <*> (remaining >>= getBytes)
+
+parseHandshakeResponse :: Get Packet
+parseHandshakeResponse = do
+ skip 4
+ HandshakeResponse <$> getWord32le <*> getWord32le <*> (remaining >>= getBytes)
+
+parsePacketData :: Get Packet
+parsePacketData = do
+ skip 4
+ PacketData <$> getWord32le <*> getWord64le <*>
+ (remaining >>= getBytes . subtract authLength) <*> getBytes authLength
+
+buildPacket :: (BS.ByteString -> BS.ByteString) -> Putter Packet
+buildPacket getMac1 HandshakeInitiation{..} = appendMac getMac1 $ do
+ putWord8 1
+ replicateM_ 3 (putWord8 0)
+ putWord32le senderIndex
+ putByteString encryptedPayload
+
+buildPacket getMac1 HandshakeResponse{..} = appendMac getMac1 $ do
+ putWord8 2
+ replicateM_ 3 (putWord8 0)
+ putWord32le senderIndex
+ putWord32le receiverIndex
+ putByteString encryptedPayload
+
+buildPacket _getMac1 PacketData{..} = do
+ putWord8 4
+ replicateM_ 3 (putWord8 0)
+ putWord32le receiverIndex
+ putWord64le counter
+ putByteString encryptedPayload
+ putByteString authTag
+
+verifyLength :: (Int -> Bool) -> Get a -> Get a
+verifyLength check ga = do
+ outcome <- check <$> remaining
+ unless outcome $ fail "wrong packet length"
+ ga
+
+verifyMac :: (BS.ByteString -> BS.ByteString) -> Get Packet -> Get Packet
+verifyMac getMac1 ga = do
+ bodyLength <- subtract (mac1Length + mac2Length) <$> remaining
+ when (bodyLength < 0) $ fail "packet too small"
+ expectedMac1 <- getMac1 <$> lookAhead (getBytes bodyLength)
+ parsed <- isolate bodyLength ga
+ receivedMac1 <- getBytes mac1Length
+ when (expectedMac1 /= receivedMac1) $ fail "wrong mac1"
+ skip mac2Length
+ return parsed
+
+appendMac :: (BS.ByteString -> BS.ByteString) -> Put -> Put
+appendMac getMac1 p = do
+ -- TODO: find a smart approach to avoid extra ByteString allocation
+ let bs = runPut p
+ putByteString bs
+ putByteString (getMac1 bs)
+ replicateM_ mac2Length (putWord8 0)