From a60988db82aed029b71531e8a4bc5698fa247c02 Mon Sep 17 00:00:00 2001 From: Bin Jin Date: Tue, 14 Mar 2017 21:18:00 +0800 Subject: Tun: add timestamp for incoming packets --- src/Network/WireGuard/Core.hs | 10 ++++++---- src/Network/WireGuard/Daemon.hs | 1 - src/Network/WireGuard/Internal/State.hs | 18 ++++++++++++------ src/Network/WireGuard/Internal/Util.hs | 10 ++++++++++ src/Network/WireGuard/TunListener.hs | 8 +++++--- 5 files changed, 33 insertions(+), 14 deletions(-) (limited to 'src/Network/WireGuard') diff --git a/src/Network/WireGuard/Core.hs b/src/Network/WireGuard/Core.hs index beb4e36..116ea9f 100644 --- a/src/Network/WireGuard/Core.hs +++ b/src/Network/WireGuard/Core.hs @@ -45,7 +45,7 @@ import Network.WireGuard.Internal.Types import Network.WireGuard.Internal.Util runCore :: Device - -> PacketQueue TunPacket -> PacketQueue TunPacket + -> PacketQueue (Time, TunPacket) -> PacketQueue TunPacket -> PacketQueue UdpPacket -> PacketQueue UdpPacket -> IO () runCore device readTunChan writeTunChan readUdpChan writeUdpChan = do @@ -66,9 +66,11 @@ runCore device readTunChan writeTunChan readUdpChan writeUdpChan = do withAsync (retryWithBackoff $ handleReadUdp device readUdpChan writeTunChan writeUdpChan) $ \ru -> loop (x-1) (rt:ru:asyncs) -handleReadTun :: Device -> PacketQueue TunPacket -> PacketQueue UdpPacket -> IO () +handleReadTun :: Device -> PacketQueue (Time, TunPacket) -> PacketQueue UdpPacket -> IO () handleReadTun device readTunChan writeUdpChan = forever $ do - tunPacket <- atomically $ popPacketQueue readTunChan + earliestToProcess <- (`addTime` (-handshakeRetryTime)) <$> epochTime + (_, tunPacket) <- dropUntilM ((>=earliestToProcess).fst) $ + atomically $ popPacketQueue readTunChan res <- runExceptT $ processTunPacket device writeUdpChan tunPacket case res of Right udpPacket -> atomically $ pushPacketQueue writeUdpChan udpPacket @@ -106,7 +108,7 @@ processTunPacket device@Device{..} writeUdpChan packet = do now0 <- liftIO epochTime endp0 <- assertJust EndPointUnknownError $ liftIO $ readTVarIO (endPoint peer) liftIO $ void $ checkAndTryInitiateHandshake device key psk writeUdpChan peer endp0 now0 - liftIO $ atomically $ waitForSession peer + assertJust OutdatedPacketError $ liftIO $ waitForSession (handshakeRetryTime * 1000000) peer nonce <- liftIO $ atomically $ nextNonce session let (msg, authtag) = encryptMessage (sessionKey session) nonce packet encrypted = runPut $ buildPacket (error "internal error") $ diff --git a/src/Network/WireGuard/Daemon.hs b/src/Network/WireGuard/Daemon.hs index 5b3f225..12a6a2f 100644 --- a/src/Network/WireGuard/Daemon.hs +++ b/src/Network/WireGuard/Daemon.hs @@ -34,7 +34,6 @@ runDaemon intfName sockPath tunFds = do tunListenerThread <- async $ runTunListener tunFds readTunChan writeTunChan -- TODO: Support per-host packet queue - -- TODO: Add timestamp and discard really ancient UDP packets readUdpChan <- atomically $ newPacketQueue maxQueuedUdpPackets writeUdpChan <- atomically $ newPacketQueue maxQueuedUdpPackets udpListenerThread <- async $ runUdpListener device readUdpChan writeUdpChan diff --git a/src/Network/WireGuard/Internal/State.hs b/src/Network/WireGuard/Internal/State.hs index 38866e8..f7b1ca0 100644 --- a/src/Network/WireGuard/Internal/State.hs +++ b/src/Network/WireGuard/Internal/State.hs @@ -193,12 +193,18 @@ getSession peer = do [] -> return Nothing (s:_) -> return (Just s) -waitForSession :: Peer -> STM Session -waitForSession peer = do - sessions' <- readTVar (sessions peer) - case sessions' of - [] -> retry - (s:_) -> return s +waitForSession :: Int -> Peer -> IO (Maybe Session) +waitForSession timelimit peer = do + getTimeout <- registerDelay timelimit + atomically $ do + sessions' <- readTVar (sessions peer) + case sessions' of + [] -> do + timeout <- readTVar getTimeout + if timeout + then return Nothing + else retry + (s:_) -> return (Just s) findSession :: Peer -> Index -> STM (Maybe (Either ResponderWait Session)) findSession peer index = do diff --git a/src/Network/WireGuard/Internal/Util.hs b/src/Network/WireGuard/Internal/Util.hs index f7ecde5..6aefee7 100644 --- a/src/Network/WireGuard/Internal/Util.hs +++ b/src/Network/WireGuard/Internal/Util.hs @@ -7,6 +7,7 @@ module Network.WireGuard.Internal.Util , catchIOExceptionAnd , catchSomeExceptionAnd , withJust + , dropUntilM , zeroMemory , copyMemory ) where @@ -52,6 +53,15 @@ withJust mma func = do Nothing -> return () Just a -> func a +dropUntilM :: Monad m => (a -> Bool) -> m a -> m a +dropUntilM cond ma = loop + where + loop = do + a <- ma + if cond a + then return a + else loop + zeroMemory :: Ptr a -> CSize -> IO () zeroMemory dest nbytes = memset dest 0 (fromIntegral nbytes) diff --git a/src/Network/WireGuard/TunListener.hs b/src/Network/WireGuard/TunListener.hs index 0cc23ec..8e058df 100644 --- a/src/Network/WireGuard/TunListener.hs +++ b/src/Network/WireGuard/TunListener.hs @@ -9,6 +9,7 @@ import qualified Data.ByteArray as BA import Data.Word (Word8) import Foreign.Marshal.Alloc (allocaBytes) import Foreign.Ptr (Ptr) +import System.Posix.Time (epochTime) import System.Posix.Types (Fd) import Network.WireGuard.Foreign.Tun @@ -17,7 +18,7 @@ import Network.WireGuard.Internal.PacketQueue import Network.WireGuard.Internal.Types import Network.WireGuard.Internal.Util -runTunListener :: [Fd] -> PacketQueue TunPacket -> PacketQueue TunPacket -> IO () +runTunListener :: [Fd] -> PacketQueue (Time, TunPacket) -> PacketQueue TunPacket -> IO () runTunListener fds readTunChan writeTunChan = loop fds [] where loop [] asyncs = mapM_ wait asyncs @@ -26,9 +27,10 @@ runTunListener fds readTunChan writeTunChan = loop fds [] withAsync (retryWithBackoff $ handleWrite writeTunChan fd) $ \wt -> loop rest (rt:wt:asyncs) -handleRead :: PacketQueue TunPacket -> Fd -> IO () +handleRead :: PacketQueue (Time, TunPacket) -> Fd -> IO () handleRead readTunChan fd = allocaBytes tunReadBufferLength $ \buf -> - forever (readTun buf fd >>= atomically . pushPacketQueue readTunChan) + forever (((,) <$> epochTime <*> readTun buf fd) + >>= atomically . pushPacketQueue readTunChan) handleWrite :: PacketQueue TunPacket -> Fd -> IO () handleWrite writeTunChan fd = -- cgit v1.2.3-59-g8ed1b