1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
|
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}
module Network.WireGuard.Internal.State
( PeerId
, Device(..)
, Peer(..)
, InitiatorWait(..)
, ResponderWait(..)
, Session(..)
, createDevice
, createPeer
, invalidateSessions
, buildRouteTables
, acquireEmptyIndex
, removeIndex
, nextNonce
, eraseInitiatorWait
, eraseResponderWait
, getSession
, waitForSession
, findSession
, addSession
, filterSessions
, updateTai64n
, updateEndPoint
) where
import Control.Monad (forM, when)
import Crypto.Noise (NoiseState)
import Crypto.Noise.Cipher.ChaChaPoly1305 (ChaChaPoly1305)
import Crypto.Noise.DH.Curve25519 (Curve25519)
import Crypto.Noise.Hash.BLAKE2s (BLAKE2s)
import qualified Data.HashMap.Strict as HM
import Data.IP (IPRange (..), IPv4, IPv6)
import qualified Data.IP.RouteTable as RT
import Data.Maybe (catMaybes, fromJust,
isNothing, mapMaybe)
import Data.Word
import Network.Socket.Internal (SockAddr)
import Control.Concurrent.STM
import Network.WireGuard.Internal.Constant
import Network.WireGuard.Internal.Types
-- | Per-device state. Every mutable part is held in a 'TVar' so it can be
-- read and updated inside STM transactions.
data Device = Device
  { intfName :: String
    -- ^ Name passed to 'createDevice' (presumably the network interface
    -- name — confirm).
  , localKey :: TVar (Maybe KeyPair)
    -- ^ Local key pair; 'Nothing' on creation.
  , presharedKey :: TVar (Maybe PresharedKey)
    -- ^ Optional preshared key; 'Nothing' on creation.
  , fwmark :: TVar Word
    -- ^ Starts at 0. NOTE(review): presumably the firewall mark for
    -- outgoing packets — confirm.
  , port :: TVar Int
    -- ^ Starts at 0. NOTE(review): presumably the UDP listen port — confirm.
  , peers :: TVar (HM.HashMap PeerId Peer)
    -- ^ All peers of this device, keyed by 'PeerId'.
  , routeTable4 :: TVar (RT.IPRTable IPv4 Peer)
    -- ^ IPv4 ranges taken from every peer's 'ipmasks', each mapped to its
    -- peer; rebuilt wholesale by 'buildRouteTables'.
  , routeTable6 :: TVar (RT.IPRTable IPv6 Peer)
    -- ^ IPv6 counterpart of 'routeTable4'.
  , indexMap :: TVar (HM.HashMap Index Peer)
    -- ^ Locally allocated indices and the peer owning each one. Entries are
    -- added by 'acquireEmptyIndex', removed by 'removeIndex', and the whole
    -- map is cleared by 'invalidateSessions'.
  }
-- | Per-peer state. Apart from the immutable public key, every field is a
-- 'TVar' updated inside STM transactions.
data Peer = Peer
  { remotePub :: !PublicKey
    -- ^ The peer's public key; fixed for the life of the record.
  , ipmasks :: TVar [IPRange]
    -- ^ IP ranges associated with this peer; read by 'buildRouteTables'
    -- to populate the device route tables.
  , endPoint :: TVar (Maybe SockAddr)
    -- ^ Last known socket address of the peer; set by 'updateEndPoint'.
  , lastHandshakeTime :: TVar (Maybe Time)
    -- ^ 'Nothing' on creation and after 'invalidateSessions'.
  , receivedBytes :: TVar Word64
    -- ^ Starts at 0. NOTE(review): presumably a received-bytes counter
    -- maintained elsewhere — confirm.
  , transferredBytes :: TVar Word64
    -- ^ Starts at 0. NOTE(review): presumably a sent-bytes counter
    -- maintained elsewhere — confirm.
  , keepaliveInterval :: TVar Int
    -- ^ Starts at 0.
  , initiatorWait :: TVar (Maybe InitiatorWait)
    -- ^ Pending handshake state for a handshake we initiated, if any;
    -- cleared by 'eraseInitiatorWait' and 'invalidateSessions'.
  , responderWait :: TVar (Maybe ResponderWait)
    -- ^ Pending responder-side handshake state, if any; consulted by
    -- 'findSession', cleared by 'eraseResponderWait' and
    -- 'invalidateSessions'.
  , sessions :: TVar [Session] -- last two active sessions
    -- ^ Active sessions, newest first: 'addSession' prepends and keeps at
    -- most 'maxActiveSessions' entries.
  , lastTai64n :: TVar TAI64n
    -- ^ Latest accepted timestamp; 'updateTai64n' only accepts strictly
    -- greater values. Starts at 'mempty'.
  , lastReceiveTime :: TVar Time
    -- ^ Starts at 'farFuture'.
  , lastTransferTime :: TVar Time
    -- ^ Starts at 'farFuture'.
  , lastKeepaliveTime :: TVar Time
    -- ^ Starts at 0.
  }
-- | State kept while waiting for a response to a handshake we initiated.
data InitiatorWait = InitiatorWait
  { initOurIndex :: !Index
    -- ^ Our local index for this handshake; matched by 'eraseInitiatorWait'.
  , initRetryTime :: !Time
    -- ^ NOTE(review): presumably when the initiation should be resent —
    -- confirm with the code that reads it.
  , initStopTime :: !Time
    -- ^ NOTE(review): presumably when to give up on this handshake —
    -- confirm with the code that reads it.
  , initNoise :: !(NoiseState ChaChaPoly1305 Curve25519 BLAKE2s)
    -- ^ Noise handshake state (ChaChaPoly1305 cipher, Curve25519 DH,
    -- BLAKE2s hash).
  }
-- | State kept on the responder side while a handshake is pending.
data ResponderWait = ResponderWait
  { respOurIndex :: !Index
    -- ^ Our local index; matched by 'findSession' and 'eraseResponderWait'.
  , respTheirIndex :: !Index
    -- ^ The index chosen by the remote peer.
  , respStopTime :: !Time
    -- ^ NOTE(review): presumably the deadline for this pending handshake —
    -- confirm with the code that reads it.
  , respSessionKey :: !SessionKey
    -- ^ NOTE(review): presumably the key the session will use once the
    -- handshake is confirmed — confirm.
  }
-- | An established session with a peer.
data Session = Session
  { ourIndex :: !Index
    -- ^ Our local index; released from the device's 'indexMap' when the
    -- session is dropped by 'addSession' or 'filterSessions'.
  , theirIndex :: !Index
    -- ^ The index chosen by the remote peer.
  , sessionKey :: !SessionKey
    -- ^ Key material for this session.
  , renewTime :: !Time
    -- ^ NOTE(review): presumably when to start a new handshake — confirm.
  , expireTime :: !Time
    -- ^ NOTE(review): presumably when the session stops being usable —
    -- confirm.
  , sessionCounter :: TVar Counter
    -- ^ Next nonce to hand out; 'nextNonce' returns it and increments it.
    -- TODO: avoid nonce reuse from remote peer
  }
-- | Allocate a new 'Device' with the given name.
--
-- All mutable state starts out empty: no local or preshared key, fwmark
-- and port both 0, and an empty peer map, route tables and index map.
createDevice :: String -> STM Device
createDevice intf = do
  key <- newTVar Nothing
  psk <- newTVar Nothing
  markVar <- newTVar 0
  portVar <- newTVar 0
  peerMap <- newTVar HM.empty
  table4 <- newTVar RT.empty
  table6 <- newTVar RT.empty
  indices <- newTVar HM.empty
  return $ Device intf key psk markVar portVar peerMap table4 table6 indices
-- | Allocate a new 'Peer' for the given remote public key.
--
-- The peer starts with no IP ranges, endpoint, handshake time, pending
-- handshakes or sessions. Byte counters and the keepalive interval are 0,
-- the TAI64N mark is 'mempty', the receive and transfer times are
-- 'farFuture', and the keepalive time is 0.
createPeer :: PublicKey -> STM Peer
createPeer rpub = do
  ranges <- newTVar []
  endpointVar <- newTVar Nothing
  handshakeAt <- newTVar Nothing
  rxBytes <- newTVar 0
  txBytes <- newTVar 0
  keepalive <- newTVar 0
  iwait <- newTVar Nothing
  rwait <- newTVar Nothing
  active <- newTVar []
  tai <- newTVar mempty
  receivedAt <- newTVar farFuture
  transferredAt <- newTVar farFuture
  keepaliveAt <- newTVar 0
  return $ Peer rpub ranges endpointVar handshakeAt rxBytes txBytes keepalive
    iwait rwait active tai receivedAt transferredAt keepaliveAt
-- | Forget every allocated index and all per-peer handshake and session state.
--
-- The device's index map is emptied. Then, for every peer, the last
-- handshake time, both pending handshake waits and the session list are reset.
invalidateSessions :: Device -> STM ()
invalidateSessions Device{indexMap = imap, peers = peerVar} = do
  writeTVar imap HM.empty
  allPeers <- readTVar peerVar
  mapM_ resetPeer allPeers
  where
    resetPeer p = do
      writeTVar (lastHandshakeTime p) Nothing
      writeTVar (initiatorWait p) Nothing
      writeTVar (responderWait p) Nothing
      writeTVar (sessions p) []
-- | Rebuild both route tables from the peers' current 'ipmasks'.
--
-- Each IP range of each peer becomes an entry mapping that range to the
-- peer. IPv4 ranges go into 'routeTable4' and IPv6 ranges into
-- 'routeTable6'; both tables are replaced, not merged.
buildRouteTables :: Device -> STM ()
buildRouteTables Device{peers = peerVar, routeTable4 = rt4, routeTable6 = rt6} = do
  peerMap <- readTVar peerVar
  perPeer <- forM peerMap $ \p -> map (\r -> (r, p)) <$> readTVar (ipmasks p)
  let routes = concat perPeer
  writeTVar rt4 $ RT.fromList [ (r, p) | (IPv4Range r, p) <- routes ]
  writeTVar rt6 $ RT.fromList [ (r, p) | (IPv6Range r, p) <- routes ]
-- | Reserve an index not yet present in the device's 'indexMap', register it
-- for the given peer, and return it.
--
-- Probing starts at @seed@ and steps with @i -> i * 3 + 1@ until an unused
-- index is found. NOTE(review): the search never ends if every index on
-- that probe sequence is already taken.
acquireEmptyIndex :: Device -> Peer -> Index -> STM Index
acquireEmptyIndex device peer seed = do
  taken <- readTVar (indexMap device)
  let isFree i = not (HM.member i taken)
      chosen = until isFree (\i -> i * 3 + 1) seed
  writeTVar (indexMap device) (HM.insert chosen peer taken)
  return chosen
-- | Release an index from the device's 'indexMap'. If the index is absent,
-- the map is unchanged. The new map is forced before it is stored.
removeIndex :: Device -> Index -> STM ()
removeIndex device index = do
  let imap = indexMap device
  current <- readTVar imap
  writeTVar imap $! HM.delete index current
-- | Return the session's current counter value and advance the counter by one.
--
-- NOTE(review): there is no bound check here — confirm that callers enforce
-- a message limit before the counter could repeat.
nextNonce :: Session -> STM Counter
nextNonce Session{sessionCounter = counterVar} = do
  current <- readTVar counterVar
  writeTVar counterVar (current + 1)
  pure current
-- | Drop the peer's pending initiator handshake if it matches.
--
-- * @Nothing@: clear any pending 'InitiatorWait' and also release its
--   'initOurIndex' from the device's index map.
-- * @Just i@: clear the pending wait only when its 'initOurIndex' equals
--   @i@. The index stays in the index map (presumably because the caller
--   reuses it for the resulting session — confirm at the call site).
--
-- Returns 'True' when a wait was cleared and 'False' otherwise.
eraseInitiatorWait :: Device -> Peer -> Maybe Index -> STM Bool
eraseInitiatorWait device Peer{..} index = do
  miwait <- readTVar initiatorWait
  case miwait of
    -- 'maybe True (== ...)' accepts any wait when no index is given and
    -- otherwise requires an exact match, without the partial 'fromJust'.
    Just iwait | maybe True (== initOurIndex iwait) index -> do
      writeTVar initiatorWait Nothing
      when (isNothing index) $ removeIndex device (initOurIndex iwait)
      return True
    _ -> return False
-- | Drop the peer's pending responder handshake if it matches.
--
-- * @Nothing@: clear any pending 'ResponderWait' and also release its
--   'respOurIndex' from the device's index map.
-- * @Just i@: clear the pending wait only when its 'respOurIndex' equals
--   @i@. The index stays in the index map (presumably because the caller
--   reuses it for the resulting session — confirm at the call site).
--
-- Returns 'True' when a wait was cleared and 'False' otherwise.
eraseResponderWait :: Device -> Peer -> Maybe Index -> STM Bool
eraseResponderWait device Peer{..} index = do
  mrwait <- readTVar responderWait
  case mrwait of
    -- 'maybe True (== ...)' accepts any wait when no index is given and
    -- otherwise requires an exact match, without the partial 'fromJust'.
    Just rwait | maybe True (== respOurIndex rwait) index -> do
      writeTVar responderWait Nothing
      when (isNothing index) $ removeIndex device (respOurIndex rwait)
      return True
    _ -> return False
-- | Read the peer's newest session (the head of 'sessions') outside any
-- transaction, or 'Nothing' if it has none.
getSession :: Peer -> IO (Maybe Session)
getSession peer = newest <$> readTVarIO (sessions peer)
  where
    newest (s : _) = Just s
    newest [] = Nothing
-- | Return the peer's newest session (the head of 'sessions'). If the list
-- is empty, the transaction 'retry's and so blocks until a session appears.
waitForSession :: Peer -> STM Session
waitForSession peer = readTVar (sessions peer) >>= pickFirst
  where
    pickFirst (s : _) = return s
    pickFirst [] = retry
-- | Resolve one of our local indices for this peer.
--
-- Active sessions are checked first: the first whose 'ourIndex' matches is
-- returned as 'Right'. Otherwise, a pending 'ResponderWait' whose
-- 'respOurIndex' matches is returned as 'Left'. If neither matches, the
-- result is 'Nothing'.
findSession :: Peer -> Index -> STM (Maybe (Either ResponderWait Session))
findSession peer index = do
  active <- readTVar (sessions peer)
  case [ s | s <- active, ourIndex s == index ] of
    s : _ -> return (Just (Right s))
    [] -> fmap Left . matchingWait <$> readTVar (responderWait peer)
  where
    matchingWait (Just rwait) | respOurIndex rwait == index = Just rwait
    matchingWait _ = Nothing
-- | Make the session the peer's newest, keeping at most 'maxActiveSessions'
-- sessions. Sessions pushed past that limit are discarded and their
-- 'ourIndex' is released from the device's index map.
addSession :: Device -> Peer -> Session -> STM ()
addSession device peer session = do
  existing <- readTVar (sessions peer)
  let (kept, evicted) = splitAt maxActiveSessions (session : existing)
  mapM_ (\old -> removeIndex device (ourIndex old)) evicted
  writeTVar (sessions peer) kept
-- | Keep only the peer's sessions that satisfy @cond@, preserving their
-- order. Each rejected session has its 'ourIndex' released from the
-- device's index map, in list order.
filterSessions :: Device -> Peer -> (Session -> Bool) -> STM ()
filterSessions device peer cond = do
  current <- readTVar (sessions peer)
  let (kept, dropped) = foldr sortOut ([], []) current
      sortOut s (ks, ds)
        | cond s = (s : ks, ds)
        | otherwise = (ks, s : ds)
  mapM_ (removeIndex device . ourIndex) dropped
  writeTVar (sessions peer) kept
-- | Accept a TAI64N timestamp only if it is strictly greater than the last
-- accepted one.
--
-- On acceptance the timestamp is stored and 'True' is returned. Otherwise
-- nothing changes and the result is 'False'. NOTE(review): presumably this
-- is replay protection for handshake messages — confirm with callers.
updateTai64n :: Peer -> TAI64n -> STM Bool
updateTai64n peer tai64n = do
  let stampVar = lastTai64n peer
  previous <- readTVar stampVar
  let accepted = not (tai64n <= previous)
  when accepted $ writeTVar stampVar tai64n
  return accepted
-- | Record the given socket address as the peer's current endpoint,
-- replacing any previous one.
updateEndPoint :: Peer -> SockAddr -> STM ()
updateEndPoint peer = writeTVar (endPoint peer) . Just
|