aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network/WireGuard/UdpListener.hs
blob: 77b8ae0a6ccedce744c155dfbc257be02bb1cf76 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
{-# LANGUAGE OverloadedStrings #-}

module Network.WireGuard.UdpListener
  ( runUdpListener
  ) where

import           Control.Concurrent.Async               (cancel, wait,
                                                         waitBoth, withAsync)
import           Control.Concurrent.STM.TVar            (TVar, readTVar)
import           Control.Exception                      (IOException, bracket,
                                                         try)
import           Control.Monad                          (forever, void)
import           Control.Monad.STM                      (STM, atomically, retry)
import           Data.Streaming.Network                 (bindPortUDP,
                                                         bindRandomPortUDP)
import           Network.Socket                         (Socket, close)
import           Network.Socket.ByteString              (recvFrom, sendTo)

import           Network.WireGuard.Internal.State       (Device (..))

import           Network.WireGuard.Internal.Constant
import           Network.WireGuard.Internal.PacketQueue
import           Network.WireGuard.Internal.Types
import           Network.WireGuard.Internal.Util

-- | Serve the device's UDP socket for as long as this action runs.
--
-- Received datagrams are handed to @readUdpChan@, and packets popped from
-- @writeUdpChan@ are sent out (both via 'handlePort'). Whenever the value in
-- the device's 'port' 'TVar' changes, the current listener is cancelled and a
-- new one is started on the new port. Port @0@ means "bind a random port".
runUdpListener :: Device -> PacketQueue UdpPacket -> PacketQueue UdpPacket -> IO ()
runUdpListener device readUdpChan writeUdpChan =
    -- Start from the currently configured port instead of always binding a
    -- throwaway random port first.
    atomically (readTVar (port device)) >>= loop
  where
    -- Serve 'oport' until the configured port changes. Leaving 'withAsync'
    -- cancels the listener thread (whose 'bracket' closes the socket). The
    -- recursive call happens only after 'withAsync' has returned, so the
    -- loop stays tail-recursive. It no longer nests one 'withAsync' scope
    -- per port change.
    loop oport = do
        nport <- withAsync (handlePort oport readUdpChan writeUdpChan) $ \_ ->
            atomically $ waitNewVar oport (port device)
        loop nport

-- | Bind a UDP socket on @bindPort@ and pump packets between it and the two
-- queues. The whole action is wrapped in 'retryWithBackoff' (from
-- "Network.WireGuard.Internal.Util"; presumably it re-runs the action after a
-- failure — verify there).
--
-- The reader and writer share one socket and run in separate threads. Both
-- loop forever, so this can only finish by exception. 'waitBoth' rethrows the
-- first exception from either thread. A dead writer (or reader) therefore
-- tears down the socket instead of being silently ignored while the other
-- direction keeps running.
handlePort :: Int -> PacketQueue UdpPacket -> PacketQueue UdpPacket -> IO ()
handlePort bindPort readUdpChan writeUdpChan = retryWithBackoff $
    bracket (bind bindPort) close $ \sock ->
        withAsync (handleRead sock readUdpChan) $ \rt ->
        withAsync (handleWrite sock writeUdpChan) $ \wt ->
            void $ waitBoth rt wt
  where
    -- TODO: prefer ipv6 binding here
    -- Port 0 requests a random port; the HostPreference "!4" restricts the
    -- bind to IPv4.
    bind 0 = snd <$> bindRandomPortUDP "!4"
    bind p = bindPortUDP p "!4"

-- | Receive datagrams on @sock@ forever, reading at most
-- 'udpReadBufferLength' bytes per call. Each (payload, sender) pair is offered
-- to @readUdpChan@ with 'tryPushPacketQueue'.
--
-- The push result is discarded; presumably it reports whether the queue
-- accepted the packet, so a rejected packet is simply dropped (verify in
-- "Network.WireGuard.Internal.PacketQueue").
handleRead :: Socket -> PacketQueue UdpPacket -> IO ()
handleRead sock readUdpChan =
    forever $ recvFrom sock udpReadBufferLength >>= enqueue
  where
    enqueue = void . atomically . tryPushPacketQueue readUdpChan

-- | Pop (payload, destination) pairs from @writeUdpChan@ forever and send
-- each one as a datagram on @sock@.
--
-- A synchronous 'IOException' from 'sendTo' drops only that datagram. Such
-- errors can be specific to a single destination (e.g. an unreachable
-- network). Previously one such failure killed this thread for good.
-- Asynchronous exceptions (such as cancellation) are not 'IOException's and
-- still propagate.
handleWrite :: Socket -> PacketQueue UdpPacket -> IO ()
handleWrite sock writeUdpChan = forever $ do
    (packet, dest) <- atomically $ popPacketQueue writeUdpChan
    -- UDP is best-effort: discard the failed datagram and keep serving.
    void (try (sendTo sock packet dest) :: IO (Either IOException Int))

-- | Read @tvar@ and return its value if it differs from @old@. Otherwise
-- 'retry', which blocks the enclosing transaction until @tvar@ is written
-- with a different value.
waitNewVar :: Eq a => a -> TVar a -> STM a
waitNewVar old tvar = readTVar tvar >>= settle
  where
    settle current
      | current == old = retry
      | otherwise      = pure current