From 9fe8c0f80dafd297fa7828095241b5be005432de Mon Sep 17 00:00:00 2001 From: Baylac-Jacqué Félix Date: Thu, 14 Sep 2017 15:38:29 +0200 Subject: Wire up RPC set parsers to STM state. --- src/Network/WireGuard/RPC.hs | 118 +++++++++++++++++++++++++------------------ 1 file changed, 70 insertions(+), 48 deletions(-) (limited to 'src/Network/WireGuard') diff --git a/src/Network/WireGuard/RPC.hs b/src/Network/WireGuard/RPC.hs index 162b5b4..6875332 100644 --- a/src/Network/WireGuard/RPC.hs +++ b/src/Network/WireGuard/RPC.hs @@ -1,4 +1,5 @@ {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE OverloadedStrings #-} module Network.WireGuard.RPC ( runRPC, @@ -16,17 +17,20 @@ import Control.Monad.IO.Class (liftIO) import qualified Crypto.Noise.DH as DH (dhPubToBytes, dhSecToBytes, dhBytesToPair, dhBytesToPair, dhBytesToPub) +import Crypto.Noise.DH.Curve25519 (Curve25519) import qualified Data.ByteArray as BA (convert) import qualified Data.ByteString as BS (ByteString, concat, replicate, empty) +import Data.ByteString.Lazy (fromStrict) +import Data.ByteString.Conversion (toByteString') import qualified Data.ByteString.Char8 as BC (pack, singleton, map) import Data.Char (toLower) import Data.Conduit.Attoparsec (sinkParserEither) import Data.Conduit.Network.Unix (appSink, appSource, runUnixServer, serverSettings) -import qualified Data.HashMap.Strict as HM ( delete, lookup, insert, - empty, elems) +import qualified Data.HashMap.Strict as HM (delete, lookup, insert, + empty, elems, member) import Data.Hex (hex) import Data.Int (Int32) import Data.List (foldl') @@ -37,7 +41,8 @@ import Data.IP (IPRange(..), addrRan toHostAddress, toHostAddress6, fromHostAddress, makeAddrRange, fromHostAddress6) -import Data.Maybe (fromJust, isJust) +import Data.Maybe (fromJust, isJust, + fromMaybe) import Network.WireGuard.Foreign.UAPI (WgPeer(..), WgDevice(..), WgIpmask(..), @@ -52,7 +57,8 @@ import Network.WireGuard.Internal.State (Device(..), Peer(..) import Network.WireGuard.Internal.Data.Types (PrivateKey, PublicKey, PresharedKey, KeyPair) import Network.WireGuard.Internal.Data.RpcTypes (RpcRequest(..), RpcSetPayload(..), - OpType(..)) + OpType(..), RpcDevicePayload(..), + RpcPeerPayload(..)) import Network.WireGuard.Internal.Util (catchIOExceptionAnd) -- | Run RPC service over a unix socket @@ -71,11 +77,66 @@ serveConduit device = do routeRequest (Left _) = yield mempty routeRequest (Right req) = case opType req of - Set -> undefined + Set -> do + err <- liftIO . atomically $ setDevice req device + let errno = fromMaybe "0" err + yield $ BS.concat [BC.pack "errno=", errno, BC.pack "\n\n"] Get -> do deviceBstr <- liftIO . atomically $ showDevice device yield $ BS.concat [deviceBstr, BC.pack "errno=0\n\n"] +setDevice :: RpcRequest -> Device -> STM (Maybe BS.ByteString) +setDevice req dev = do + let devReq = devicePayload . fromJust $ payload req + when (isJust $ pk devReq) . writeTVar (localKey dev) $ pk devReq + writeTVar (port dev) $ listenPort devReq + when (isJust $ fwMark devReq) . writeTVar (fwmark dev) . fromJust $ fwMark devReq + when (replacePeers devReq) $ delDevPeers dev + let peersList = peersPayload . fromJust $ payload req + when (not $ null peersList) $ setPeers peersList dev + return Nothing + -- TODO: Handle errors using errno.h + +setPeers :: [RpcPeerPayload] -> Device -> STM () +setPeers peerList dev = mapM_ inFunc peerList + where + inFunc peer = do + statePeers <- readTVar $ peers dev + let peerPubK = pubToString $ pubK peer + let peerExists = HM.member peerPubK statePeers + if remove peer + then removePeer peer dev + else if peerExists + then do + stmPeer <- modifyPeer peer (fromJust $ HM.lookup peerPubK statePeers) + let nPeers = HM.insert peerPubK stmPeer statePeers + writeTVar (peers dev) nPeers + else do + stmPeer <- createSTMPeer peer + let nPeers = HM.insert peerPubK stmPeer statePeers + writeTVar (peers dev) nPeers + +modifyPeer :: RpcPeerPayload -> Peer -> STM Peer +modifyPeer peer stmPeer = undefined + +createSTMPeer :: RpcPeerPayload -> STM Peer +createSTMPeer peer = do + stmPeer <- createPeer $ pubK peer + writeTVar (endPoint stmPeer) . Just $ endpoint peer + writeTVar (keepaliveInterval stmPeer) $ persistantKeepaliveInterval peer + writeTVar (ipmasks stmPeer) $ allowedIp peer + return stmPeer + + +delDevPeers :: Device -> STM () +delDevPeers dev = writeTVar (peers dev) HM.empty + +removePeer :: RpcPeerPayload -> Device -> STM () +removePeer peer dev = do + currentPeers <- readTVar $ peers dev + let nPeers = HM.delete (pubToString $ pubK peer) currentPeers + writeTVar (peers dev) nPeers + showDevice :: Device -> STM BS.ByteString showDevice device@Device{..} = do listen_port <- BC.pack . show <$> readTVar port @@ -92,7 +153,7 @@ showDevice device@Device{..} = do showPeer :: Peer -> STM BS.ByteString showPeer Peer{..} = do let hm = HM.empty - let public_key = toLowerBs . hex $ pubToBytes remotePub + let public_key = pubToString remotePub endpoint <- readTVar endPoint persistant_keepalive_interval <- readTVar keepaliveInterval allowed_ip <- readTVar ipmasks @@ -119,48 +180,6 @@ serializeRpcKeyValue = foldl' showKeyValueLine BS.empty showKeyValueLine acc (_, Nothing) = acc --- | implementation of config.c::set_peer() -setPeer :: Device -> WgPeer -> [IPRange] -> STM Bool -setPeer Device{..} WgPeer{..} ipranges - | peerPubKey == emptyKey = return False - | testFlag peerFlags peerFlagRemoveMe = modifyTVar' peers (HM.delete peerPubKey) >> return False - | otherwise = do - peers' <- readTVar peers - Peer{..} <- case HM.lookup peerPubKey peers' of - Nothing -> do - newPeer <- createPeer (fromJust $ bytesToPub peerPubKey) -- TODO: replace fromJust - modifyTVar' peers (HM.insert peerPubKey newPeer) - return newPeer - Just p -> return p - when (isJust peerAddr) $ writeTVar endPoint peerAddr - let replaceIpmasks = testFlag peerFlags peerFlagReplaceIpmasks - changeIpmasks = replaceIpmasks || not (null ipranges) - when changeIpmasks $ - if replaceIpmasks - then writeTVar ipmasks ipranges - else modifyTVar' ipmasks (++ipranges) - when (peerKeepaliveInterval /= complement 0) $ - writeTVar keepaliveInterval (fromIntegral peerKeepaliveInterval) - return changeIpmasks - --- | implementation of config.c::config_set_device() -setDevice :: Device -> WgDevice -> STM () -setDevice device@Device{..} WgDevice{..} = do - when (deviceFwmark /= 0 || deviceFwmark == 0 && testFlag deviceFlags deviceFlagRemoveFwmark) $ - writeTVar fwmark (fromIntegral deviceFwmark) - when (devicePort /= 0) $ writeTVar port (fromIntegral devicePort) - when (testFlag deviceFlags deviceFlagReplacePeers) $ writeTVar peers HM.empty - - let removeLocalKey = testFlag deviceFlags deviceFlagRemovePrivateKey - changeLocalKey = removeLocalKey || devicePrivkey /= emptyKey - changeLocalKeyTo = if removeLocalKey then Nothing else bytesToPair devicePrivkey - when changeLocalKey $ writeTVar localKey changeLocalKeyTo - - let removePSK = testFlag deviceFlags deviceFlagRemovePresharedKey - changePSK = removePSK || devicePSK /= emptyKey - changePSKTo = if removePSK then Nothing else Just (bytesToPSK devicePSK) - when changePSK $ writeTVar presharedKey changePSKTo - when (changeLocalKey || changePSK) $ invalidateSessions device ipRangeToWgIpmask :: IPRange -> WgIpmask ipRangeToWgIpmask (IPv4Range ipv4range) = case addrRangePair ipv4range of @@ -182,6 +201,9 @@ emptyKey = BS.replicate keyLength 0 pubToBytes :: PublicKey -> BS.ByteString pubToBytes = BA.convert . DH.dhPubToBytes +pubToString :: PublicKey -> BS.ByteString +pubToString = toLowerBs . hex . pubToBytes + privToBytes :: PrivateKey -> BS.ByteString privToBytes = BA.convert . DH.dhSecToBytes -- cgit v1.2.3-59-g8ed1b