From 72ce5a8715976ec5eccedab0552fc7d9233903c1 Mon Sep 17 00:00:00 2001 From: Bin Jin Date: Thu, 16 Mar 2017 00:26:40 +0800 Subject: remove STM packet queue --- src/Network/WireGuard/Core.hs | 23 ++++++------- src/Network/WireGuard/Daemon.hs | 9 +++-- src/Network/WireGuard/Internal/Constant.hs | 3 -- src/Network/WireGuard/Internal/PacketQueue.hs | 47 +++++---------------------- src/Network/WireGuard/TunListener.hs | 6 ++-- src/Network/WireGuard/UdpListener.hs | 4 +-- 6 files changed, 29 insertions(+), 63 deletions(-) diff --git a/src/Network/WireGuard/Core.hs b/src/Network/WireGuard/Core.hs index 116ea9f..6d65e37 100644 --- a/src/Network/WireGuard/Core.hs +++ b/src/Network/WireGuard/Core.hs @@ -69,24 +69,23 @@ runCore device readTunChan writeTunChan readUdpChan writeUdpChan = do handleReadTun :: Device -> PacketQueue (Time, TunPacket) -> PacketQueue UdpPacket -> IO () handleReadTun device readTunChan writeUdpChan = forever $ do earliestToProcess <- (`addTime` (-handshakeRetryTime)) <$> epochTime - (_, tunPacket) <- dropUntilM ((>=earliestToProcess).fst) $ - atomically $ popPacketQueue readTunChan + (_, tunPacket) <- dropUntilM ((>=earliestToProcess).fst) $ popPacketQueue readTunChan res <- runExceptT $ processTunPacket device writeUdpChan tunPacket case res of - Right udpPacket -> atomically $ pushPacketQueue writeUdpChan udpPacket + Right udpPacket -> pushPacketQueue writeUdpChan udpPacket Left err -> hPutStrLn stderr (show err) -- TODO: proper logging handleReadUdp :: Device -> PacketQueue UdpPacket -> PacketQueue TunPacket -> PacketQueue UdpPacket -> IO () handleReadUdp device readUdpChan writeTunChan writeUdpChan = forever $ do - udpPacket <- atomically $ popPacketQueue readUdpChan + udpPacket <- popPacketQueue readUdpChan res <- runExceptT $ processUdpPacket device udpPacket case res of Left err -> hPutStrLn stderr (show err) -- TODO: proper logging Right mpacket -> case mpacket of - Just (Right tunp) -> atomically $ pushPacketQueue writeTunChan tunp - Just (Left udpp) -> atomically $ pushPacketQueue writeUdpChan udpp + Just (Right tunp) -> pushPacketQueue writeTunChan tunp + Just (Left udpp) -> pushPacketQueue writeUdpChan udpp Nothing -> return () processTunPacket :: Device -> PacketQueue UdpPacket -> TunPacket @@ -263,7 +262,7 @@ runHeartbeat device key chan = do let (msg, authtag) = encryptMessage (sessionKey session) nonce mempty keepalivePacket = runPut $ buildPacket (error "internal error") $ PacketData (theirIndex session) nonce msg authtag - atomically $ pushPacketQueue chan (keepalivePacket, endp) + pushPacketQueue chan (keepalivePacket, endp) when (lastrecv < lastsent && lastkeep < lastsent && lastsent <= addTime now (-(sessionKeepaliveTime + handshakeRetryTime))) $ do atomically $ writeTVar (lastTransferTime peer) now atomically $ writeTVar (lastReceiveTime peer) now @@ -298,7 +297,7 @@ tryInitiateHandshakeIfEmpty device key psk chan peer@Peer{..} endp stopTime = do let state0 = newNoiseState key psk ekey (Just remotePub) InitiatorRole Right (payload, state1) = sendFirstMessage state0 timestamp timestamp = BA.convert (genTai64n now) - atomically $ do + mpacket <- atomically $ do isEmpty <- isNothing <$> readTVar initiatorWait if isEmpty then do @@ -310,9 +309,11 @@ tryInitiateHandshakeIfEmpty device key psk chan peer@Peer{..} endp stopTime = do writeTVar initiatorWait (Just iwait) let packet = runPut $ buildPacket (getMac1 remotePub psk) $ HandshakeInitiation index payload - void $ tryPushPacketQueue chan $ (packet, endp) - return True - else return False + return (Just packet) + else return Nothing + case mpacket of + Just packet -> pushPacketQueue chan (packet, endp) >> return True + Nothing -> return False genTai64n :: Time -> TAI64n genTai64n (CTime now) = runPut $ do diff --git a/src/Network/WireGuard/Daemon.hs b/src/Network/WireGuard/Daemon.hs index 12a6a2f..cc0d22c 100644 --- a/src/Network/WireGuard/Daemon.hs +++ b/src/Network/WireGuard/Daemon.hs @@ -19,7 +19,6 @@ import Network.WireGuard.RPC (runRPC) import Network.WireGuard.TunListener (runTunListener) import Network.WireGuard.UdpListener (runUdpListener) -import Network.WireGuard.Internal.Constant import Network.WireGuard.Internal.PacketQueue import Network.WireGuard.Internal.Util @@ -29,13 +28,13 @@ runDaemon intfName sockPath tunFds = do rpcThread <- async $ runRPC sockPath device - readTunChan <- atomically $ newPacketQueue maxQueuedTunPackets - writeTunChan <- atomically $ newPacketQueue maxQueuedTunPackets + readTunChan <- newPacketQueue + writeTunChan <- newPacketQueue tunListenerThread <- async $ runTunListener tunFds readTunChan writeTunChan -- TODO: Support per-host packet queue - readUdpChan <- atomically $ newPacketQueue maxQueuedUdpPackets - writeUdpChan <- atomically $ newPacketQueue maxQueuedUdpPackets + readUdpChan <- newPacketQueue + writeUdpChan <- newPacketQueue udpListenerThread <- async $ runUdpListener device readUdpChan writeUdpChan coreThread <- async $ runCore device readTunChan writeTunChan readUdpChan writeUdpChan diff --git a/src/Network/WireGuard/Internal/Constant.hs b/src/Network/WireGuard/Internal/Constant.hs index c615a0f..73dbd42 100644 --- a/src/Network/WireGuard/Internal/Constant.hs +++ b/src/Network/WireGuard/Internal/Constant.hs @@ -21,9 +21,6 @@ mac2Length = 16 maxQueuedUdpPackets :: Int maxQueuedUdpPackets = 4096 -maxQueuedTunPackets :: Int -maxQueuedTunPackets = 4096 - udpReadBufferLength :: Int udpReadBufferLength = 4096 diff --git a/src/Network/WireGuard/Internal/PacketQueue.hs b/src/Network/WireGuard/Internal/PacketQueue.hs index bc390f8..2840a73 100644 --- a/src/Network/WireGuard/Internal/PacketQueue.hs +++ b/src/Network/WireGuard/Internal/PacketQueue.hs @@ -1,49 +1,20 @@ -{-# LANGUAGE RecordWildCards #-} - module Network.WireGuard.Internal.PacketQueue ( PacketQueue , newPacketQueue , popPacketQueue , pushPacketQueue - , tryPushPacketQueue + , module Control.Concurrent.Chan ) where -import Control.Concurrent.STM - -data PacketQueue packet = PacketQueue - { tqueue :: TQueue packet - , allowance :: TVar Int - } - --- | Create a new PacketQueue with size limit of |maxQueuedPackets|. -newPacketQueue :: Int -> STM (PacketQueue packet) -newPacketQueue maxQueuedPackets = PacketQueue <$> newTQueue <*> newTVar maxQueuedPackets +import Control.Concurrent.Chan --- | Pop a packet out from the queue, blocks if no packet is available. -popPacketQueue :: PacketQueue packet -> STM packet -popPacketQueue PacketQueue{..} = do - packet <- readTQueue tqueue - modifyTVar' allowance (+1) - return packet +type PacketQueue packet = Chan packet --- | Push a packet into the queue. Blocks if it's full. -pushPacketQueue :: PacketQueue packet -> packet -> STM () -pushPacketQueue PacketQueue{..} packet = do - allowance' <- readTVar allowance - if allowance' <= 0 - then retry - else do - writeTQueue tqueue packet - writeTVar allowance (allowance' - 1) +newPacketQueue :: IO (PacketQueue packet) +newPacketQueue = newChan --- | Try to push a packet into the queue. Returns True if it's pushed. -tryPushPacketQueue :: PacketQueue packet -> packet -> STM Bool -tryPushPacketQueue PacketQueue{..} packet = do - allowance' <- readTVar allowance - if allowance' <= 0 - then return False - else do - writeTQueue tqueue packet - writeTVar allowance (allowance' - 1) - return True +popPacketQueue :: PacketQueue packet -> IO packet +popPacketQueue = readChan +pushPacketQueue :: PacketQueue packet -> packet -> IO () +pushPacketQueue = writeChan diff --git a/src/Network/WireGuard/TunListener.hs b/src/Network/WireGuard/TunListener.hs index 8e058df..f5628e5 100644 --- a/src/Network/WireGuard/TunListener.hs +++ b/src/Network/WireGuard/TunListener.hs @@ -4,7 +4,6 @@ module Network.WireGuard.TunListener import Control.Concurrent.Async (wait, withAsync) import Control.Monad (forever, void) -import Control.Monad.STM (atomically) import qualified Data.ByteArray as BA import Data.Word (Word8) import Foreign.Marshal.Alloc (allocaBytes) @@ -29,12 +28,11 @@ runTunListener fds readTunChan writeTunChan = loop fds [] handleRead :: PacketQueue (Time, TunPacket) -> Fd -> IO () handleRead readTunChan fd = allocaBytes tunReadBufferLength $ \buf -> - forever (((,) <$> epochTime <*> readTun buf fd) - >>= atomically . pushPacketQueue readTunChan) + forever (((,) <$> epochTime <*> readTun buf fd) >>= pushPacketQueue readTunChan) handleWrite :: PacketQueue TunPacket -> Fd -> IO () handleWrite writeTunChan fd = - forever (atomically (popPacketQueue writeTunChan) >>= writeTun fd) + forever (popPacketQueue writeTunChan >>= writeTun fd) readTun :: BA.ByteArray ba => Ptr Word8 -> Fd -> IO ba readTun buf fd = do diff --git a/src/Network/WireGuard/UdpListener.hs b/src/Network/WireGuard/UdpListener.hs index 77b8ae0..93369f4 100644 --- a/src/Network/WireGuard/UdpListener.hs +++ b/src/Network/WireGuard/UdpListener.hs @@ -46,11 +46,11 @@ handlePort bindPort readUdpChan writeUdpChan = retryWithBackoff $ handleRead :: Socket -> PacketQueue UdpPacket -> IO () handleRead sock readUdpChan = forever $ do packet <- recvFrom sock udpReadBufferLength - void $ atomically $ tryPushPacketQueue readUdpChan packet + pushPacketQueue readUdpChan packet handleWrite :: Socket -> PacketQueue UdpPacket -> IO () handleWrite sock writeUdpChan = forever $ do - (packet, dest) <- atomically $ popPacketQueue writeUdpChan + (packet, dest) <- popPacketQueue writeUdpChan void $ sendTo sock packet dest waitNewVar :: Eq a => a -> TVar a -> STM a -- cgit v1.2.3-59-g8ed1b