diff options
Diffstat (limited to 'src/Network/WireGuard/Internal/Packet.hs')
-rw-r--r-- | src/Network/WireGuard/Internal/Packet.hs | 112 |
1 files changed, 112 insertions, 0 deletions
diff --git a/src/Network/WireGuard/Internal/Packet.hs b/src/Network/WireGuard/Internal/Packet.hs new file mode 100644 index 0000000..ebc24fc --- /dev/null +++ b/src/Network/WireGuard/Internal/Packet.hs @@ -0,0 +1,112 @@ +{-# LANGUAGE RecordWildCards #-} + +module Network.WireGuard.Internal.Packet + ( Packet(..) + , parsePacket + , buildPacket + ) where + +import Control.Monad (replicateM_, unless, when) +import qualified Data.ByteString as BS +import Foreign.Storable (sizeOf) + +import Data.Serialize + +import Network.WireGuard.Internal.Constant +import Network.WireGuard.Internal.Types + +data Packet = HandshakeInitiation + { senderIndex :: !Index + , encryptedPayload :: !EncryptedPayload + } + | HandshakeResponse + { senderIndex :: !Index + , receiverIndex :: !Index + , encryptedPayload :: !EncryptedPayload + } + | PacketData + { receiverIndex :: !Index + , counter :: !Counter + , encryptedPayload :: !EncryptedPayload + , authTag :: !AuthTag + } + deriving (Show) + +parsePacket :: (BS.ByteString -> BS.ByteString) -> Get Packet +parsePacket getMac1 = do + packetType <- lookAhead getWord8 + case packetType of + 1 -> verifyLength (==handshakeInitiationPacketLength) $ verifyMac getMac1 parseHandshakeInitiation + 2 -> verifyLength (==handshakeResponsePacketLength) $ verifyMac getMac1 parseHandshakeResponse + 4 -> verifyLength (>=packetDataMinimumPacketLength) parsePacketData + _ -> fail "unknown packet" + where + handshakeInitiationPacketLength = 4 + indexSize + keyLength + aeadLength keyLength + aeadLength timestampLength + mac1Length + mac2Length + handshakeResponsePacketLength = 4 + indexSize + indexSize + keyLength + aeadLength 0 + mac1Length + mac2Length + packetDataMinimumPacketLength = 4 + indexSize + counterSize + aeadLength 0 + + indexSize = sizeOf (undefined :: Index) + counterSize = sizeOf (undefined :: Counter) + +parseHandshakeInitiation :: Get Packet +parseHandshakeInitiation = do + skip 4 + HandshakeInitiation <$> getWord32le <*> (remaining >>= getBytes) + +parseHandshakeResponse :: Get Packet +parseHandshakeResponse = do + skip 4 + HandshakeResponse <$> getWord32le <*> getWord32le <*> (remaining >>= getBytes) + +parsePacketData :: Get Packet +parsePacketData = do + skip 4 + PacketData <$> getWord32le <*> getWord64le <*> + (remaining >>= getBytes . subtract authLength) <*> getBytes authLength + +buildPacket :: (BS.ByteString -> BS.ByteString) -> Putter Packet +buildPacket getMac1 HandshakeInitiation{..} = appendMac getMac1 $ do + putWord8 1 + replicateM_ 3 (putWord8 0) + putWord32le senderIndex + putByteString encryptedPayload + +buildPacket getMac1 HandshakeResponse{..} = appendMac getMac1 $ do + putWord8 2 + replicateM_ 3 (putWord8 0) + putWord32le senderIndex + putWord32le receiverIndex + putByteString encryptedPayload + +buildPacket _getMac1 PacketData{..} = do + putWord8 4 + replicateM_ 3 (putWord8 0) + putWord32le receiverIndex + putWord64le counter + putByteString encryptedPayload + putByteString authTag + +verifyLength :: (Int -> Bool) -> Get a -> Get a +verifyLength check ga = do + outcome <- check <$> remaining + unless outcome $ fail "wrong packet length" + ga + +verifyMac :: (BS.ByteString -> BS.ByteString) -> Get Packet -> Get Packet +verifyMac getMac1 ga = do + bodyLength <- subtract (mac1Length + mac2Length) <$> remaining + when (bodyLength < 0) $ fail "packet too small" + expectedMac1 <- getMac1 <$> lookAhead (getBytes bodyLength) + parsed <- isolate bodyLength ga + receivedMac1 <- getBytes mac1Length + when (expectedMac1 /= receivedMac1) $ fail "wrong mac1" + skip mac2Length + return parsed + +appendMac :: (BS.ByteString -> BS.ByteString) -> Put -> Put +appendMac getMac1 p = do + -- TODO: find a smart approach to avoid extra ByteString allocation + let bs = runPut p + putByteString bs + putByteString (getMac1 bs) + replicateM_ mac2Length (putWord8 0) |