blob: d57cf9866937d7a319ff7e015700d985d495484a (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
|
{-# LANGUAGE CPP #-}
{-# LANGUAGE RecordWildCards #-}
module Network.WireGuard.TunListener
( runTunListener
) where
import Control.Concurrent.Async (concurrently_, forConcurrently_, wait,
                                 withAsync)
import Control.Monad (forever, void)
import Control.Monad.STM (atomically)
import qualified Data.ByteArray as BA
import Data.Word (Word8)
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Ptr (Ptr)
import System.Posix.Types (Fd)

import Network.WireGuard.Foreign.Tun (fdReadBuf, fdWriteBuf)
import Network.WireGuard.Internal.Constant
import Network.WireGuard.Internal.PacketQueue
import Network.WireGuard.Internal.Types
import Network.WireGuard.Internal.Util

#ifdef OS_LINUX
import Control.Concurrent (threadWaitRead, threadWaitWrite)
#endif
-- | Serve every TUN file descriptor in @fds@ concurrently.
--
-- Each descriptor gets two workers, both wrapped in 'retryWithBackoff':
--
--   * a reader that moves packets from the descriptor onto @readTunChan@;
--   * a writer that moves packets from @writeTunChan@ onto the descriptor.
--
-- The call returns only once every worker has finished (immediately for an
-- empty @fds@). 'forConcurrently_' / 'concurrently_' are used rather than
-- waiting on each 'Async' in turn: if any worker throws, the remaining workers
-- are cancelled and the exception is rethrown right away, instead of being
-- left unobserved behind a worker that runs 'forever'.
runTunListener :: [Fd] -> PacketQueue TunPacket -> PacketQueue TunPacket -> IO ()
runTunListener fds readTunChan writeTunChan = forConcurrently_ fds serveFd
  where
    -- Run the reader and writer for one descriptor side by side.
    serveFd fd =
      concurrently_
        (retryWithBackoff $ handleRead readTunChan fd)
        (retryWithBackoff $ handleWrite writeTunChan fd)
-- | Endlessly read packets from a TUN descriptor and push each one onto
-- @readTunChan@.
--
-- One scratch buffer of 'tunReadBufferLength' bytes is allocated up front
-- and reused by every 'readFd' call for the lifetime of the loop.
handleRead :: PacketQueue TunPacket -> Fd -> IO ()
handleRead readTunChan fd =
  allocaBytes tunReadBufferLength $ \scratch ->
    let pump = do
          packet <- readFd scratch fd
          atomically (pushPacketQueue readTunChan packet)
    in forever pump
-- | Endlessly take packets from @writeTunChan@ and write each one to the
-- TUN descriptor with 'writeFd'.
--
-- Each packet is taken in its own STM transaction via 'popPacketQueue'.
-- NOTE(review): presumably that call blocks/retries while the queue is
-- empty; confirm in "Network.WireGuard.Internal.PacketQueue".
handleWrite :: PacketQueue TunPacket -> Fd -> IO ()
handleWrite writeTunChan fd = forever $ do
  packet <- atomically $ popPacketQueue writeTunChan
  writeFd fd packet
-- | Read one packet from a TUN descriptor into a freshly allocated byte array.
--
-- @buf@ is a caller-owned scratch buffer. It is assumed to hold at least
-- 'tunReadBufferLength' bytes, since that is the length passed to
-- 'fdReadBuf'. The bytes just read are copied into a new array sized exactly
-- to the read count. The used prefix of @buf@ is then zeroed (presumably so
-- packet contents do not linger in the reused scratch buffer).
--
-- On Linux the thread first waits in 'threadWaitRead' until the descriptor
-- is readable.
readFd :: BA.ByteArray ba => Ptr Word8 -> Fd -> IO ba
readFd buf fd = do
#ifdef OS_LINUX
  threadWaitRead fd
#endif
  nbytes <- fdReadBuf fd buf (fromIntegral tunReadBufferLength)
  let copyOut dst = do
        copyMemory dst buf nbytes
        zeroMemory buf nbytes
  (_, packet) <- BA.allocRet (fromIntegral nbytes) copyOut
  return packet
-- | Write one packet to a TUN descriptor.
--
-- The packet's bytes are exposed through 'BA.withByteArray' and handed to
-- 'fdWriteBuf' in a single call. The byte count it returns is discarded.
-- On Linux the thread first waits in 'threadWaitWrite', while still holding
-- the buffer pointer, until the descriptor is writable.
writeFd :: BA.ByteArrayAccess ba => Fd -> ba -> IO ()
writeFd fd packet = BA.withByteArray packet sendFrom
  where
    packetLen = fromIntegral (BA.length packet)
    sendFrom src = do
#ifdef OS_LINUX
      threadWaitWrite fd
#endif
      void (fdWriteBuf fd src packetLen)
|