From 8f5716c876f96be640539a3af129ed0a02cdcd85 Mon Sep 17 00:00:00 2001 From: Baylac-Jacqué Félix Date: Sat, 16 Sep 2017 15:35:34 +0200 Subject: Plumbed RPC set parser to STM state. --- src/Network/WireGuard/RPC.hs | 106 ++++++++++--------------------------------- 1 file changed, 25 insertions(+), 81 deletions(-) (limited to 'src/Network/WireGuard') diff --git a/src/Network/WireGuard/RPC.hs b/src/Network/WireGuard/RPC.hs index 6875332..0175127 100644 --- a/src/Network/WireGuard/RPC.hs +++ b/src/Network/WireGuard/RPC.hs @@ -10,52 +10,33 @@ module Network.WireGuard.RPC ) where import Control.Concurrent.STM (STM, atomically, - modifyTVar', readTVar, - writeTVar) -import Control.Monad (when) + readTVar, writeTVar) +import Control.Monad (when, unless) import Control.Monad.IO.Class (liftIO) import qualified Crypto.Noise.DH as DH (dhPubToBytes, dhSecToBytes, - dhBytesToPair, dhBytesToPair, - dhBytesToPub) -import Crypto.Noise.DH.Curve25519 (Curve25519) + dhBytesToPair, dhBytesToPair) import qualified Data.ByteArray as BA (convert) import qualified Data.ByteString as BS (ByteString, concat, - replicate, empty) -import Data.ByteString.Lazy (fromStrict) -import Data.ByteString.Conversion (toByteString') + empty) import qualified Data.ByteString.Char8 as BC (pack, singleton, map) import Data.Char (toLower) import Data.Conduit.Attoparsec (sinkParserEither) import Data.Conduit.Network.Unix (appSink, appSource, runUnixServer, serverSettings) -import qualified Data.HashMap.Strict as HM (delete, lookup, insert, - empty, elems, member) +import qualified Data.HashMap.Strict as HM (delete, lookup, insert, empty, + elems, member) import Data.Hex (hex) -import Data.Int (Int32) import Data.List (foldl') -import Data.Bits (Bits(..)) import Data.Conduit (ConduitM, (.|), yield, runConduit) -import Data.IP (IPRange(..), addrRangePair, - toHostAddress, toHostAddress6, - fromHostAddress, makeAddrRange, - fromHostAddress6) import Data.Maybe (fromJust, isJust, fromMaybe) - -import Network.WireGuard.Foreign.UAPI (WgPeer(..), WgDevice(..), - WgIpmask(..), - peerFlagRemoveMe, peerFlagReplaceIpmasks, - deviceFlagRemoveFwmark, deviceFlagReplacePeers, - deviceFlagRemovePrivateKey, deviceFlagRemovePresharedKey) -import Network.WireGuard.Internal.Constant (keyLength) import Network.WireGuard.Internal.RpcParsers (requestParser) import Network.WireGuard.Internal.State (Device(..), Peer(..), - createPeer, - invalidateSessions) + createPeer) import Network.WireGuard.Internal.Data.Types (PrivateKey, PublicKey, - PresharedKey, KeyPair) + KeyPair) import Network.WireGuard.Internal.Data.RpcTypes (RpcRequest(..), RpcSetPayload(..), OpType(..), RpcDevicePayload(..), RpcPeerPayload(..)) @@ -93,7 +74,7 @@ setDevice req dev = do when (isJust $ fwMark devReq) . writeTVar (fwmark dev) . fromJust $ fwMark devReq when (replacePeers devReq) $ delDevPeers dev let peersList = peersPayload . fromJust $ payload req - when (not $ null peersList) $ setPeers peersList dev + unless (null peersList) $ setPeers peersList dev return Nothing -- TODO: Handle errors using errno.h @@ -104,30 +85,25 @@ setPeers peerList dev = mapM_ inFunc peerList statePeers <- readTVar $ peers dev let peerPubK = pubToString $ pubK peer let peerExists = HM.member peerPubK statePeers - if remove peer + if remove peer && peerExists then removePeer peer dev - else if peerExists - then do - stmPeer <- modifyPeer peer (fromJust $ HM.lookup peerPubK statePeers) - let nPeers = HM.insert peerPubK stmPeer statePeers - writeTVar (peers dev) nPeers - else do - stmPeer <- createSTMPeer peer - let nPeers = HM.insert peerPubK stmPeer statePeers - writeTVar (peers dev) nPeers - -modifyPeer :: RpcPeerPayload -> Peer -> STM Peer -modifyPeer peer stmPeer = undefined - -createSTMPeer :: RpcPeerPayload -> STM Peer -createSTMPeer peer = do - stmPeer <- createPeer $ pubK peer + else do + stmPeer <- if peerExists + then return . fromJust $ HM.lookup peerPubK statePeers + else createPeer $ pubK peer + modifySTMPeer peer stmPeer + let nPeers = HM.insert peerPubK stmPeer statePeers + writeTVar (peers dev) nPeers + +modifySTMPeer :: RpcPeerPayload -> Peer -> STM () +modifySTMPeer peer stmPeer = do + stmPIps <- if replaceIps peer + then return [] + else readTVar $ ipmasks stmPeer writeTVar (endPoint stmPeer) . Just $ endpoint peer writeTVar (keepaliveInterval stmPeer) $ persistantKeepaliveInterval peer - writeTVar (ipmasks stmPeer) $ allowedIp peer - return stmPeer + writeTVar (ipmasks stmPeer) $ stmPIps ++ allowedIp peer - delDevPeers :: Device -> STM () delDevPeers dev = writeTVar (peers dev) HM.empty @@ -138,7 +114,7 @@ removePeer peer dev = do writeTVar (peers dev) nPeers showDevice :: Device -> STM BS.ByteString -showDevice device@Device{..} = do +showDevice Device{..} = do listen_port <- BC.pack . show <$> readTVar port fwm <- BC.pack . show <$> readTVar fwmark private_key <- fmap (toLowerBs . hex . privToBytes . fst) <$> readTVar localKey @@ -152,7 +128,6 @@ showDevice device@Device{..} = do showPeer :: Peer -> STM BS.ByteString showPeer Peer{..} = do - let hm = HM.empty let public_key = pubToString remotePub endpoint <- readTVar endPoint persistant_keepalive_interval <- readTVar keepaliveInterval @@ -179,25 +154,6 @@ serializeRpcKeyValue = foldl' showKeyValueLine BS.empty | otherwise = BS.concat [acc, BC.pack key, BC.singleton '=', val, BC.singleton '\n'] showKeyValueLine acc (_, Nothing) = acc - - -ipRangeToWgIpmask :: IPRange -> WgIpmask -ipRangeToWgIpmask (IPv4Range ipv4range) = case addrRangePair ipv4range of - (ipv4, prefix) -> WgIpmask (Left (toHostAddress ipv4)) (fromIntegral prefix) -ipRangeToWgIpmask (IPv6Range ipv6range) = case addrRangePair ipv6range of - (ipv6, prefix) -> WgIpmask (Right (toHostAddress6 ipv6)) (fromIntegral prefix) - -wgIpmaskToIpRange :: WgIpmask -> IPRange -wgIpmaskToIpRange (WgIpmask ip cidr) = case ip of - Left ipv4 -> IPv4Range $ makeAddrRange (fromHostAddress ipv4) (fromIntegral cidr) - Right ipv6 -> IPv6Range $ makeAddrRange (fromHostAddress6 ipv6) (fromIntegral cidr) - -invalidValueError :: Int32 -invalidValueError = 22 -- TODO: report back actual error - -emptyKey :: BS.ByteString -emptyKey = BS.replicate keyLength 0 - pubToBytes :: PublicKey -> BS.ByteString pubToBytes = BA.convert . DH.dhPubToBytes @@ -207,20 +163,8 @@ pubToString = toLowerBs . hex . pubToBytes privToBytes :: PrivateKey -> BS.ByteString privToBytes = BA.convert . DH.dhSecToBytes -pskToBytes :: PresharedKey -> BS.ByteString -pskToBytes = BA.convert - bytesToPair :: BS.ByteString -> Maybe KeyPair bytesToPair = DH.dhBytesToPair . BA.convert -bytesToPub :: BS.ByteString -> Maybe PublicKey -bytesToPub = DH.dhBytesToPub . BA.convert - -bytesToPSK :: BS.ByteString -> PresharedKey -bytesToPSK = BA.convert - toLowerBs :: BS.ByteString -> BS.ByteString toLowerBs = BC.map toLower - -testFlag :: Bits a => a -> a -> Bool -testFlag a flag = (a .&. flag) /= zeroBits -- cgit v1.2.3-59-g8ed1b