diff options
Diffstat (limited to 'src/Network/WireGuard/Core.hs')
-rw-r--r-- | src/Network/WireGuard/Core.hs | 336 |
1 files changed, 336 insertions, 0 deletions
diff --git a/src/Network/WireGuard/Core.hs b/src/Network/WireGuard/Core.hs new file mode 100644 index 0000000..f36b3c9 --- /dev/null +++ b/src/Network/WireGuard/Core.hs @@ -0,0 +1,336 @@ +{-# LANGUAGE RecordWildCards #-} + +module Network.WireGuard.Core + ( runCore + ) where + +import Control.Concurrent (getNumCapabilities, + threadDelay) +import Control.Concurrent.Async (wait, withAsync) +import Control.Monad (forM_, forever, unless, + void, when) +import Control.Monad.IO.Class (liftIO) +import Control.Monad.STM (atomically) +import Control.Monad.Trans.Except (ExceptT, runExceptT, + throwE) +import Crypto.Noise (HandshakeRole (..)) +import Crypto.Noise.DH (dhGenKey, dhPubEq, + dhPubToBytes) +import qualified Data.ByteArray as BA +import qualified Data.ByteString as BS +import qualified Data.HashMap.Strict as HM +import Data.IP (makeAddrRange) +import qualified Data.IP.RouteTable as RT +import Data.Maybe (fromMaybe, isJust, + isNothing) +import Data.Serialize (putWord32be, + putWord64be, runGet, + runPut) +import Foreign.C.Types (CTime (..)) +import Network.Socket (SockAddr) +import System.IO (hPutStrLn, stderr) +import System.Posix.Time (epochTime) +import System.Random (randomIO) + +import Control.Concurrent.STM.TVar +import Crypto.Hash.BLAKE2.BLAKE2s + +import Network.WireGuard.Internal.Constant +import Network.WireGuard.Internal.IPPacket +import Network.WireGuard.Internal.Noise +import Network.WireGuard.Internal.Packet +import Network.WireGuard.Internal.PacketQueue +import Network.WireGuard.Internal.State +import Network.WireGuard.Internal.Types +import Network.WireGuard.Internal.Util + +runCore :: Device + -> PacketQueue TunPacket -> PacketQueue TunPacket + -> PacketQueue UdpPacket -> PacketQueue UdpPacket + -> IO () +runCore device readTunChan writeTunChan readUdpChan writeUdpChan = do + threads <- getNumCapabilities + loop threads [] + where + heartbeatLoop = forever $ ignoreSyncExceptions $ do + withJust (readTVarIO (localKey device)) $ \key -> + runHeartbeat device key writeUdpChan + -- TODO: use accurate timer + threadDelay heartbeatWaitTime + + loop 0 asyncs = + withAsync heartbeatLoop $ \ht -> + mapM_ wait asyncs >> wait ht + loop x asyncs = + withAsync (retryWithBackoff $ handleReadTun device readTunChan writeUdpChan) $ \rt -> + withAsync (retryWithBackoff $ handleReadUdp device readUdpChan writeTunChan writeUdpChan) $ \ru -> + loop (x-1) (rt:ru:asyncs) + +handleReadTun :: Device -> PacketQueue TunPacket -> PacketQueue UdpPacket -> IO () +handleReadTun device readTunChan writeUdpChan = forever $ do + tunPacket <- atomically $ popPacketQueue readTunChan + res <- runExceptT $ processTunPacket device writeUdpChan tunPacket + case res of + Right udpPacket -> atomically $ pushPacketQueue writeUdpChan udpPacket + Left err -> hPutStrLn stderr (show err) -- TODO: proper logging + +handleReadUdp :: Device -> PacketQueue UdpPacket -> PacketQueue TunPacket + -> PacketQueue UdpPacket + -> IO () +handleReadUdp device readUdpChan writeTunChan writeUdpChan = forever $ do + udpPacket <- atomically $ popPacketQueue readUdpChan + res <- runExceptT $ processUdpPacket device udpPacket + case res of + Left err -> hPutStrLn stderr (show err) -- TODO: proper logging + Right mpacket -> case mpacket of + Just (Right tunp) -> atomically $ pushPacketQueue writeTunChan tunp + Just (Left udpp) -> atomically $ pushPacketQueue writeUdpChan udpp + Nothing -> return () + +processTunPacket :: Device -> PacketQueue UdpPacket -> TunPacket + -> ExceptT WireGuardError IO UdpPacket +processTunPacket device@Device{..} writeUdpChan packet = do + key <- assertJust DeviceNotReadyError $ liftIO (readTVarIO localKey) + psk <- liftIO (readTVarIO presharedKey) + parsedPacket <- liftIO $ parseIPPacket packet + peer <- assertJust DestinationNotReachableError $ case parsedPacket of + InvalidIPPacket -> throwE InvalidIPPacketError + IPv4Packet _ dest4 -> RT.lookup (makeAddrRange dest4 32) + <$> liftIO (readTVarIO routeTable4) + IPv6Packet _ dest6 -> RT.lookup (makeAddrRange dest6 128) + <$> liftIO (readTVarIO routeTable6) + msession <- liftIO (getSession peer) + session <- case msession of + Just session -> return session + Nothing -> do + now0 <- liftIO epochTime + endp0 <- assertJust EndPointUnknownError $ liftIO $ readTVarIO (endPoint peer) + liftIO $ void $ checkAndTryInitiateHandshake device key psk writeUdpChan peer endp0 now0 + liftIO $ atomically $ waitForSession peer + nonce <- liftIO $ atomically $ nextNonce session + let (msg, authtag) = encryptMessage (sessionKey session) nonce packet + encrypted = runPut $ buildPacket (error "internal error") $ + PacketData (theirIndex session) nonce msg authtag + now <- liftIO epochTime + endp <- assertJust EndPointUnknownError $ liftIO $ readTVarIO (endPoint peer) + when (now >= renewTime session) $ liftIO $ + void $ checkAndTryInitiateHandshake device key psk writeUdpChan peer endp now + liftIO $ atomically $ modifyTVar' (transferredBytes peer) (+fromIntegral (BA.length packet)) + liftIO $ atomically $ writeTVar (lastTransferTime peer) now + return (encrypted, endp) + +processUdpPacket :: Device -> UdpPacket + -> ExceptT WireGuardError IO (Maybe (Either UdpPacket TunPacket)) +processUdpPacket device@Device{..} (packet, sock) = do + key <- assertJust DeviceNotReadyError $ liftIO (readTVarIO localKey) + psk <- liftIO (readTVarIO presharedKey) + let mp = runGet (parsePacket (getMac1 (snd key) psk)) packet + case mp of + Left errMsg -> throwE (InvalidWGPacketError errMsg) + Right parsedPacket -> processPacket device key psk sock parsedPacket + +processPacket :: Device -> KeyPair -> Maybe PresharedKey -> SockAddr -> Packet + -> ExceptT WireGuardError IO (Maybe (Either UdpPacket TunPacket)) +processPacket device@Device{..} key psk sock HandshakeInitiation{..} = do + ekey <- liftIO dhGenKey + let state0 = newNoiseState key psk ekey Nothing ResponderRole + outcome = recvFirstMessageAndReply state0 encryptedPayload mempty + case outcome of + Left err -> throwE (NoiseError err) + Right (reply, decryptedPayload, rpub, sks) -> do + when (BA.length decryptedPayload /= timestampLength) $ + throwE $ InvalidWGPacketError "timestamp expected" + peer <- assertJust RemotePeerNotFoundError $ + HM.lookup (getPeerId rpub) <$> liftIO (readTVarIO peers) + notReplayAttack <- liftIO $ atomically $ updateTai64n peer (BA.convert decryptedPayload) + unless notReplayAttack $ throwE HandshakeInitiationReplayError + now <- liftIO epochTime + seed <- liftIO randomIO + ourindex <- liftIO $ atomically $ do + ourindex <- acquireEmptyIndex device peer seed + void $ eraseResponderWait device peer Nothing + let rwait = ResponderWait ourindex senderIndex + (addTime now handshakeStopTime) sks + writeTVar (responderWait peer) (Just rwait) + return ourindex + let responsePacket = runPut $ buildPacket (getMac1 rpub psk) $ + HandshakeResponse ourindex senderIndex reply + return (Just (Left (responsePacket, sock))) + +processPacket device@Device{..} _key _psk sock HandshakeResponse{..} = do + peer <- assertJust UnknownIndexError $ + HM.lookup receiverIndex <$> liftIO (readTVarIO indexMap) + iwait <- assertJust OutdatedPacketError $ liftIO (readTVarIO (initiatorWait peer)) + when (initOurIndex iwait /= receiverIndex) $ throwE OutdatedPacketError + let state1 = initNoise iwait + outcome = recvSecondMessage state1 encryptedPayload + case outcome of + Left err -> throwE (NoiseError err) + Right (decryptedPayload, rpub, sks) -> do + now <- liftIO epochTime + newCounter <- liftIO $ atomically $ newTVar 0 + let newsession = Session receiverIndex senderIndex sks + (addTime now sessionRenewTime) + (addTime now sessionExpireTime) + newCounter + when (BA.length decryptedPayload /= 0) $ + throwE $ InvalidWGPacketError "empty payload expected" + unless (rpub `dhPubEq` remotePub peer) $ throwE RemotePeerNotFoundError + succeeded <- liftIO $ atomically $ do + erased <- eraseInitiatorWait device peer (Just receiverIndex) + when erased $ do + addSession device peer newsession + writeTVar (lastHandshakeTime peer) (Just now) + return erased + unless succeeded $ throwE OutdatedPacketError + liftIO $ atomically $ updateEndPoint peer sock + return Nothing + +processPacket device@Device{..} _key _psk sock PacketData{..} = do + peer <- assertJust UnknownIndexError $ + HM.lookup receiverIndex <$> liftIO (readTVarIO indexMap) + outcome <- liftIO $ atomically $ findSession peer receiverIndex + now <- liftIO epochTime + (isFromResponderWait, session) <- case outcome of + Nothing -> throwE OutdatedPacketError + Just (Right session) -> return (False, session) + Just (Left ResponderWait{..}) -> do + newCounter <- liftIO $ atomically $ newTVar 0 + let newsession = Session respOurIndex respTheirIndex respSessionKey + (addTime now (sessionRenewTime + 2 * handshakeRetryTime)) + (addTime now sessionExpireTime) + newCounter + return (True, newsession) + case decryptMessage (sessionKey session) counter (encryptedPayload, authTag) of + Nothing -> throwE DecryptFailureError + Just decryptedPayload -> do + when isFromResponderWait $ liftIO $ atomically $ do + erased <- eraseResponderWait device peer (Just receiverIndex) + when erased $ do + addSession device peer session + writeTVar (lastHandshakeTime peer) (Just now) + liftIO $ atomically $ updateEndPoint peer sock + if BA.length decryptedPayload /= 0 + then do + parsedPacket <- liftIO $ parseIPPacket decryptedPayload + case parsedPacket of + InvalidIPPacket -> throwE InvalidIPPacketError + IPv4Packet src4 _ -> do + peer' <- assertJust SourceAddrBlockedError $ + RT.lookup (makeAddrRange src4 32) <$> liftIO (readTVarIO routeTable4) + when (remotePub peer /= remotePub peer') $ throwE SourceAddrBlockedError + IPv6Packet src6 _ -> do + peer' <- assertJust SourceAddrBlockedError $ + RT.lookup (makeAddrRange src6 128) <$> liftIO (readTVarIO routeTable6) + when (remotePub peer /= remotePub peer') $ throwE SourceAddrBlockedError + liftIO $ atomically $ writeTVar (lastReceiveTime peer) now + liftIO $ atomically $ modifyTVar' (receivedBytes peer) (+fromIntegral (BA.length decryptedPayload)) + else do + liftIO $ atomically $ writeTVar (lastKeepaliveTime peer) now + return (Just (Right decryptedPayload)) + +runHeartbeat :: Device -> KeyPair -> PacketQueue UdpPacket -> IO () +runHeartbeat device key chan = do + psk <- readTVarIO (presharedKey device) + now <- epochTime + peers' <- readTVarIO (peers device) + forM_ peers' $ \peer -> do + reinitiate <- atomically $ do + miwait <- readTVar (initiatorWait peer) + case miwait of + Just iwait | now >= initStopTime iwait -> do + void $ eraseInitiatorWait device peer Nothing + return Nothing + Just iwait | now >= initRetryTime iwait -> do + void $ eraseInitiatorWait device peer Nothing + return (Just (initStopTime iwait)) + _ -> return Nothing + when (isJust reinitiate) $ withJust (readTVarIO (endPoint peer)) $ \endp -> + void $ tryInitiateHandshakeIfEmpty device key psk chan peer endp reinitiate + atomically $ withJust (readTVar (responderWait peer)) $ \rwait -> + when (now >= respStopTime rwait) $ void $ eraseResponderWait device peer Nothing + atomically $ filterSessions device peer ((now<).expireTime) + lastrecv <- readTVarIO (lastReceiveTime peer) + lastsent <- readTVarIO (lastTransferTime peer) + lastkeep <- readTVarIO (lastKeepaliveTime peer) + when (lastsent < lastrecv && lastrecv <= addTime now (-sessionKeepaliveTime)) $ do + atomically $ writeTVar (lastTransferTime peer) now + atomically $ writeTVar (lastReceiveTime peer) now + withJust (readTVarIO (endPoint peer)) $ \endp -> + withJust (getSession peer) $ \session -> do + nonce <- atomically $ nextNonce session + let (msg, authtag) = encryptMessage (sessionKey session) nonce mempty + keepalivePacket = runPut $ buildPacket (error "internal error") $ + PacketData (theirIndex session) nonce msg authtag + atomically $ pushPacketQueue chan (keepalivePacket, endp) + when (lastrecv < lastsent && lastkeep < lastsent && lastsent <= addTime now (-(sessionKeepaliveTime + handshakeRetryTime))) $ do + atomically $ writeTVar (lastTransferTime peer) now + atomically $ writeTVar (lastReceiveTime peer) now + withJust (readTVarIO (endPoint peer)) $ \endp -> + void $ checkAndTryInitiateHandshake device key psk chan peer endp now + +checkAndTryInitiateHandshake :: Device -> KeyPair -> Maybe PresharedKey + -> PacketQueue UdpPacket -> Peer -> SockAddr -> Time + -> IO Bool +checkAndTryInitiateHandshake device key psk chan peer@Peer{..} endp now = do + initiated <- readAndVerifyStopTime initStopTime initiatorWait (eraseInitiatorWait device peer Nothing) + responded <- readAndVerifyStopTime respStopTime responderWait (eraseResponderWait device peer Nothing) + if initiated || responded + then return False + else tryInitiateHandshakeIfEmpty device key psk chan peer endp Nothing + where + readAndVerifyStopTime getStopTime tvar erase = atomically $ do + ma <- readTVar tvar + case ma of + Just a | now > getStopTime a -> erase >> return False + Just _ -> return True + Nothing -> return False + + +tryInitiateHandshakeIfEmpty :: Device -> KeyPair -> Maybe PresharedKey + -> PacketQueue UdpPacket -> Peer -> SockAddr -> Maybe Time + -> IO Bool +tryInitiateHandshakeIfEmpty device key psk chan peer@Peer{..} endp stopTime = do + ekey <- dhGenKey + now <- epochTime + seed <- randomIO + let state0 = newNoiseState key psk ekey (Just remotePub) InitiatorRole + Right (payload, state1) = sendFirstMessage state0 timestamp + timestamp = BA.convert (genTai64n now) + atomically $ do + isEmpty <- isNothing <$> readTVar initiatorWait + if isEmpty + then do + index <- acquireEmptyIndex device peer seed + let iwait = InitiatorWait index + (addTime now handshakeRetryTime) + (fromMaybe (addTime now handshakeStopTime) stopTime) + state1 + writeTVar initiatorWait (Just iwait) + let packet = runPut $ buildPacket (getMac1 remotePub psk) $ + HandshakeInitiation index payload + void $ tryPushPacketQueue chan $ (packet, endp) + return True + else return False + +genTai64n :: Time -> TAI64n +genTai64n (CTime now) = runPut $ do + putWord64be (fromIntegral now + 4611686018427387914) + putWord32be 0 + +addTime :: Time -> Int -> Time +addTime (CTime now) secs = CTime (now + fromIntegral secs) + +getMac1 :: PublicKey -> Maybe PresharedKey -> BS.ByteString -> BS.ByteString +getMac1 pub mpsk payload = + finalize mac1Length $ update payload $ update (BA.convert (dhPubToBytes pub)) $ + case mpsk of + Nothing -> initialize mac1Length + Just psk -> initialize' mac1Length (BA.convert psk) + +assertJust :: Monad m => e -> ExceptT e m (Maybe a) -> ExceptT e m a +assertJust err ma = do + res <- ma + case res of + Just a -> return a + Nothing -> throwE err |