blob: f89c160edff6e993dafd486a644de114822f0d93 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
|
{-# LANGUAGE RecordWildCards #-}
module Network.WireGuard.Internal.Packet
( Packet(..)
, parsePacket
, buildPacket
) where
import Control.Monad (replicateM_, unless, when)
import Data.Bits (xor, (.|.))
import Foreign.Storable (sizeOf)

import qualified Data.ByteString as BS
import Data.Serialize

import Network.WireGuard.Internal.Constant
import Network.WireGuard.Internal.Data.Types
-- | A decoded WireGuard message.
--
-- Three message kinds are represented, matching the type bytes that
-- 'parsePacket' accepts and 'buildPacket' writes: handshake initiation (1),
-- handshake response (2) and transport data (4).
--
-- Field names are shared between constructors, so some selectors are
-- partial (for example 'counter' on a handshake, or 'receiverIndex' on an
-- initiation); prefer pattern matching over selector calls.
--
-- 'Eq' is derived so that packets can be compared directly, e.g. in
-- round-trip tests. Every field is an integer or 'BS.ByteString' synonym:
-- 'buildPacket' writes them with 'putWord32le', 'putWord64le' and
-- 'putByteString'.
data Packet
  = HandshakeInitiation
      { senderIndex      :: !Index            -- ^ Initiator's index (u32, little-endian on the wire).
      , encryptedPayload :: !EncryptedPayload -- ^ Remaining body bytes, excluding the trailing MAC1/MAC2.
      }
  | HandshakeResponse
      { senderIndex      :: !Index            -- ^ Responder's index (u32, little-endian on the wire).
      , receiverIndex    :: !Index            -- ^ Index echoed back from the initiation (u32, little-endian).
      , encryptedPayload :: !EncryptedPayload -- ^ Remaining body bytes, excluding the trailing MAC1/MAC2.
      }
  | PacketData
      { receiverIndex    :: !Index            -- ^ Receiver's index (u32, little-endian on the wire).
      , counter          :: !Counter          -- ^ Message counter (u64, little-endian on the wire).
      , encryptedPayload :: !EncryptedPayload -- ^ Bytes between the counter and the trailing tag.
      , authTag          :: !AuthTag          -- ^ The final 'authLength' bytes of the message.
      }
  deriving (Eq, Show)
-- | Decode one WireGuard message.
--
-- The first byte is peeked (not consumed) and selects the message kind:
--
--   * @1@: handshake initiation. The total length must equal the fixed
--     initiation size, and MAC1 must verify.
--
--   * @2@: handshake response. The total length must equal the fixed
--     response size, and MAC1 must verify.
--
--   * @4@: transport data. The total length must be at least the fixed
--     minimum; no MAC is checked.
--
-- Any other type byte fails with @"unknown packet"@. The argument computes
-- the expected MAC1 for a given message body.
parsePacket :: (BS.ByteString -> BS.ByteString) -> Get Packet
parsePacket getMac1 = lookAhead getWord8 >>= dispatch
  where
    dispatch 1 =
      verifyLength (== initiationSize) (verifyMac getMac1 parseHandshakeInitiation)
    dispatch 2 =
      verifyLength (== responseSize) (verifyMac getMac1 parseHandshakeResponse)
    dispatch 4 = verifyLength (>= transportMinSize) parsePacketData
    dispatch _ = fail "unknown packet"

    -- One type byte followed by three reserved bytes.
    headerSize = 4

    initiationSize =
      headerSize + indexSize + keyLength + aeadLength keyLength
        + aeadLength timestampLength + mac1Length + mac2Length
    responseSize =
      headerSize + indexSize + indexSize + keyLength + aeadLength 0
        + mac1Length + mac2Length
    transportMinSize = headerSize + indexSize + counterSize + aeadLength 0

    indexSize = sizeOf (undefined :: Index)
    counterSize = sizeOf (undefined :: Counter)
-- | Parse a handshake-initiation body.
--
-- Layout: a 4-byte header (skipped), then the sender index as a
-- little-endian u32, then every remaining byte as the payload.
parseHandshakeInitiation :: Get Packet
parseHandshakeInitiation = do
  skip 4
  sender <- getWord32le
  payload <- getBytes =<< remaining
  return (HandshakeInitiation sender payload)
-- | Parse a handshake-response body.
--
-- Layout: a 4-byte header (skipped), the sender index and then the
-- receiver index (both little-endian u32), then every remaining byte as the
-- payload.
parseHandshakeResponse :: Get Packet
parseHandshakeResponse = do
  skip 4
  sender <- getWord32le
  receiver <- getWord32le
  payload <- getBytes =<< remaining
  return (HandshakeResponse sender receiver payload)
-- | Parse a transport-data message.
--
-- Layout: a 4-byte header (skipped), the receiver index (little-endian u32),
-- the counter (little-endian u64), the ciphertext, and finally a tag of
-- 'authLength' bytes. The ciphertext is whatever lies between the counter
-- and the tag.
--
-- NOTE(review): this assumes at least 'authLength' bytes remain after the
-- counter. Presumably 'parsePacket's minimum-length check guarantees that
-- (if @aeadLength 0 == authLength@) — confirm.
parsePacketData :: Get Packet
parsePacketData = do
  skip 4
  receiver <- getWord32le
  ctr <- getWord64le
  payloadLen <- subtract authLength <$> remaining
  payload <- getBytes payloadLen
  tag <- getBytes authLength
  return (PacketData receiver ctr payload tag)
-- | Serialise a 'Packet'.
--
-- Every message starts with its type byte (1, 2 or 4) followed by three
-- zero bytes. For the two handshake kinds, the serialised body is passed
-- to 'appendMac', which appends MAC1 (computed by @getMac1@ over the body)
-- and a zero MAC2. Transport data carries no MAC fields.
buildPacket :: (BS.ByteString -> BS.ByteString) -> Putter Packet
buildPacket getMac1 packet = case packet of
  HandshakeInitiation{..} -> appendMac getMac1 $ do
    header 1
    putWord32le senderIndex
    putByteString encryptedPayload
  HandshakeResponse{..} -> appendMac getMac1 $ do
    header 2
    putWord32le senderIndex
    putWord32le receiverIndex
    putByteString encryptedPayload
  PacketData{..} -> do
    header 4
    putWord32le receiverIndex
    putWord64le counter
    putByteString encryptedPayload
    putByteString authTag
  where
    -- Type byte followed by three reserved zero bytes.
    header ty = putWord8 ty >> replicateM_ 3 (putWord8 0)
-- | Run the parser only if the number of unconsumed input bytes satisfies
-- the predicate; otherwise fail with @"wrong packet length"@. No input is
-- consumed by the check itself.
verifyLength :: (Int -> Bool) -> Get a -> Get a
verifyLength check ga = do
  left <- remaining
  if check left
    then ga
    else fail "wrong packet length"
-- | Run a handshake parser over the message body and verify the MAC1 that
-- follows it.
--
-- The body is all remaining input except the trailing MAC1 and MAC2
-- fields. The expected MAC1 is @getMac1 body@.
--
-- The parser runs isolated to exactly the body. The received MAC1 is then
-- read and compared with the expected value; on mismatch the parse fails
-- with @"wrong mac1"@. MAC2 is skipped without being checked.
--
-- The comparison deliberately does not short-circuit on the first
-- differing byte (unlike 'BS.ByteString' equality). That way its running
-- time does not reveal how many leading bytes of a forged MAC1 were
-- correct; Linux WireGuard likewise uses a constant-time compare here.
verifyMac :: (BS.ByteString -> BS.ByteString) -> Get Packet -> Get Packet
verifyMac getMac1 ga = do
  bodyLength <- subtract (mac1Length + mac2Length) <$> remaining
  when (bodyLength < 0) $ fail "packet too small"
  expectedMac1 <- getMac1 <$> lookAhead (getBytes bodyLength)
  parsed <- isolate bodyLength ga
  receivedMac1 <- getBytes mac1Length
  unless (constantTimeEq expectedMac1 receivedMac1) $ fail "wrong mac1"
  skip mac2Length
  return parsed
  where
    -- Byte-string equality that ORs together the XOR of every byte pair
    -- instead of stopping at the first difference. MAC lengths are public,
    -- so returning early on a length mismatch leaks nothing secret.
    constantTimeEq a b =
      BS.length a == BS.length b
        && BS.foldl' (.|.) 0 (BS.pack (BS.zipWith xor a b)) == 0
-- | Emit the output of @p@ followed by its MAC fields.
--
-- The output is:
--
--   1. the bytes produced by @p@;
--   2. MAC1, computed by @getMac1@ over those bytes;
--   3. 'mac2Length' zero bytes, since MAC2 is always left as zeros here.
appendMac :: (BS.ByteString -> BS.ByteString) -> Put -> Put
appendMac getMac1 p =
  mapM_ putByteString [body, getMac1 body, BS.replicate mac2Length 0]
  where
    -- TODO: find a smart approach to avoid extra ByteString allocation
    -- (the body is rendered to a strict ByteString so MAC1 can be computed
    -- over it).
    body = runPut p
|