aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network/WireGuard/TunListener.hs
blob: 8e058dffc55ee4214ee90a16f7ceb93f298f0f90 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
module Network.WireGuard.TunListener
  ( runTunListener
  ) where

import           Control.Concurrent.Async               (wait, withAsync)
import           Control.Monad                          (forever, void)
import           Control.Monad.STM                      (atomically)
import qualified Data.ByteArray                         as BA
import           Data.Word                              (Word8)
import           Foreign.Marshal.Alloc                  (allocaBytes)
import           Foreign.Ptr                            (Ptr)
import           System.Posix.Time                      (epochTime)
import           System.Posix.Types                     (Fd)

import           Network.WireGuard.Foreign.Tun
import           Network.WireGuard.Internal.Constant
import           Network.WireGuard.Internal.PacketQueue
import           Network.WireGuard.Internal.Types
import           Network.WireGuard.Internal.Util

-- | Start a reader worker and a writer worker for every TUN file descriptor,
-- then wait on all of them.
--
-- For each @fd@, 'handleRead' feeds @readTunChan@ and 'handleWrite' drains
-- @writeTunChan@. Each worker is wrapped in 'retryWithBackoff'. Workers are
-- scoped with nested 'withAsync' calls, so leaving this function (normally or
-- via an exception) cancels any worker that is still running.
runTunListener :: [Fd] -> PacketQueue (Time, TunPacket) -> PacketQueue TunPacket -> IO ()
runTunListener fds readTunChan writeTunChan = foldr spawnPair (mapM_ wait) fds []
  where
    -- Fork the reader/writer pair for one fd, then hand the accumulated
    -- handles (newest first) to the continuation for the remaining fds.
    spawnPair fd andThen handles =
        withAsync (retryWithBackoff (handleRead readTunChan fd)) $ \reader ->
        withAsync (retryWithBackoff (handleWrite writeTunChan fd)) $ \writer ->
            andThen (reader : writer : handles)

-- | Pump packets from one TUN file descriptor into the shared read queue,
-- forever.
--
-- A single scratch buffer of 'tunReadBufferLength' bytes is allocated once
-- with 'allocaBytes' and reused for every read. 'readTun' copies each packet
-- out of it into its own byte array. Each packet is paired with the current
-- 'epochTime' and pushed onto @readTunChan@ in one STM transaction.
handleRead :: PacketQueue (Time, TunPacket) -> Fd -> IO ()
handleRead readTunChan fd = allocaBytes tunReadBufferLength $ \buf ->
    forever $ do
        -- Read first, timestamp second. 'readTun' presumably waits until a
        -- packet is available, so sampling the clock before the read would
        -- stamp the packet with the time the wait began, which can be
        -- arbitrarily stale after an idle period.
        packet <- readTun buf fd
        now <- epochTime
        atomically $ pushPacketQueue readTunChan (now, packet)

-- | Drain the shared write queue into one TUN file descriptor, forever.
--
-- Each iteration takes the next packet with 'popPacketQueue' in its own STM
-- transaction and passes it to 'writeTun'.
handleWrite :: PacketQueue TunPacket -> Fd -> IO ()
handleWrite writeTunChan fd = forever $ do
    packet <- atomically (popPacketQueue writeTunChan)
    writeTun fd packet

-- | Read one packet from @fd@ into the scratch buffer @buf@ and return a
-- freshly allocated byte array holding exactly the bytes read.
--
-- 'tunReadBuf' is told the buffer holds 'tunReadBufferLength' bytes, so
-- @buf@ must be at least that large. After copying, 'zeroMemory' is applied
-- to the first @nbytes@ bytes of @buf@. Presumably this keeps packet
-- contents from lingering in the reused buffer.
readTun :: BA.ByteArray ba => Ptr Word8 -> Fd -> IO ba
readTun buf fd = do
    nbytes <- tunReadBuf fd buf (fromIntegral tunReadBufferLength)
    let copyOut dst = do
            copyMemory dst buf nbytes
            zeroMemory buf nbytes
    (_, packet) <- BA.allocRet (fromIntegral nbytes) copyOut
    return packet

-- | Write the whole contents of @ba@ to @fd@ with a single 'tunWriteBuf'
-- call.
--
-- The result of 'tunWriteBuf' is discarded. A short or failed write is
-- therefore not reported here unless 'tunWriteBuf' itself throws.
writeTun :: BA.ByteArrayAccess ba => Fd -> ba -> IO ()
writeTun fd ba =
    let len = fromIntegral (BA.length ba)
    in  BA.withByteArray ba (\ptr -> () <$ tunWriteBuf fd ptr len)