1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
|
{-# LANGUAGE OverloadedStrings #-}
module Network.WireGuard.Internal.Noise
( NoiseStateWG
, newNoiseState
, sendFirstMessage
, recvFirstMessageAndReply
, recvSecondMessage
, encryptMessage
, decryptMessage
) where
import Control.Exception (SomeException)
import Control.Lens ((&), (.~), (^.))
import Control.Monad (unless)
import Control.Monad.Catch (throwM)
import qualified Crypto.Cipher.ChaChaPoly1305 as CCP
import Crypto.Error (throwCryptoError)
import Crypto.Noise
import Crypto.Noise.Cipher (cipherSymToBytes)
import Crypto.Noise.Cipher.ChaChaPoly1305 (ChaChaPoly1305)
import Crypto.Noise.DH.Curve25519 (Curve25519)
import Crypto.Noise.HandshakePatterns (noiseIK)
import Crypto.Noise.Hash.BLAKE2s (BLAKE2s)
import Crypto.Noise.Internal.CipherState (csk)
import Crypto.Noise.Internal.NoiseState (nsReceivingCipherState,
                                         nsSendingCipherState)
import Data.ByteArray (ScrubbedBytes, constEq, convert)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.Maybe (fromJust)
import Data.Serialize (putWord64le, runPut)

import Network.WireGuard.Internal.Types
-- | The Noise state used by this module: 'NoiseState' specialised to the
-- ChaChaPoly1305 cipher, the Curve25519 DH function and the BLAKE2s hash.
type NoiseStateWG = NoiseState ChaChaPoly1305 Curve25519 BLAKE2s
-- | Create a fresh handshake state for the @noiseIK@ pattern.
--
-- Starts from 'defaultHandshakeOpts' for the given role and then sets the
-- prologue, the local static key pair, the optional pre-shared key, the
-- optional remote static public key and the local ephemeral key pair, in
-- that argument order: static key, pre-shared key, ephemeral key, remote
-- public key, role.
newNoiseState :: KeyPair -> Maybe PresharedKey -> KeyPair -> Maybe PublicKey -> HandshakeRole -> NoiseStateWG
newNoiseState staticKey presharedKey ephemeralKey remotePub role = noiseState opts
  where
    -- Each setter touches a different option field of the defaults.
    opts = defaultHandshakeOpts noiseIK role
         & hoPrologue       .~ "WireGuard v0 zx2c4 Jason@zx2c4.com"
         & hoLocalStatic    .~ Just staticKey
         & hoPreSharedKey   .~ presharedKey
         & hoRemoteStatic   .~ remotePub
         & hoLocalEphemeral .~ Just ephemeralKey
-- | Produce the first handshake message.
--
-- Takes the current handshake state and the payload to embed, and hands both
-- straight to 'writeMessage'. On success yields the bytes to send together
-- with the advanced state; any failure is reported as 'Left'.
sendFirstMessage :: NoiseStateWG -> ScrubbedBytes
                 -> Either SomeException (ByteString, NoiseStateWG)
sendFirstMessage = writeMessage
-- | Consume the first handshake message and produce the reply.
--
-- Reads @ciphertext1@ with 'readMessage', then writes @plaintext2@ with
-- 'writeMessage' on the resulting state. The result is
-- @(reply bytes, payload of the first message, remote static key, session key)@.
--
-- Fails with 'internalError' when the handshake is not complete after the
-- reply has been written, or when no remote static key is available; errors
-- from reading or writing are propagated as 'Left'.
recvFirstMessageAndReply :: NoiseStateWG -> ByteString -> ScrubbedBytes
                         -> Either SomeException (ByteString, ScrubbedBytes, PublicKey, SessionKey)
recvFirstMessageAndReply state0 ciphertext1 plaintext2 = do
    (receivedPayload, afterRead) <- readMessage state0 ciphertext1
    (replyBytes, finished)       <- writeMessage afterRead plaintext2
    let result remoteKey =
            (replyBytes, receivedPayload, remoteKey, extractSessionKey finished)
    if handshakeComplete finished
        then maybe internalError (pure . result) (remoteStaticKey finished)
        else internalError
-- | Consume the second handshake message.
--
-- Reads @ciphertext2@ with 'readMessage' and returns
-- @(payload, remote static key, session key)@.
--
-- Fails with 'internalError' when the handshake is not complete after the
-- read, or when no remote static key is available; a read failure is
-- propagated as 'Left'.
recvSecondMessage :: NoiseStateWG -> ByteString
                  -> Either SomeException (ScrubbedBytes, PublicKey, SessionKey)
recvSecondMessage state1 ciphertext2 = do
    (receivedPayload, finished) <- readMessage state1 ciphertext2
    let result remoteKey = (receivedPayload, remoteKey, extractSessionKey finished)
    if handshakeComplete finished
        then maybe internalError (pure . result) (remoteStaticKey finished)
        else internalError
-- | Encrypt a payload and compute its authentication tag.
--
-- The cipher is initialised with @sendKey key@ and the nonce derived from
-- @counter@ via 'getNonce'; no additional authenticated data is supplied.
-- Returns the ciphertext paired with the tag.
--
-- 'throwCryptoError' is applied to the initialisation result, so an
-- initialisation failure surfaces as an exception rather than a return value.
encryptMessage :: SessionKey -> Counter -> ScrubbedBytes -> (EncryptedPayload, AuthTag)
encryptMessage key counter plaintext =
    let initial           = throwCryptoError $ CCP.initialize (sendKey key) (getNonce counter)
        (sealed, closing) = CCP.encrypt (convert plaintext) initial
    in  (sealed, convert (CCP.finalize closing))
-- | Decrypt a payload and verify its authentication tag.
--
-- The cipher is initialised with @recvKey key@ and the nonce derived from
-- @counter@ via 'getNonce'; @ciphertext@ is decrypted and the tag is
-- recomputed. Returns 'Just' the plaintext only when the supplied tag matches
-- the recomputed one, and 'Nothing' otherwise.
--
-- The tags are compared with 'constEq' rather than with the 'Eq' instance of
-- 'AuthTag', so the time taken by the check does not depend on how many
-- leading bytes of a forged tag happen to be correct. Tags of differing
-- length compare as unequal.
--
-- 'throwCryptoError' is applied to the initialisation result, so an
-- initialisation failure surfaces as an exception rather than 'Nothing'.
decryptMessage :: SessionKey -> Counter -> (EncryptedPayload, AuthTag) -> Maybe ScrubbedBytes
decryptMessage key counter (ciphertext, authtag)
    | authtag `constEq` authtagExpected = Just (convert plaintext)
    | otherwise                         = Nothing
  where
    st0             = throwCryptoError (CCP.initialize (recvKey key) (getNonce counter))
    (plaintext, st) = CCP.decrypt ciphertext st0
    -- Kept as the cipher's own tag type; 'constEq' accepts any byte-array view.
    authtagExpected = CCP.finalize st
-- | Derive the cipher nonce for a packet counter.
--
-- The nonce is built by 'CCP.nonce8' from a constant of four zero bytes and
-- the counter serialised as a little-endian 64-bit word. 'throwCryptoError'
-- unwraps the result, so a construction failure surfaces as an exception.
getNonce :: Counter -> CCP.Nonce
getNonce = throwCryptoError . CCP.nonce8 (BS.replicate 4 0) . runPut . putWord64le
-- | Build a 'SessionKey' from the symmetric keys held in the sending and
-- receiving cipher states of a Noise state (in that order).
--
-- Both cipher states must be present. The callers in this module only invoke
-- this after checking 'handshakeComplete'; if a cipher state is missing
-- anyway, this calls 'error' with a message naming the missing direction
-- instead of failing through an anonymous 'fromJust'.
extractSessionKey :: NoiseStateWG -> SessionKey
extractSessionKey ns =
    SessionKey (keyBytes "sending"   (ns ^. nsSendingCipherState))
               (keyBytes "receiving" (ns ^. nsReceivingCipherState))
  where
    -- Pull the raw symmetric key bytes out of an optional cipher state.
    keyBytes _         (Just cs) = cipherSymToBytes (cs ^. csk)
    keyBytes direction Nothing   =
        error ("extractSessionKey: no " ++ direction
               ++ " cipher state (handshake not complete?)")
-- | Failure value used when the handshake ends in an unexpected condition:
-- an 'InvalidHandshakeOptions' exception carrying the text
-- @"internal error"@, raised through 'throwM' in the 'Either' monad.
internalError :: Either SomeException a
internalError = throwM failure
  where
    failure = InvalidHandshakeOptions "internal error"
|