aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network
diff options
context:
space:
mode:
Diffstat (limited to 'src/Network')
-rw-r--r--src/Network/WireGuard/RPC.hs118
1 files changed, 70 insertions, 48 deletions
diff --git a/src/Network/WireGuard/RPC.hs b/src/Network/WireGuard/RPC.hs
index 162b5b4..6875332 100644
--- a/src/Network/WireGuard/RPC.hs
+++ b/src/Network/WireGuard/RPC.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE RecordWildCards #-}
+{-# LANGUAGE OverloadedStrings #-}
module Network.WireGuard.RPC
( runRPC,
@@ -16,17 +17,20 @@ import Control.Monad.IO.Class (liftIO)
import qualified Crypto.Noise.DH as DH (dhPubToBytes, dhSecToBytes,
dhBytesToPair, dhBytesToPair,
dhBytesToPub)
+import Crypto.Noise.DH.Curve25519 (Curve25519)
import qualified Data.ByteArray as BA (convert)
import qualified Data.ByteString as BS (ByteString, concat,
replicate, empty)
+import Data.ByteString.Lazy (fromStrict)
+import Data.ByteString.Conversion (toByteString')
import qualified Data.ByteString.Char8 as BC (pack, singleton, map)
import Data.Char (toLower)
import Data.Conduit.Attoparsec (sinkParserEither)
import Data.Conduit.Network.Unix (appSink, appSource,
runUnixServer,
serverSettings)
-import qualified Data.HashMap.Strict as HM ( delete, lookup, insert,
- empty, elems)
+import qualified Data.HashMap.Strict as HM (delete, lookup, insert,
+ empty, elems, member)
import Data.Hex (hex)
import Data.Int (Int32)
import Data.List (foldl')
@@ -37,7 +41,8 @@ import Data.IP (IPRange(..), addrRan
toHostAddress, toHostAddress6,
fromHostAddress, makeAddrRange,
fromHostAddress6)
-import Data.Maybe (fromJust, isJust)
+import Data.Maybe (fromJust, isJust,
+ fromMaybe)
import Network.WireGuard.Foreign.UAPI (WgPeer(..), WgDevice(..),
WgIpmask(..),
@@ -52,7 +57,8 @@ import Network.WireGuard.Internal.State (Device(..), Peer(..)
import Network.WireGuard.Internal.Data.Types (PrivateKey, PublicKey,
PresharedKey, KeyPair)
import Network.WireGuard.Internal.Data.RpcTypes (RpcRequest(..), RpcSetPayload(..),
- OpType(..))
+ OpType(..), RpcDevicePayload(..),
+ RpcPeerPayload(..))
import Network.WireGuard.Internal.Util (catchIOExceptionAnd)
-- | Run RPC service over a unix socket
@@ -71,11 +77,66 @@ serveConduit device = do
routeRequest (Left _) = yield mempty
routeRequest (Right req) =
case opType req of
- Set -> undefined
+ Set -> do
+ err <- liftIO . atomically $ setDevice req device
+ let errno = fromMaybe "0" err
+ yield $ BS.concat [BC.pack "errno=", errno, BC.pack "\n\n"]
Get -> do
deviceBstr <- liftIO . atomically $ showDevice device
yield $ BS.concat [deviceBstr, BC.pack "errno=0\n\n"]
+setDevice :: RpcRequest -> Device -> STM (Maybe BS.ByteString)
+setDevice req dev = do
+ let devReq = devicePayload . fromJust $ payload req
+ when (isJust $ pk devReq) . writeTVar (localKey dev) $ pk devReq
+ writeTVar (port dev) $ listenPort devReq
+ when (isJust $ fwMark devReq) . writeTVar (fwmark dev) . fromJust $ fwMark devReq
+ when (replacePeers devReq) $ delDevPeers dev
+ let peersList = peersPayload . fromJust $ payload req
+ when (not $ null peersList) $ setPeers peersList dev
+ return Nothing
+ -- TODO: Handle errors using errno.h
+
+setPeers :: [RpcPeerPayload] -> Device -> STM ()
+setPeers peerList dev = mapM_ inFunc peerList
+ where
+ inFunc peer = do
+ statePeers <- readTVar $ peers dev
+ let peerPubK = pubToString $ pubK peer
+ let peerExists = HM.member peerPubK statePeers
+ if remove peer
+ then removePeer peer dev
+ else if peerExists
+ then do
+ stmPeer <- modifyPeer peer (fromJust $ HM.lookup peerPubK statePeers)
+ let nPeers = HM.insert peerPubK stmPeer statePeers
+ writeTVar (peers dev) nPeers
+ else do
+ stmPeer <- createSTMPeer peer
+ let nPeers = HM.insert peerPubK stmPeer statePeers
+ writeTVar (peers dev) nPeers
+
+modifyPeer :: RpcPeerPayload -> Peer -> STM Peer
+modifyPeer peer stmPeer = undefined
+
+createSTMPeer :: RpcPeerPayload -> STM Peer
+createSTMPeer peer = do
+ stmPeer <- createPeer $ pubK peer
+ writeTVar (endPoint stmPeer) . Just $ endpoint peer
+ writeTVar (keepaliveInterval stmPeer) $ persistantKeepaliveInterval peer
+ writeTVar (ipmasks stmPeer) $ allowedIp peer
+ return stmPeer
+
+
+delDevPeers :: Device -> STM ()
+delDevPeers dev = writeTVar (peers dev) HM.empty
+
+removePeer :: RpcPeerPayload -> Device -> STM ()
+removePeer peer dev = do
+ currentPeers <- readTVar $ peers dev
+ let nPeers = HM.delete (pubToString $ pubK peer) currentPeers
+ writeTVar (peers dev) nPeers
+
showDevice :: Device -> STM BS.ByteString
showDevice device@Device{..} = do
listen_port <- BC.pack . show <$> readTVar port
@@ -92,7 +153,7 @@ showDevice device@Device{..} = do
showPeer :: Peer -> STM BS.ByteString
showPeer Peer{..} = do
let hm = HM.empty
- let public_key = toLowerBs . hex $ pubToBytes remotePub
+ let public_key = pubToString remotePub
endpoint <- readTVar endPoint
persistant_keepalive_interval <- readTVar keepaliveInterval
allowed_ip <- readTVar ipmasks
@@ -119,48 +180,6 @@ serializeRpcKeyValue = foldl' showKeyValueLine BS.empty
showKeyValueLine acc (_, Nothing) = acc
--- | implementation of config.c::set_peer()
-setPeer :: Device -> WgPeer -> [IPRange] -> STM Bool
-setPeer Device{..} WgPeer{..} ipranges
- | peerPubKey == emptyKey = return False
- | testFlag peerFlags peerFlagRemoveMe = modifyTVar' peers (HM.delete peerPubKey) >> return False
- | otherwise = do
- peers' <- readTVar peers
- Peer{..} <- case HM.lookup peerPubKey peers' of
- Nothing -> do
- newPeer <- createPeer (fromJust $ bytesToPub peerPubKey) -- TODO: replace fromJust
- modifyTVar' peers (HM.insert peerPubKey newPeer)
- return newPeer
- Just p -> return p
- when (isJust peerAddr) $ writeTVar endPoint peerAddr
- let replaceIpmasks = testFlag peerFlags peerFlagReplaceIpmasks
- changeIpmasks = replaceIpmasks || not (null ipranges)
- when changeIpmasks $
- if replaceIpmasks
- then writeTVar ipmasks ipranges
- else modifyTVar' ipmasks (++ipranges)
- when (peerKeepaliveInterval /= complement 0) $
- writeTVar keepaliveInterval (fromIntegral peerKeepaliveInterval)
- return changeIpmasks
-
--- | implementation of config.c::config_set_device()
-setDevice :: Device -> WgDevice -> STM ()
-setDevice device@Device{..} WgDevice{..} = do
- when (deviceFwmark /= 0 || deviceFwmark == 0 && testFlag deviceFlags deviceFlagRemoveFwmark) $
- writeTVar fwmark (fromIntegral deviceFwmark)
- when (devicePort /= 0) $ writeTVar port (fromIntegral devicePort)
- when (testFlag deviceFlags deviceFlagReplacePeers) $ writeTVar peers HM.empty
-
- let removeLocalKey = testFlag deviceFlags deviceFlagRemovePrivateKey
- changeLocalKey = removeLocalKey || devicePrivkey /= emptyKey
- changeLocalKeyTo = if removeLocalKey then Nothing else bytesToPair devicePrivkey
- when changeLocalKey $ writeTVar localKey changeLocalKeyTo
-
- let removePSK = testFlag deviceFlags deviceFlagRemovePresharedKey
- changePSK = removePSK || devicePSK /= emptyKey
- changePSKTo = if removePSK then Nothing else Just (bytesToPSK devicePSK)
- when changePSK $ writeTVar presharedKey changePSKTo
- when (changeLocalKey || changePSK) $ invalidateSessions device
ipRangeToWgIpmask :: IPRange -> WgIpmask
ipRangeToWgIpmask (IPv4Range ipv4range) = case addrRangePair ipv4range of
@@ -182,6 +201,9 @@ emptyKey = BS.replicate keyLength 0
pubToBytes :: PublicKey -> BS.ByteString
pubToBytes = BA.convert . DH.dhPubToBytes
+pubToString :: PublicKey -> BS.ByteString
+pubToString = toLowerBs . hex . pubToBytes
+
privToBytes :: PrivateKey -> BS.ByteString
privToBytes = BA.convert . DH.dhSecToBytes