1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
|
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE OverloadedStrings #-}
{-|
Module : Network.WireGuard.RPC
Description : Wireguard's RPC protocol implementation.
Copyright : Félix Baylac-Jacqué, 2017
License : GPL-3
Maintainer : felix@alternativebit.fr
Stability : experimental
Portability : POSIX
WireGuard's RPC protocol implementation. This module contains
the various operations needed to communicate with the wg CLI utility.
-}
module Network.WireGuard.RPC
( runRPC,
serveConduit,
showDevice,
showPeer
) where
import Control.Concurrent.STM (STM, atomically,
readTVar, writeTVar)
import Control.Monad (when, unless)
import Control.Monad.IO.Class (liftIO)
import qualified Crypto.Noise.DH as DH (dhPubToBytes, dhSecToBytes)
import qualified Data.ByteArray as BA (convert)
import qualified Data.ByteString as BS (ByteString, concat,
empty)
import qualified Data.ByteString.Char8 as BC (pack, singleton, map)
import Data.Char (toLower)
import Data.Conduit.Attoparsec (sinkParserEither)
import Data.Conduit.Network.Unix (appSink, appSource,
runUnixServer,
serverSettings)
import qualified Data.HashMap.Strict as HM (delete, lookup, insert, empty,
elems, member)
import Data.Hex (hex)
import Data.List (foldl')
import Data.Conduit (ConduitM, (.|),
yield, runConduit)
import Data.Maybe (fromJust, isJust,
fromMaybe)
import Network.WireGuard.Internal.RpcParsers (requestParser)
import Network.WireGuard.Internal.State (Device(..), Peer(..),
createPeer)
import Network.WireGuard.Internal.Data.Types (PrivateKey, PublicKey)
import Network.WireGuard.Internal.Data.RpcTypes (RpcRequest(..), RpcSetPayload(..),
OpType(..), RpcDevicePayload(..),
RpcPeerPayload(..))
import Network.WireGuard.Internal.Util (catchIOExceptionAnd)
-- TODO: return an appropriate errno from set operations.
-- | Serve the RPC protocol on the Unix domain socket at @sockPath@.
--
-- Every accepted connection streams its incoming bytes through
-- 'serveConduit' and writes the produced response back to the same socket.
-- Each connection is wrapped in 'catchIOExceptionAnd' with @return ()@ as the
-- fallback action (presumably so a failing client is dropped quietly instead
-- of bringing the server down — see that helper to confirm).
runRPC :: FilePath -> Device -> IO ()
runRPC sockPath device = runUnixServer settings handleClient
  where
    settings = serverSettings sockPath
    handleClient app =
      catchIOExceptionAnd (return ()) . runConduit $
        appSource app .| serveConduit device .| appSink app
-- | Handle one RPC exchange on a byte stream: parse a single request with
-- 'requestParser', act on it, and emit the response bytes.
--
-- * If the request cannot be parsed, an empty chunk is emitted.
-- * A @Set@ request is applied atomically with 'setDevice'. The reply is
--   @errno=N@ followed by a blank line. @N@ is the error returned by
--   'setDevice', or @0@ when it reports none.
-- * A @Get@ request replies with the serialized device ('showDevice'),
--   followed by @errno=0@ and a blank line.
serveConduit :: Device -> ConduitM BS.ByteString BS.ByteString IO ()
serveConduit device =
  sinkParserEither requestParser >>= either (const (yield mempty)) respond
  where
    respond req = case opType req of
      Set -> do
        outcome <- liftIO (atomically (setDevice req device))
        yield $ BS.concat [BC.pack "errno=", fromMaybe "0" outcome, BC.pack "\n\n"]
      Get -> do
        rendered <- liftIO (atomically (showDevice device))
        yield $ BS.concat [rendered, BC.pack "errno=0\n\n"]
-- | Render a device, followed by each of its peers, in WireGuard's RPC
-- key/value text format.
--
-- The device section lists @private_key@, @listen_port@ and @fwmark@.
--
-- * @private_key@ is lowercase hex and appears only when a local key pair
--   is set.
-- * 'serializeRpcKeyValue' drops any value rendered as @0@.
-- * Each peer is then rendered with 'showPeer'. Peer order follows
--   'HM.elems', which is unspecified.
--
-- Format description: <https://www.wireguard.com/xplatform/>
showDevice :: Device -> STM BS.ByteString
showDevice Device{..} = do
  portBs  <- renderShow <$> readTVar port
  markBs  <- renderShow <$> readTVar fwmark
  keyPair <- readTVar localKey
  peerMap <- readTVar peers
  peerBs  <- traverse showPeer (HM.elems peerMap)
  let keyBs = toLowerBs . hex . privToBytes . fst <$> keyPair
      header = serializeRpcKeyValue
        [ ("private_key", keyBs)
        , ("listen_port", Just portBs)
        , ("fwmark", Just markBs)
        ]
  return $ BS.concat (header : peerBs)
  where
    -- Explicit signature: used at two different field types above.
    renderShow :: Show a => a -> BS.ByteString
    renderShow = BC.pack . show
-- | Render a single peer in WireGuard's RPC key/value text format.
--
-- Fields are emitted in this order:
--
-- * @public_key@ (lowercase hex; always first);
-- * @endpoint@ and @last_handshake_time@, only when set;
-- * @persistent_keepalive_interval@, @tx_bytes@ and @rx_bytes@;
-- * one @allowed_ip@ line per allowed IP range.
--
-- Values are rendered with 'show', and any value rendered as @0@ is
-- dropped by 'serializeRpcKeyValue'.
--
-- Format description: <https://www.wireguard.com/xplatform/>
showPeer :: Peer -> STM BS.ByteString
showPeer Peer{..} = do
  mEndpoint   <- readTVar endPoint
  keepalive   <- readTVar keepaliveInterval
  ipRanges    <- readTVar ipmasks
  received    <- readTVar receivedBytes
  transferred <- readTVar transferredBytes
  mHandshake  <- readTVar lastHandshakeTime
  let scalarFields =
        [ ("public_key", Just (pubToString remotePub))
        , ("endpoint", renderShow <$> mEndpoint)
        , ("persistent_keepalive_interval", Just (renderShow keepalive))
        , ("tx_bytes", Just (renderShow transferred))
        , ("rx_bytes", Just (renderShow received))
        , ("last_handshake_time", renderShow <$> mHandshake)
        ]
      ipFields = [("allowed_ip", Just (renderShow range)) | range <- ipRanges]
  return . serializeRpcKeyValue $ scalarFields ++ ipFields
  where
    -- Explicit signature: used at several different field types above.
    renderShow :: Show a => a -> BS.ByteString
    renderShow = BC.pack . show
-- | Apply a @set@ request to the device inside the caller's STM transaction.
--
-- Device-level fields are applied first:
--
-- * the private key and fwmark only when present in the request;
-- * the listen port unconditionally.
--
-- If the request asks for peers to be replaced, the peer table is cleared.
-- The per-peer changes are then applied with 'setPeers'.
--
-- A request without a payload is treated as an empty @set@: nothing changes
-- and success is reported. Previously this case hit a partial 'fromJust'
-- and threw. NOTE(review): assumes the parser only omits the payload when
-- the request carries no keys — confirm against 'requestParser'.
--
-- Returns 'Nothing' on success. A @Just errno@ value is what the caller
-- reports back as @errno=@, but no error path is implemented yet.
setDevice :: RpcRequest -> Device -> STM (Maybe BS.ByteString)
setDevice req dev =
  case payload req of
    Nothing -> return Nothing
    Just setPayload -> do
      let devReq = devicePayload setPayload
          peersList = peersPayload setPayload
      when (isJust $ pk devReq) . writeTVar (localKey dev) $ pk devReq
      writeTVar (port dev) $ listenPort devReq
      -- Only overwrite the fwmark when the request supplies one.
      mapM_ (writeTVar (fwmark dev)) (fwMark devReq)
      -- Clearing must happen before the new peers are inserted.
      when (replacePeers devReq) $ delDevPeers dev
      unless (null peersList) $ setPeers peersList dev
      return Nothing
-- TODO: Handle errors using errno.h
-- | Apply peer change requests to the device's peer table, one request at a
-- time and in list order.
--
-- Peers are keyed by the lowercase hex rendering of their public key
-- ('pubToString').
--
-- * A request with the @remove@ flag deletes the peer if it is present. A
--   removal request for an unknown peer is a no-op; it must never create
--   the peer it is trying to remove.
-- * Any other request works on the existing peer, or on a new one from
--   'createPeer' when the key is unknown. It applies the payload with
--   'modifySTMPeer' and stores the peer back in the table.
setPeers :: [RpcPeerPayload] -> Device -> STM ()
setPeers peerList dev = mapM_ applyPeer peerList
  where
    applyPeer peer = do
      statePeers <- readTVar $ peers dev
      let peerPubK = pubToString $ pubK peer
      if remove peer
        then when (HM.member peerPubK statePeers) $ removePeer peer dev
        else do
          stmPeer <- maybe (createPeer $ pubK peer) return $
                       HM.lookup peerPubK statePeers
          modifySTMPeer peer stmPeer
          writeTVar (peers dev) $ HM.insert peerPubK stmPeer statePeers
-- | Copy the settings carried by a peer RPC payload onto a live peer.
--
-- The endpoint and persistent keepalive interval are overwritten with the
-- payload's values. The payload's allowed IPs are appended to the peer's
-- current list, or replace it entirely when the payload's @replaceIps@ flag
-- is set.
--
-- NOTE(review): endpoint and keepalive are written unconditionally. Confirm
-- that the parser fills them with the peer's current values when a request
-- omits them.
modifySTMPeer :: RpcPeerPayload -> Peer -> STM ()
modifySTMPeer peer stmPeer = do
  writeTVar (endPoint stmPeer) (Just (endpoint peer))
  writeTVar (keepaliveInterval stmPeer) (persistantKeepaliveInterval peer)
  kept <- if replaceIps peer then pure [] else readTVar (ipmasks stmPeer)
  writeTVar (ipmasks stmPeer) (kept ++ allowedIp peer)
-- | Drop every peer from the device, leaving an empty peer table.
delDevPeers :: Device -> STM ()
delDevPeers Device{..} = writeTVar peers HM.empty
-- | Delete the peer named by the payload's public key from the device's
-- peer table. The key is its lowercase hex rendering ('pubToString').
-- An unknown key leaves the table's contents unchanged.
removePeer :: RpcPeerPayload -> Device -> STM ()
removePeer peer dev =
  readTVar peerTable >>= writeTVar peerTable . HM.delete peerKey
  where
    peerTable = peers dev
    peerKey = pubToString (pubK peer)
-- | Serialize key/value pairs as newline-terminated @key=value@ lines, in the
-- order given.
--
-- A pair is omitted from the output when its value is 'Nothing', or when the
-- value is exactly the byte string @"0"@.
--
-- All line fragments are joined with a single 'BS.concat'. The cost is
-- therefore linear in the output size. Appending each line to a growing
-- accumulator would instead re-copy everything emitted so far for every
-- pair, which is quadratic in the number of pairs (e.g. many @allowed_ip@
-- entries).
serializeRpcKeyValue :: [(String, Maybe BS.ByteString)] -> BS.ByteString
serializeRpcKeyValue = BS.concat . concatMap keyValueLine
  where
    keyValueLine (key, Just val)
      | val /= BC.pack "0" = [BC.pack key, BC.singleton '=', val, BC.singleton '\n']
    keyValueLine _ = []
-- | Raw bytes of a public key, taken from 'DH.dhPubToBytes' and converted
-- to a strict 'BS.ByteString'.
pubToBytes :: PublicKey -> BS.ByteString
pubToBytes key = BA.convert (DH.dhPubToBytes key)
-- | Lowercase hexadecimal rendering of a public key. It is what the RPC
-- output prints, and it is the key used for the device's peer map.
pubToString :: PublicKey -> BS.ByteString
pubToString key = toLowerBs (hex (pubToBytes key))
-- | Raw bytes of a private key, taken from 'DH.dhSecToBytes' and converted
-- to a strict 'BS.ByteString'.
privToBytes :: PrivateKey -> BS.ByteString
privToBytes secret = BA.convert (DH.dhSecToBytes secret)
-- | Lowercase every byte, treating each one as a Latin-1 character
-- (via 'BC.map'). Used here to turn hex output into lowercase digits.
toLowerBs :: BS.ByteString -> BS.ByteString
toLowerBs bytes = BC.map toLower bytes
|