From a2a3e540f2d1a507b34eccae26de09066a2a12fa Mon Sep 17 00:00:00 2001 From: Bin Jin Date: Sun, 12 Mar 2017 17:48:20 +0800 Subject: Initial commit --- src/Network/WireGuard/Internal/State.hs | 242 ++++++++++++++++++++++++++++++++ 1 file changed, 242 insertions(+) create mode 100644 src/Network/WireGuard/Internal/State.hs (limited to 'src/Network/WireGuard/Internal/State.hs') diff --git a/src/Network/WireGuard/Internal/State.hs b/src/Network/WireGuard/Internal/State.hs new file mode 100644 index 0000000..38866e8 --- /dev/null +++ b/src/Network/WireGuard/Internal/State.hs @@ -0,0 +1,242 @@ +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TupleSections #-} + +module Network.WireGuard.Internal.State + ( PeerId + , Device(..) + , Peer(..) + , InitiatorWait(..) + , ResponderWait(..) + , Session(..) + , createDevice + , createPeer + , invalidateSessions + , buildRouteTables + , acquireEmptyIndex + , removeIndex + , nextNonce + , eraseInitiatorWait + , eraseResponderWait + , getSession + , waitForSession + , findSession + , addSession + , filterSessions + , updateTai64n + , updateEndPoint + ) where + +import Control.Monad (forM, when) +import Crypto.Noise (NoiseState) +import Crypto.Noise.Cipher.ChaChaPoly1305 (ChaChaPoly1305) +import Crypto.Noise.DH.Curve25519 (Curve25519) +import Crypto.Noise.Hash.BLAKE2s (BLAKE2s) +import qualified Data.HashMap.Strict as HM +import Data.IP (IPRange (..), IPv4, IPv6) +import qualified Data.IP.RouteTable as RT +import Data.Maybe (catMaybes, fromJust, + isNothing, mapMaybe) +import Data.Word +import Network.Socket.Internal (SockAddr) + +import Control.Concurrent.STM + +import Network.WireGuard.Internal.Constant +import Network.WireGuard.Internal.Types + +data Device = Device + { intfName :: String + , localKey :: TVar (Maybe KeyPair) + , presharedKey :: TVar (Maybe PresharedKey) + , fwmark :: TVar Word + , port :: TVar Int + , peers :: TVar (HM.HashMap PeerId Peer) + , routeTable4 :: TVar (RT.IPRTable IPv4 Peer) + , routeTable6 :: TVar (RT.IPRTable IPv6 Peer) + , indexMap :: TVar (HM.HashMap Index Peer) + } + +data Peer = Peer + { remotePub :: !PublicKey + , ipmasks :: TVar [IPRange] + , endPoint :: TVar (Maybe SockAddr) + , lastHandshakeTime :: TVar (Maybe Time) + , receivedBytes :: TVar Word64 + , transferredBytes :: TVar Word64 + , keepaliveInterval :: TVar Int + , initiatorWait :: TVar (Maybe InitiatorWait) + , responderWait :: TVar (Maybe ResponderWait) + , sessions :: TVar [Session] -- last two active sessions + , lastTai64n :: TVar TAI64n + , lastReceiveTime :: TVar Time + , lastTransferTime :: TVar Time + , lastKeepaliveTime :: TVar Time + } + +data InitiatorWait = InitiatorWait + { initOurIndex :: !Index + , initRetryTime :: !Time + , initStopTime :: !Time + , initNoise :: !(NoiseState ChaChaPoly1305 Curve25519 BLAKE2s) + } + +data ResponderWait = ResponderWait + { respOurIndex :: !Index + , respTheirIndex :: !Index + , respStopTime :: !Time + , respSessionKey :: !SessionKey + } + +data Session = Session + { ourIndex :: !Index + , theirIndex :: !Index + , sessionKey :: !SessionKey + , renewTime :: !Time + , expireTime :: !Time + , sessionCounter :: TVar Counter + -- TODO: avoid nonce reuse from remote peer + } + +createDevice :: String -> STM Device +createDevice intf = Device intf <$> newTVar Nothing + <*> newTVar Nothing + <*> newTVar 0 + <*> newTVar 0 + <*> newTVar HM.empty + <*> newTVar RT.empty + <*> newTVar RT.empty + <*> newTVar HM.empty + +createPeer :: PublicKey -> STM Peer +createPeer rpub = Peer rpub <$> newTVar [] + <*> newTVar Nothing + <*> newTVar Nothing + <*> newTVar 0 + <*> newTVar 0 + <*> newTVar 0 + <*> newTVar Nothing + <*> newTVar Nothing + <*> newTVar [] + <*> newTVar mempty + <*> newTVar farFuture + <*> newTVar farFuture + <*> newTVar 0 + +invalidateSessions :: Device -> STM () +invalidateSessions Device{..} = do + writeTVar indexMap HM.empty + readTVar peers >>= mapM_ invalidatePeerSessions + where + invalidatePeerSessions Peer{..} = do + writeTVar lastHandshakeTime Nothing + writeTVar initiatorWait Nothing + writeTVar responderWait Nothing + writeTVar sessions [] + +buildRouteTables :: Device -> STM () +buildRouteTables Device{..} = do + gather pickIPv4 >>= writeTVar routeTable4 . RT.fromList . concat + gather pickIPv6 >>= writeTVar routeTable6 . RT.fromList . concat + where + gather pick = do + peers' <- readTVar peers + forM peers' $ \peer -> + map (,peer) . mapMaybe pick <$> readTVar (ipmasks peer) + pickIPv4 (IPv4Range ipv4) = Just ipv4 + pickIPv4 _ = Nothing + pickIPv6 (IPv6Range ipv6) = Just ipv6 + pickIPv6 _ = Nothing + +acquireEmptyIndex :: Device -> Peer -> Index -> STM Index +acquireEmptyIndex device peer seed = do + imap <- readTVar (indexMap device) + let findEmpty idx + | HM.member idx imap = findEmpty (idx * 3 + 1) + | otherwise = idx + emptyIndex = findEmpty seed + writeTVar (indexMap device) $ HM.insert emptyIndex peer imap + return emptyIndex + +removeIndex :: Device -> Index -> STM () +removeIndex device index = modifyTVar' (indexMap device) (HM.delete index) + +nextNonce :: Session -> STM Counter +nextNonce Session{..} = do + nonce <- readTVar sessionCounter + writeTVar sessionCounter (nonce + 1) + return nonce + +eraseInitiatorWait :: Device -> Peer -> Maybe Index -> STM Bool +eraseInitiatorWait device Peer{..} index = do + miwait <- readTVar initiatorWait + case miwait of + Just iwait | isNothing index || initOurIndex iwait == fromJust index -> do + writeTVar initiatorWait Nothing + when (isNothing index) $ removeIndex device (initOurIndex iwait) + return True + _ -> return False + +eraseResponderWait :: Device -> Peer -> Maybe Index -> STM Bool +eraseResponderWait device Peer{..} index = do + mrwait <- readTVar responderWait + case mrwait of + Just rwait | isNothing index || respOurIndex rwait == fromJust index -> do + writeTVar responderWait Nothing + when (isNothing index) $ removeIndex device (respOurIndex rwait) + return True + _ -> return False + +getSession :: Peer -> IO (Maybe Session) +getSession peer = do + sessions' <- readTVarIO (sessions peer) + case sessions' of + [] -> return Nothing + (s:_) -> return (Just s) + +waitForSession :: Peer -> STM Session +waitForSession peer = do + sessions' <- readTVar (sessions peer) + case sessions' of + [] -> retry + (s:_) -> return s + +findSession :: Peer -> Index -> STM (Maybe (Either ResponderWait Session)) +findSession peer index = do + sessions' <- filter ((==index).ourIndex) <$> readTVar (sessions peer) + case sessions' of + (s:_) -> return (Just (Right s)) + [] -> do + mrwait <- readTVar (responderWait peer) + case mrwait of + Just rwait | respOurIndex rwait == index -> return (Just (Left rwait)) + _ -> return Nothing + + +addSession :: Device -> Peer -> Session -> STM () +addSession device peer session = do + (toKeep, toDrop) <- splitAt maxActiveSessions . (session:) <$> readTVar (sessions peer) + mapM_ (removeIndex device . ourIndex) toDrop + writeTVar (sessions peer) toKeep + +filterSessions :: Device -> Peer -> (Session -> Bool) -> STM () +filterSessions device peer cond = do + sessions' <- readTVar (sessions peer) + filtered <- fmap catMaybes $ forM sessions' $ \session -> + if cond session + then return (Just session) + else do + removeIndex device (ourIndex session) + return Nothing + writeTVar (sessions peer) filtered + +updateTai64n :: Peer -> TAI64n -> STM Bool +updateTai64n peer tai64n = do + lastTai64n' <- readTVar (lastTai64n peer) + if tai64n <= lastTai64n' + then return False + else do + writeTVar (lastTai64n peer) tai64n + return True + +updateEndPoint :: Peer -> SockAddr -> STM () +updateEndPoint peer sock = writeTVar (endPoint peer) (Just sock) -- cgit v1.2.3-59-g8ed1b