aboutsummaryrefslogblamecommitdiffstats
path: root/src/Network/WireGuard/Internal/State.hs
blob: f7b1ca03d341f1e6e7f5ba3dab06200a73931710 (plain) (tree)


































































































































































































                                                                                     











                                                   








































                                                                                           
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections   #-}

module Network.WireGuard.Internal.State
  ( PeerId
  , Device(..)
  , Peer(..)
  , InitiatorWait(..)
  , ResponderWait(..)
  , Session(..)
  , createDevice
  , createPeer
  , invalidateSessions
  , buildRouteTables
  , acquireEmptyIndex
  , removeIndex
  , nextNonce
  , eraseInitiatorWait
  , eraseResponderWait
  , getSession
  , waitForSession
  , findSession
  , addSession
  , filterSessions
  , updateTai64n
  , updateEndPoint
  ) where

import           Control.Monad                       (forM, when)
import           Crypto.Noise                        (NoiseState)
import           Crypto.Noise.Cipher.ChaChaPoly1305  (ChaChaPoly1305)
import           Crypto.Noise.DH.Curve25519          (Curve25519)
import           Crypto.Noise.Hash.BLAKE2s           (BLAKE2s)
import qualified Data.HashMap.Strict                 as HM
import           Data.IP                             (IPRange (..), IPv4, IPv6)
import qualified Data.IP.RouteTable                  as RT
import           Data.Maybe                          (catMaybes, fromJust,
                                                      isNothing, mapMaybe)
import           Data.Word
import           Network.Socket.Internal             (SockAddr)

import           Control.Concurrent.STM

import           Network.WireGuard.Internal.Constant
import           Network.WireGuard.Internal.Types

data Device = Device
            { intfName     :: String
            , localKey     :: TVar (Maybe KeyPair)
            , presharedKey :: TVar (Maybe PresharedKey)
            , fwmark       :: TVar Word
            , port         :: TVar Int
            , peers        :: TVar (HM.HashMap PeerId Peer)
            , routeTable4  :: TVar (RT.IPRTable IPv4 Peer)
            , routeTable6  :: TVar (RT.IPRTable IPv6 Peer)
            , indexMap     :: TVar (HM.HashMap Index Peer)
            }

data Peer = Peer
          { remotePub         :: !PublicKey
          , ipmasks           :: TVar [IPRange]
          , endPoint          :: TVar (Maybe SockAddr)
          , lastHandshakeTime :: TVar (Maybe Time)
          , receivedBytes     :: TVar Word64
          , transferredBytes  :: TVar Word64
          , keepaliveInterval :: TVar Int
          , initiatorWait     :: TVar (Maybe InitiatorWait)
          , responderWait     :: TVar (Maybe ResponderWait)
          , sessions          :: TVar [Session]  -- last two active sessions
          , lastTai64n        :: TVar TAI64n
          , lastReceiveTime   :: TVar Time
          , lastTransferTime  :: TVar Time
          , lastKeepaliveTime :: TVar Time
          }

data InitiatorWait = InitiatorWait
                   { initOurIndex  :: !Index
                   , initRetryTime :: !Time
                   , initStopTime  :: !Time
                   , initNoise     :: !(NoiseState ChaChaPoly1305 Curve25519 BLAKE2s)
                   }

data ResponderWait = ResponderWait
                   { respOurIndex   :: !Index
                   , respTheirIndex :: !Index
                   , respStopTime   :: !Time
                   , respSessionKey :: !SessionKey
                   }

data Session = Session
             { ourIndex       :: !Index
             , theirIndex     :: !Index
             , sessionKey     :: !SessionKey
             , renewTime      :: !Time
             , expireTime     :: !Time
             , sessionCounter :: TVar Counter
             -- TODO: avoid nonce reuse from remote peer
             }

createDevice :: String -> STM Device
createDevice intf = Device intf <$> newTVar Nothing
                                <*> newTVar Nothing
                                <*> newTVar 0
                                <*> newTVar 0
                                <*> newTVar HM.empty
                                <*> newTVar RT.empty
                                <*> newTVar RT.empty
                                <*> newTVar HM.empty

createPeer :: PublicKey -> STM Peer
createPeer rpub = Peer rpub <$> newTVar []
                            <*> newTVar Nothing
                            <*> newTVar Nothing
                            <*> newTVar 0
                            <*> newTVar 0
                            <*> newTVar 0
                            <*> newTVar Nothing
                            <*> newTVar Nothing
                            <*> newTVar []
                            <*> newTVar mempty
                            <*> newTVar farFuture
                            <*> newTVar farFuture
                            <*> newTVar 0

invalidateSessions :: Device -> STM ()
invalidateSessions Device{..} = do
    writeTVar indexMap HM.empty
    readTVar peers >>= mapM_ invalidatePeerSessions
  where
    invalidatePeerSessions Peer{..} = do
        writeTVar lastHandshakeTime Nothing
        writeTVar initiatorWait Nothing
        writeTVar responderWait Nothing
        writeTVar sessions []

buildRouteTables :: Device -> STM ()
buildRouteTables Device{..} = do
    gather pickIPv4 >>= writeTVar routeTable4 . RT.fromList . concat
    gather pickIPv6 >>= writeTVar routeTable6 . RT.fromList . concat
  where
    gather pick = do
        peers' <- readTVar peers
        forM peers' $ \peer ->
            map (,peer) . mapMaybe pick <$> readTVar (ipmasks peer)
    pickIPv4 (IPv4Range ipv4) = Just ipv4
    pickIPv4 _                = Nothing
    pickIPv6 (IPv6Range ipv6) = Just ipv6
    pickIPv6 _                = Nothing

acquireEmptyIndex :: Device -> Peer -> Index -> STM Index
acquireEmptyIndex device peer seed = do
    imap <- readTVar (indexMap device)
    let findEmpty idx
            | HM.member idx imap = findEmpty (idx * 3 + 1)
            | otherwise          = idx
        emptyIndex = findEmpty seed
    writeTVar (indexMap device) $ HM.insert emptyIndex peer imap
    return emptyIndex

removeIndex :: Device -> Index -> STM ()
removeIndex device index = modifyTVar' (indexMap device) (HM.delete index)

nextNonce :: Session -> STM Counter
nextNonce Session{..} = do
    nonce <- readTVar sessionCounter
    writeTVar sessionCounter (nonce + 1)
    return nonce

eraseInitiatorWait :: Device -> Peer -> Maybe Index -> STM Bool
eraseInitiatorWait device Peer{..} index = do
    miwait <- readTVar initiatorWait
    case miwait of
        Just iwait | isNothing index || initOurIndex iwait == fromJust index -> do
            writeTVar initiatorWait Nothing
            when (isNothing index) $ removeIndex device (initOurIndex iwait)
            return True
        _ -> return False

eraseResponderWait :: Device -> Peer -> Maybe Index -> STM Bool
eraseResponderWait device Peer{..} index = do
    mrwait <- readTVar responderWait
    case mrwait of
        Just rwait | isNothing index || respOurIndex rwait == fromJust index -> do
            writeTVar responderWait Nothing
            when (isNothing index) $ removeIndex device (respOurIndex rwait)
            return True
        _ -> return False

getSession :: Peer -> IO (Maybe Session)
getSession peer = do
    sessions' <- readTVarIO (sessions peer)
    case sessions' of
        []    -> return Nothing
        (s:_) -> return (Just s)

waitForSession :: Int -> Peer -> IO (Maybe Session)
waitForSession timelimit peer = do
    getTimeout <- registerDelay timelimit
    atomically $ do
        sessions' <- readTVar (sessions peer)
        case sessions' of
            []    -> do
                timeout <- readTVar getTimeout
                if timeout
                  then return Nothing
                  else retry
            (s:_) -> return (Just s)

findSession :: Peer -> Index -> STM (Maybe (Either ResponderWait Session))
findSession peer index = do
    sessions' <- filter ((==index).ourIndex) <$> readTVar (sessions peer)
    case sessions' of
        (s:_) -> return (Just (Right s))
        []    -> do
            mrwait <- readTVar (responderWait peer)
            case mrwait of
                Just rwait | respOurIndex rwait == index -> return (Just (Left rwait))
                _                                        -> return Nothing


addSession :: Device -> Peer -> Session -> STM ()
addSession device peer session = do
    (toKeep, toDrop) <- splitAt maxActiveSessions . (session:) <$> readTVar (sessions peer)
    mapM_ (removeIndex device . ourIndex) toDrop
    writeTVar (sessions peer) toKeep

filterSessions :: Device -> Peer -> (Session -> Bool) -> STM ()
filterSessions device peer cond = do
    sessions' <- readTVar (sessions peer)
    filtered <- fmap catMaybes $ forM sessions' $ \session ->
        if cond session
          then return (Just session)
          else do
            removeIndex device (ourIndex session)
            return Nothing
    writeTVar (sessions peer) filtered

updateTai64n :: Peer -> TAI64n -> STM Bool
updateTai64n peer tai64n = do
    lastTai64n' <- readTVar (lastTai64n peer)
    if tai64n <= lastTai64n'
      then return False
      else do
        writeTVar (lastTai64n peer) tai64n
        return True

updateEndPoint :: Peer -> SockAddr -> STM ()
updateEndPoint peer sock = writeTVar (endPoint peer) (Just sock)