aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network/WireGuard/RPC.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Network/WireGuard/RPC.hs')
-rw-r--r--src/Network/WireGuard/RPC.hs93
1 files changed, 36 insertions, 57 deletions
diff --git a/src/Network/WireGuard/RPC.hs b/src/Network/WireGuard/RPC.hs
index 73d9e7a..ae9e552 100644
--- a/src/Network/WireGuard/RPC.hs
+++ b/src/Network/WireGuard/RPC.hs
@@ -13,67 +13,49 @@ module Network.WireGuard.RPC
import Control.Concurrent.STM (STM, atomically,
modifyTVar', readTVar,
writeTVar)
-import Control.Monad (replicateM, sequence,
- when)
+import Control.Monad (when)
import Control.Monad.IO.Class (liftIO)
import qualified Crypto.Noise.DH as DH (dhPubToBytes, dhSecToBytes,
dhBytesToPair, dhBytesToPair,
dhBytesToPub)
import qualified Data.ByteArray as BA (convert)
import qualified Data.ByteString as BS (ByteString, concat,
- replicate, empty, pack)
-import qualified Data.ByteString.Lazy.Char8 as CL (unpack)
+ replicate, empty)
import qualified Data.ByteString.Char8 as BC (pack, singleton, map)
import Data.Char (toLower)
-import qualified Data.Conduit.Binary as CB (sinkStorable, sinkLbs)
+import Data.Conduit.Attoparsec (sinkParserEither)
import Data.Conduit.Network.Unix (appSink, appSource,
runUnixServer,
serverSettings)
-import qualified Data.HashMap.Strict as HM (HashMap(..), size, delete,
- lookup, insert,
- empty, fromList,
- foldrWithKey, elems)
+import qualified Data.HashMap.Strict as HM ( delete, lookup, insert,
+ empty, elems)
import Data.Hex (hex)
import Data.Int (Int32)
-import Data.List (foldl', genericLength)
-import Foreign.C.Types (CTime (..))
-
+import Data.List (foldl')
import Data.Bits (Bits(..))
import Data.Conduit (ConduitM, (.|),
- yield, runConduit,
- toConsumer)
+ yield, runConduit)
import Data.IP (IPRange(..), addrRangePair,
toHostAddress, toHostAddress6,
fromHostAddress, makeAddrRange,
fromHostAddress6)
-import Data.Maybe (fromMaybe, fromJust, isJust)
+import Data.Maybe (fromJust, isJust)
import Network.WireGuard.Foreign.UAPI (WgPeer(..), WgDevice(..),
- WgIpmask(..), writeConfig,
+ WgIpmask(..),
peerFlagRemoveMe, peerFlagReplaceIpmasks,
deviceFlagRemoveFwmark, deviceFlagReplacePeers,
deviceFlagRemovePrivateKey, deviceFlagRemovePresharedKey)
import Network.WireGuard.Internal.Constant (keyLength)
+import Network.WireGuard.Internal.RpcParsers (RpcRequest(..), RpcSetPayload(..),
+ OpType(..), requestParser)
import Network.WireGuard.Internal.State (Device(..), Peer(..),
- buildRouteTables, createPeer,
+ createPeer,
invalidateSessions)
-import Network.WireGuard.Internal.Types (PrivateKey, PublicKey,
+import Network.WireGuard.Internal.Data.Types (PrivateKey, PublicKey,
PresharedKey, KeyPair)
import Network.WireGuard.Internal.Util (catchIOExceptionAnd)
--- | Kind of client operation.
---
--- See <https://www.wireguard.com/xplatform/#configuration-protocol> for more informations.
-data OpType = Get | Set
-
--- | Request wrapper. The payload is set only for Set operations.
---
--- See <https://www.wireguard.com/xplatform/#configuration-protocol> for more informations.
-data RpcRequest = RpcRequest {
- opType :: OpType,
- payload :: BS.ByteString
-}
-
-- | Run RPC service over a unix socket
runRPC :: FilePath -> Device -> IO ()
runRPC sockPath device = runUnixServer (serverSettings sockPath) $ \app ->
@@ -83,34 +65,30 @@ runRPC sockPath device = runUnixServer (serverSettings sockPath) $ \app ->
-- TODO: ensure that all bytestring over sockets will be erased
serveConduit :: Device -> ConduitM BS.ByteString BS.ByteString IO ()
serveConduit device = do
- request <- CL.unpack <$> toConsumer CB.sinkLbs
- if request /= ""
- then routeRequest request
- else yield mempty
+ request <- sinkParserEither requestParser
+ routeRequest request
where
--returnError = yield $ writeConfig (-invalidValueError)
- isGet = (== "get=1")
- isSet = (== "set=1")
- routeRequest req = do
- let line = head $ lines req
- case () of _
- | isGet line -> do
- deviceBstr <- liftIO . atomically $ showDevice device
- yield deviceBstr
- | otherwise -> yield mempty
+ routeRequest (Left _) = yield mempty
+ routeRequest (Right req) =
+ case opType req of
+ Set -> undefined
+ Get -> do
+ deviceBstr <- liftIO . atomically $ showDevice device
+ yield $ BS.concat [deviceBstr, BC.pack "errno=0\n\n"]
showDevice :: Device -> STM BS.ByteString
showDevice device@Device{..} = do
listen_port <- BC.pack . show <$> readTVar port
- fwmark <- BC.pack . show <$> readTVar fwmark
+ fwm <- BC.pack . show <$> readTVar fwmark
private_key <- fmap (toLowerBs . hex . privToBytes . fst) <$> readTVar localKey
let devHm = [("private_key", private_key),
("listen_port", Just listen_port),
- ("fwmark", Just fwmark)]
+ ("fwmark", Just fwm)]
let devBs = serializeRpcKeyValue devHm
- peers <- readTVar peers
- peersBstrList <- mapM showPeer $ HM.elems peers
- return . BS.concat $ (devBs : peersBstrList ++ [BC.singleton '\n'])
+ prs <- readTVar peers
+ peersBstrList <- mapM showPeer $ HM.elems prs
+ return . BS.concat $ (devBs : peersBstrList)
showPeer :: Peer -> STM BS.ByteString
showPeer Peer{..} = do
@@ -124,9 +102,9 @@ showPeer Peer{..} = do
last_handshake_time <- readTVar lastHandshakeTime
let peer = [("public_key", Just public_key),
("endpoint", BC.pack . show <$> endpoint),
- ("persistant_keepalive_interval", Just . BC.pack . show $ persistant_keepalive_interval),
- ("rx_bytes", Just . BC.pack . show $ rx_bytes),
+ ("persistent_keepalive_interval", Just . BC.pack . show $ persistant_keepalive_interval),
("tx_bytes", Just . BC.pack . show $ tx_bytes),
+ ("rx_bytes", Just . BC.pack . show $ rx_bytes),
("last_handshake_time", BC.pack . show <$> last_handshake_time)
] ++ expandAllowedIps (Just . BC.pack . show <$> allowed_ip)
return $ serializeRpcKeyValue peer
@@ -136,8 +114,10 @@ showPeer Peer{..} = do
serializeRpcKeyValue :: [(String, Maybe BS.ByteString)] -> BS.ByteString
serializeRpcKeyValue = foldl' showKeyValueLine BS.empty
where
- showKeyValueLine acc (key, Just val) = BS.concat [acc, BC.pack key, BC.singleton '=', val, BC.singleton '\n']
- showKeyValueLine acc (_, Nothing) = acc
+ showKeyValueLine acc (key, Just val)
+ | val == BC.pack "0" = acc
+ | otherwise = BS.concat [acc, BC.pack key, BC.singleton '=', val, BC.singleton '\n']
+ showKeyValueLine acc (_, Nothing) = acc
-- | implementation of config.c::set_peer()
@@ -181,7 +161,6 @@ setDevice device@Device{..} WgDevice{..} = do
changePSK = removePSK || devicePSK /= emptyKey
changePSKTo = if removePSK then Nothing else Just (bytesToPSK devicePSK)
when changePSK $ writeTVar presharedKey changePSKTo
-
when (changeLocalKey || changePSK) $ invalidateSessions device
ipRangeToWgIpmask :: IPRange -> WgIpmask
@@ -201,9 +180,6 @@ invalidValueError = 22 -- TODO: report back actual error
emptyKey :: BS.ByteString
emptyKey = BS.replicate keyLength 0
-testFlag :: Bits a => a -> a -> Bool
-testFlag a flag = (a .&. flag) /= zeroBits
-
pubToBytes :: PublicKey -> BS.ByteString
pubToBytes = BA.convert . DH.dhPubToBytes
@@ -224,3 +200,6 @@ bytesToPSK = BA.convert
toLowerBs :: BS.ByteString -> BS.ByteString
toLowerBs = BC.map toLower
+
+testFlag :: Bits a => a -> a -> Bool
+testFlag a flag = (a .&. flag) /= zeroBits