aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network/WireGuard/Internal/Packet.hs
blob: f89c160edff6e993dafd486a644de114822f0d93 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
{-# LANGUAGE RecordWildCards #-}

module Network.WireGuard.Internal.Packet
  ( Packet(..)
  , parsePacket
  , buildPacket
  ) where

import           Control.Monad                       (replicateM_, unless, when)
import qualified Data.ByteString                     as BS
import           Foreign.Storable                    (sizeOf)

import           Data.Serialize

import           Network.WireGuard.Internal.Constant
import           Network.WireGuard.Internal.Data.Types

-- | A decoded WireGuard wire message, as produced by 'parsePacket' and
-- consumed by 'buildPacket'.
--
-- The 4-byte message header and the trailing mac1\/mac2 fields of handshake
-- messages are not stored here: the parsers skip or verify them, and
-- 'buildPacket' regenerates them.
data Packet = HandshakeInitiation
              -- Message type 1 on the wire.
              { senderIndex      :: !Index
                -- ^ Index read\/written as a little-endian 32-bit word
                -- directly after the header.
              , encryptedPayload :: !EncryptedPayload
                -- ^ Opaque bytes of the message body that follow the index
                -- field(s); for handshake messages this stops before the MAC
                -- fields, for data messages it stops before 'authTag'.
              }
            | HandshakeResponse
              -- Message type 2 on the wire.
              { senderIndex      :: !Index
              , receiverIndex    :: !Index
                -- ^ Index read\/written as a little-endian 32-bit word; in a
                -- response it follows 'senderIndex', in a data message it
                -- directly follows the header.
              , encryptedPayload :: !EncryptedPayload
              }
            | PacketData
              -- Message type 4 on the wire.
              { receiverIndex    :: !Index
              , counter          :: !Counter
                -- ^ Read\/written as a little-endian 64-bit word after
                -- 'receiverIndex'.
              , encryptedPayload :: !EncryptedPayload
              , authTag          :: !AuthTag
                -- ^ The final @authLength@ bytes of a data message.
              }
    deriving (Show)

-- | Parse one WireGuard message, dispatching on its 4-byte header.
--
-- The header is peeked at (not consumed) as a little-endian 32-bit word, so a
-- message is only accepted when its type byte is 1, 2 or 4 /and/ the three
-- reserved bytes that follow it are all zero.  Anything else, including a
-- known type byte with non-zero reserved bytes, fails with
-- @"unknown packet"@.
--
-- * Type 1 (handshake initiation) and type 2 (handshake response) must have
--   exactly the expected length and carry a mac1 that matches
--   @getMac1 body@, where @body@ is the message without its two trailing MAC
--   fields (see 'verifyMac').
-- * Type 4 (transport data) only has a minimum length and is not MAC-checked
--   here.
--
-- @getMac1@ is the function used to compute the expected mac1 over the
-- message body.
parsePacket :: (BS.ByteString -> BS.ByteString) -> Get Packet
parsePacket getMac1 = do
    -- 'lookAhead' leaves the header in the input: the type-specific parsers
    -- skip it themselves and the mac1 computation has to cover it.  Reading
    -- all four bytes (instead of just the first) is what rejects messages
    -- whose reserved bytes are not zero.
    packetType <- lookAhead getWord32le
    case packetType of
        1 -> verifyLength (== handshakeInitiationPacketLength) $
                 verifyMac getMac1 parseHandshakeInitiation
        2 -> verifyLength (== handshakeResponsePacketLength) $
                 verifyMac getMac1 parseHandshakeResponse
        4 -> verifyLength (>= packetDataMinimumPacketLength) parsePacketData
        _ -> fail "unknown packet"
  where
    -- Every length below counts the whole message, starting with the 4-byte
    -- header.
    handshakeInitiationPacketLength =
        4 + indexSize + keyLength + aeadLength keyLength
          + aeadLength timestampLength + mac1Length + mac2Length
    handshakeResponsePacketLength =
        4 + indexSize + indexSize + keyLength + aeadLength 0
          + mac1Length + mac2Length
    -- A data message may carry an empty payload, hence only a lower bound.
    packetDataMinimumPacketLength =
        4 + indexSize + counterSize + aeadLength 0

    indexSize = sizeOf (undefined :: Index)
    counterSize = sizeOf (undefined :: Counter)

-- | Decode the body of a handshake initiation message.
--
-- Skips the 4-byte header, reads the sender index as a little-endian 32-bit
-- word and takes every remaining byte of the input as the encrypted payload.
parseHandshakeInitiation :: Get Packet
parseHandshakeInitiation = do
    skip 4
    sender  <- getWord32le
    payload <- getBytes =<< remaining
    return (HandshakeInitiation sender payload)

-- | Decode the body of a handshake response message.
--
-- Skips the 4-byte header, reads the sender and receiver indices (in that
-- order, each a little-endian 32-bit word) and takes every remaining byte of
-- the input as the encrypted payload.
parseHandshakeResponse :: Get Packet
parseHandshakeResponse = do
    skip 4
    sender   <- getWord32le
    receiver <- getWord32le
    payload  <- getBytes =<< remaining
    return (HandshakeResponse sender receiver payload)

-- | Decode a transport data message.
--
-- Skips the 4-byte header, then reads the receiver index (little-endian
-- 32-bit) and the counter (little-endian 64-bit).  Of the bytes that are left,
-- the final @authLength@ form the authentication tag and everything before
-- them is the encrypted payload.
parsePacketData :: Get Packet
parsePacketData = do
    skip 4
    receiver   <- getWord32le
    nonce      <- getWord64le
    -- Whatever is left minus the trailing tag is payload.
    payloadLen <- subtract authLength <$> remaining
    payload    <- getBytes payloadLen
    tag        <- getBytes authLength
    return (PacketData receiver nonce payload tag)

-- | Serialise a 'Packet' into its wire format.
--
-- Every message starts with a one-byte message type (1, 2 or 4) followed by
-- three zero bytes, then the constructor's fields in declaration order with
-- integers in little-endian byte order.  Handshake messages are additionally
-- passed through 'appendMac', which appends @getMac1@ of the serialised body
-- and a zero-filled mac2; data messages ignore @getMac1@.
buildPacket :: (BS.ByteString -> BS.ByteString) -> Putter Packet
buildPacket getMac1 packet = case packet of
    HandshakeInitiation sender payload -> appendMac getMac1 $ do
        putHeader 1
        putWord32le sender
        putByteString payload

    HandshakeResponse sender receiver payload -> appendMac getMac1 $ do
        putHeader 2
        putWord32le sender
        putWord32le receiver
        putByteString payload

    PacketData receiver nonce payload tag -> do
        putHeader 4
        putWord32le receiver
        putWord64le nonce
        putByteString payload
        putByteString tag
  where
    -- Message type byte followed by the three zero header bytes.
    putHeader msgType = putWord8 msgType >> replicateM_ 3 (putWord8 0)

-- | Run a parser only if the number of bytes still available in the input
-- satisfies the given predicate; otherwise fail with
-- @"wrong packet length"@ without running it.
verifyLength :: (Int -> Bool) -> Get a -> Get a
verifyLength check ga = do
    available <- remaining
    if check available
        then ga
        else fail "wrong packet length"

-- | Run a parser over the message body and check the mac1 that follows it.
--
-- The remaining input is treated as @body ++ mac1 ++ mac2@, where the two MAC
-- fields are @mac1Length@ and @mac2Length@ bytes long.  The parser is run
-- against exactly the body (it must consume all of it, see 'isolate'), then
-- the received mac1 is compared with @getMac1 body@.  mac2 is skipped without
-- being inspected.
--
-- Fails with @"packet too small"@ if the input cannot hold both MAC fields
-- and with @"wrong mac1"@ on a mismatch.
--
-- NOTE(review): the mac1 comparison is plain 'BS.ByteString' equality, which
-- is presumably not constant-time -- confirm whether that matters here.
verifyMac :: (BS.ByteString -> BS.ByteString) -> Get Packet -> Get Packet
verifyMac getMac1 ga = do
    available <- remaining
    let bodyLen = available - (mac1Length + mac2Length)
    when (bodyLen < 0) $ fail "packet too small"
    -- Grab the raw body for the MAC without consuming it, so the parser
    -- below still sees it.
    body <- lookAhead (getBytes bodyLen)
    packet <- isolate bodyLen ga
    mac1 <- getBytes mac1Length
    unless (getMac1 body == mac1) $ fail "wrong mac1"
    skip mac2Length
    return packet

-- | Emit the output of the given 'Put' followed by its two MAC fields:
-- @getMac1@ applied to the serialised bytes, then @mac2Length@ zero bytes
-- standing in for mac2.
appendMac :: (BS.ByteString -> BS.ByteString) -> Put -> Put
appendMac getMac1 p = do
    putByteString body
    putByteString (getMac1 body)
    putByteString (BS.replicate mac2Length 0)
  where
    -- TODO: find a smart approach to avoid extra ByteString allocation
    -- The body is rendered to a strict ByteString first because the MAC has
    -- to be computed over the finished bytes.
    body = runPut p