aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network/WireGuard/TunListener.hs
blob: d57cf9866937d7a319ff7e015700d985d495484a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
{-# LANGUAGE CPP             #-}
{-# LANGUAGE RecordWildCards #-}

module Network.WireGuard.TunListener
  ( runTunListener
  ) where

import           Control.Concurrent.Async               (wait, withAsync)
import           Control.Monad                          (forever, void)
import           Control.Monad.STM                      (atomically)
import qualified Data.ByteArray                         as BA
import           Data.Word                              (Word8)
import           Foreign.Marshal.Alloc                  (allocaBytes)
import           Foreign.Ptr                            (Ptr)
import           System.Posix.Types                     (Fd)

import           Network.WireGuard.Foreign.Tun          (fdReadBuf, fdWriteBuf)

import           Network.WireGuard.Internal.Constant
import           Network.WireGuard.Internal.PacketQueue
import           Network.WireGuard.Internal.Types
import           Network.WireGuard.Internal.Util

#ifdef OS_LINUX
import           Control.Concurrent                     (threadWaitRead,
                                                         threadWaitWrite)
#endif

-- | Start a reader and a writer worker for every tun file descriptor, then
-- wait on all of the workers.
--
-- Every reader pushes packets onto @readTunChan@ and every writer pops
-- packets from @writeTunChan@. When several fds are given, they all share
-- these two queues. Each worker runs under 'retryWithBackoff'. Workers are
-- started with 'withAsync', so they are cancelled if this function exits
-- early (e.g. because an exception surfaces while waiting).
runTunListener :: [Fd] -> PacketQueue TunPacket -> PacketQueue TunPacket -> IO ()
runTunListener fds readTunChan writeTunChan = foldr spawnPair (mapM_ wait) fds []
  where
    -- Open the reader/writer 'withAsync' scopes for one fd and hand the
    -- accumulated handles to the next step. The final step (after the last
    -- fd) waits on every collected handle in list order.
    spawnPair fd next running =
        withAsync (retryWithBackoff (handleRead readTunChan fd)) $ \reader ->
            withAsync (retryWithBackoff (handleWrite writeTunChan fd)) $ \writer ->
                next (reader : writer : running)

-- | Reader worker: repeatedly read one packet from @fd@ and push it onto
-- @readTunChan@.
--
-- A single scratch buffer of 'tunReadBufferLength' bytes is allocated once
-- and reused for every read. 'readFd' copies each packet out of it into a
-- freshly allocated byte array.
handleRead :: PacketQueue TunPacket -> Fd -> IO ()
handleRead readTunChan fd = allocaBytes tunReadBufferLength readLoop
  where
    readLoop scratch = forever $ do
        packet <- readFd scratch fd
        atomically (pushPacketQueue readTunChan packet)

-- | Writer worker: repeatedly take the next packet from @writeTunChan@
-- (blocking in STM until one is available) and write it to @fd@.
handleWrite :: PacketQueue TunPacket -> Fd -> IO ()
handleWrite writeTunChan fd = forever $ do
    packet <- atomically (popPacketQueue writeTunChan)
    writeFd fd packet

-- | Read one packet from @fd@ into the scratch buffer @buf@ and return a
-- freshly allocated byte array holding exactly the bytes that were read.
--
-- @buf@ must be at least 'tunReadBufferLength' bytes long, because that is
-- the maximum length requested from 'fdReadBuf'. After the copy, the bytes
-- just read are zeroed in @buf@. Presumably this keeps packet contents
-- from lingering in the reused buffer (NOTE(review): confirm intent).
readFd :: BA.ByteArray ba => Ptr Word8 -> Fd -> IO ba
readFd buf fd = do
#ifdef OS_LINUX
    -- Park this Haskell thread until the fd is readable; presumably the fd
    -- is in non-blocking mode on Linux (TODO confirm against fd setup).
    threadWaitRead fd
#endif
    len <- fdReadBuf fd buf (fromIntegral tunReadBufferLength)
    (_, packet) <- BA.allocRet (fromIntegral len) $ \dst -> do
        copyMemory dst buf len
        zeroMemory buf len
    return packet

-- | Write the entire contents of @ba@ to @fd@ in a single 'fdWriteBuf'
-- call.
--
-- The byte count returned by 'fdWriteBuf' is discarded.
writeFd :: BA.ByteArrayAccess ba => Fd -> ba -> IO ()
writeFd fd ba = BA.withByteArray ba writeFrom
  where
    -- Runs while 'BA.withByteArray' keeps @ba@'s memory accessible via @src@.
    writeFrom src = do
#ifdef OS_LINUX
        -- Park this Haskell thread until the fd is writable.
        threadWaitWrite fd
#endif
        void (fdWriteBuf fd src (fromIntegral (BA.length ba)))