aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network/WireGuard/Core.hs
blob: 39beb788d4ca2dc7f9be9f0a5d9f0d9b8036b2d2 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
{-# LANGUAGE RecordWildCards #-}

module Network.WireGuard.Core
 ( runCore
 ) where

import           Control.Concurrent                     (getNumCapabilities,
                                                         threadDelay)
import           Control.Concurrent.Async               (wait, withAsync)
import           Control.Monad                          (forM_, forever, unless,
                                                         void, when)
import           Control.Monad.IO.Class                 (liftIO)
import           Control.Monad.STM                      (atomically)
import           Control.Monad.Trans.Except             (ExceptT, runExceptT,
                                                         throwE)
import           Crypto.Noise                           (HandshakeRole (..))
import           Crypto.Noise.DH                        (dhGenKey, dhPubEq,
                                                         dhPubToBytes)
import qualified Data.ByteArray                         as BA
import qualified Data.ByteString                        as BS
import qualified Data.HashMap.Strict                    as HM
import           Data.IP                                (makeAddrRange)
import qualified Data.IP.RouteTable                     as RT
import           Data.Maybe                             (fromMaybe, isJust,
                                                         isNothing)
import           Data.Serialize                         (putWord32be,
                                                         putWord64be, runGet,
                                                         runPut)
import           Foreign.C.Types                        (CTime (..))
import           Network.Socket                         (SockAddr)
import           System.IO                              (hPrint, stderr)
import           System.Posix.Time                      (epochTime)
import           System.Random                          (randomIO)

import           Control.Concurrent.STM.TVar
import           Crypto.Hash.BLAKE2.BLAKE2s

import           Network.WireGuard.Internal.Constant
import           Network.WireGuard.Internal.IPPacket
import           Network.WireGuard.Internal.Noise
import           Network.WireGuard.Internal.Packet
import           Network.WireGuard.Internal.PacketQueue
import           Network.WireGuard.Internal.State
import           Network.WireGuard.Internal.Data.Types
import           Network.WireGuard.Internal.Util

-- | Start the packet-processing core and block until its workers finish.
--
-- For every capability reported by 'getNumCapabilities' one TUN reader and
-- one UDP reader are started, each wrapped in 'retryWithBackoff'.  Once all
-- of them are running a single heartbeat loop is started as well, and the
-- call then waits on every worker and finally on the heartbeat.
runCore :: Device
        -> PacketQueue (Time, TunPacket) -> PacketQueue TunPacket
        -> PacketQueue UdpPacket -> PacketQueue UdpPacket
        -> IO ()
runCore device readTunChan writeTunChan readUdpChan writeUdpChan =
    getNumCapabilities >>= \capabilities -> spawnWorkers capabilities []
  where
    tunWorker = retryWithBackoff $
        handleReadTun device readTunChan writeUdpChan
    udpWorker = retryWithBackoff $
        handleReadUdp device readUdpChan writeTunChan writeUdpChan

    -- Hands the current local key to 'runHeartbeat' (through 'withJust'),
    -- then sleeps for 'heartbeatWaitTime' before the next round.
    -- TODO: use accurate timer
    heartbeat = forever $ ignoreSyncExceptions $ do
        withJust (readTVarIO (localKey device)) $ \key ->
            runHeartbeat device key writeUdpChan
        threadDelay heartbeatWaitTime

    -- Nests one pair of 'withAsync' scopes per remaining worker slot, so all
    -- workers stay alive until the innermost scope (the heartbeat) returns.
    spawnWorkers remaining running
        | remaining == 0 =
            withAsync heartbeat $ \hb ->
                mapM_ wait running >> wait hb
        | otherwise =
            withAsync tunWorker $ \tunAsync ->
            withAsync udpWorker $ \udpAsync ->
                spawnWorkers (remaining - 1) (tunAsync : udpAsync : running)

-- | Worker loop for packets read from the TUN side.
--
-- Each iteration computes a cutoff of 'handshakeRetryTime' seconds before the
-- current epoch time and asks 'dropUntilM' for the next queued packet whose
-- timestamp is not older than that cutoff.  The packet is then run through
-- 'processTunPacket'; a successful result is pushed to @writeUdpChan@, an
-- error is printed to 'stderr'.
handleReadTun :: Device -> PacketQueue (Time, TunPacket) -> PacketQueue UdpPacket -> IO ()
handleReadTun device readTunChan writeUdpChan = forever $ do
    now <- epochTime
    let cutoff = addTime now (-handshakeRetryTime)
        freshEnough (queuedAt, _) = queuedAt >= cutoff
    (_, tunPacket) <- dropUntilM freshEnough (popPacketQueue readTunChan)
    outcome <- runExceptT (processTunPacket device writeUdpChan tunPacket)
    either
        (hPrint stderr)  -- TODO: proper logging
        (pushPacketQueue writeUdpChan)
        outcome

-- | Worker loop for datagrams read from the UDP side.
--
-- Each datagram popped from @readUdpChan@ is run through 'processUdpPacket'.
-- Errors are printed to 'stderr'; a @Right@ result goes to @writeTunChan@, a
-- @Left@ result goes to @writeUdpChan@, and 'Nothing' produces no output.
handleReadUdp :: Device -> PacketQueue UdpPacket -> PacketQueue TunPacket
              -> PacketQueue UdpPacket
              -> IO ()
handleReadUdp device readUdpChan writeTunChan writeUdpChan = forever $
    popPacketQueue readUdpChan
        >>= runExceptT . processUdpPacket device
        >>= dispatch
  where
    dispatch (Left err)                  = hPrint stderr err -- TODO: proper logging
    dispatch (Right Nothing)             = return ()
    dispatch (Right (Just (Left udpp)))  = pushPacketQueue writeUdpChan udpp
    dispatch (Right (Just (Right tunp))) = pushPacketQueue writeTunChan tunp

-- | Encrypt one packet coming from the TUN side for its destination peer.
--
-- Steps, in order: read the local key and preshared key, parse the IP header,
-- look the destination up in the IPv4/IPv6 route table, obtain a session
-- (initiating a handshake and waiting for it when the peer has none), take
-- the next nonce, encrypt, and finally update the peer's transfer counters.
-- When the session has reached its 'renewTime' a new handshake is attempted
-- as well.
--
-- Returns the serialised data message paired with the peer's endpoint.
-- Fails with 'DeviceNotReadyError', 'InvalidIPPacketError',
-- 'DestinationNotReachableError', 'EndPointUnknownError' or
-- 'OutdatedPacketError' (no session appeared within the wait).
processTunPacket :: Device -> PacketQueue UdpPacket -> TunPacket
                 -> ExceptT WireGuardError IO UdpPacket
processTunPacket device@Device{..} writeUdpChan packet = do
    key <- assertJust DeviceNotReadyError $ liftIO (readTVarIO localKey)
    psk <- liftIO (readTVarIO presharedKey)
    parsed <- liftIO (parseIPPacket packet)
    peer <- assertJust DestinationNotReachableError (routeFor parsed)
    existing <- liftIO (getSession peer)
    session <- maybe (negotiate key psk peer) return existing
    nonce <- liftIO . atomically $ nextNonce session
    let (msg, authtag) = encryptMessage (sessionKey session) nonce packet
        -- The first argument of 'buildPacket' is deliberately an error
        -- thunk here, exactly as in the keepalive path of 'runHeartbeat'.
        datagram = runPut . buildPacket (error "internal error") $
            PacketData (theirIndex session) nonce msg authtag
    now <- liftIO epochTime
    endp <- currentEndPoint peer
    when (now >= renewTime session) . liftIO . void $
        checkAndTryInitiateHandshake device key psk writeUdpChan peer endp now
    liftIO . atomically $
        modifyTVar' (transferredBytes peer) (+ fromIntegral (BA.length packet))
    liftIO . atomically $ writeTVar (lastTransferTime peer) now
    return (datagram, endp)
  where
    -- Longest-prefix lookup of the destination as a host route.
    routeFor InvalidIPPacket      = throwE InvalidIPPacketError
    routeFor (IPv4Packet _ dest4) =
        RT.lookup (makeAddrRange dest4 32) <$> liftIO (readTVarIO routeTable4)
    routeFor (IPv6Packet _ dest6) =
        RT.lookup (makeAddrRange dest6 128) <$> liftIO (readTVarIO routeTable6)

    currentEndPoint peer =
        assertJust EndPointUnknownError . liftIO $ readTVarIO (endPoint peer)

    -- No session yet: try to start a handshake, then wait up to
    -- 'handshakeRetryTime' (scaled by 1000000) for a session to appear.
    negotiate key psk peer = do
        now0 <- liftIO epochTime
        endp0 <- currentEndPoint peer
        liftIO . void $
            checkAndTryInitiateHandshake device key psk writeUdpChan peer endp0 now0
        assertJust OutdatedPacketError . liftIO $
            waitForSession (handshakeRetryTime * 1000000) peer

-- | Parse a raw datagram and hand the resulting message to 'processPacket'.
--
-- The parser is given 'getMac1' applied to the public half of the local key
-- pair and the current preshared key.  Fails with 'DeviceNotReadyError' when
-- no local key is set, and with 'InvalidWGPacketError' carrying the parser's
-- message when the datagram cannot be parsed.
processUdpPacket :: Device -> UdpPacket
                 -> ExceptT WireGuardError IO (Maybe (Either UdpPacket TunPacket))
processUdpPacket device@Device{..} (packet, sock) = do
    key <- assertJust DeviceNotReadyError $ liftIO (readTVarIO localKey)
    psk <- liftIO (readTVarIO presharedKey)
    let localMac1 = getMac1 (snd key) psk
    either
        (throwE . InvalidWGPacketError)
        (processPacket device key psk sock)
        (runGet (parsePacket localMac1) packet)

-- | Handle one parsed WireGuard message received from @sock@.
--
-- The result tells the caller what to emit: @Just (Left udp)@ is a datagram
-- to send, @Just (Right tun)@ is a decrypted payload for the TUN side, and
-- 'Nothing' means there is nothing to forward (a completed handshake or a
-- keepalive).  Failures are reported through 'throwE' as a 'WireGuardError'.
processPacket :: Device -> KeyPair -> Maybe PresharedKey -> SockAddr -> Packet
              -> ExceptT WireGuardError IO (Maybe (Either UdpPacket TunPacket))
-- Handshake initiation: run the responder side of the handshake, check that
-- the payload is a timestamp newer than the last one accepted for the peer,
-- store a 'ResponderWait' for the peer and answer with a handshake response.
processPacket device@Device{..} key psk sock HandshakeInitiation{..} = do
    ekey <- liftIO dhGenKey
    let state0 = newNoiseState key psk ekey Nothing ResponderRole
        outcome = recvFirstMessageAndReply state0 encryptedPayload mempty
    case outcome of
        Left err                                   -> throwE (NoiseError err)
        Right (reply, decryptedPayload, rpub, sks) -> do
            when (BA.length decryptedPayload /= timestampLength) $
                throwE $ InvalidWGPacketError "timestamp expected"
            peer <- assertJust RemotePeerNotFoundError $
                HM.lookup (getPeerId rpub) <$> liftIO (readTVarIO peers)
            notReplayAttack <- liftIO $ atomically $ updateTai64n peer (BA.convert decryptedPayload)
            unless notReplayAttack $ throwE HandshakeInitiationReplayError
            now <- liftIO epochTime
            seed <- liftIO randomIO
            -- Allocate our index and replace any previous responder wait in
            -- a single transaction.
            ourindex <- liftIO $ atomically $ do
                ourindex <- acquireEmptyIndex device peer seed
                void $ eraseResponderWait device peer Nothing
                let rwait = ResponderWait ourindex senderIndex
                        (addTime now handshakeStopTime) sks
                writeTVar (responderWait peer) (Just rwait)
                return ourindex
            let responsePacket = runPut $ buildPacket (getMac1 rpub psk) $
                    HandshakeResponse ourindex senderIndex reply
            return (Just (Left (responsePacket, sock)))

-- Handshake response: finish the handshake we initiated.  The response must
-- match the peer's current 'InitiatorWait'; on success the wait is replaced
-- by a new session and the peer's endpoint is updated to @sock@.
processPacket device@Device{..} _key _psk sock HandshakeResponse{..} = do
    peer <- assertJust UnknownIndexError $
        HM.lookup receiverIndex <$> liftIO (readTVarIO indexMap)
    iwait <- assertJust OutdatedPacketError $ liftIO (readTVarIO (initiatorWait peer))
    when (initOurIndex iwait /= receiverIndex) $ throwE OutdatedPacketError
    let state1 = initNoise iwait
        outcome = recvSecondMessage state1 encryptedPayload
    case outcome of
        Left err                      -> throwE (NoiseError err)
        Right (decryptedPayload, sks) -> do
            now <- liftIO epochTime
            newCounter <- liftIO $ atomically $ newTVar 0
            let newsession = Session receiverIndex senderIndex sks
                    (addTime now sessionRenewTime)
                    (addTime now sessionExpireTime)
                    newCounter
            when (BA.length decryptedPayload /= 0) $
                throwE $ InvalidWGPacketError "empty payload expected"
            -- Only install the session if the initiator wait for this index
            -- is still there to be erased.
            succeeded <- liftIO $ atomically $ do
                erased <- eraseInitiatorWait device peer (Just receiverIndex)
                when erased $ do
                    addSession device peer newsession
                    writeTVar (lastHandshakeTime peer) (Just now)
                return erased
            unless succeeded $ throwE OutdatedPacketError
            liftIO $ atomically $ updateEndPoint peer sock
            return Nothing

-- Transport data: decrypt with the session (or pending responder wait) that
-- owns the receiver index.  A successful decryption against a responder wait
-- promotes it to a full session.  Non-empty payloads must carry a source
-- address routed to the sending peer; empty payloads are keepalives.
processPacket device@Device{..} _key _psk sock PacketData{..} = do
    peer <- assertJust UnknownIndexError $
        HM.lookup receiverIndex <$> liftIO (readTVarIO indexMap)
    outcome <- liftIO $ atomically $ findSession peer receiverIndex
    now <- liftIO epochTime
    (isFromResponderWait, session) <- case outcome of
        Nothing                       -> throwE OutdatedPacketError
        Just (Right session)          -> return (False, session)
        Just (Left ResponderWait{..}) -> do
            newCounter <- liftIO $ atomically $ newTVar 0
            let newsession = Session respOurIndex respTheirIndex respSessionKey
                    (addTime now (sessionRenewTime + 2 * handshakeRetryTime))
                    (addTime now sessionExpireTime)
                    newCounter
            return (True, newsession)
    case decryptMessage (sessionKey session) counter (encryptedPayload, authTag) of
        Nothing               -> throwE DecryptFailureError
        Just decryptedPayload -> do
            when isFromResponderWait $ liftIO $ atomically $ do
                erased <- eraseResponderWait device peer (Just receiverIndex)
                when erased $ do
                    addSession device peer session
                    writeTVar (lastHandshakeTime peer) (Just now)
            liftIO $ atomically $ updateEndPoint peer sock
            if BA.length decryptedPayload /= 0
              then do
                parsedPacket <- liftIO $ parseIPPacket decryptedPayload
                case parsedPacket of
                    InvalidIPPacket   -> throwE InvalidIPPacketError
                    IPv4Packet src4 _ -> do
                        peer' <- assertJust SourceAddrBlockedError $
                            RT.lookup (makeAddrRange src4 32) <$> liftIO (readTVarIO routeTable4)
                        unless (remotePub peer `dhPubEq` remotePub peer') $ throwE SourceAddrBlockedError
                    IPv6Packet src6 _ -> do
                        peer' <- assertJust SourceAddrBlockedError $
                            RT.lookup (makeAddrRange src6 128) <$> liftIO (readTVarIO routeTable6)
                        unless (remotePub peer `dhPubEq` remotePub peer') $ throwE SourceAddrBlockedError
                liftIO $ atomically $ writeTVar (lastReceiveTime peer) now
                liftIO $ atomically $ modifyTVar' (receivedBytes peer) (+fromIntegral (BA.length decryptedPayload))
                return (Just (Right decryptedPayload))
              else do
                -- Keepalive: record it, but do not hand a zero-length
                -- packet to the TUN writer.
                liftIO $ atomically $ writeTVar (lastKeepaliveTime peer) now
                return Nothing

-- | One maintenance pass over every configured peer.
--
-- For each peer, in order:
--
-- 1. an 'InitiatorWait' past its stop time is erased; one past its retry
--    time is erased and the handshake is re-initiated with the same stop
--    time;
-- 2. a 'ResponderWait' past its stop time is erased;
-- 3. sessions whose 'expireTime' is not after the current time are dropped;
-- 4. if data was received after the last send and at least
--    'sessionKeepaliveTime' seconds ago, a keepalive (empty encrypted
--    payload) is queued;
-- 5. if data was sent after both the last receive and the last keepalive,
--    at least @sessionKeepaliveTime + handshakeRetryTime@ seconds ago, a
--    new handshake is attempted.
--
-- Steps 4 and 5 both reset 'lastTransferTime' and 'lastReceiveTime' to the
-- current time before acting.
runHeartbeat :: Device -> KeyPair -> PacketQueue UdpPacket -> IO ()
runHeartbeat device key chan = do
    psk <- readTVarIO (presharedKey device)
    now <- epochTime
    peers' <- readTVarIO (peers device)
    mapM_ (heartbeatPeer psk now) peers'
  where
    heartbeatPeer psk now peer = do
        let withEndPoint = withJust (readTVarIO (endPoint peer))
            touch = do
                atomically $ writeTVar (lastTransferTime peer) now
                atomically $ writeTVar (lastReceiveTime peer) now

        -- Step 1: expire or retry an outstanding initiation.
        reinitiate <- atomically $ do
            pending <- readTVar (initiatorWait peer)
            case pending of
                Just iwait
                    | now >= initStopTime iwait -> do
                        void $ eraseInitiatorWait device peer Nothing
                        return Nothing
                    | now >= initRetryTime iwait -> do
                        void $ eraseInitiatorWait device peer Nothing
                        return (Just (initStopTime iwait))
                _ -> return Nothing
        case reinitiate of
            Nothing -> return ()
            Just _  -> withEndPoint $ \endp ->
                void $ tryInitiateHandshakeIfEmpty device key psk chan peer endp reinitiate

        -- Step 2: expire an outstanding response.
        atomically $ withJust (readTVar (responderWait peer)) $ \rwait ->
            when (now >= respStopTime rwait) $
                void $ eraseResponderWait device peer Nothing

        -- Step 3: keep only sessions that have not expired yet.
        atomically $ filterSessions device peer (\s -> now < expireTime s)

        -- All three timestamps are sampled before either action below, so
        -- the second condition does not see the first action's updates.
        lastrecv <- readTVarIO (lastReceiveTime peer)
        lastsent <- readTVarIO (lastTransferTime peer)
        lastkeep <- readTVarIO (lastKeepaliveTime peer)
        let keepaliveDue =
                lastsent < lastrecv
                    && lastrecv <= addTime now (-sessionKeepaliveTime)
            handshakeDue =
                lastrecv < lastsent
                    && lastkeep < lastsent
                    && lastsent <= addTime now (-(sessionKeepaliveTime + handshakeRetryTime))

        -- Step 4: keepalive.
        when keepaliveDue $ do
            touch
            withEndPoint $ \endp ->
                withJust (getSession peer) $ \session -> do
                    nonce <- atomically $ nextNonce session
                    let (msg, authtag) = encryptMessage (sessionKey session) nonce mempty
                        keepalivePacket = runPut $ buildPacket (error "internal error") $
                            PacketData (theirIndex session) nonce msg authtag
                    pushPacketQueue chan (keepalivePacket, endp)

        -- Step 5: the peer has gone quiet after we sent data.
        when handshakeDue $ do
            touch
            withEndPoint $ \endp ->
                void $ checkAndTryInitiateHandshake device key psk chan peer endp now

-- | Initiate a handshake with @peer@ unless one is already in flight.
--
-- Both the initiator wait and the responder wait are inspected (each in its
-- own transaction, and always both of them).  A wait whose stop time lies
-- strictly before @now@ is erased and counts as absent.  If either wait is
-- still live the result is 'False'; otherwise the result is that of
-- 'tryInitiateHandshakeIfEmpty'.
checkAndTryInitiateHandshake :: Device -> KeyPair -> Maybe PresharedKey
                             -> PacketQueue UdpPacket -> Peer -> SockAddr -> Time
                             -> IO Bool
checkAndTryInitiateHandshake device key psk chan peer@Peer{..} endp now = do
    initiating <- stillPending initStopTime initiatorWait
        (eraseInitiatorWait device peer Nothing)
    responding <- stillPending respStopTime responderWait
        (eraseResponderWait device peer Nothing)
    if not (initiating || responding)
      then tryInitiateHandshakeIfEmpty device key psk chan peer endp Nothing
      else return False
  where
    -- True when the slot holds a wait that has not passed its stop time;
    -- a stale wait is purged as a side effect.
    stillPending stopOf slot purge = atomically $
        readTVar slot >>= \pending -> case pending of
            Nothing -> return False
            Just w
                | now > stopOf w -> purge >> return False
                | otherwise      -> return True


-- | Build a handshake initiation for @peer@ and queue it for @endp@, but
-- only if the peer currently has no 'InitiatorWait'.
--
-- A fresh ephemeral key, the current epoch time (as a TAI64N payload) and a
-- random seed for index allocation are drawn first.  The emptiness check,
-- the index allocation and the write of the new 'InitiatorWait' happen in
-- one transaction.  The wait's retry time is @now + handshakeRetryTime@; its
-- stop time is @stopTime@ when given, otherwise @now + handshakeStopTime@.
--
-- Returns 'True' when an initiation was queued and 'False' when a wait was
-- already present.  If the first handshake message cannot be produced an
-- 'IOError' is raised before any peer state is modified.
tryInitiateHandshakeIfEmpty :: Device -> KeyPair -> Maybe PresharedKey
                            -> PacketQueue UdpPacket -> Peer -> SockAddr -> Maybe Time
                            -> IO Bool
tryInitiateHandshakeIfEmpty device key psk chan peer@Peer{..} endp stopTime = do
    ekey <- dhGenKey
    now <- epochTime
    seed <- randomIO
    let state0 = newNoiseState key psk ekey (Just remotePub) InitiatorRole
        timestamp = BA.convert (genTai64n now)
    -- Match the result totally and up front: a lazy @Right (..) = ..@
    -- binding would defer a failure to wherever the payload or state thunk
    -- is eventually forced, possibly after 'initiatorWait' was written.
    case sendFirstMessage state0 timestamp of
        Left _ -> ioError $ userError
            "tryInitiateHandshakeIfEmpty: failed to build handshake initiation"
        Right (payload, state1) -> do
            mpacket <- atomically $ do
                isEmpty <- isNothing <$> readTVar initiatorWait
                if isEmpty
                  then do
                    index <- acquireEmptyIndex device peer seed
                    let iwait = InitiatorWait index
                            (addTime now handshakeRetryTime)
                            (fromMaybe (addTime now handshakeStopTime) stopTime)
                            state1
                    writeTVar initiatorWait (Just iwait)
                    let packet = runPut $ buildPacket (getMac1 remotePub psk) $
                            HandshakeInitiation index payload
                    return (Just packet)
                  else return Nothing
            case mpacket of
                Just packet -> pushPacketQueue chan (packet, endp) >> return True
                Nothing     -> return False

-- | Serialise a timestamp as 12 bytes: a big-endian 64-bit seconds field
-- followed by a big-endian 32-bit field that is always zero.
--
-- The seconds field is the epoch time plus 4611686018427387914 (2^62 + 10).
-- NOTE(review): presumably the TAI64 label bias plus a fixed 10 s TAI-UTC
-- offset, with the zero field being the nanoseconds -- confirm against the
-- TAI64N specification.
genTai64n :: Time -> TAI64n
genTai64n (CTime now) = runPut (putWord64be seconds >> putWord32be 0)
  where
    seconds = fromIntegral now + 4611686018427387914

-- | Shift a timestamp by a (possibly negative) number of seconds.
addTime :: Time -> Int -> Time
addTime base secs = base + fromIntegral secs

-- | Compute the MAC1 value of @payload@ for the given public key.
--
-- A BLAKE2s state with output length 'mac1Length' is created -- keyed with
-- the preshared key when one is supplied, unkeyed otherwise -- then fed the
-- public key bytes followed by the payload, and finalised.
getMac1 :: PublicKey -> Maybe PresharedKey -> BS.ByteString -> BS.ByteString
getMac1 pub mpsk payload = finalize mac1Length (update payload withPub)
  where
    initial = maybe (initialize mac1Length)
                    (initialize' mac1Length . BA.convert)
                    mpsk
    withPub = update (BA.convert (dhPubToBytes pub)) initial

-- | Run an action that yields a 'Maybe' and unwrap it, failing with @err@
-- when the result is 'Nothing'.  Errors raised by the action itself pass
-- through unchanged.
assertJust :: Monad m => e -> ExceptT e m (Maybe a) -> ExceptT e m a
assertJust err ma = ma >>= maybe (throwE err) return