From a2a3e540f2d1a507b34eccae26de09066a2a12fa Mon Sep 17 00:00:00 2001 From: Bin Jin Date: Sun, 12 Mar 2017 17:48:20 +0800 Subject: Initial commit --- src/Network/WireGuard/Internal/Constant.hs | 55 ++++++ src/Network/WireGuard/Internal/IPPacket.hs | 56 ++++++ src/Network/WireGuard/Internal/Noise.hs | 98 +++++++++++ src/Network/WireGuard/Internal/Packet.hs | 112 ++++++++++++ src/Network/WireGuard/Internal/PacketQueue.hs | 49 ++++++ src/Network/WireGuard/Internal/State.hs | 242 ++++++++++++++++++++++++++ src/Network/WireGuard/Internal/Types.hs | 78 +++++++++ src/Network/WireGuard/Internal/Util.hs | 62 +++++++ 8 files changed, 752 insertions(+) create mode 100644 src/Network/WireGuard/Internal/Constant.hs create mode 100644 src/Network/WireGuard/Internal/IPPacket.hs create mode 100644 src/Network/WireGuard/Internal/Noise.hs create mode 100644 src/Network/WireGuard/Internal/Packet.hs create mode 100644 src/Network/WireGuard/Internal/PacketQueue.hs create mode 100644 src/Network/WireGuard/Internal/State.hs create mode 100644 src/Network/WireGuard/Internal/Types.hs create mode 100644 src/Network/WireGuard/Internal/Util.hs (limited to 'src/Network/WireGuard/Internal') diff --git a/src/Network/WireGuard/Internal/Constant.hs b/src/Network/WireGuard/Internal/Constant.hs new file mode 100644 index 0000000..c615a0f --- /dev/null +++ b/src/Network/WireGuard/Internal/Constant.hs @@ -0,0 +1,55 @@ +module Network.WireGuard.Internal.Constant where + +authLength :: Int +authLength = 16 + +aeadLength :: Int -> Int +aeadLength payload = payload + authLength + +keyLength :: Int +keyLength = 32 + +timestampLength :: Int +timestampLength = 12 + +mac1Length :: Int +mac1Length = 16 + +mac2Length :: Int +mac2Length = 16 + +maxQueuedUdpPackets :: Int +maxQueuedUdpPackets = 4096 + +maxQueuedTunPackets :: Int +maxQueuedTunPackets = 4096 + +udpReadBufferLength :: Int +udpReadBufferLength = 4096 + +tunReadBufferLength :: Int +tunReadBufferLength = 4096 + +retryMaxWaitTime :: Int +retryMaxWaitTime = 5 * 1000000 -- 5 seconds + +handshakeRetryTime :: Int +handshakeRetryTime = 5 + +handshakeStopTime :: Int +handshakeStopTime = 90 + +sessionRenewTime :: Int +sessionRenewTime = 120 + +sessionExpireTime :: Int +sessionExpireTime = 180 + +sessionKeepaliveTime :: Int +sessionKeepaliveTime = 10 + +maxActiveSessions :: Int +maxActiveSessions = 2 + +heartbeatWaitTime :: Int +heartbeatWaitTime = 250 * 1000 -- 0.25 second diff --git a/src/Network/WireGuard/Internal/IPPacket.hs b/src/Network/WireGuard/Internal/IPPacket.hs new file mode 100644 index 0000000..56f4461 --- /dev/null +++ b/src/Network/WireGuard/Internal/IPPacket.hs @@ -0,0 +1,56 @@ +module Network.WireGuard.Internal.IPPacket + ( IPPacket(..) + , parseIPPacket + ) where + +import qualified Data.ByteArray as BA +import Data.IP (IPv4, IPv6, fromHostAddress, + fromHostAddress6) +import Foreign.Ptr (Ptr) +import Foreign.Storable (peekByteOff) + +import Data.Bits +import Data.Word + +data IPPacket = InvalidIPPacket + | IPv4Packet { src4 :: IPv4, dest4 :: IPv4 } + | IPv6Packet { src6 :: IPv6, dest6 :: IPv6 } + +parseIPPacket :: BA.ByteArrayAccess ba => ba -> IO IPPacket +parseIPPacket packet | BA.length packet < 20 = return InvalidIPPacket +parseIPPacket packet = BA.withByteArray packet $ \ptr -> do + firstByte <- peekByteOff ptr 0 :: IO Word8 + let version = firstByte `shiftR` 4 + parse4 = do + s4 <- peekByteOff ptr 12 + d4 <- peekByteOff ptr 16 + return (IPv4Packet (fromHostAddress s4) (fromHostAddress d4)) + parse6 + | BA.length packet < 40 = return InvalidIPPacket + | otherwise = do + s6a <- peek32be ptr 8 + s6b <- peek32be ptr 12 + s6c <- peek32be ptr 16 + s6d <- peek32be ptr 20 + d6a <- peek32be ptr 24 + d6b <- peek32be ptr 28 + d6c <- peek32be ptr 32 + d6d <- peek32be ptr 36 + let s6 = (s6a, s6b, s6c, s6d) + d6 = (d6a, d6b, d6c, d6d) + return (IPv6Packet (fromHostAddress6 s6) (fromHostAddress6 d6)) + case version of + 4 -> parse4 + 6 -> parse6 + _ -> return InvalidIPPacket + +peek32be :: Ptr a -> Int -> IO Word32 +peek32be ptr offset = do + a <- peekByteOff ptr offset :: IO Word8 + b <- peekByteOff ptr (offset + 1) :: IO Word8 + c <- peekByteOff ptr (offset + 2) :: IO Word8 + d <- peekByteOff ptr (offset + 3) :: IO Word8 + return $! (fromIntegral a `unsafeShiftL` 24) .|. + (fromIntegral b `unsafeShiftL` 16) .|. + (fromIntegral c `unsafeShiftL` 8) .|. + fromIntegral d diff --git a/src/Network/WireGuard/Internal/Noise.hs b/src/Network/WireGuard/Internal/Noise.hs new file mode 100644 index 0000000..b529d25 --- /dev/null +++ b/src/Network/WireGuard/Internal/Noise.hs @@ -0,0 +1,98 @@ +{-# LANGUAGE OverloadedStrings #-} +module Network.WireGuard.Internal.Noise + ( NoiseStateWG + , newNoiseState + , sendFirstMessage + , recvFirstMessageAndReply + , recvSecondMessage + , encryptMessage + , decryptMessage + ) where + +import Control.Exception (SomeException) +import Control.Lens ((&), (.~), (^.)) +import Control.Monad (unless) +import Control.Monad.Catch (throwM) +import qualified Crypto.Cipher.ChaChaPoly1305 as CCP +import Crypto.Error (throwCryptoError) +import Crypto.Noise.Cipher (cipherSymToBytes) +import Crypto.Noise.Cipher.ChaChaPoly1305 (ChaChaPoly1305) +import Crypto.Noise.DH.Curve25519 (Curve25519) +import Crypto.Noise.HandshakePatterns (noiseIK) +import Crypto.Noise.Hash.BLAKE2s (BLAKE2s) +import Crypto.Noise.Internal.CipherState (csk) +import Crypto.Noise.Internal.NoiseState (nsReceivingCipherState, + nsSendingCipherState) +import Data.ByteArray (ScrubbedBytes, convert) +import Data.ByteString (ByteString) +import qualified Data.ByteString as BS +import Data.Maybe (fromJust) +import Data.Serialize (putWord64le, runPut) + +import Crypto.Noise + +import Network.WireGuard.Internal.Types + +type NoiseStateWG = NoiseState ChaChaPoly1305 Curve25519 BLAKE2s + +newNoiseState :: KeyPair -> Maybe PresharedKey -> KeyPair -> Maybe PublicKey -> HandshakeRole -> NoiseStateWG +newNoiseState staticKey presharedKey ephemeralKey remotePub role = + noiseState $ defaultHandshakeOpts noiseIK role + & hoPrologue .~ "WireGuard v0 zx2c4 Jason@zx2c4.com" + & hoLocalStatic .~ Just staticKey + & hoPreSharedKey .~ presharedKey + & hoRemoteStatic .~ remotePub + & hoLocalEphemeral .~ Just ephemeralKey + +sendFirstMessage :: NoiseStateWG -> ScrubbedBytes + -> Either SomeException (ByteString, NoiseStateWG) +sendFirstMessage state0 plaintext1 = writeMessage state0 plaintext1 + +recvFirstMessageAndReply :: NoiseStateWG -> ByteString -> ScrubbedBytes + -> Either SomeException (ByteString, ScrubbedBytes, PublicKey, SessionKey) +recvFirstMessageAndReply state0 ciphertext1 plaintext2 = do + (plaintext1, state1) <- readMessage state0 ciphertext1 + (ciphertext2, state2) <- writeMessage state1 plaintext2 + unless (handshakeComplete state2) internalError + case remoteStaticKey state2 of + Nothing -> internalError + Just rpub -> return (ciphertext2, plaintext1, rpub, extractSessionKey state2) + +recvSecondMessage :: NoiseStateWG -> ByteString + -> Either SomeException (ScrubbedBytes, PublicKey, SessionKey) +recvSecondMessage state1 ciphertext2 = do + (plaintext2, state2) <- readMessage state1 ciphertext2 + unless (handshakeComplete state2) internalError + case remoteStaticKey state2 of + Nothing -> internalError + Just rpub -> return (plaintext2, rpub, extractSessionKey state2) + +encryptMessage :: SessionKey -> Counter -> ScrubbedBytes -> (EncryptedPayload, AuthTag) +encryptMessage key counter plaintext = (ciphertext, convert authtag) + where + st0 = throwCryptoError (CCP.initialize (sendKey key) (getNonce counter)) + (ciphertext, st) = CCP.encrypt (convert plaintext) st0 + authtag = CCP.finalize st + +decryptMessage :: SessionKey -> Counter -> (EncryptedPayload, AuthTag) -> Maybe ScrubbedBytes +decryptMessage key counter (ciphertext, authtag) + | authtag == authtagExpected = Just (convert plaintext) + | otherwise = Nothing + where + st0 = throwCryptoError (CCP.initialize (recvKey key) (getNonce counter)) + (plaintext, st) = CCP.decrypt ciphertext st0 + authtagExpected = convert $ CCP.finalize st + +getNonce :: Counter -> CCP.Nonce +getNonce counter = throwCryptoError (CCP.nonce8 constant iv) + where + constant = BS.replicate 4 0 + iv = runPut (putWord64le counter) + +extractSessionKey :: NoiseStateWG -> SessionKey +extractSessionKey ns = + SessionKey (cipherSymToBytes $ fromJust (ns ^. nsSendingCipherState) ^. csk) + (cipherSymToBytes $ fromJust (ns ^. nsReceivingCipherState) ^. csk) + +internalError :: Either SomeException a +internalError = throwM (InvalidHandshakeOptions "internal error") diff --git a/src/Network/WireGuard/Internal/Packet.hs b/src/Network/WireGuard/Internal/Packet.hs new file mode 100644 index 0000000..ebc24fc --- /dev/null +++ b/src/Network/WireGuard/Internal/Packet.hs @@ -0,0 +1,112 @@ +{-# LANGUAGE RecordWildCards #-} + +module Network.WireGuard.Internal.Packet + ( Packet(..) + , parsePacket + , buildPacket + ) where + +import Control.Monad (replicateM_, unless, when) +import qualified Data.ByteString as BS +import Foreign.Storable (sizeOf) + +import Data.Serialize + +import Network.WireGuard.Internal.Constant +import Network.WireGuard.Internal.Types + +data Packet = HandshakeInitiation + { senderIndex :: !Index + , encryptedPayload :: !EncryptedPayload + } + | HandshakeResponse + { senderIndex :: !Index + , receiverIndex :: !Index + , encryptedPayload :: !EncryptedPayload + } + | PacketData + { receiverIndex :: !Index + , counter :: !Counter + , encryptedPayload :: !EncryptedPayload + , authTag :: !AuthTag + } + deriving (Show) + +parsePacket :: (BS.ByteString -> BS.ByteString) -> Get Packet +parsePacket getMac1 = do + packetType <- lookAhead getWord8 + case packetType of + 1 -> verifyLength (==handshakeInitiationPacketLength) $ verifyMac getMac1 parseHandshakeInitiation + 2 -> verifyLength (==handshakeResponsePacketLength) $ verifyMac getMac1 parseHandshakeResponse + 4 -> verifyLength (>=packetDataMinimumPacketLength) parsePacketData + _ -> fail "unknown packet" + where + handshakeInitiationPacketLength = 4 + indexSize + keyLength + aeadLength keyLength + aeadLength timestampLength + mac1Length + mac2Length + handshakeResponsePacketLength = 4 + indexSize + indexSize + keyLength + aeadLength 0 + mac1Length + mac2Length + packetDataMinimumPacketLength = 4 + indexSize + counterSize + aeadLength 0 + + indexSize = sizeOf (undefined :: Index) + counterSize = sizeOf (undefined :: Counter) + +parseHandshakeInitiation :: Get Packet +parseHandshakeInitiation = do + skip 4 + HandshakeInitiation <$> getWord32le <*> (remaining >>= getBytes) + +parseHandshakeResponse :: Get Packet +parseHandshakeResponse = do + skip 4 + HandshakeResponse <$> getWord32le <*> getWord32le <*> (remaining >>= getBytes) + +parsePacketData :: Get Packet +parsePacketData = do + skip 4 + PacketData <$> getWord32le <*> getWord64le <*> + (remaining >>= getBytes . subtract authLength) <*> getBytes authLength + +buildPacket :: (BS.ByteString -> BS.ByteString) -> Putter Packet +buildPacket getMac1 HandshakeInitiation{..} = appendMac getMac1 $ do + putWord8 1 + replicateM_ 3 (putWord8 0) + putWord32le senderIndex + putByteString encryptedPayload + +buildPacket getMac1 HandshakeResponse{..} = appendMac getMac1 $ do + putWord8 2 + replicateM_ 3 (putWord8 0) + putWord32le senderIndex + putWord32le receiverIndex + putByteString encryptedPayload + +buildPacket _getMac1 PacketData{..} = do + putWord8 4 + replicateM_ 3 (putWord8 0) + putWord32le receiverIndex + putWord64le counter + putByteString encryptedPayload + putByteString authTag + +verifyLength :: (Int -> Bool) -> Get a -> Get a +verifyLength check ga = do + outcome <- check <$> remaining + unless outcome $ fail "wrong packet length" + ga + +verifyMac :: (BS.ByteString -> BS.ByteString) -> Get Packet -> Get Packet +verifyMac getMac1 ga = do + bodyLength <- subtract (mac1Length + mac2Length) <$> remaining + when (bodyLength < 0) $ fail "packet too small" + expectedMac1 <- getMac1 <$> lookAhead (getBytes bodyLength) + parsed <- isolate bodyLength ga + receivedMac1 <- getBytes mac1Length + when (expectedMac1 /= receivedMac1) $ fail "wrong mac1" + skip mac2Length + return parsed + +appendMac :: (BS.ByteString -> BS.ByteString) -> Put -> Put +appendMac getMac1 p = do + -- TODO: find a smart approach to avoid extra ByteString allocation + let bs = runPut p + putByteString bs + putByteString (getMac1 bs) + replicateM_ mac2Length (putWord8 0) diff --git a/src/Network/WireGuard/Internal/PacketQueue.hs b/src/Network/WireGuard/Internal/PacketQueue.hs new file mode 100644 index 0000000..bc390f8 --- /dev/null +++ b/src/Network/WireGuard/Internal/PacketQueue.hs @@ -0,0 +1,49 @@ +{-# LANGUAGE RecordWildCards #-} + +module Network.WireGuard.Internal.PacketQueue + ( PacketQueue + , newPacketQueue + , popPacketQueue + , pushPacketQueue + , tryPushPacketQueue + ) where + +import Control.Concurrent.STM + +data PacketQueue packet = PacketQueue + { tqueue :: TQueue packet + , allowance :: TVar Int + } + +-- | Create a new PacketQueue with size limit of |maxQueuedPackets|. +newPacketQueue :: Int -> STM (PacketQueue packet) +newPacketQueue maxQueuedPackets = PacketQueue <$> newTQueue <*> newTVar maxQueuedPackets + +-- | Pop a packet out from the queue, blocks if no packet is available. +popPacketQueue :: PacketQueue packet -> STM packet +popPacketQueue PacketQueue{..} = do + packet <- readTQueue tqueue + modifyTVar' allowance (+1) + return packet + +-- | Push a packet into the queue. Blocks if it's full. +pushPacketQueue :: PacketQueue packet -> packet -> STM () +pushPacketQueue PacketQueue{..} packet = do + allowance' <- readTVar allowance + if allowance' <= 0 + then retry + else do + writeTQueue tqueue packet + writeTVar allowance (allowance' - 1) + +-- | Try to push a packet into the queue. Returns True if it's pushed. +tryPushPacketQueue :: PacketQueue packet -> packet -> STM Bool +tryPushPacketQueue PacketQueue{..} packet = do + allowance' <- readTVar allowance + if allowance' <= 0 + then return False + else do + writeTQueue tqueue packet + writeTVar allowance (allowance' - 1) + return True + diff --git a/src/Network/WireGuard/Internal/State.hs b/src/Network/WireGuard/Internal/State.hs new file mode 100644 index 0000000..38866e8 --- /dev/null +++ b/src/Network/WireGuard/Internal/State.hs @@ -0,0 +1,242 @@ +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TupleSections #-} + +module Network.WireGuard.Internal.State + ( PeerId + , Device(..) + , Peer(..) + , InitiatorWait(..) + , ResponderWait(..) + , Session(..) + , createDevice + , createPeer + , invalidateSessions + , buildRouteTables + , acquireEmptyIndex + , removeIndex + , nextNonce + , eraseInitiatorWait + , eraseResponderWait + , getSession + , waitForSession + , findSession + , addSession + , filterSessions + , updateTai64n + , updateEndPoint + ) where + +import Control.Monad (forM, when) +import Crypto.Noise (NoiseState) +import Crypto.Noise.Cipher.ChaChaPoly1305 (ChaChaPoly1305) +import Crypto.Noise.DH.Curve25519 (Curve25519) +import Crypto.Noise.Hash.BLAKE2s (BLAKE2s) +import qualified Data.HashMap.Strict as HM +import Data.IP (IPRange (..), IPv4, IPv6) +import qualified Data.IP.RouteTable as RT +import Data.Maybe (catMaybes, fromJust, + isNothing, mapMaybe) +import Data.Word +import Network.Socket.Internal (SockAddr) + +import Control.Concurrent.STM + +import Network.WireGuard.Internal.Constant +import Network.WireGuard.Internal.Types + +data Device = Device + { intfName :: String + , localKey :: TVar (Maybe KeyPair) + , presharedKey :: TVar (Maybe PresharedKey) + , fwmark :: TVar Word + , port :: TVar Int + , peers :: TVar (HM.HashMap PeerId Peer) + , routeTable4 :: TVar (RT.IPRTable IPv4 Peer) + , routeTable6 :: TVar (RT.IPRTable IPv6 Peer) + , indexMap :: TVar (HM.HashMap Index Peer) + } + +data Peer = Peer + { remotePub :: !PublicKey + , ipmasks :: TVar [IPRange] + , endPoint :: TVar (Maybe SockAddr) + , lastHandshakeTime :: TVar (Maybe Time) + , receivedBytes :: TVar Word64 + , transferredBytes :: TVar Word64 + , keepaliveInterval :: TVar Int + , initiatorWait :: TVar (Maybe InitiatorWait) + , responderWait :: TVar (Maybe ResponderWait) + , sessions :: TVar [Session] -- last two active sessions + , lastTai64n :: TVar TAI64n + , lastReceiveTime :: TVar Time + , lastTransferTime :: TVar Time + , lastKeepaliveTime :: TVar Time + } + +data InitiatorWait = InitiatorWait + { initOurIndex :: !Index + , initRetryTime :: !Time + , initStopTime :: !Time + , initNoise :: !(NoiseState ChaChaPoly1305 Curve25519 BLAKE2s) + } + +data ResponderWait = ResponderWait + { respOurIndex :: !Index + , respTheirIndex :: !Index + , respStopTime :: !Time + , respSessionKey :: !SessionKey + } + +data Session = Session + { ourIndex :: !Index + , theirIndex :: !Index + , sessionKey :: !SessionKey + , renewTime :: !Time + , expireTime :: !Time + , sessionCounter :: TVar Counter + -- TODO: avoid nonce reuse from remote peer + } + +createDevice :: String -> STM Device +createDevice intf = Device intf <$> newTVar Nothing + <*> newTVar Nothing + <*> newTVar 0 + <*> newTVar 0 + <*> newTVar HM.empty + <*> newTVar RT.empty + <*> newTVar RT.empty + <*> newTVar HM.empty + +createPeer :: PublicKey -> STM Peer +createPeer rpub = Peer rpub <$> newTVar [] + <*> newTVar Nothing + <*> newTVar Nothing + <*> newTVar 0 + <*> newTVar 0 + <*> newTVar 0 + <*> newTVar Nothing + <*> newTVar Nothing + <*> newTVar [] + <*> newTVar mempty + <*> newTVar farFuture + <*> newTVar farFuture + <*> newTVar 0 + +invalidateSessions :: Device -> STM () +invalidateSessions Device{..} = do + writeTVar indexMap HM.empty + readTVar peers >>= mapM_ invalidatePeerSessions + where + invalidatePeerSessions Peer{..} = do + writeTVar lastHandshakeTime Nothing + writeTVar initiatorWait Nothing + writeTVar responderWait Nothing + writeTVar sessions [] + +buildRouteTables :: Device -> STM () +buildRouteTables Device{..} = do + gather pickIPv4 >>= writeTVar routeTable4 . RT.fromList . concat + gather pickIPv6 >>= writeTVar routeTable6 . RT.fromList . concat + where + gather pick = do + peers' <- readTVar peers + forM peers' $ \peer -> + map (,peer) . mapMaybe pick <$> readTVar (ipmasks peer) + pickIPv4 (IPv4Range ipv4) = Just ipv4 + pickIPv4 _ = Nothing + pickIPv6 (IPv6Range ipv6) = Just ipv6 + pickIPv6 _ = Nothing + +acquireEmptyIndex :: Device -> Peer -> Index -> STM Index +acquireEmptyIndex device peer seed = do + imap <- readTVar (indexMap device) + let findEmpty idx + | HM.member idx imap = findEmpty (idx * 3 + 1) + | otherwise = idx + emptyIndex = findEmpty seed + writeTVar (indexMap device) $ HM.insert emptyIndex peer imap + return emptyIndex + +removeIndex :: Device -> Index -> STM () +removeIndex device index = modifyTVar' (indexMap device) (HM.delete index) + +nextNonce :: Session -> STM Counter +nextNonce Session{..} = do + nonce <- readTVar sessionCounter + writeTVar sessionCounter (nonce + 1) + return nonce + +eraseInitiatorWait :: Device -> Peer -> Maybe Index -> STM Bool +eraseInitiatorWait device Peer{..} index = do + miwait <- readTVar initiatorWait + case miwait of + Just iwait | isNothing index || initOurIndex iwait == fromJust index -> do + writeTVar initiatorWait Nothing + when (isNothing index) $ removeIndex device (initOurIndex iwait) + return True + _ -> return False + +eraseResponderWait :: Device -> Peer -> Maybe Index -> STM Bool +eraseResponderWait device Peer{..} index = do + mrwait <- readTVar responderWait + case mrwait of + Just rwait | isNothing index || respOurIndex rwait == fromJust index -> do + writeTVar responderWait Nothing + when (isNothing index) $ removeIndex device (respOurIndex rwait) + return True + _ -> return False + +getSession :: Peer -> IO (Maybe Session) +getSession peer = do + sessions' <- readTVarIO (sessions peer) + case sessions' of + [] -> return Nothing + (s:_) -> return (Just s) + +waitForSession :: Peer -> STM Session +waitForSession peer = do + sessions' <- readTVar (sessions peer) + case sessions' of + [] -> retry + (s:_) -> return s + +findSession :: Peer -> Index -> STM (Maybe (Either ResponderWait Session)) +findSession peer index = do + sessions' <- filter ((==index).ourIndex) <$> readTVar (sessions peer) + case sessions' of + (s:_) -> return (Just (Right s)) + [] -> do + mrwait <- readTVar (responderWait peer) + case mrwait of + Just rwait | respOurIndex rwait == index -> return (Just (Left rwait)) + _ -> return Nothing + + +addSession :: Device -> Peer -> Session -> STM () +addSession device peer session = do + (toKeep, toDrop) <- splitAt maxActiveSessions . (session:) <$> readTVar (sessions peer) + mapM_ (removeIndex device . ourIndex) toDrop + writeTVar (sessions peer) toKeep + +filterSessions :: Device -> Peer -> (Session -> Bool) -> STM () +filterSessions device peer cond = do + sessions' <- readTVar (sessions peer) + filtered <- fmap catMaybes $ forM sessions' $ \session -> + if cond session + then return (Just session) + else do + removeIndex device (ourIndex session) + return Nothing + writeTVar (sessions peer) filtered + +updateTai64n :: Peer -> TAI64n -> STM Bool +updateTai64n peer tai64n = do + lastTai64n' <- readTVar (lastTai64n peer) + if tai64n <= lastTai64n' + then return False + else do + writeTVar (lastTai64n peer) tai64n + return True + +updateEndPoint :: Peer -> SockAddr -> STM () +updateEndPoint peer sock = writeTVar (endPoint peer) (Just sock) diff --git a/src/Network/WireGuard/Internal/Types.hs b/src/Network/WireGuard/Internal/Types.hs new file mode 100644 index 0000000..3409e2a --- /dev/null +++ b/src/Network/WireGuard/Internal/Types.hs @@ -0,0 +1,78 @@ +module Network.WireGuard.Internal.Types + ( Index + , Counter + , PeerId + , PublicKey + , PrivateKey + , KeyPair + , PresharedKey + , Time + , UdpPacket + , TunPacket + , EncryptedPayload + , AuthTag + , TAI64n + , SessionKey(..) + , WireGuardError(..) + , getPeerId + , farFuture + ) where + +import Control.Exception (Exception, SomeException) +import qualified Crypto.Noise.DH as DH +import Crypto.Noise.DH.Curve25519 (Curve25519) +import Data.ByteArray (ScrubbedBytes) +import qualified Data.ByteArray as BA +import qualified Data.ByteString as BS +import Foreign.C.Types (CTime (..)) +import Network.Socket (SockAddr) +import System.Posix.Types (EpochTime) + +import Data.Word + +type Index = Word32 +type Counter = Word64 +type PeerId = BS.ByteString + +type PublicKey = DH.PublicKey Curve25519 +type PrivateKey = DH.SecretKey Curve25519 +type KeyPair = DH.KeyPair Curve25519 +type PresharedKey = ScrubbedBytes + +type Time = EpochTime + +type UdpPacket = (BS.ByteString, SockAddr) +type TunPacket = ScrubbedBytes + +type EncryptedPayload = BS.ByteString +type AuthTag = BS.ByteString +type TAI64n = BS.ByteString + +data SessionKey = SessionKey + { sendKey :: !ScrubbedBytes + , recvKey :: !ScrubbedBytes + } + +data WireGuardError + = DecryptFailureError + | DestinationNotReachableError + | DeviceNotReadyError + | EndPointUnknownError + | HandshakeInitiationReplayError + | InvalidIPPacketError + | InvalidWGPacketError String + | NoiseError SomeException + | NonceReuseError + | OutdatedPacketError + | RemotePeerNotFoundError + | SourceAddrBlockedError + | UnknownIndexError + deriving (Show) + +instance Exception WireGuardError + +getPeerId :: PublicKey -> PeerId +getPeerId = BA.convert . DH.dhPubToBytes + +farFuture :: Time +farFuture = CTime maxBound diff --git a/src/Network/WireGuard/Internal/Util.hs b/src/Network/WireGuard/Internal/Util.hs new file mode 100644 index 0000000..f7ecde5 --- /dev/null +++ b/src/Network/WireGuard/Internal/Util.hs @@ -0,0 +1,62 @@ +{-# LANGUAGE ScopedTypeVariables #-} + +module Network.WireGuard.Internal.Util + ( retryWithBackoff + , ignoreSyncExceptions + , foreverWithBackoff + , catchIOExceptionAnd + , catchSomeExceptionAnd + , withJust + , zeroMemory + , copyMemory + ) where + +import Control.Concurrent (threadDelay) +import Control.Exception (Exception (..), + IOException, + SomeAsyncException, + SomeException, throwIO) +import Control.Monad.Catch (MonadCatch (..)) +import System.IO (hPutStrLn, stderr) + +import Foreign +import Foreign.C + +import Network.WireGuard.Internal.Constant + +retryWithBackoff :: IO () -> IO () +retryWithBackoff = foreverWithBackoff . ignoreSyncExceptions + +ignoreSyncExceptions :: IO () -> IO () +ignoreSyncExceptions m = catch m handleExcept + where + handleExcept e = case fromException e of + Just asyncExcept -> throwIO (asyncExcept :: SomeAsyncException) + Nothing -> hPutStrLn stderr (displayException e) -- TODO: proper logging + +foreverWithBackoff :: IO () -> IO () +foreverWithBackoff m = loop 1 + where + loop t = m >> threadDelay t >> loop (min (t * 2) retryMaxWaitTime) + +catchIOExceptionAnd :: MonadCatch m => m () -> m () -> m () +catchIOExceptionAnd what m = catch m $ \(_ :: IOException) -> what + +catchSomeExceptionAnd :: MonadCatch m => m () -> m () -> m () +catchSomeExceptionAnd what m = catch m $ \(_ :: SomeException) -> what + +withJust :: Monad m => m (Maybe a) -> (a -> m ()) -> m () +withJust mma func = do + ma <- mma + case ma of + Nothing -> return () + Just a -> func a + +zeroMemory :: Ptr a -> CSize -> IO () +zeroMemory dest nbytes = memset dest 0 (fromIntegral nbytes) + +copyMemory :: Ptr a -> Ptr b -> CSize -> IO () +copyMemory dest src nbytes = memcpy dest src nbytes + +foreign import ccall unsafe "string.h" memset :: Ptr a -> CInt -> CSize -> IO () +foreign import ccall unsafe "string.h" memcpy :: Ptr a -> Ptr b -> CSize -> IO () -- cgit v1.2.3-59-g8ed1b