aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network/WireGuard/Internal
diff options
context:
space:
mode:
authorBin Jin <bjin@ctrl-d.org>2017-03-12 17:48:20 +0800
committerBin Jin <bjin@ctrl-d.org>2017-03-12 17:48:20 +0800
commita2a3e540f2d1a507b34eccae26de09066a2a12fa (patch)
tree1e336189acd095d761eaad252f18ee6e32b7216b /src/Network/WireGuard/Internal
downloadwireguard-hs-a2a3e540f2d1a507b34eccae26de09066a2a12fa.tar.xz
wireguard-hs-a2a3e540f2d1a507b34eccae26de09066a2a12fa.zip
Initial commit
Diffstat (limited to 'src/Network/WireGuard/Internal')
-rw-r--r--src/Network/WireGuard/Internal/Constant.hs55
-rw-r--r--src/Network/WireGuard/Internal/IPPacket.hs56
-rw-r--r--src/Network/WireGuard/Internal/Noise.hs98
-rw-r--r--src/Network/WireGuard/Internal/Packet.hs112
-rw-r--r--src/Network/WireGuard/Internal/PacketQueue.hs49
-rw-r--r--src/Network/WireGuard/Internal/State.hs242
-rw-r--r--src/Network/WireGuard/Internal/Types.hs78
-rw-r--r--src/Network/WireGuard/Internal/Util.hs62
8 files changed, 752 insertions, 0 deletions
diff --git a/src/Network/WireGuard/Internal/Constant.hs b/src/Network/WireGuard/Internal/Constant.hs
new file mode 100644
index 0000000..c615a0f
--- /dev/null
+++ b/src/Network/WireGuard/Internal/Constant.hs
@@ -0,0 +1,55 @@
+module Network.WireGuard.Internal.Constant where
+
+authLength :: Int
+authLength = 16
+
+aeadLength :: Int -> Int
+aeadLength payload = payload + authLength
+
+keyLength :: Int
+keyLength = 32
+
+timestampLength :: Int
+timestampLength = 12
+
+mac1Length :: Int
+mac1Length = 16
+
+mac2Length :: Int
+mac2Length = 16
+
+maxQueuedUdpPackets :: Int
+maxQueuedUdpPackets = 4096
+
+maxQueuedTunPackets :: Int
+maxQueuedTunPackets = 4096
+
+udpReadBufferLength :: Int
+udpReadBufferLength = 4096
+
+tunReadBufferLength :: Int
+tunReadBufferLength = 4096
+
+retryMaxWaitTime :: Int
+retryMaxWaitTime = 5 * 1000000 -- 5 seconds
+
+handshakeRetryTime :: Int
+handshakeRetryTime = 5
+
+handshakeStopTime :: Int
+handshakeStopTime = 90
+
+sessionRenewTime :: Int
+sessionRenewTime = 120
+
+sessionExpireTime :: Int
+sessionExpireTime = 180
+
+sessionKeepaliveTime :: Int
+sessionKeepaliveTime = 10
+
+maxActiveSessions :: Int
+maxActiveSessions = 2
+
+heartbeatWaitTime :: Int
+heartbeatWaitTime = 250 * 1000 -- 0.25 second
diff --git a/src/Network/WireGuard/Internal/IPPacket.hs b/src/Network/WireGuard/Internal/IPPacket.hs
new file mode 100644
index 0000000..56f4461
--- /dev/null
+++ b/src/Network/WireGuard/Internal/IPPacket.hs
@@ -0,0 +1,56 @@
+module Network.WireGuard.Internal.IPPacket
+ ( IPPacket(..)
+ , parseIPPacket
+ ) where
+
+import qualified Data.ByteArray as BA
+import Data.IP (IPv4, IPv6, fromHostAddress,
+ fromHostAddress6)
+import Foreign.Ptr (Ptr)
+import Foreign.Storable (peekByteOff)
+
+import Data.Bits
+import Data.Word
+
+data IPPacket = InvalidIPPacket
+ | IPv4Packet { src4 :: IPv4, dest4 :: IPv4 }
+ | IPv6Packet { src6 :: IPv6, dest6 :: IPv6 }
+
+parseIPPacket :: BA.ByteArrayAccess ba => ba -> IO IPPacket
+parseIPPacket packet | BA.length packet < 20 = return InvalidIPPacket
+parseIPPacket packet = BA.withByteArray packet $ \ptr -> do
+ firstByte <- peekByteOff ptr 0 :: IO Word8
+ let version = firstByte `shiftR` 4
+ parse4 = do
+ s4 <- peekByteOff ptr 12
+ d4 <- peekByteOff ptr 16
+ return (IPv4Packet (fromHostAddress s4) (fromHostAddress d4))
+ parse6
+ | BA.length packet < 40 = return InvalidIPPacket
+ | otherwise = do
+ s6a <- peek32be ptr 8
+ s6b <- peek32be ptr 12
+ s6c <- peek32be ptr 16
+ s6d <- peek32be ptr 20
+ d6a <- peek32be ptr 24
+ d6b <- peek32be ptr 28
+ d6c <- peek32be ptr 32
+ d6d <- peek32be ptr 36
+ let s6 = (s6a, s6b, s6c, s6d)
+ d6 = (d6a, d6b, d6c, d6d)
+ return (IPv6Packet (fromHostAddress6 s6) (fromHostAddress6 d6))
+ case version of
+ 4 -> parse4
+ 6 -> parse6
+ _ -> return InvalidIPPacket
+
+peek32be :: Ptr a -> Int -> IO Word32
+peek32be ptr offset = do
+ a <- peekByteOff ptr offset :: IO Word8
+ b <- peekByteOff ptr (offset + 1) :: IO Word8
+ c <- peekByteOff ptr (offset + 2) :: IO Word8
+ d <- peekByteOff ptr (offset + 3) :: IO Word8
+ return $! (fromIntegral a `unsafeShiftL` 24) .|.
+ (fromIntegral b `unsafeShiftL` 16) .|.
+ (fromIntegral c `unsafeShiftL` 8) .|.
+ fromIntegral d
diff --git a/src/Network/WireGuard/Internal/Noise.hs b/src/Network/WireGuard/Internal/Noise.hs
new file mode 100644
index 0000000..b529d25
--- /dev/null
+++ b/src/Network/WireGuard/Internal/Noise.hs
@@ -0,0 +1,98 @@
+{-# LANGUAGE OverloadedStrings #-}
+module Network.WireGuard.Internal.Noise
+ ( NoiseStateWG
+ , newNoiseState
+ , sendFirstMessage
+ , recvFirstMessageAndReply
+ , recvSecondMessage
+ , encryptMessage
+ , decryptMessage
+ ) where
+
+import Control.Exception (SomeException)
+import Control.Lens ((&), (.~), (^.))
+import Control.Monad (unless)
+import Control.Monad.Catch (throwM)
+import qualified Crypto.Cipher.ChaChaPoly1305 as CCP
+import Crypto.Error (throwCryptoError)
+import Crypto.Noise.Cipher (cipherSymToBytes)
+import Crypto.Noise.Cipher.ChaChaPoly1305 (ChaChaPoly1305)
+import Crypto.Noise.DH.Curve25519 (Curve25519)
+import Crypto.Noise.HandshakePatterns (noiseIK)
+import Crypto.Noise.Hash.BLAKE2s (BLAKE2s)
+import Crypto.Noise.Internal.CipherState (csk)
+import Crypto.Noise.Internal.NoiseState (nsReceivingCipherState,
+ nsSendingCipherState)
+import Data.ByteArray (ScrubbedBytes, convert)
+import Data.ByteString (ByteString)
+import qualified Data.ByteString as BS
+import Data.Maybe (fromJust)
+import Data.Serialize (putWord64le, runPut)
+
+import Crypto.Noise
+
+import Network.WireGuard.Internal.Types
+
+type NoiseStateWG = NoiseState ChaChaPoly1305 Curve25519 BLAKE2s
+
+newNoiseState :: KeyPair -> Maybe PresharedKey -> KeyPair -> Maybe PublicKey -> HandshakeRole -> NoiseStateWG
+newNoiseState staticKey presharedKey ephemeralKey remotePub role =
+ noiseState $ defaultHandshakeOpts noiseIK role
+ & hoPrologue .~ "WireGuard v0 zx2c4 Jason@zx2c4.com"
+ & hoLocalStatic .~ Just staticKey
+ & hoPreSharedKey .~ presharedKey
+ & hoRemoteStatic .~ remotePub
+ & hoLocalEphemeral .~ Just ephemeralKey
+
+sendFirstMessage :: NoiseStateWG -> ScrubbedBytes
+ -> Either SomeException (ByteString, NoiseStateWG)
+sendFirstMessage state0 plaintext1 = writeMessage state0 plaintext1
+
+recvFirstMessageAndReply :: NoiseStateWG -> ByteString -> ScrubbedBytes
+ -> Either SomeException (ByteString, ScrubbedBytes, PublicKey, SessionKey)
+recvFirstMessageAndReply state0 ciphertext1 plaintext2 = do
+ (plaintext1, state1) <- readMessage state0 ciphertext1
+ (ciphertext2, state2) <- writeMessage state1 plaintext2
+ unless (handshakeComplete state2) internalError
+ case remoteStaticKey state2 of
+ Nothing -> internalError
+ Just rpub -> return (ciphertext2, plaintext1, rpub, extractSessionKey state2)
+
+recvSecondMessage :: NoiseStateWG -> ByteString
+ -> Either SomeException (ScrubbedBytes, PublicKey, SessionKey)
+recvSecondMessage state1 ciphertext2 = do
+ (plaintext2, state2) <- readMessage state1 ciphertext2
+ unless (handshakeComplete state2) internalError
+ case remoteStaticKey state2 of
+ Nothing -> internalError
+ Just rpub -> return (plaintext2, rpub, extractSessionKey state2)
+
+encryptMessage :: SessionKey -> Counter -> ScrubbedBytes -> (EncryptedPayload, AuthTag)
+encryptMessage key counter plaintext = (ciphertext, convert authtag)
+ where
+ st0 = throwCryptoError (CCP.initialize (sendKey key) (getNonce counter))
+ (ciphertext, st) = CCP.encrypt (convert plaintext) st0
+ authtag = CCP.finalize st
+
+decryptMessage :: SessionKey -> Counter -> (EncryptedPayload, AuthTag) -> Maybe ScrubbedBytes
+decryptMessage key counter (ciphertext, authtag)
+ | authtag == authtagExpected = Just (convert plaintext)
+ | otherwise = Nothing
+ where
+ st0 = throwCryptoError (CCP.initialize (recvKey key) (getNonce counter))
+ (plaintext, st) = CCP.decrypt ciphertext st0
+ authtagExpected = convert $ CCP.finalize st
+
+getNonce :: Counter -> CCP.Nonce
+getNonce counter = throwCryptoError (CCP.nonce8 constant iv)
+ where
+ constant = BS.replicate 4 0
+ iv = runPut (putWord64le counter)
+
+extractSessionKey :: NoiseStateWG -> SessionKey
+extractSessionKey ns =
+ SessionKey (cipherSymToBytes $ fromJust (ns ^. nsSendingCipherState) ^. csk)
+ (cipherSymToBytes $ fromJust (ns ^. nsReceivingCipherState) ^. csk)
+
+internalError :: Either SomeException a
+internalError = throwM (InvalidHandshakeOptions "internal error")
diff --git a/src/Network/WireGuard/Internal/Packet.hs b/src/Network/WireGuard/Internal/Packet.hs
new file mode 100644
index 0000000..ebc24fc
--- /dev/null
+++ b/src/Network/WireGuard/Internal/Packet.hs
@@ -0,0 +1,112 @@
+{-# LANGUAGE RecordWildCards #-}
+
+module Network.WireGuard.Internal.Packet
+ ( Packet(..)
+ , parsePacket
+ , buildPacket
+ ) where
+
+import Control.Monad (replicateM_, unless, when)
+import qualified Data.ByteString as BS
+import Foreign.Storable (sizeOf)
+
+import Data.Serialize
+
+import Network.WireGuard.Internal.Constant
+import Network.WireGuard.Internal.Types
+
+data Packet = HandshakeInitiation
+ { senderIndex :: !Index
+ , encryptedPayload :: !EncryptedPayload
+ }
+ | HandshakeResponse
+ { senderIndex :: !Index
+ , receiverIndex :: !Index
+ , encryptedPayload :: !EncryptedPayload
+ }
+ | PacketData
+ { receiverIndex :: !Index
+ , counter :: !Counter
+ , encryptedPayload :: !EncryptedPayload
+ , authTag :: !AuthTag
+ }
+ deriving (Show)
+
+parsePacket :: (BS.ByteString -> BS.ByteString) -> Get Packet
+parsePacket getMac1 = do
+ packetType <- lookAhead getWord8
+ case packetType of
+ 1 -> verifyLength (==handshakeInitiationPacketLength) $ verifyMac getMac1 parseHandshakeInitiation
+ 2 -> verifyLength (==handshakeResponsePacketLength) $ verifyMac getMac1 parseHandshakeResponse
+ 4 -> verifyLength (>=packetDataMinimumPacketLength) parsePacketData
+ _ -> fail "unknown packet"
+ where
+ handshakeInitiationPacketLength = 4 + indexSize + keyLength + aeadLength keyLength + aeadLength timestampLength + mac1Length + mac2Length
+ handshakeResponsePacketLength = 4 + indexSize + indexSize + keyLength + aeadLength 0 + mac1Length + mac2Length
+ packetDataMinimumPacketLength = 4 + indexSize + counterSize + aeadLength 0
+
+ indexSize = sizeOf (undefined :: Index)
+ counterSize = sizeOf (undefined :: Counter)
+
+parseHandshakeInitiation :: Get Packet
+parseHandshakeInitiation = do
+ skip 4
+ HandshakeInitiation <$> getWord32le <*> (remaining >>= getBytes)
+
+parseHandshakeResponse :: Get Packet
+parseHandshakeResponse = do
+ skip 4
+ HandshakeResponse <$> getWord32le <*> getWord32le <*> (remaining >>= getBytes)
+
+parsePacketData :: Get Packet
+parsePacketData = do
+ skip 4
+ PacketData <$> getWord32le <*> getWord64le <*>
+ (remaining >>= getBytes . subtract authLength) <*> getBytes authLength
+
+buildPacket :: (BS.ByteString -> BS.ByteString) -> Putter Packet
+buildPacket getMac1 HandshakeInitiation{..} = appendMac getMac1 $ do
+ putWord8 1
+ replicateM_ 3 (putWord8 0)
+ putWord32le senderIndex
+ putByteString encryptedPayload
+
+buildPacket getMac1 HandshakeResponse{..} = appendMac getMac1 $ do
+ putWord8 2
+ replicateM_ 3 (putWord8 0)
+ putWord32le senderIndex
+ putWord32le receiverIndex
+ putByteString encryptedPayload
+
+buildPacket _getMac1 PacketData{..} = do
+ putWord8 4
+ replicateM_ 3 (putWord8 0)
+ putWord32le receiverIndex
+ putWord64le counter
+ putByteString encryptedPayload
+ putByteString authTag
+
+verifyLength :: (Int -> Bool) -> Get a -> Get a
+verifyLength check ga = do
+ outcome <- check <$> remaining
+ unless outcome $ fail "wrong packet length"
+ ga
+
+verifyMac :: (BS.ByteString -> BS.ByteString) -> Get Packet -> Get Packet
+verifyMac getMac1 ga = do
+ bodyLength <- subtract (mac1Length + mac2Length) <$> remaining
+ when (bodyLength < 0) $ fail "packet too small"
+ expectedMac1 <- getMac1 <$> lookAhead (getBytes bodyLength)
+ parsed <- isolate bodyLength ga
+ receivedMac1 <- getBytes mac1Length
+ when (expectedMac1 /= receivedMac1) $ fail "wrong mac1"
+ skip mac2Length
+ return parsed
+
+appendMac :: (BS.ByteString -> BS.ByteString) -> Put -> Put
+appendMac getMac1 p = do
+ -- TODO: find a smart approach to avoid extra ByteString allocation
+ let bs = runPut p
+ putByteString bs
+ putByteString (getMac1 bs)
+ replicateM_ mac2Length (putWord8 0)
diff --git a/src/Network/WireGuard/Internal/PacketQueue.hs b/src/Network/WireGuard/Internal/PacketQueue.hs
new file mode 100644
index 0000000..bc390f8
--- /dev/null
+++ b/src/Network/WireGuard/Internal/PacketQueue.hs
@@ -0,0 +1,49 @@
+{-# LANGUAGE RecordWildCards #-}
+
+module Network.WireGuard.Internal.PacketQueue
+ ( PacketQueue
+ , newPacketQueue
+ , popPacketQueue
+ , pushPacketQueue
+ , tryPushPacketQueue
+ ) where
+
+import Control.Concurrent.STM
+
+data PacketQueue packet = PacketQueue
+ { tqueue :: TQueue packet
+ , allowance :: TVar Int
+ }
+
+-- | Create a new PacketQueue with size limit of |maxQueuedPackets|.
+newPacketQueue :: Int -> STM (PacketQueue packet)
+newPacketQueue maxQueuedPackets = PacketQueue <$> newTQueue <*> newTVar maxQueuedPackets
+
+-- | Pop a packet out from the queue, blocks if no packet is available.
+popPacketQueue :: PacketQueue packet -> STM packet
+popPacketQueue PacketQueue{..} = do
+ packet <- readTQueue tqueue
+ modifyTVar' allowance (+1)
+ return packet
+
+-- | Push a packet into the queue. Blocks if it's full.
+pushPacketQueue :: PacketQueue packet -> packet -> STM ()
+pushPacketQueue PacketQueue{..} packet = do
+ allowance' <- readTVar allowance
+ if allowance' <= 0
+ then retry
+ else do
+ writeTQueue tqueue packet
+ writeTVar allowance (allowance' - 1)
+
+-- | Try to push a packet into the queue. Returns True if it's pushed.
+tryPushPacketQueue :: PacketQueue packet -> packet -> STM Bool
+tryPushPacketQueue PacketQueue{..} packet = do
+ allowance' <- readTVar allowance
+ if allowance' <= 0
+ then return False
+ else do
+ writeTQueue tqueue packet
+ writeTVar allowance (allowance' - 1)
+ return True
+
diff --git a/src/Network/WireGuard/Internal/State.hs b/src/Network/WireGuard/Internal/State.hs
new file mode 100644
index 0000000..38866e8
--- /dev/null
+++ b/src/Network/WireGuard/Internal/State.hs
@@ -0,0 +1,242 @@
+{-# LANGUAGE RecordWildCards #-}
+{-# LANGUAGE TupleSections #-}
+
+module Network.WireGuard.Internal.State
+ ( PeerId
+ , Device(..)
+ , Peer(..)
+ , InitiatorWait(..)
+ , ResponderWait(..)
+ , Session(..)
+ , createDevice
+ , createPeer
+ , invalidateSessions
+ , buildRouteTables
+ , acquireEmptyIndex
+ , removeIndex
+ , nextNonce
+ , eraseInitiatorWait
+ , eraseResponderWait
+ , getSession
+ , waitForSession
+ , findSession
+ , addSession
+ , filterSessions
+ , updateTai64n
+ , updateEndPoint
+ ) where
+
+import Control.Monad (forM, when)
+import Crypto.Noise (NoiseState)
+import Crypto.Noise.Cipher.ChaChaPoly1305 (ChaChaPoly1305)
+import Crypto.Noise.DH.Curve25519 (Curve25519)
+import Crypto.Noise.Hash.BLAKE2s (BLAKE2s)
+import qualified Data.HashMap.Strict as HM
+import Data.IP (IPRange (..), IPv4, IPv6)
+import qualified Data.IP.RouteTable as RT
+import Data.Maybe (catMaybes, fromJust,
+ isNothing, mapMaybe)
+import Data.Word
+import Network.Socket.Internal (SockAddr)
+
+import Control.Concurrent.STM
+
+import Network.WireGuard.Internal.Constant
+import Network.WireGuard.Internal.Types
+
+data Device = Device
+ { intfName :: String
+ , localKey :: TVar (Maybe KeyPair)
+ , presharedKey :: TVar (Maybe PresharedKey)
+ , fwmark :: TVar Word
+ , port :: TVar Int
+ , peers :: TVar (HM.HashMap PeerId Peer)
+ , routeTable4 :: TVar (RT.IPRTable IPv4 Peer)
+ , routeTable6 :: TVar (RT.IPRTable IPv6 Peer)
+ , indexMap :: TVar (HM.HashMap Index Peer)
+ }
+
+data Peer = Peer
+ { remotePub :: !PublicKey
+ , ipmasks :: TVar [IPRange]
+ , endPoint :: TVar (Maybe SockAddr)
+ , lastHandshakeTime :: TVar (Maybe Time)
+ , receivedBytes :: TVar Word64
+ , transferredBytes :: TVar Word64
+ , keepaliveInterval :: TVar Int
+ , initiatorWait :: TVar (Maybe InitiatorWait)
+ , responderWait :: TVar (Maybe ResponderWait)
+ , sessions :: TVar [Session] -- last two active sessions
+ , lastTai64n :: TVar TAI64n
+ , lastReceiveTime :: TVar Time
+ , lastTransferTime :: TVar Time
+ , lastKeepaliveTime :: TVar Time
+ }
+
+data InitiatorWait = InitiatorWait
+ { initOurIndex :: !Index
+ , initRetryTime :: !Time
+ , initStopTime :: !Time
+ , initNoise :: !(NoiseState ChaChaPoly1305 Curve25519 BLAKE2s)
+ }
+
+data ResponderWait = ResponderWait
+ { respOurIndex :: !Index
+ , respTheirIndex :: !Index
+ , respStopTime :: !Time
+ , respSessionKey :: !SessionKey
+ }
+
+data Session = Session
+ { ourIndex :: !Index
+ , theirIndex :: !Index
+ , sessionKey :: !SessionKey
+ , renewTime :: !Time
+ , expireTime :: !Time
+ , sessionCounter :: TVar Counter
+ -- TODO: avoid nonce reuse from remote peer
+ }
+
+createDevice :: String -> STM Device
+createDevice intf = Device intf <$> newTVar Nothing
+ <*> newTVar Nothing
+ <*> newTVar 0
+ <*> newTVar 0
+ <*> newTVar HM.empty
+ <*> newTVar RT.empty
+ <*> newTVar RT.empty
+ <*> newTVar HM.empty
+
+createPeer :: PublicKey -> STM Peer
+createPeer rpub = Peer rpub <$> newTVar []
+ <*> newTVar Nothing
+ <*> newTVar Nothing
+ <*> newTVar 0
+ <*> newTVar 0
+ <*> newTVar 0
+ <*> newTVar Nothing
+ <*> newTVar Nothing
+ <*> newTVar []
+ <*> newTVar mempty
+ <*> newTVar farFuture
+ <*> newTVar farFuture
+ <*> newTVar 0
+
+invalidateSessions :: Device -> STM ()
+invalidateSessions Device{..} = do
+ writeTVar indexMap HM.empty
+ readTVar peers >>= mapM_ invalidatePeerSessions
+ where
+ invalidatePeerSessions Peer{..} = do
+ writeTVar lastHandshakeTime Nothing
+ writeTVar initiatorWait Nothing
+ writeTVar responderWait Nothing
+ writeTVar sessions []
+
+buildRouteTables :: Device -> STM ()
+buildRouteTables Device{..} = do
+ gather pickIPv4 >>= writeTVar routeTable4 . RT.fromList . concat
+ gather pickIPv6 >>= writeTVar routeTable6 . RT.fromList . concat
+ where
+ gather pick = do
+ peers' <- readTVar peers
+ forM peers' $ \peer ->
+ map (,peer) . mapMaybe pick <$> readTVar (ipmasks peer)
+ pickIPv4 (IPv4Range ipv4) = Just ipv4
+ pickIPv4 _ = Nothing
+ pickIPv6 (IPv6Range ipv6) = Just ipv6
+ pickIPv6 _ = Nothing
+
+acquireEmptyIndex :: Device -> Peer -> Index -> STM Index
+acquireEmptyIndex device peer seed = do
+ imap <- readTVar (indexMap device)
+ let findEmpty idx
+ | HM.member idx imap = findEmpty (idx * 3 + 1)
+ | otherwise = idx
+ emptyIndex = findEmpty seed
+ writeTVar (indexMap device) $ HM.insert emptyIndex peer imap
+ return emptyIndex
+
+removeIndex :: Device -> Index -> STM ()
+removeIndex device index = modifyTVar' (indexMap device) (HM.delete index)
+
+nextNonce :: Session -> STM Counter
+nextNonce Session{..} = do
+ nonce <- readTVar sessionCounter
+ writeTVar sessionCounter (nonce + 1)
+ return nonce
+
+eraseInitiatorWait :: Device -> Peer -> Maybe Index -> STM Bool
+eraseInitiatorWait device Peer{..} index = do
+ miwait <- readTVar initiatorWait
+ case miwait of
+ Just iwait | isNothing index || initOurIndex iwait == fromJust index -> do
+ writeTVar initiatorWait Nothing
+ when (isNothing index) $ removeIndex device (initOurIndex iwait)
+ return True
+ _ -> return False
+
+eraseResponderWait :: Device -> Peer -> Maybe Index -> STM Bool
+eraseResponderWait device Peer{..} index = do
+ mrwait <- readTVar responderWait
+ case mrwait of
+ Just rwait | isNothing index || respOurIndex rwait == fromJust index -> do
+ writeTVar responderWait Nothing
+ when (isNothing index) $ removeIndex device (respOurIndex rwait)
+ return True
+ _ -> return False
+
+getSession :: Peer -> IO (Maybe Session)
+getSession peer = do
+ sessions' <- readTVarIO (sessions peer)
+ case sessions' of
+ [] -> return Nothing
+ (s:_) -> return (Just s)
+
+waitForSession :: Peer -> STM Session
+waitForSession peer = do
+ sessions' <- readTVar (sessions peer)
+ case sessions' of
+ [] -> retry
+ (s:_) -> return s
+
+findSession :: Peer -> Index -> STM (Maybe (Either ResponderWait Session))
+findSession peer index = do
+ sessions' <- filter ((==index).ourIndex) <$> readTVar (sessions peer)
+ case sessions' of
+ (s:_) -> return (Just (Right s))
+ [] -> do
+ mrwait <- readTVar (responderWait peer)
+ case mrwait of
+ Just rwait | respOurIndex rwait == index -> return (Just (Left rwait))
+ _ -> return Nothing
+
+
+addSession :: Device -> Peer -> Session -> STM ()
+addSession device peer session = do
+ (toKeep, toDrop) <- splitAt maxActiveSessions . (session:) <$> readTVar (sessions peer)
+ mapM_ (removeIndex device . ourIndex) toDrop
+ writeTVar (sessions peer) toKeep
+
+filterSessions :: Device -> Peer -> (Session -> Bool) -> STM ()
+filterSessions device peer cond = do
+ sessions' <- readTVar (sessions peer)
+ filtered <- fmap catMaybes $ forM sessions' $ \session ->
+ if cond session
+ then return (Just session)
+ else do
+ removeIndex device (ourIndex session)
+ return Nothing
+ writeTVar (sessions peer) filtered
+
+updateTai64n :: Peer -> TAI64n -> STM Bool
+updateTai64n peer tai64n = do
+ lastTai64n' <- readTVar (lastTai64n peer)
+ if tai64n <= lastTai64n'
+ then return False
+ else do
+ writeTVar (lastTai64n peer) tai64n
+ return True
+
+updateEndPoint :: Peer -> SockAddr -> STM ()
+updateEndPoint peer sock = writeTVar (endPoint peer) (Just sock)
diff --git a/src/Network/WireGuard/Internal/Types.hs b/src/Network/WireGuard/Internal/Types.hs
new file mode 100644
index 0000000..3409e2a
--- /dev/null
+++ b/src/Network/WireGuard/Internal/Types.hs
@@ -0,0 +1,78 @@
+module Network.WireGuard.Internal.Types
+ ( Index
+ , Counter
+ , PeerId
+ , PublicKey
+ , PrivateKey
+ , KeyPair
+ , PresharedKey
+ , Time
+ , UdpPacket
+ , TunPacket
+ , EncryptedPayload
+ , AuthTag
+ , TAI64n
+ , SessionKey(..)
+ , WireGuardError(..)
+ , getPeerId
+ , farFuture
+ ) where
+
+import Control.Exception (Exception, SomeException)
+import qualified Crypto.Noise.DH as DH
+import Crypto.Noise.DH.Curve25519 (Curve25519)
+import Data.ByteArray (ScrubbedBytes)
+import qualified Data.ByteArray as BA
+import qualified Data.ByteString as BS
+import Foreign.C.Types (CTime (..))
+import Network.Socket (SockAddr)
+import System.Posix.Types (EpochTime)
+
+import Data.Word
+
+type Index = Word32
+type Counter = Word64
+type PeerId = BS.ByteString
+
+type PublicKey = DH.PublicKey Curve25519
+type PrivateKey = DH.SecretKey Curve25519
+type KeyPair = DH.KeyPair Curve25519
+type PresharedKey = ScrubbedBytes
+
+type Time = EpochTime
+
+type UdpPacket = (BS.ByteString, SockAddr)
+type TunPacket = ScrubbedBytes
+
+type EncryptedPayload = BS.ByteString
+type AuthTag = BS.ByteString
+type TAI64n = BS.ByteString
+
+data SessionKey = SessionKey
+ { sendKey :: !ScrubbedBytes
+ , recvKey :: !ScrubbedBytes
+ }
+
+data WireGuardError
+ = DecryptFailureError
+ | DestinationNotReachableError
+ | DeviceNotReadyError
+ | EndPointUnknownError
+ | HandshakeInitiationReplayError
+ | InvalidIPPacketError
+ | InvalidWGPacketError String
+ | NoiseError SomeException
+ | NonceReuseError
+ | OutdatedPacketError
+ | RemotePeerNotFoundError
+ | SourceAddrBlockedError
+ | UnknownIndexError
+ deriving (Show)
+
+instance Exception WireGuardError
+
+getPeerId :: PublicKey -> PeerId
+getPeerId = BA.convert . DH.dhPubToBytes
+
+farFuture :: Time
+farFuture = CTime maxBound
diff --git a/src/Network/WireGuard/Internal/Util.hs b/src/Network/WireGuard/Internal/Util.hs
new file mode 100644
index 0000000..f7ecde5
--- /dev/null
+++ b/src/Network/WireGuard/Internal/Util.hs
@@ -0,0 +1,62 @@
+{-# LANGUAGE ScopedTypeVariables #-}
+
+module Network.WireGuard.Internal.Util
+ ( retryWithBackoff
+ , ignoreSyncExceptions
+ , foreverWithBackoff
+ , catchIOExceptionAnd
+ , catchSomeExceptionAnd
+ , withJust
+ , zeroMemory
+ , copyMemory
+ ) where
+
+import Control.Concurrent (threadDelay)
+import Control.Exception (Exception (..),
+ IOException,
+ SomeAsyncException,
+ SomeException, throwIO)
+import Control.Monad.Catch (MonadCatch (..))
+import System.IO (hPutStrLn, stderr)
+
+import Foreign
+import Foreign.C
+
+import Network.WireGuard.Internal.Constant
+
+retryWithBackoff :: IO () -> IO ()
+retryWithBackoff = foreverWithBackoff . ignoreSyncExceptions
+
+ignoreSyncExceptions :: IO () -> IO ()
+ignoreSyncExceptions m = catch m handleExcept
+ where
+ handleExcept e = case fromException e of
+ Just asyncExcept -> throwIO (asyncExcept :: SomeAsyncException)
+ Nothing -> hPutStrLn stderr (displayException e) -- TODO: proper logging
+
+foreverWithBackoff :: IO () -> IO ()
+foreverWithBackoff m = loop 1
+ where
+ loop t = m >> threadDelay t >> loop (min (t * 2) retryMaxWaitTime)
+
+catchIOExceptionAnd :: MonadCatch m => m () -> m () -> m ()
+catchIOExceptionAnd what m = catch m $ \(_ :: IOException) -> what
+
+catchSomeExceptionAnd :: MonadCatch m => m () -> m () -> m ()
+catchSomeExceptionAnd what m = catch m $ \(_ :: SomeException) -> what
+
+withJust :: Monad m => m (Maybe a) -> (a -> m ()) -> m ()
+withJust mma func = do
+ ma <- mma
+ case ma of
+ Nothing -> return ()
+ Just a -> func a
+
+zeroMemory :: Ptr a -> CSize -> IO ()
+zeroMemory dest nbytes = memset dest 0 (fromIntegral nbytes)
+
+copyMemory :: Ptr a -> Ptr b -> CSize -> IO ()
+copyMemory dest src nbytes = memcpy dest src nbytes
+
+foreign import ccall unsafe "string.h" memset :: Ptr a -> CInt -> CSize -> IO ()
+foreign import ccall unsafe "string.h" memcpy :: Ptr a -> Ptr b -> CSize -> IO ()