aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network/WireGuard
diff options
context:
space:
mode:
Diffstat (limited to 'src/Network/WireGuard')
-rw-r--r--src/Network/WireGuard/RPC.hs106
1 files changed, 25 insertions, 81 deletions
diff --git a/src/Network/WireGuard/RPC.hs b/src/Network/WireGuard/RPC.hs
index 6875332..0175127 100644
--- a/src/Network/WireGuard/RPC.hs
+++ b/src/Network/WireGuard/RPC.hs
@@ -10,52 +10,33 @@ module Network.WireGuard.RPC
) where
import Control.Concurrent.STM (STM, atomically,
- modifyTVar', readTVar,
- writeTVar)
-import Control.Monad (when)
+ readTVar, writeTVar)
+import Control.Monad (when, unless)
import Control.Monad.IO.Class (liftIO)
import qualified Crypto.Noise.DH as DH (dhPubToBytes, dhSecToBytes,
- dhBytesToPair, dhBytesToPair,
- dhBytesToPub)
-import Crypto.Noise.DH.Curve25519 (Curve25519)
+ dhBytesToPair, dhBytesToPair)
import qualified Data.ByteArray as BA (convert)
import qualified Data.ByteString as BS (ByteString, concat,
- replicate, empty)
-import Data.ByteString.Lazy (fromStrict)
-import Data.ByteString.Conversion (toByteString')
+ empty)
import qualified Data.ByteString.Char8 as BC (pack, singleton, map)
import Data.Char (toLower)
import Data.Conduit.Attoparsec (sinkParserEither)
import Data.Conduit.Network.Unix (appSink, appSource,
runUnixServer,
serverSettings)
-import qualified Data.HashMap.Strict as HM (delete, lookup, insert,
- empty, elems, member)
+import qualified Data.HashMap.Strict as HM (delete, lookup, insert, empty,
+ elems, member)
import Data.Hex (hex)
-import Data.Int (Int32)
import Data.List (foldl')
-import Data.Bits (Bits(..))
import Data.Conduit (ConduitM, (.|),
yield, runConduit)
-import Data.IP (IPRange(..), addrRangePair,
- toHostAddress, toHostAddress6,
- fromHostAddress, makeAddrRange,
- fromHostAddress6)
import Data.Maybe (fromJust, isJust,
fromMaybe)
-
-import Network.WireGuard.Foreign.UAPI (WgPeer(..), WgDevice(..),
- WgIpmask(..),
- peerFlagRemoveMe, peerFlagReplaceIpmasks,
- deviceFlagRemoveFwmark, deviceFlagReplacePeers,
- deviceFlagRemovePrivateKey, deviceFlagRemovePresharedKey)
-import Network.WireGuard.Internal.Constant (keyLength)
import Network.WireGuard.Internal.RpcParsers (requestParser)
import Network.WireGuard.Internal.State (Device(..), Peer(..),
- createPeer,
- invalidateSessions)
+ createPeer)
import Network.WireGuard.Internal.Data.Types (PrivateKey, PublicKey,
- PresharedKey, KeyPair)
+ KeyPair)
import Network.WireGuard.Internal.Data.RpcTypes (RpcRequest(..), RpcSetPayload(..),
OpType(..), RpcDevicePayload(..),
RpcPeerPayload(..))
@@ -93,7 +74,7 @@ setDevice req dev = do
when (isJust $ fwMark devReq) . writeTVar (fwmark dev) . fromJust $ fwMark devReq
when (replacePeers devReq) $ delDevPeers dev
let peersList = peersPayload . fromJust $ payload req
- when (not $ null peersList) $ setPeers peersList dev
+ unless (null peersList) $ setPeers peersList dev
return Nothing
-- TODO: Handle errors using errno.h
@@ -104,30 +85,25 @@ setPeers peerList dev = mapM_ inFunc peerList
statePeers <- readTVar $ peers dev
let peerPubK = pubToString $ pubK peer
let peerExists = HM.member peerPubK statePeers
- if remove peer
+ if remove peer && peerExists
then removePeer peer dev
- else if peerExists
- then do
- stmPeer <- modifyPeer peer (fromJust $ HM.lookup peerPubK statePeers)
- let nPeers = HM.insert peerPubK stmPeer statePeers
- writeTVar (peers dev) nPeers
- else do
- stmPeer <- createSTMPeer peer
- let nPeers = HM.insert peerPubK stmPeer statePeers
- writeTVar (peers dev) nPeers
-
-modifyPeer :: RpcPeerPayload -> Peer -> STM Peer
-modifyPeer peer stmPeer = undefined
-
-createSTMPeer :: RpcPeerPayload -> STM Peer
-createSTMPeer peer = do
- stmPeer <- createPeer $ pubK peer
+ else do
+ stmPeer <- if peerExists
+ then return . fromJust $ HM.lookup peerPubK statePeers
+ else createPeer $ pubK peer
+ modifySTMPeer peer stmPeer
+ let nPeers = HM.insert peerPubK stmPeer statePeers
+ writeTVar (peers dev) nPeers
+
+modifySTMPeer :: RpcPeerPayload -> Peer -> STM ()
+modifySTMPeer peer stmPeer = do
+ stmPIps <- if replaceIps peer
+ then return []
+ else readTVar $ ipmasks stmPeer
writeTVar (endPoint stmPeer) . Just $ endpoint peer
writeTVar (keepaliveInterval stmPeer) $ persistantKeepaliveInterval peer
- writeTVar (ipmasks stmPeer) $ allowedIp peer
- return stmPeer
+ writeTVar (ipmasks stmPeer) $ stmPIps ++ allowedIp peer
-
delDevPeers :: Device -> STM ()
delDevPeers dev = writeTVar (peers dev) HM.empty
@@ -138,7 +114,7 @@ removePeer peer dev = do
writeTVar (peers dev) nPeers
showDevice :: Device -> STM BS.ByteString
-showDevice device@Device{..} = do
+showDevice Device{..} = do
listen_port <- BC.pack . show <$> readTVar port
fwm <- BC.pack . show <$> readTVar fwmark
private_key <- fmap (toLowerBs . hex . privToBytes . fst) <$> readTVar localKey
@@ -152,7 +128,6 @@ showDevice device@Device{..} = do
showPeer :: Peer -> STM BS.ByteString
showPeer Peer{..} = do
- let hm = HM.empty
let public_key = pubToString remotePub
endpoint <- readTVar endPoint
persistant_keepalive_interval <- readTVar keepaliveInterval
@@ -179,25 +154,6 @@ serializeRpcKeyValue = foldl' showKeyValueLine BS.empty
| otherwise = BS.concat [acc, BC.pack key, BC.singleton '=', val, BC.singleton '\n']
showKeyValueLine acc (_, Nothing) = acc
-
-
-ipRangeToWgIpmask :: IPRange -> WgIpmask
-ipRangeToWgIpmask (IPv4Range ipv4range) = case addrRangePair ipv4range of
- (ipv4, prefix) -> WgIpmask (Left (toHostAddress ipv4)) (fromIntegral prefix)
-ipRangeToWgIpmask (IPv6Range ipv6range) = case addrRangePair ipv6range of
- (ipv6, prefix) -> WgIpmask (Right (toHostAddress6 ipv6)) (fromIntegral prefix)
-
-wgIpmaskToIpRange :: WgIpmask -> IPRange
-wgIpmaskToIpRange (WgIpmask ip cidr) = case ip of
- Left ipv4 -> IPv4Range $ makeAddrRange (fromHostAddress ipv4) (fromIntegral cidr)
- Right ipv6 -> IPv6Range $ makeAddrRange (fromHostAddress6 ipv6) (fromIntegral cidr)
-
-invalidValueError :: Int32
-invalidValueError = 22 -- TODO: report back actual error
-
-emptyKey :: BS.ByteString
-emptyKey = BS.replicate keyLength 0
-
pubToBytes :: PublicKey -> BS.ByteString
pubToBytes = BA.convert . DH.dhPubToBytes
@@ -207,20 +163,8 @@ pubToString = toLowerBs . hex . pubToBytes
privToBytes :: PrivateKey -> BS.ByteString
privToBytes = BA.convert . DH.dhSecToBytes
-pskToBytes :: PresharedKey -> BS.ByteString
-pskToBytes = BA.convert
-
bytesToPair :: BS.ByteString -> Maybe KeyPair
bytesToPair = DH.dhBytesToPair . BA.convert
-bytesToPub :: BS.ByteString -> Maybe PublicKey
-bytesToPub = DH.dhBytesToPub . BA.convert
-
-bytesToPSK :: BS.ByteString -> PresharedKey
-bytesToPSK = BA.convert
-
toLowerBs :: BS.ByteString -> BS.ByteString
toLowerBs = BC.map toLower
-
-testFlag :: Bits a => a -> a -> Bool
-testFlag a flag = (a .&. flag) /= zeroBits