aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network/WireGuard/Internal/State.hs
blob: f7b1ca03d341f1e6e7f5ba3dab06200a73931710 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections   #-}

module Network.WireGuard.Internal.State
  ( PeerId
  , Device(..)
  , Peer(..)
  , InitiatorWait(..)
  , ResponderWait(..)
  , Session(..)
  , createDevice
  , createPeer
  , invalidateSessions
  , buildRouteTables
  , acquireEmptyIndex
  , removeIndex
  , nextNonce
  , eraseInitiatorWait
  , eraseResponderWait
  , getSession
  , waitForSession
  , findSession
  , addSession
  , filterSessions
  , updateTai64n
  , updateEndPoint
  ) where

import           Control.Monad                       (forM, when)
import           Crypto.Noise                        (NoiseState)
import           Crypto.Noise.Cipher.ChaChaPoly1305  (ChaChaPoly1305)
import           Crypto.Noise.DH.Curve25519          (Curve25519)
import           Crypto.Noise.Hash.BLAKE2s           (BLAKE2s)
import qualified Data.HashMap.Strict                 as HM
import           Data.IP                             (IPRange (..), IPv4, IPv6)
import qualified Data.IP.RouteTable                  as RT
import           Data.Maybe                          (catMaybes, fromJust,
                                                      isNothing, mapMaybe)
import           Data.Word
import           Network.Socket.Internal             (SockAddr)

import           Control.Concurrent.STM

import           Network.WireGuard.Internal.Constant
import           Network.WireGuard.Internal.Types

-- | Mutable state of one WireGuard device.
--
-- Every field except 'intfName' lives in a 'TVar', so it is read and updated
-- from within STM transactions.
data Device = Device
            { intfName     :: String
              -- ^ The name handed to 'createDevice' (presumably the network
              -- interface name).
            , localKey     :: TVar (Maybe KeyPair)
              -- ^ 'Nothing' on a freshly created device.
            , presharedKey :: TVar (Maybe PresharedKey)
              -- ^ 'Nothing' on a freshly created device.
            , fwmark       :: TVar Word
              -- ^ Starts at 0.
            , port         :: TVar Int
              -- ^ Starts at 0.
            , peers        :: TVar (HM.HashMap PeerId Peer)
              -- ^ All peers of this device, keyed by 'PeerId'.
            , routeTable4  :: TVar (RT.IPRTable IPv4 Peer)
              -- ^ IPv4 range -> peer; rebuilt by 'buildRouteTables'.
            , routeTable6  :: TVar (RT.IPRTable IPv6 Peer)
              -- ^ IPv6 range -> peer; rebuilt by 'buildRouteTables'.
            , indexMap     :: TVar (HM.HashMap Index Peer)
              -- ^ Index -> owning peer; filled by 'acquireEmptyIndex',
              -- shrunk by 'removeIndex', emptied by 'invalidateSessions'.
            }

-- | Mutable per-peer state.
--
-- The initial value of every 'TVar' is set in 'createPeer'.
data Peer = Peer
          { remotePub         :: !PublicKey
            -- ^ The public key handed to 'createPeer'.
          , ipmasks           :: TVar [IPRange]
            -- ^ Address ranges that 'buildRouteTables' maps to this peer.
          , endPoint          :: TVar (Maybe SockAddr)
            -- ^ 'Nothing' until 'updateEndPoint' stores an address.
          , lastHandshakeTime :: TVar (Maybe Time)
            -- ^ Reset to 'Nothing' by 'invalidateSessions'.
          , receivedBytes     :: TVar Word64
            -- ^ Starts at 0 (presumably a running byte counter).
          , transferredBytes  :: TVar Word64
            -- ^ Starts at 0 (presumably a running byte counter).
          , keepaliveInterval :: TVar Int
            -- ^ Starts at 0; the unit is not visible in this module.
          , initiatorWait     :: TVar (Maybe InitiatorWait)
            -- ^ Cleared by 'eraseInitiatorWait' and 'invalidateSessions'.
          , responderWait     :: TVar (Maybe ResponderWait)
            -- ^ Consulted by 'findSession'; cleared by 'eraseResponderWait'
            -- and 'invalidateSessions'.
          , sessions          :: TVar [Session]
            -- ^ Active sessions, newest first; 'addSession' keeps at most
            -- 'maxActiveSessions' of them.
          , lastTai64n        :: TVar TAI64n
            -- ^ Latest timestamp accepted by 'updateTai64n'.
          , lastReceiveTime   :: TVar Time
            -- ^ Starts at 'farFuture'.
          , lastTransferTime  :: TVar Time
            -- ^ Starts at 'farFuture'.
          , lastKeepaliveTime :: TVar Time
            -- ^ Starts at 0.
          }

-- | Pending state stored in a peer's 'initiatorWait' slot.
data InitiatorWait = InitiatorWait
                   { initOurIndex  :: !Index
                     -- ^ Index that 'eraseInitiatorWait' compares against and
                     -- may release from the device's index map.
                   , initRetryTime :: !Time
                     -- ^ Presumably when to retry -- confirm against callers.
                   , initStopTime  :: !Time
                     -- ^ Presumably when to give up -- confirm against callers.
                   , initNoise     :: !(NoiseState ChaChaPoly1305 Curve25519 BLAKE2s)
                     -- ^ Noise state for this pending exchange.
                   }

-- | Pending state stored in a peer's 'responderWait' slot.
data ResponderWait = ResponderWait
                   { respOurIndex   :: !Index
                     -- ^ Index matched by 'findSession' and
                     -- 'eraseResponderWait'.
                   , respTheirIndex :: !Index
                     -- ^ Presumably the index chosen by the remote side.
                   , respStopTime   :: !Time
                     -- ^ Presumably when to give up -- confirm against callers.
                   , respSessionKey :: !SessionKey
                     -- ^ Session key held while waiting.
                   }

-- | One entry of a peer's 'sessions' list.
data Session = Session
             { ourIndex       :: !Index
               -- ^ Index 'findSession' looks sessions up by; released from
               -- the device's index map when the session is dropped.
             , theirIndex     :: !Index
               -- ^ Presumably the index chosen by the remote side.
             , sessionKey     :: !SessionKey
               -- ^ Key material for this session.
             , renewTime      :: !Time
               -- ^ Presumably when the session should be renewed.
             , expireTime     :: !Time
               -- ^ Presumably when the session stops being usable.
             , sessionCounter :: TVar Counter
               -- ^ Next value 'nextNonce' will hand out.

             -- TODO: avoid nonce reuse from remote peer
             }

-- | Allocate the state of a new device called @intf@.
--
-- The result has no keys and no peers, empty route tables, an empty index
-- map, and both 'fwmark' and 'port' set to 0.
createDevice :: String -> STM Device
createDevice intf = do
    localKeyVar     <- newTVar Nothing
    presharedKeyVar <- newTVar Nothing
    fwmarkVar       <- newTVar 0
    portVar         <- newTVar 0
    peersVar        <- newTVar HM.empty
    routeTable4Var  <- newTVar RT.empty
    routeTable6Var  <- newTVar RT.empty
    indexMapVar     <- newTVar HM.empty
    return Device
        { intfName     = intf
        , localKey     = localKeyVar
        , presharedKey = presharedKeyVar
        , fwmark       = fwmarkVar
        , port         = portVar
        , peers        = peersVar
        , routeTable4  = routeTable4Var
        , routeTable6  = routeTable6Var
        , indexMap     = indexMapVar
        }

-- | Allocate the state of a new peer identified by public key @rpub@.
--
-- The peer starts with no address ranges, no endpoint, no handshake in
-- progress and no sessions. Byte counters, keepalive interval and
-- 'lastKeepaliveTime' start at 0; 'lastReceiveTime' and 'lastTransferTime'
-- start at 'farFuture'; 'lastTai64n' starts at 'mempty'.
createPeer :: PublicKey -> STM Peer
createPeer rpub = do
    ipmasksVar       <- newTVar []
    endPointVar      <- newTVar Nothing
    handshakeTimeVar <- newTVar Nothing
    receivedVar      <- newTVar 0
    transferredVar   <- newTVar 0
    keepaliveVar     <- newTVar 0
    initiatorVar     <- newTVar Nothing
    responderVar     <- newTVar Nothing
    sessionsVar      <- newTVar []
    tai64nVar        <- newTVar mempty
    receiveTimeVar   <- newTVar farFuture
    transferTimeVar  <- newTVar farFuture
    keepaliveTimeVar <- newTVar 0
    return Peer
        { remotePub         = rpub
        , ipmasks           = ipmasksVar
        , endPoint          = endPointVar
        , lastHandshakeTime = handshakeTimeVar
        , receivedBytes     = receivedVar
        , transferredBytes  = transferredVar
        , keepaliveInterval = keepaliveVar
        , initiatorWait     = initiatorVar
        , responderWait     = responderVar
        , sessions          = sessionsVar
        , lastTai64n        = tai64nVar
        , lastReceiveTime   = receiveTimeVar
        , lastTransferTime  = transferTimeVar
        , lastKeepaliveTime = keepaliveTimeVar
        }

-- | Drop every handshake and session of every peer of the device.
--
-- The device's index map is emptied, and each peer loses its sessions, its
-- pending initiator/responder state and its last handshake time.
invalidateSessions :: Device -> STM ()
invalidateSessions device = do
    writeTVar (indexMap device) HM.empty
    knownPeers <- readTVar (peers device)
    mapM_ resetPeer knownPeers
  where
    -- Put a single peer back into the "never shook hands" state.
    resetPeer peer = do
        writeTVar (lastHandshakeTime peer) Nothing
        writeTVar (initiatorWait peer) Nothing
        writeTVar (responderWait peer) Nothing
        writeTVar (sessions peer) []

-- | Recompute both route tables from the peers' current 'ipmasks'.
--
-- Every IPv4 range of every peer is mapped to that peer in 'routeTable4',
-- every IPv6 range in 'routeTable6'. The previous tables are overwritten.
buildRouteTables :: Device -> STM ()
buildRouteTables device = do
    rebuild (routeTable4 device) ipv4Only
    rebuild (routeTable6 device) ipv6Only
  where
    -- Collect the ranges accepted by @select@ from all peers and store the
    -- resulting table in @table@.
    rebuild table select = do
        peerMap <- readTVar (peers device)
        routes  <- forM peerMap $ \peer -> do
            ranges <- readTVar (ipmasks peer)
            return [ (range, peer) | range <- mapMaybe select ranges ]
        writeTVar table (RT.fromList (concat routes))

    ipv4Only range = case range of
        IPv4Range ipv4 -> Just ipv4
        _              -> Nothing

    ipv6Only range = case range of
        IPv6Range ipv6 -> Just ipv6
        _              -> Nothing

-- | Reserve an unused index for @peer@ and return it.
--
-- Candidates are tried in the order @seed@, @3 * seed + 1@, and so on, until
-- one is found that is not yet a key of the device's index map; that index
-- is then registered as belonging to @peer@.
--
-- Note: the search does not terminate if the candidate sequence never
-- reaches a free index.
acquireEmptyIndex :: Device -> Peer -> Index -> STM Index
acquireEmptyIndex device peer seed = do
    taken <- readTVar (indexMap device)
    let isFree candidate      = not (HM.member candidate taken)
        nextCandidate current = current * 3 + 1
        freeIndex             = until isFree nextCandidate seed
    writeTVar (indexMap device) (HM.insert freeIndex peer taken)
    return freeIndex

-- | Release an index from the device's index map (no-op if it is absent).
removeIndex :: Device -> Index -> STM ()
removeIndex device = modifyTVar' (indexMap device) . HM.delete

-- | Return the session's current counter value and advance the counter by
-- one.
--
-- The incremented value is forced before it is stored, so the 'TVar' always
-- holds an evaluated counter rather than a growing chain of @(+ 1)@ thunks
-- when callers do not immediately force the returned nonce.
nextNonce :: Session -> STM Counter
nextNonce session = do
    nonce <- readTVar counter
    writeTVar counter $! nonce + 1
    return nonce
  where
    counter = sessionCounter session

-- | Clear the peer's pending 'InitiatorWait', if it matches.
--
-- * @index = Nothing@: any pending wait is cleared and its 'initOurIndex'
--   is released from the device's index map.
-- * @index = Just i@: the wait is cleared only when its 'initOurIndex'
--   equals @i@; the index map entry is left in place (presumably because
--   the caller keeps using that index).
--
-- Returns 'True' when a wait was cleared, 'False' otherwise.
eraseInitiatorWait :: Device -> Peer -> Maybe Index -> STM Bool
eraseInitiatorWait device peer index = do
    miwait <- readTVar (initiatorWait peer)
    case miwait of
        Just iwait | matches (initOurIndex iwait) -> do
            writeTVar (initiatorWait peer) Nothing
            when (isNothing index) $ removeIndex device (initOurIndex iwait)
            return True
        _ -> return False
  where
    -- No requested index matches everything; otherwise the indices must
    -- agree. Total replacement for the former @isNothing || fromJust@ test.
    matches ours = maybe True (== ours) index

-- | Clear the peer's pending 'ResponderWait', if it matches.
--
-- * @index = Nothing@: any pending wait is cleared and its 'respOurIndex'
--   is released from the device's index map.
-- * @index = Just i@: the wait is cleared only when its 'respOurIndex'
--   equals @i@; the index map entry is left in place (presumably because
--   the caller keeps using that index).
--
-- Returns 'True' when a wait was cleared, 'False' otherwise.
eraseResponderWait :: Device -> Peer -> Maybe Index -> STM Bool
eraseResponderWait device peer index = do
    mrwait <- readTVar (responderWait peer)
    case mrwait of
        Just rwait | matches (respOurIndex rwait) -> do
            writeTVar (responderWait peer) Nothing
            when (isNothing index) $ removeIndex device (respOurIndex rwait)
            return True
        _ -> return False
  where
    -- No requested index matches everything; otherwise the indices must
    -- agree. Total replacement for the former @isNothing || fromJust@ test.
    matches ours = maybe True (== ours) index

-- | Non-blocking lookup of the peer's first (newest) session, if any.
getSession :: Peer -> IO (Maybe Session)
getSession peer = newest <$> readTVarIO (sessions peer)
  where
    newest (s:_) = Just s
    newest []    = Nothing

-- | Block until the peer has a session or the time limit runs out.
--
-- @timelimit@ is passed straight to 'registerDelay'. Returns the first
-- session of the peer as soon as there is one, or 'Nothing' once the delay
-- has fired while the session list is still empty.
waitForSession :: Int -> Peer -> IO (Maybe Session)
waitForSession timelimit peer = do
    deadline <- registerDelay timelimit
    atomically $ do
        available <- readTVar (sessions peer)
        case available of
            (s:_) -> return (Just s)
            []    -> do
                expired <- readTVar deadline
                -- Not timed out yet: sleep until either TVar changes.
                when (not expired) retry
                return Nothing

-- | Look up what the peer has registered under @index@.
--
-- Established sessions take precedence: the first session whose 'ourIndex'
-- equals @index@ is returned as 'Right'. Failing that, a pending
-- 'ResponderWait' whose 'respOurIndex' equals @index@ is returned as 'Left'.
-- 'Nothing' means neither matched.
findSession :: Peer -> Index -> STM (Maybe (Either ResponderWait Session))
findSession peer index = do
    candidates <- readTVar (sessions peer)
    case [ s | s <- candidates, ourIndex s == index ] of
        (s:_) -> return (Just (Right s))
        []    -> pendingResponder <$> readTVar (responderWait peer)
  where
    pendingResponder (Just rwait)
        | respOurIndex rwait == index = Just (Left rwait)
    pendingResponder _                = Nothing


-- | Make @session@ the peer's newest session.
--
-- The session is put at the front of the list; only the first
-- 'maxActiveSessions' entries are kept, and the indices of the ones that
-- fall off the end are released from the device's index map.
addSession :: Device -> Peer -> Session -> STM ()
addSession device peer session = do
    existing <- readTVar (sessions peer)
    let (kept, evicted) = splitAt maxActiveSessions (session : existing)
    writeTVar (sessions peer) kept
    mapM_ (removeIndex device . ourIndex) evicted

-- | Keep only the peer's sessions that satisfy @cond@.
--
-- Each rejected session has its 'ourIndex' released from the device's index
-- map; the relative order of the surviving sessions is unchanged.
filterSessions :: Device -> Peer -> (Session -> Bool) -> STM ()
filterSessions device peer cond = do
    current   <- readTVar (sessions peer)
    survivors <- catMaybes <$> mapM keepOrRelease current
    writeTVar (sessions peer) survivors
  where
    -- 'Just' the session if it stays, otherwise free its index.
    keepOrRelease session
        | cond session = return (Just session)
        | otherwise    = removeIndex device (ourIndex session) >> return Nothing

-- | Record @tai64n@ as the peer's latest timestamp if it is strictly newer.
--
-- Returns 'True' when the stored value was replaced, and 'False' (leaving
-- the stored value untouched) when @tai64n@ is not greater than it.
updateTai64n :: Peer -> TAI64n -> STM Bool
updateTai64n peer tai64n = readTVar slot >>= accept
  where
    slot = lastTai64n peer
    accept newest
        | tai64n <= newest = return False
        | otherwise        = writeTVar slot tai64n >> return True

-- | Remember @sock@ as the peer's endpoint, replacing any previous one.
updateEndPoint :: Peer -> SockAddr -> STM ()
updateEndPoint peer = writeTVar (endPoint peer) . Just