1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
|
{-# LANGUAGE RecordWildCards #-}
module Network.WireGuard.Core
( runCore
) where
import Control.Concurrent (getNumCapabilities,
threadDelay)
import Control.Concurrent.Async (wait, withAsync)
import Control.Monad (forM_, forever, unless,
void, when)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.STM (atomically)
import Control.Monad.Trans.Except (ExceptT, runExceptT,
throwE)
import Crypto.Noise (HandshakeRole (..))
import Crypto.Noise.DH (dhGenKey, dhPubEq,
dhPubToBytes)
import qualified Data.ByteArray as BA
import qualified Data.ByteString as BS
import qualified Data.HashMap.Strict as HM
import Data.IP (makeAddrRange)
import qualified Data.IP.RouteTable as RT
import Data.Maybe (fromMaybe, isJust,
isNothing)
import Data.Serialize (putWord32be,
putWord64be, runGet,
runPut)
import Foreign.C.Types (CTime (..))
import Network.Socket (SockAddr)
import System.IO (hPutStrLn, stderr)
import System.Posix.Time (epochTime)
import System.Random (randomIO)
import Control.Concurrent.STM.TVar
import Crypto.Hash.BLAKE2.BLAKE2s
import Network.WireGuard.Internal.Constant
import Network.WireGuard.Internal.IPPacket
import Network.WireGuard.Internal.Noise
import Network.WireGuard.Internal.Packet
import Network.WireGuard.Internal.PacketQueue
import Network.WireGuard.Internal.State
import Network.WireGuard.Internal.Types
import Network.WireGuard.Internal.Util
-- | Entry point of the packet-processing core.
--
-- For every RTS capability (as reported by 'getNumCapabilities') it spawns
-- one worker that encrypts packets taken from the TUN queue and one worker
-- that handles datagrams taken from the UDP queue.  A single heartbeat
-- thread is added on top.  Each worker is wrapped in 'retryWithBackoff'.
-- The call returns only after every spawned thread has been waited on.
runCore :: Device
        -> PacketQueue TunPacket -> PacketQueue TunPacket
        -> PacketQueue UdpPacket -> PacketQueue UdpPacket
        -> IO ()
runCore device readTunChan writeTunChan readUdpChan writeUdpChan =
    getNumCapabilities >>= \capabilities -> spawnWorkers capabilities []
  where
    tunWorker = retryWithBackoff $ handleReadTun device readTunChan writeUdpChan
    udpWorker = retryWithBackoff $
        handleReadUdp device readUdpChan writeTunChan writeUdpChan
    -- Runs a heartbeat pass (only once a local key is configured), then
    -- sleeps 'heartbeatWaitTime' microseconds (the unit of 'threadDelay').
    heartbeat = forever $ ignoreSyncExceptions $ do
        withJust (readTVarIO (localKey device)) $ \keypair ->
            runHeartbeat device keypair writeUdpChan
        -- TODO: use accurate timer
        threadDelay heartbeatWaitTime
    -- Nested 'withAsync' scopes: when the innermost action exits, normally
    -- or by exception, every thread started in an enclosing scope is
    -- cancelled.
    spawnWorkers 0 running =
        withAsync heartbeat $ \hb -> mapM_ wait running >> wait hb
    spawnWorkers remaining running =
        withAsync tunWorker $ \tunA ->
        withAsync udpWorker $ \udpA ->
        spawnWorkers (remaining - 1) (tunA : udpA : running)
-- | Outbound worker loop.
--
-- Takes plaintext packets from the TUN queue and encrypts each one with
-- 'processTunPacket'.  The resulting datagram is pushed onto the UDP write
-- queue.  Any 'WireGuardError' is printed to stderr and the loop moves on
-- to the next packet.
handleReadTun :: Device -> PacketQueue TunPacket -> PacketQueue UdpPacket -> IO ()
handleReadTun device readTunChan writeUdpChan = forever $
    atomically (popPacketQueue readTunChan)
        >>= runExceptT . processTunPacket device writeUdpChan
        -- TODO: proper logging
        >>= either (hPutStrLn stderr . show)
                   (atomically . pushPacketQueue writeUdpChan)
-- | Inbound worker loop.
--
-- Takes datagrams from the UDP queue and hands each one to
-- 'processUdpPacket'.  The result is routed as follows:
--
-- * reply datagrams ('Left') go back onto the UDP write queue;
-- * decrypted payloads ('Right') go onto the TUN write queue;
-- * 'Nothing' means there is nothing to forward.
--
-- Errors are printed to stderr and the loop continues.
handleReadUdp :: Device -> PacketQueue UdpPacket -> PacketQueue TunPacket
              -> PacketQueue UdpPacket
              -> IO ()
handleReadUdp device readUdpChan writeTunChan writeUdpChan = forever $ do
    incoming <- atomically $ popPacketQueue readUdpChan
    result <- runExceptT $ processUdpPacket device incoming
    case result of
        Left err                    -> hPutStrLn stderr (show err) -- TODO: proper logging
        Right Nothing               -> return ()
        Right (Just (Left reply))   -> atomically $ pushPacketQueue writeUdpChan reply
        Right (Just (Right tunPkt)) -> atomically $ pushPacketQueue writeTunChan tunPkt
-- | Encrypts one plaintext packet read from the TUN device.
--
-- The peer is chosen by looking up the packet's destination address
-- (host-length prefix) in the IPv4 or IPv6 routing table.
--
-- * If the peer has no session yet, a handshake is attempted via
--   'checkAndTryInitiateHandshake' and the result is then obtained from
--   'waitForSession'.
-- * Once the session's 'renewTime' has been reached, another handshake is
--   attempted.
--
-- On success the peer's 'transferredBytes' (plaintext length) and
-- 'lastTransferTime' are updated.  The encrypted datagram is returned
-- paired with the peer's current endpoint.
--
-- Errors thrown:
--
-- * 'DeviceNotReadyError' when no local key is set;
-- * 'InvalidIPPacketError' when the packet cannot be parsed;
-- * 'DestinationNotReachableError' when no route matches;
-- * 'EndPointUnknownError' when the peer's endpoint is unknown.
processTunPacket :: Device -> PacketQueue UdpPacket -> TunPacket
                 -> ExceptT WireGuardError IO UdpPacket
processTunPacket device@Device{..} writeUdpChan packet = do
    key <- assertJust DeviceNotReadyError $ liftIO (readTVarIO localKey)
    psk <- liftIO (readTVarIO presharedKey)
    parsedPacket <- liftIO $ parseIPPacket packet
    -- Route on the destination address using a /32 (IPv4) or /128 (IPv6) range.
    peer <- assertJust DestinationNotReachableError $ case parsedPacket of
        InvalidIPPacket -> throwE InvalidIPPacketError
        IPv4Packet _ dest4 -> RT.lookup (makeAddrRange dest4 32)
                                  <$> liftIO (readTVarIO routeTable4)
        IPv6Packet _ dest6 -> RT.lookup (makeAddrRange dest6 128)
                                  <$> liftIO (readTVarIO routeTable6)
    msession <- liftIO (getSession peer)
    session <- case msession of
        Just session -> return session
        Nothing -> do
            now0 <- liftIO epochTime
            endp0 <- assertJust EndPointUnknownError $ liftIO $ readTVarIO (endPoint peer)
            liftIO $ void $ checkAndTryInitiateHandshake device key psk writeUdpChan peer endp0 now0
            -- NOTE(review): presumably 'waitForSession' retries until a session
            -- is installed.  If the handshake never completes, this worker
            -- thread would then stay blocked here; confirm.
            liftIO $ atomically $ waitForSession peer
    nonce <- liftIO $ atomically $ nextNonce session
    -- NOTE(review): the MAC callback passed to 'buildPacket' is 'error';
    -- this assumes 'buildPacket' never calls it for 'PacketData'. Confirm.
    let (msg, authtag) = encryptMessage (sessionKey session) nonce packet
        encrypted = runPut $ buildPacket (error "internal error") $
            PacketData (theirIndex session) nonce msg authtag
    now <- liftIO epochTime
    endp <- assertJust EndPointUnknownError $ liftIO $ readTVarIO (endPoint peer)
    -- Session is due for renewal: try to start a new handshake.
    -- The result is ignored; a handshake may already be in flight.
    when (now >= renewTime session) $ liftIO $
        void $ checkAndTryInitiateHandshake device key psk writeUdpChan peer endp now
    liftIO $ atomically $ modifyTVar' (transferredBytes peer) (+fromIntegral (BA.length packet))
    liftIO $ atomically $ writeTVar (lastTransferTime peer) now
    return (encrypted, endp)
-- | Parses one received datagram and dispatches it to 'processPacket'.
--
-- The MAC1 check performed during parsing is keyed on our own public key
-- and the optional preshared key.  Parse failures become
-- 'InvalidWGPacketError'.  'DeviceNotReadyError' is thrown when no local
-- key is configured.
processUdpPacket :: Device -> UdpPacket
                 -> ExceptT WireGuardError IO (Maybe (Either UdpPacket TunPacket))
processUdpPacket device@Device{..} (packet, sock) = do
    keypair <- assertJust DeviceNotReadyError $ liftIO (readTVarIO localKey)
    mpsk <- liftIO (readTVarIO presharedKey)
    either (throwE . InvalidWGPacketError)
           (processPacket device keypair mpsk sock)
           (runGet (parsePacket (getMac1 (snd keypair) mpsk)) packet)
-- | Handles one parsed WireGuard message received from @sock@.
--
-- The result says what, if anything, must be forwarded:
--
-- * @Just (Left udpPacket)@: a reply datagram to send back (handshake
--   response);
-- * @Just (Right tunPacket)@: decrypted traffic for the TUN device;
-- * 'Nothing': nothing to forward (handshake response consumed, or a
--   keepalive).
processPacket :: Device -> KeyPair -> Maybe PresharedKey -> SockAddr -> Packet
              -> ExceptT WireGuardError IO (Maybe (Either UdpPacket TunPacket))
-- Handshake initiation: we act as responder.  A 'ResponderWait' is recorded
-- for the peer and a handshake response is returned.
processPacket device@Device{..} key psk sock HandshakeInitiation{..} = do
    ekey <- liftIO dhGenKey
    let state0 = newNoiseState key psk ekey Nothing ResponderRole
        outcome = recvFirstMessageAndReply state0 encryptedPayload mempty
    case outcome of
        Left err -> throwE (NoiseError err)
        Right (reply, decryptedPayload, rpub, sks) -> do
            -- The initiation payload must be exactly one timestamp.
            when (BA.length decryptedPayload /= timestampLength) $
                throwE $ InvalidWGPacketError "timestamp expected"
            peer <- assertJust RemotePeerNotFoundError $
                HM.lookup (getPeerId rpub) <$> liftIO (readTVarIO peers)
            -- A rejected timestamp is reported as a replayed initiation.
            notReplayAttack <- liftIO $ atomically $ updateTai64n peer (BA.convert decryptedPayload)
            unless notReplayAttack $ throwE HandshakeInitiationReplayError
            now <- liftIO epochTime
            seed <- liftIO randomIO
            ourindex <- liftIO $ atomically $ do
                idx <- acquireEmptyIndex device peer seed
                -- Replace any previously pending responder state of this peer.
                void $ eraseResponderWait device peer Nothing
                let rwait = ResponderWait idx senderIndex
                                          (addTime now handshakeStopTime) sks
                writeTVar (responderWait peer) (Just rwait)
                return idx
            let responsePacket = runPut $ buildPacket (getMac1 rpub psk) $
                    HandshakeResponse ourindex senderIndex reply
            return (Just (Left (responsePacket, sock)))
-- Handshake response: we were the initiator.  The pending 'InitiatorWait' is
-- turned into a live session.
processPacket device@Device{..} _key _psk sock HandshakeResponse{..} = do
    peer <- assertJust UnknownIndexError $
        HM.lookup receiverIndex <$> liftIO (readTVarIO indexMap)
    iwait <- assertJust OutdatedPacketError $ liftIO (readTVarIO (initiatorWait peer))
    when (initOurIndex iwait /= receiverIndex) $ throwE OutdatedPacketError
    let state1 = initNoise iwait
        outcome = recvSecondMessage state1 encryptedPayload
    case outcome of
        Left err -> throwE (NoiseError err)
        Right (decryptedPayload, rpub, sks) -> do
            now <- liftIO epochTime
            newCounter <- liftIO $ atomically $ newTVar 0
            let newsession = Session receiverIndex senderIndex sks
                                     (addTime now sessionRenewTime)
                                     (addTime now sessionExpireTime)
                                     newCounter
            when (BA.length decryptedPayload /= 0) $
                throwE $ InvalidWGPacketError "empty payload expected"
            unless (rpub `dhPubEq` remotePub peer) $ throwE RemotePeerNotFoundError
            -- Only install the session if our initiator wait was still the
            -- one this response answers.
            succeeded <- liftIO $ atomically $ do
                erased <- eraseInitiatorWait device peer (Just receiverIndex)
                when erased $ do
                    addSession device peer newsession
                    writeTVar (lastHandshakeTime peer) (Just now)
                return erased
            unless succeeded $ throwE OutdatedPacketError
            liftIO $ atomically $ updateEndPoint peer sock
            return Nothing
-- Transport data.  The first data packet on a responder handshake confirms
-- it and promotes the 'ResponderWait' to a session.
processPacket device@Device{..} _key _psk sock PacketData{..} = do
    peer <- assertJust UnknownIndexError $
        HM.lookup receiverIndex <$> liftIO (readTVarIO indexMap)
    outcome <- liftIO $ atomically $ findSession peer receiverIndex
    now <- liftIO epochTime
    (isFromResponderWait, session) <- case outcome of
        Nothing -> throwE OutdatedPacketError
        Just (Right session) -> return (False, session)
        Just (Left ResponderWait{..}) -> do
            newCounter <- liftIO $ atomically $ newTVar 0
            let newsession = Session respOurIndex respTheirIndex respSessionKey
                                     (addTime now (sessionRenewTime + 2 * handshakeRetryTime))
                                     (addTime now sessionExpireTime)
                                     newCounter
            return (True, newsession)
    case decryptMessage (sessionKey session) counter (encryptedPayload, authTag) of
        Nothing -> throwE DecryptFailureError
        Just decryptedPayload -> do
            when isFromResponderWait $ liftIO $ atomically $ do
                erased <- eraseResponderWait device peer (Just receiverIndex)
                when erased $ do
                    addSession device peer session
                    writeTVar (lastHandshakeTime peer) (Just now)
            liftIO $ atomically $ updateEndPoint peer sock
            if BA.length decryptedPayload /= 0
                then do
                    -- The inner source address must route back to the peer
                    -- that sent the packet.
                    parsedPacket <- liftIO $ parseIPPacket decryptedPayload
                    case parsedPacket of
                        InvalidIPPacket -> throwE InvalidIPPacketError
                        IPv4Packet src4 _ -> do
                            peer' <- assertJust SourceAddrBlockedError $
                                RT.lookup (makeAddrRange src4 32) <$> liftIO (readTVarIO routeTable4)
                            when (remotePub peer /= remotePub peer') $ throwE SourceAddrBlockedError
                        IPv6Packet src6 _ -> do
                            peer' <- assertJust SourceAddrBlockedError $
                                RT.lookup (makeAddrRange src6 128) <$> liftIO (readTVarIO routeTable6)
                            when (remotePub peer /= remotePub peer') $ throwE SourceAddrBlockedError
                    liftIO $ atomically $ writeTVar (lastReceiveTime peer) now
                    liftIO $ atomically $ modifyTVar' (receivedBytes peer) (+fromIntegral (BA.length decryptedPayload))
                    return (Just (Right decryptedPayload))
                else do
                    -- An empty payload is a keepalive.  Only record it:
                    -- there is no packet to hand to the TUN device, so a
                    -- zero-length write is not queued.
                    liftIO $ atomically $ writeTVar (lastKeepaliveTime peer) now
                    return Nothing
-- | One maintenance pass over every peer, run periodically by the heartbeat
-- thread.  For each peer it does the following:
--
-- * Pending initiation ('initiatorWait'):
--   - past 'initStopTime', the wait is erased;
--   - otherwise, past 'initRetryTime', the wait is erased and a new
--     initiation is attempted with the same stop time.
-- * Pending responder state: erased once 'respStopTime' has passed.
-- * Sessions: 'filterSessions' is called with the predicate
--   "expireTime is after now".
-- * Keepalive: an empty data packet is sent when data was received more
--   recently than sent and the last receive is at least
--   'sessionKeepaliveTime' old.
-- * Re-handshake: a new handshake is attempted when data was sent more
--   recently than anything was received or kept alive, and the last send
--   is at least @sessionKeepaliveTime + handshakeRetryTime@ old.
runHeartbeat :: Device -> KeyPair -> PacketQueue UdpPacket -> IO ()
runHeartbeat device key chan = do
    psk <- readTVarIO (presharedKey device)
    now <- epochTime
    peers' <- readTVarIO (peers device)
    forM_ peers' $ \peer -> do
        -- 'Just stopTime' means: retry the initiation, keeping this stop time.
        reinitiate <- atomically $ do
            miwait <- readTVar (initiatorWait peer)
            case miwait of
                Just iwait | now >= initStopTime iwait -> do
                    void $ eraseInitiatorWait device peer Nothing
                    return Nothing
                Just iwait | now >= initRetryTime iwait -> do
                    void $ eraseInitiatorWait device peer Nothing
                    return (Just (initStopTime iwait))
                _ -> return Nothing
        when (isJust reinitiate) $ withJust (readTVarIO (endPoint peer)) $ \endp ->
            void $ tryInitiateHandshakeIfEmpty device key psk chan peer endp reinitiate
        atomically $ withJust (readTVar (responderWait peer)) $ \rwait ->
            when (now >= respStopTime rwait) $ void $ eraseResponderWait device peer Nothing
        -- Presumably keeps only the sessions satisfying the predicate,
        -- i.e. not yet expired. Confirm against 'filterSessions'.
        atomically $ filterSessions device peer ((now<).expireTime)
        lastrecv <- readTVarIO (lastReceiveTime peer)
        lastsent <- readTVarIO (lastTransferTime peer)
        lastkeep <- readTVarIO (lastKeepaliveTime peer)
        -- Received but not replied to for a while: send a keepalive.
        -- Setting both timestamps to 'now' keeps this condition false until
        -- new traffic arrives.
        when (lastsent < lastrecv && lastrecv <= addTime now (-sessionKeepaliveTime)) $ do
            atomically $ writeTVar (lastTransferTime peer) now
            atomically $ writeTVar (lastReceiveTime peer) now
            withJust (readTVarIO (endPoint peer)) $ \endp ->
                withJust (getSession peer) $ \session -> do
                    nonce <- atomically $ nextNonce session
                    let (msg, authtag) = encryptMessage (sessionKey session) nonce mempty
                        keepalivePacket = runPut $ buildPacket (error "internal error") $
                            PacketData (theirIndex session) nonce msg authtag
                    atomically $ pushPacketQueue chan (keepalivePacket, endp)
        -- Sent but heard nothing back (neither data nor keepalive) for too
        -- long: try a fresh handshake.
        when (lastrecv < lastsent && lastkeep < lastsent && lastsent <= addTime now (-(sessionKeepaliveTime + handshakeRetryTime))) $ do
            atomically $ writeTVar (lastTransferTime peer) now
            atomically $ writeTVar (lastReceiveTime peer) now
            withJust (readTVarIO (endPoint peer)) $ \endp ->
                void $ checkAndTryInitiateHandshake device key psk chan peer endp now
-- | Attempts a new handshake with @peer@ unless one is still pending.
--
-- Each pending state (initiator and responder) whose stop time lies
-- strictly before @now@ is erased first and no longer counts as pending.
-- Both checks always run, initiator first.  If either state is still
-- pending, the result is 'False'.  Otherwise the result of
-- 'tryInitiateHandshakeIfEmpty' is returned.
checkAndTryInitiateHandshake :: Device -> KeyPair -> Maybe PresharedKey
                             -> PacketQueue UdpPacket -> Peer -> SockAddr -> Time
                             -> IO Bool
checkAndTryInitiateHandshake device key psk chan peer@Peer{..} endp now = do
    initPending <- stillPending initStopTime initiatorWait (eraseInitiatorWait device peer Nothing)
    respPending <- stillPending respStopTime responderWait (eraseResponderWait device peer Nothing)
    if not (initPending || respPending)
        then tryInitiateHandshakeIfEmpty device key psk chan peer endp Nothing
        else return False
  where
    -- True when the slot holds a state that has not yet passed its stop
    -- time.  A stale state is erased as a side effect.
    stillPending stopTimeOf slot eraseStale = atomically $ do
        current <- readTVar slot
        case current of
            Nothing -> return False
            Just pending
                | now > stopTimeOf pending -> eraseStale >> return False
                | otherwise -> return True
-- | Builds a Noise handshake initiation for @peer@.  If the peer has no
-- pending 'initiatorWait', it records a new 'InitiatorWait' and offers the
-- initiation datagram (addressed to @endp@) to @chan@.
--
-- The new wait's timers are set as follows:
--
-- * retry time: @now + handshakeRetryTime@;
-- * stop time: @stopTime@, or @now + handshakeStopTime@ when @stopTime@ is
--   'Nothing'.
--
-- The datagram is offered with 'tryPushPacketQueue' and the outcome is
-- ignored.  A datagram the queue does not accept (presumably when it is
-- full) is therefore only re-sent when the retry time triggers a new
-- attempt.
--
-- Returns 'True' when a new initiator state was installed.  Returns 'False'
-- when one already existed, or when the initiation message could not be
-- built.  In the latter case no state is touched and a message is printed
-- to stderr.
tryInitiateHandshakeIfEmpty :: Device -> KeyPair -> Maybe PresharedKey
                            -> PacketQueue UdpPacket -> Peer -> SockAddr -> Maybe Time
                            -> IO Bool
tryInitiateHandshakeIfEmpty device key psk chan peer@Peer{..} endp stopTime = do
    ekey <- dhGenKey
    now <- epochTime
    seed <- randomIO
    let state0 = newNoiseState key psk ekey (Just remotePub) InitiatorRole
        timestamp = BA.convert (genTai64n now)
    -- Match explicitly instead of with an irrefutable 'Right' pattern.  A
    -- lazy pattern failure would otherwise install a bottom Noise state and
    -- queue a bottom packet, crashing later in whichever thread forces them.
    case sendFirstMessage state0 timestamp of
        Left _ -> do
            hPutStrLn stderr "failed to build handshake initiation message" -- TODO: proper logging
            return False
        Right (payload, state1) -> atomically $ do
            isEmpty <- isNothing <$> readTVar initiatorWait
            if isEmpty
                then do
                    index <- acquireEmptyIndex device peer seed
                    let iwait = InitiatorWait index
                                    (addTime now handshakeRetryTime)
                                    (fromMaybe (addTime now handshakeStopTime) stopTime)
                                    state1
                    writeTVar initiatorWait (Just iwait)
                    let packet = runPut $ buildPacket (getMac1 remotePub psk) $
                            HandshakeInitiation index payload
                    void $ tryPushPacketQueue chan (packet, endp)
                    return True
                else return False
-- | Encodes a 'Time' as a TAI64N timestamp.
--
-- Layout: a big-endian 64-bit seconds label (the input seconds plus
-- @2^62 + 10@), followed by a big-endian 32-bit nanoseconds field.  Only
-- whole seconds are available here, so the nanoseconds field is always
-- zero.  Two timestamps generated within the same second are therefore
-- equal.
genTai64n :: Time -> TAI64n
genTai64n (CTime seconds) = runPut (putWord64be label *> putWord32be nanoseconds)
  where
    -- 4611686018427387914 == 2^62 + 10: the TAI64 base plus the 10-second
    -- offset applied when converting UNIX seconds.
    label = fromIntegral seconds + 4611686018427387914
    nanoseconds = 0
-- | Shifts a 'Time' by a signed offset, in the unit of the underlying
-- 'CTime' (seconds for values from 'epochTime').  Negative offsets are
-- used for "at least this long ago" comparisons.
addTime :: Time -> Int -> Time
addTime (CTime base) = CTime . (base +) . fromIntegral
-- | Computes the MAC1 value for @payload@.
--
-- It is a BLAKE2s digest of 'mac1Length' bytes over the given public key
-- followed by the payload.  The digest is keyed with the preshared key when
-- one is configured and unkeyed otherwise.  Every call site in this module
-- passes the public key of the message's recipient.
getMac1 :: PublicKey -> Maybe PresharedKey -> BS.ByteString -> BS.ByteString
getMac1 pub mpsk payload = finalize mac1Length (update payload withPub)
  where
    baseCtx = maybe (initialize mac1Length) (initialize' mac1Length . BA.convert) mpsk
    withPub = update (BA.convert (dhPubToBytes pub)) baseCtx
-- | Runs an 'ExceptT' action that yields a 'Maybe'.  A 'Nothing' result is
-- turned into the error @err@.  Errors raised by the action itself
-- propagate unchanged.
assertJust :: Monad m => e -> ExceptT e m (Maybe a) -> ExceptT e m a
assertJust err action = action >>= maybe (throwE err) return
|