aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network/WireGuard
diff options
context:
space:
mode:
Diffstat (limited to 'src/Network/WireGuard')
-rw-r--r--src/Network/WireGuard/Core.hs10
-rw-r--r--src/Network/WireGuard/Daemon.hs1
-rw-r--r--src/Network/WireGuard/Internal/State.hs18
-rw-r--r--src/Network/WireGuard/Internal/Util.hs10
-rw-r--r--src/Network/WireGuard/TunListener.hs8
5 files changed, 33 insertions, 14 deletions
diff --git a/src/Network/WireGuard/Core.hs b/src/Network/WireGuard/Core.hs
index beb4e36..116ea9f 100644
--- a/src/Network/WireGuard/Core.hs
+++ b/src/Network/WireGuard/Core.hs
@@ -45,7 +45,7 @@ import Network.WireGuard.Internal.Types
import Network.WireGuard.Internal.Util
runCore :: Device
- -> PacketQueue TunPacket -> PacketQueue TunPacket
+ -> PacketQueue (Time, TunPacket) -> PacketQueue TunPacket
-> PacketQueue UdpPacket -> PacketQueue UdpPacket
-> IO ()
runCore device readTunChan writeTunChan readUdpChan writeUdpChan = do
@@ -66,9 +66,11 @@ runCore device readTunChan writeTunChan readUdpChan writeUdpChan = do
withAsync (retryWithBackoff $ handleReadUdp device readUdpChan writeTunChan writeUdpChan) $ \ru ->
loop (x-1) (rt:ru:asyncs)
-handleReadTun :: Device -> PacketQueue TunPacket -> PacketQueue UdpPacket -> IO ()
+handleReadTun :: Device -> PacketQueue (Time, TunPacket) -> PacketQueue UdpPacket -> IO ()
handleReadTun device readTunChan writeUdpChan = forever $ do
- tunPacket <- atomically $ popPacketQueue readTunChan
+ earliestToProcess <- (`addTime` (-handshakeRetryTime)) <$> epochTime
+ (_, tunPacket) <- dropUntilM ((>=earliestToProcess).fst) $
+ atomically $ popPacketQueue readTunChan
res <- runExceptT $ processTunPacket device writeUdpChan tunPacket
case res of
Right udpPacket -> atomically $ pushPacketQueue writeUdpChan udpPacket
@@ -106,7 +108,7 @@ processTunPacket device@Device{..} writeUdpChan packet = do
now0 <- liftIO epochTime
endp0 <- assertJust EndPointUnknownError $ liftIO $ readTVarIO (endPoint peer)
liftIO $ void $ checkAndTryInitiateHandshake device key psk writeUdpChan peer endp0 now0
- liftIO $ atomically $ waitForSession peer
+ assertJust OutdatedPacketError $ liftIO $ waitForSession (handshakeRetryTime * 1000000) peer
nonce <- liftIO $ atomically $ nextNonce session
let (msg, authtag) = encryptMessage (sessionKey session) nonce packet
encrypted = runPut $ buildPacket (error "internal error") $
diff --git a/src/Network/WireGuard/Daemon.hs b/src/Network/WireGuard/Daemon.hs
index 5b3f225..12a6a2f 100644
--- a/src/Network/WireGuard/Daemon.hs
+++ b/src/Network/WireGuard/Daemon.hs
@@ -34,7 +34,6 @@ runDaemon intfName sockPath tunFds = do
tunListenerThread <- async $ runTunListener tunFds readTunChan writeTunChan
-- TODO: Support per-host packet queue
- -- TODO: Add timestamp and discard really ancient UDP packets
readUdpChan <- atomically $ newPacketQueue maxQueuedUdpPackets
writeUdpChan <- atomically $ newPacketQueue maxQueuedUdpPackets
udpListenerThread <- async $ runUdpListener device readUdpChan writeUdpChan
diff --git a/src/Network/WireGuard/Internal/State.hs b/src/Network/WireGuard/Internal/State.hs
index 38866e8..f7b1ca0 100644
--- a/src/Network/WireGuard/Internal/State.hs
+++ b/src/Network/WireGuard/Internal/State.hs
@@ -193,12 +193,18 @@ getSession peer = do
[] -> return Nothing
(s:_) -> return (Just s)
-waitForSession :: Peer -> STM Session
-waitForSession peer = do
- sessions' <- readTVar (sessions peer)
- case sessions' of
- [] -> retry
- (s:_) -> return s
+waitForSession :: Int -> Peer -> IO (Maybe Session)
+waitForSession timelimit peer = do
+ getTimeout <- registerDelay timelimit
+ atomically $ do
+ sessions' <- readTVar (sessions peer)
+ case sessions' of
+ [] -> do
+ timeout <- readTVar getTimeout
+ if timeout
+ then return Nothing
+ else retry
+ (s:_) -> return (Just s)
findSession :: Peer -> Index -> STM (Maybe (Either ResponderWait Session))
findSession peer index = do
diff --git a/src/Network/WireGuard/Internal/Util.hs b/src/Network/WireGuard/Internal/Util.hs
index f7ecde5..6aefee7 100644
--- a/src/Network/WireGuard/Internal/Util.hs
+++ b/src/Network/WireGuard/Internal/Util.hs
@@ -7,6 +7,7 @@ module Network.WireGuard.Internal.Util
, catchIOExceptionAnd
, catchSomeExceptionAnd
, withJust
+ , dropUntilM
, zeroMemory
, copyMemory
) where
@@ -52,6 +53,15 @@ withJust mma func = do
Nothing -> return ()
Just a -> func a
+dropUntilM :: Monad m => (a -> Bool) -> m a -> m a
+dropUntilM cond ma = loop
+ where
+ loop = do
+ a <- ma
+ if cond a
+ then return a
+ else loop
+
zeroMemory :: Ptr a -> CSize -> IO ()
zeroMemory dest nbytes = memset dest 0 (fromIntegral nbytes)
diff --git a/src/Network/WireGuard/TunListener.hs b/src/Network/WireGuard/TunListener.hs
index 0cc23ec..8e058df 100644
--- a/src/Network/WireGuard/TunListener.hs
+++ b/src/Network/WireGuard/TunListener.hs
@@ -9,6 +9,7 @@ import qualified Data.ByteArray as BA
import Data.Word (Word8)
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Ptr (Ptr)
+import System.Posix.Time (epochTime)
import System.Posix.Types (Fd)
import Network.WireGuard.Foreign.Tun
@@ -17,7 +18,7 @@ import Network.WireGuard.Internal.PacketQueue
import Network.WireGuard.Internal.Types
import Network.WireGuard.Internal.Util
-runTunListener :: [Fd] -> PacketQueue TunPacket -> PacketQueue TunPacket -> IO ()
+runTunListener :: [Fd] -> PacketQueue (Time, TunPacket) -> PacketQueue TunPacket -> IO ()
runTunListener fds readTunChan writeTunChan = loop fds []
where
loop [] asyncs = mapM_ wait asyncs
@@ -26,9 +27,10 @@ runTunListener fds readTunChan writeTunChan = loop fds []
withAsync (retryWithBackoff $ handleWrite writeTunChan fd) $ \wt ->
loop rest (rt:wt:asyncs)
-handleRead :: PacketQueue TunPacket -> Fd -> IO ()
+handleRead :: PacketQueue (Time, TunPacket) -> Fd -> IO ()
handleRead readTunChan fd = allocaBytes tunReadBufferLength $ \buf ->
- forever (readTun buf fd >>= atomically . pushPacketQueue readTunChan)
+ forever (((,) <$> epochTime <*> readTun buf fd)
+ >>= atomically . pushPacketQueue readTunChan)
handleWrite :: PacketQueue TunPacket -> Fd -> IO ()
handleWrite writeTunChan fd =