aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network/WireGuard/RPC.hs
blob: 01751271d455ae8e08fd8d5a14301473a8accc8d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE OverloadedStrings #-}

module Network.WireGuard.RPC
  ( runRPC,
    serveConduit,
    bytesToPair,
    showDevice,
    showPeer
  ) where

import           Control.Concurrent.STM                    (STM, atomically,
                                                            readTVar, writeTVar)
import           Control.Monad                             (when, unless)
import           Control.Monad.IO.Class                    (liftIO)
import qualified Crypto.Noise.DH                     as DH (dhPubToBytes, dhSecToBytes,
                                                            dhBytesToPair, dhBytesToPair)
import qualified Data.ByteArray                      as BA (convert)
import qualified Data.ByteString                     as BS (ByteString, concat,
                                                            empty)
import qualified Data.ByteString.Char8               as BC (pack, singleton, map)
import           Data.Char                                 (toLower)
import           Data.Conduit.Attoparsec                   (sinkParserEither) 
import           Data.Conduit.Network.Unix                 (appSink, appSource,
                                                            runUnixServer,
                                                            serverSettings)
import qualified Data.HashMap.Strict                 as HM (delete, lookup, insert, empty,
                                                            elems, member)
import           Data.Hex                                  (hex)
import           Data.List                                 (foldl')
import           Data.Conduit                              (ConduitM, (.|),
                                                            yield, runConduit)
import           Data.Maybe                                (fromJust, isJust,
                                                            fromMaybe)
import           Network.WireGuard.Internal.RpcParsers     (requestParser)
import           Network.WireGuard.Internal.State          (Device(..), Peer(..),
                                                            createPeer)
import           Network.WireGuard.Internal.Data.Types     (PrivateKey, PublicKey,
                                                            KeyPair)
import           Network.WireGuard.Internal.Data.RpcTypes  (RpcRequest(..), RpcSetPayload(..),
                                                            OpType(..), RpcDevicePayload(..),
                                                            RpcPeerPayload(..))
import           Network.WireGuard.Internal.Util           (catchIOExceptionAnd)

-- | Run the RPC service on a unix socket located at @sockPath@.
--
-- Every accepted connection is piped through 'serveConduit': bytes read from
-- the client feed the request parser, and the response is written back on the
-- same connection. Exceptions handled by 'catchIOExceptionAnd' fall back to
-- @return ()@, so they end only that connection's handler.
runRPC :: FilePath -> Device -> IO ()
runRPC sockPath device = runUnixServer settings handleClient
  where
    settings = serverSettings sockPath
    handleClient app =
      catchIOExceptionAnd (return ()) . runConduit $
        appSource app .| serveConduit device .| appSink app
    
-- TODO: ensure that all bytestring over sockets will be erased
-- | Handle a single RPC request arriving on a connection.
--
-- The input is parsed with 'requestParser'. The request is then dispatched:
--
--   * @Get@: emits the 'showDevice' dump followed by @errno=0@ and a blank line.
--   * @Set@: applies the request with 'setDevice', then emits
--     @errno=<code>@ and a blank line. The code is @0@ when 'setDevice'
--     returns 'Nothing'.
--
-- When parsing fails, an empty chunk is yielded and no errno line is sent.
serveConduit :: Device -> ConduitM BS.ByteString BS.ByteString IO ()
serveConduit device =
  sinkParserEither requestParser >>= either (const $ yield mempty) dispatch
  where
    --returnError = yield $ writeConfig (-invalidValueError)
    dispatch req = case opType req of
      Get -> do
        dump <- liftIO . atomically $ showDevice device
        yield $ BS.concat [dump, BC.pack "errno=0\n\n"]
      Set -> do
        result <- liftIO . atomically $ setDevice req device
        yield $ BS.concat [BC.pack "errno=", fromMaybe "0" result, BC.pack "\n\n"]

-- | Apply a @set@ request to the device inside the caller's STM transaction.
--
-- Returns 'Nothing' on success. A @Just errno@ result is meant for failures,
-- but none are reported yet (see the TODO below).
--
-- A request whose payload is absent has nothing to apply. It is treated as a
-- successful no-op; previously it crashed on 'fromJust'.
--
-- Device-level fields:
--
--   * the private key and fwmark are only overwritten when present in the request;
--   * the listen port is always written from the request;
--   * when @replacePeers@ is set, all existing peers are dropped before the
--     request's peer list (if non-empty) is applied via 'setPeers'.
setDevice :: RpcRequest -> Device -> STM (Maybe BS.ByteString)
setDevice req dev = case payload req of
  Nothing -> return Nothing
  Just reqPayload -> do
    let devReq = devicePayload reqPayload
    -- Keys are stored as 'Maybe', so a supplied key is written back wrapped.
    mapM_ (writeTVar (localKey dev) . Just) (pk devReq)
    writeTVar (port dev) $ listenPort devReq
    mapM_ (writeTVar (fwmark dev)) (fwMark devReq)
    when (replacePeers devReq) $ delDevPeers dev
    let peersList = peersPayload reqPayload
    unless (null peersList) $ setPeers peersList dev
    -- TODO: Handle errors using errno.h
    return Nothing

-- | Apply each peer entry of a @set@ request to the device, in order.
--
-- Entry handling:
--
--   * An entry marked @remove@ deletes the peer with that public key. If no
--     such peer exists, this is a no-op. Previously, a remove request for an
--     unknown key fell through to the update branch and created the peer.
--   * Any other entry updates the existing peer with that key, or creates one
--     via 'createPeer'. The changes are applied with 'modifySTMPeer' and the
--     peer is stored back in the device's peer map.
setPeers :: [RpcPeerPayload] -> Device -> STM ()
setPeers peerList dev = mapM_ applyPeer peerList
  where
    applyPeer peer
      | remove peer = removePeer peer dev
      | otherwise = do
          statePeers <- readTVar $ peers dev
          let peerPubK = pubToString $ pubK peer
          stmPeer <- maybe (createPeer $ pubK peer) return $
                       HM.lookup peerPubK statePeers
          modifySTMPeer peer stmPeer
          writeTVar (peers dev) $ HM.insert peerPubK stmPeer statePeers

-- | Copy the settings of one request peer entry onto a stored peer.
--
-- The endpoint (wrapped in 'Just') and the keepalive interval are always
-- overwritten. The request's allowed IPs are appended to the stored list,
-- or replace it entirely when @replaceIps@ is set.
modifySTMPeer :: RpcPeerPayload -> Peer -> STM ()
modifySTMPeer peer stmPeer = do
  retained <- existingIps
  writeTVar (endPoint stmPeer) (Just (endpoint peer))
  writeTVar (keepaliveInterval stmPeer) (persistantKeepaliveInterval peer)
  writeTVar (ipmasks stmPeer) (retained ++ allowedIp peer)
  where
    existingIps
      | replaceIps peer = return []
      | otherwise       = readTVar (ipmasks stmPeer)
  
-- | Remove every peer from the device by resetting its peer map to empty.
delDevPeers :: Device -> STM ()
delDevPeers Device{peers = peerMapVar} = writeTVar peerMapVar HM.empty

-- | Delete the peer whose public key matches the request entry.
--
-- The peer map is keyed by the lower-case hex public key ('pubToString').
-- Deleting a key that is not present leaves the map unchanged.
removePeer :: RpcPeerPayload -> Device -> STM ()
removePeer peer dev =
  readTVar peersVar >>= writeTVar peersVar . HM.delete key
  where
    peersVar = peers dev
    key      = pubToString (pubK peer)

-- | Serialize the device configuration and all of its peers for a @get@ reply.
--
-- The device section contains @private_key@ (lower-case hex), @listen_port@
-- and @fwmark@ lines. Entries that are absent or equal to @0@ are dropped by
-- 'serializeRpcKeyValue'. The device section is followed by one 'showPeer'
-- section per peer, in 'HM.elems' order.
showDevice :: Device -> STM BS.ByteString
showDevice dev = do
  portVal    <- readTVar (port dev)
  markVal    <- readTVar (fwmark dev)
  mKeyPair   <- readTVar (localKey dev)
  let privHex = toLowerBs . hex . privToBytes . fst <$> mKeyPair
      header  = serializeRpcKeyValue
        [ ("private_key", privHex)
        , ("listen_port", Just . BC.pack $ show portVal)
        , ("fwmark", Just . BC.pack $ show markVal)
        ]
  peerMap    <- readTVar (peers dev)
  peerChunks <- mapM showPeer (HM.elems peerMap)
  return $ BS.concat (header : peerChunks)

-- | Serialize a single peer as RPC @key=value@ lines.
--
-- The emitted keys are @public_key@, @endpoint@,
-- @persistent_keepalive_interval@, @tx_bytes@, @rx_bytes@ and
-- @last_handshake_time@, followed by one @allowed_ip@ line per stored IP
-- mask. Values are rendered with 'show'. Missing values, and values equal to
-- @0@, are omitted by 'serializeRpcKeyValue'.
showPeer :: Peer -> STM BS.ByteString
showPeer Peer{..} = do
  mEndpoint  <- readTVar endPoint
  keepalive  <- readTVar keepaliveInterval
  ipRanges   <- readTVar ipmasks
  rxCount    <- readTVar receivedBytes
  txCount    <- readTVar transferredBytes
  mHandshake <- readTVar lastHandshakeTime
  let fields =
        [ ("public_key", Just (pubToString remotePub))
        , ("endpoint", render <$> mEndpoint)
        , ("persistent_keepalive_interval", Just (render keepalive))
        , ("tx_bytes", Just (render txCount))
        , ("rx_bytes", Just (render rxCount))
        , ("last_handshake_time", render <$> mHandshake)
        ]
      ipFields = [("allowed_ip", Just (render ip)) | ip <- ipRanges]
  return . serializeRpcKeyValue $ fields ++ ipFields
  where
    render :: Show a => a -> BS.ByteString
    render = BC.pack . show

-- | Render key/value pairs as consecutive @key=value@ lines, in list order.
-- Each line is terminated by a single newline.
--
-- A pair is omitted entirely when its value is 'Nothing' or exactly the
-- string @"0"@. For example, a zero listen port or zero byte counters
-- produce no line.
--
-- All line fragments are collected first and joined with one 'BS.concat'.
-- The previous left fold re-copied the whole growing accumulator for every
-- line, which is quadratic in the output size.
serializeRpcKeyValue :: [(String, Maybe BS.ByteString)] -> BS.ByteString
serializeRpcKeyValue = BS.concat . concatMap renderLine
  where
    renderLine (key, Just val)
      | val /= BC.pack "0" = [BC.pack key, BC.singleton '=', val, BC.singleton '\n']
    renderLine _ = []

-- | Raw bytes of a public key, converted to a strict 'BS.ByteString'.
pubToBytes :: PublicKey -> BS.ByteString
pubToBytes publicKey = BA.convert (DH.dhPubToBytes publicKey)

-- | Lower-case hex encoding of a public key.
--
-- It is used as the key of the device's peer map and as the @public_key@
-- value in @get@ output.
pubToString :: PublicKey -> BS.ByteString
pubToString publicKey = toLowerBs (hex (pubToBytes publicKey))

-- | Raw bytes of a private key, converted to a strict 'BS.ByteString'.
-- The result is secret material and should be handled accordingly.
privToBytes :: PrivateKey -> BS.ByteString
privToBytes secretKey = BA.convert (DH.dhSecToBytes secretKey)

-- | Rebuild a key pair from raw private-key bytes.
--
-- Returns 'Nothing' when 'DH.dhBytesToPair' rejects the input.
bytesToPair :: BS.ByteString -> Maybe KeyPair
bytesToPair raw = DH.dhBytesToPair (BA.convert raw)

-- | Lower-case every byte of a 'BS.ByteString'. Each byte is treated as a
-- 'Char8' character and passed through 'toLower'.
toLowerBs :: BS.ByteString -> BS.ByteString
toLowerBs bytes = BC.map toLower bytes