aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network/WireGuard/RPC.hs
diff options
context:
space:
mode:
authorBin Jin <bjin@ctrl-d.org>2017-03-12 17:48:20 +0800
committerBin Jin <bjin@ctrl-d.org>2017-03-12 17:48:20 +0800
commita2a3e540f2d1a507b34eccae26de09066a2a12fa (patch)
tree1e336189acd095d761eaad252f18ee6e32b7216b /src/Network/WireGuard/RPC.hs
downloadwireguard-hs-a2a3e540f2d1a507b34eccae26de09066a2a12fa.tar.xz
wireguard-hs-a2a3e540f2d1a507b34eccae26de09066a2a12fa.zip
Initial commit
Diffstat (limited to 'src/Network/WireGuard/RPC.hs')
-rw-r--r--src/Network/WireGuard/RPC.hs187
1 files changed, 187 insertions, 0 deletions
diff --git a/src/Network/WireGuard/RPC.hs b/src/Network/WireGuard/RPC.hs
new file mode 100644
index 0000000..7ecb8de
--- /dev/null
+++ b/src/Network/WireGuard/RPC.hs
@@ -0,0 +1,187 @@
+{-# LANGUAGE RecordWildCards #-}
+
+module Network.WireGuard.RPC
+ ( runRPC
+ ) where
+
+import Control.Concurrent.STM (STM, atomically,
+ modifyTVar', readTVar,
+ writeTVar)
+import Control.Monad (replicateM, sequence,
+ when)
+import Control.Monad.IO.Class (liftIO)
+import qualified Crypto.Noise.DH as DH
+import qualified Data.ByteArray as BA
+import qualified Data.ByteString as BS
+import qualified Data.Conduit.Binary as CB
+import Data.Conduit.Network.Unix (appSink, appSource,
+ runUnixServer,
+ serverSettings)
+import qualified Data.HashMap.Strict as HM
+import Data.Int (Int32)
+import Data.List (genericLength)
+import Foreign.C.Types (CTime (..))
+
+import Data.Bits
+import Data.Conduit
+import Data.IP
+import Data.Maybe
+
+import Network.WireGuard.Foreign.UAPI
+import Network.WireGuard.Internal.Constant
+import Network.WireGuard.Internal.State
+import Network.WireGuard.Internal.Types
+import Network.WireGuard.Internal.Util
+
+-- | Run RPC service over a unix socket
+runRPC :: FilePath -> Device -> IO ()
+runRPC sockPath device = runUnixServer (serverSettings sockPath) $ \app ->
+ catchIOExceptionAnd (return ()) $ runConduit (appSource app =$= serveConduit =$= appSink app)
+ where
+ -- TODO: ensure that all bytestring over sockets will be erased
+ serveConduit = do
+ h <- CB.head
+ case h of
+ Just 0 -> showDevice device
+ Just byte -> do
+ leftover (BS.singleton byte)
+ mWgdev <- CB.sinkStorable
+ case mWgdev of
+ Just wgdev -> catchSomeExceptionAnd returnError (updateDevice wgdev)
+ Nothing -> mempty
+ Nothing -> mempty
+
+ returnError = yield $ writeConfig (-invalidValueError)
+
+ showDevice Device{..} = do
+ (wgdevice, peers') <- liftIO buildWgDevice
+ yield (writeConfig wgdevice)
+ mapM_ showPeer peers'
+ where
+ buildWgDevice = atomically $ do
+ localKey' <- readTVar localKey
+ let (pub, priv) = case localKey' of
+ Nothing -> (emptyKey, emptyKey)
+ Just (sec, pub') -> (pubToBytes pub', privToBytes sec)
+ psk' <- fmap pskToBytes <$> readTVar presharedKey
+ fwmark' <- fromIntegral <$> readTVar fwmark
+ port' <- fromIntegral <$> readTVar port
+ peers' <- readTVar peers
+ return (WgDevice intfName 0 pub priv (fromMaybe emptyKey psk')
+ fwmark' port' (fromIntegral $ HM.size peers'), peers')
+
+ showPeer Peer{..} = do
+ (wgpeer, ipmasks') <- liftIO buildWgPeer
+ yield (writeConfig wgpeer)
+ yield $ BS.concat (map (writeConfig . ipRangeToWgIpmask) ipmasks')
+ where
+ extractTime Nothing = 0
+ extractTime (Just (CTime t)) = fromIntegral t
+
+ buildWgPeer = atomically $ do
+ ipmasks' <- readTVar ipmasks
+ wgpeer <- WgPeer (pubToBytes remotePub)
+ <$> return 0
+ <*> readTVar endPoint
+ <*> (extractTime <$> readTVar lastHandshakeTime)
+ <*> (fromIntegral <$> readTVar receivedBytes)
+ <*> (fromIntegral <$> readTVar transferredBytes)
+ <*> (fromIntegral <$> readTVar keepaliveInterval)
+ <*> return (genericLength ipmasks')
+ return (wgpeer, ipmasks')
+
+ updateDevice wgdevice = do
+ setPeerMs <- replicateM (fromIntegral $ deviceNumPeers wgdevice) $ do
+ Just wgpeer <- CB.sinkStorable
+ -- TODO: replace fromJust
+ ipranges <- replicateM (fromIntegral $ peerNumIpmasks wgpeer)
+ (wgIpmaskToIpRange . fromJust <$> CB.sinkStorable)
+ return $ setPeer device wgpeer ipranges
+ liftIO $ atomically $ do
+ setDevice device wgdevice
+ anyIpMaskChanged <- or <$> sequence setPeerMs
+ -- TODO: modify routetable incrementally
+ when anyIpMaskChanged $ buildRouteTables device
+ yield $ writeConfig (0 :: Int32)
+
+-- | implementation of config.c::set_peer()
+setPeer :: Device -> WgPeer -> [IPRange] -> STM Bool
+setPeer Device{..} WgPeer{..} ipranges
+ | peerPubKey == emptyKey = return False
+ | testFlag peerFlags peerFlagRemoveMe = modifyTVar' peers (HM.delete peerPubKey) >> return False
+ | otherwise = do
+ peers' <- readTVar peers
+ Peer{..} <- case HM.lookup peerPubKey peers' of
+ Nothing -> do
+ newPeer <- createPeer (fromJust $ bytesToPub peerPubKey) -- TODO: replace fromJust
+ modifyTVar' peers (HM.insert peerPubKey newPeer)
+ return newPeer
+ Just p -> return p
+ when (isJust peerAddr) $ writeTVar endPoint peerAddr
+ let replaceIpmasks = testFlag peerFlags peerFlagReplaceIpmasks
+ changeIpmasks = replaceIpmasks || not (null ipranges)
+ when changeIpmasks $
+ if replaceIpmasks
+ then writeTVar ipmasks ipranges
+ else modifyTVar' ipmasks (++ipranges)
+ when (peerKeepaliveInterval /= complement 0) $
+ writeTVar keepaliveInterval (fromIntegral peerKeepaliveInterval)
+ return changeIpmasks
+
+-- | implementation of config.c::config_set_device()
+setDevice :: Device -> WgDevice -> STM ()
+setDevice device@Device{..} WgDevice{..} = do
+ when (deviceFwmark /= 0 || deviceFwmark == 0 && testFlag deviceFlags deviceFlagRemoveFwmark) $
+ writeTVar fwmark (fromIntegral deviceFwmark)
+ when (devicePort /= 0) $ writeTVar port (fromIntegral devicePort)
+ when (testFlag deviceFlags deviceFlagReplacePeers) $ writeTVar peers HM.empty
+
+ let removeLocalKey = testFlag deviceFlags deviceFlagRemovePrivateKey
+ changeLocalKey = removeLocalKey || devicePrivkey /= emptyKey
+ changeLocalKeyTo = if removeLocalKey then Nothing else bytesToPair devicePrivkey
+ when changeLocalKey $ writeTVar localKey changeLocalKeyTo
+
+ let removePSK = testFlag deviceFlags deviceFlagRemovePresharedKey
+ changePSK = removePSK || devicePSK /= emptyKey
+ changePSKTo = if removePSK then Nothing else Just (bytesToPSK devicePSK)
+ when changePSK $ writeTVar presharedKey changePSKTo
+
+ when (changeLocalKey || changePSK) $ invalidateSessions device
+
+ipRangeToWgIpmask :: IPRange -> WgIpmask
+ipRangeToWgIpmask (IPv4Range ipv4range) = case addrRangePair ipv4range of
+ (ipv4, prefix) -> WgIpmask (Left (toHostAddress ipv4)) (fromIntegral prefix)
+ipRangeToWgIpmask (IPv6Range ipv6range) = case addrRangePair ipv6range of
+ (ipv6, prefix) -> WgIpmask (Right (toHostAddress6 ipv6)) (fromIntegral prefix)
+
+wgIpmaskToIpRange :: WgIpmask -> IPRange
+wgIpmaskToIpRange (WgIpmask ip cidr) = case ip of
+ Left ipv4 -> IPv4Range $ makeAddrRange (fromHostAddress ipv4) (fromIntegral cidr)
+ Right ipv6 -> IPv6Range $ makeAddrRange (fromHostAddress6 ipv6) (fromIntegral cidr)
+
+invalidValueError :: Int32
+invalidValueError = 22 -- TODO: report back actual error
+
+emptyKey :: BS.ByteString
+emptyKey = BS.replicate keyLength 0
+
+testFlag :: Bits a => a -> a -> Bool
+testFlag a flag = (a .&. flag) /= zeroBits
+
+pubToBytes :: PublicKey -> BS.ByteString
+pubToBytes = BA.convert . DH.dhPubToBytes
+
+privToBytes :: PrivateKey -> BS.ByteString
+privToBytes = BA.convert . DH.dhSecToBytes
+
+pskToBytes :: PresharedKey -> BS.ByteString
+pskToBytes = BA.convert
+
+bytesToPair :: BS.ByteString -> Maybe KeyPair
+bytesToPair = DH.dhBytesToPair . BA.convert
+
+bytesToPub :: BS.ByteString -> Maybe PublicKey
+bytesToPub = DH.dhBytesToPub . BA.convert
+
+bytesToPSK :: BS.ByteString -> PresharedKey
+bytesToPSK = BA.convert