blob: 81743f727b64feb8484b001816bef3b1583e9eb2 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
|
{-# LANGUAGE OverloadedStrings #-}
module Network.WireGuard.UdpListener
( runUdpListener
) where
import Control.Concurrent.Async (cancel, wait, waitBoth, withAsync)
import Control.Concurrent.STM.TVar (TVar, readTVar)
import Control.Exception (bracket)
import Control.Monad (forever, void)
import Control.Monad.STM (STM, atomically, retry)
import Data.Streaming.Network (bindPortUDP, bindRandomPortUDP)
import Network.Socket (Socket, close)
import Network.Socket.ByteString (recvFrom, sendTo)

import Network.WireGuard.Internal.Constant
import Network.WireGuard.Internal.Data.Types
import Network.WireGuard.Internal.PacketQueue
import Network.WireGuard.Internal.State (Device (..))
import Network.WireGuard.Internal.Util
-- | Serve UDP traffic for @device@, rebinding whenever its port changes.
--
-- Starts on the port currently held in @port device@ (0 means "bind a
-- random port"). It then blocks until that 'TVar' holds a different value,
-- cancels the running listener and starts again on the new port. This
-- function never returns normally.
runUdpListener :: Device -> PacketQueue UdpPacket -> PacketQueue UdpPacket -> IO ()
runUdpListener device readUdpChan writeUdpChan =
  -- Start from the configured port instead of a hard-coded 0, so an
  -- already-configured device does not first bind a throwaway random port.
  loop =<< atomically (readTVar (port device))
  where
    -- The new port is returned out of the 'withAsync' scope before
    -- recursing. This closes each iteration's scope, so repeated port
    -- changes do not pile up nested scopes and exception handlers.
    loop oport = do
      nport <- withAsync (handlePort oport readUdpChan writeUdpChan) $ \t -> do
        p <- atomically $ waitNewVar oport (port device)
        cancel t
        return p
      loop nport
-- | Bind a UDP socket on @bindPort@ and pump packets between it and the
-- two queues.
--
-- A @bindPort@ of 0 binds a random port. Received datagrams are pushed
-- onto @readUdpChan@. Packets popped from @writeUdpChan@ are sent through
-- the socket. 'bracket' closes the socket whenever the cycle ends.
--
-- The whole bind/serve cycle is wrapped in 'retryWithBackoff'. Presumably
-- that helper re-runs the cycle after a failure (defined in Util; confirm).
handlePort :: Int -> PacketQueue UdpPacket -> PacketQueue UdpPacket -> IO ()
handlePort bindPort readUdpChan writeUdpChan = retryWithBackoff $
  bracket (bind bindPort) close $ \sock ->
    withAsync (handleRead sock readUdpChan) $ \rt ->
      withAsync (handleWrite sock writeUdpChan) $ \wt ->
        -- Both workers loop forever, so the only way out is an exception.
        -- 'waitBoth' rethrows whichever worker fails first. Waiting on the
        -- reader and then the writer would hide a writer failure, because
        -- the reader never returns.
        void $ waitBoth rt wt
  where
    -- TODO: prefer ipv6 binding here
    -- "!4" is the IPv4-only host preference.
    bind 0 = snd <$> bindRandomPortUDP "!4"
    bind p = bindPortUDP p "!4"
-- | Receive datagrams from @sock@ forever.
--
-- Each @(payload, sender)@ pair returned by 'recvFrom' is pushed onto
-- @readUdpChan@. At most 'udpReadBufferLength' bytes are read per call.
handleRead :: Socket -> PacketQueue UdpPacket -> IO ()
handleRead sock readUdpChan =
  forever (recvFrom sock udpReadBufferLength >>= pushPacketQueue readUdpChan)
-- | Drain @writeUdpChan@ forever.
--
-- Each popped packet is sent through @sock@ to the address paired with it.
-- The byte count returned by 'sendTo' is discarded.
handleWrite :: Socket -> PacketQueue UdpPacket -> IO ()
handleWrite sock writeUdpChan = forever $
  popPacketQueue writeUdpChan >>= \(payload, target) ->
    void (sendTo sock payload target)
-- | Block until @tvar@ holds a value different from @old@, then return it.
--
-- Blocking uses STM 'retry'. The enclosing transaction is re-run only after
-- @tvar@ has been written.
waitNewVar :: Eq a => a -> TVar a -> STM a
waitNewVar old tvar = readTVar tvar >>= settle
  where
    settle current
      | current == old = retry
      | otherwise      = pure current
|