aboutsummaryrefslogtreecommitdiffstats
path: root/device/noise-protocol.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/noise-protocol.go')
-rw-r--r--device/noise-protocol.go251
1 files changed, 135 insertions, 116 deletions
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index 88c6aae..e8f6145 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -1,29 +1,50 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"errors"
+ "fmt"
"sync"
"time"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
+
"golang.zx2c4.com/wireguard/tai64n"
)
+type handshakeState int
+
const (
- HandshakeZeroed = iota
- HandshakeInitiationCreated
- HandshakeInitiationConsumed
- HandshakeResponseCreated
- HandshakeResponseConsumed
+ handshakeZeroed = handshakeState(iota)
+ handshakeInitiationCreated
+ handshakeInitiationConsumed
+ handshakeResponseCreated
+ handshakeResponseConsumed
)
+func (hs handshakeState) String() string {
+ switch hs {
+ case handshakeZeroed:
+ return "handshakeZeroed"
+ case handshakeInitiationCreated:
+ return "handshakeInitiationCreated"
+ case handshakeInitiationConsumed:
+ return "handshakeInitiationConsumed"
+ case handshakeResponseCreated:
+ return "handshakeResponseCreated"
+ case handshakeResponseConsumed:
+ return "handshakeResponseConsumed"
+ default:
+ return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
+ }
+}
+
const (
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
@@ -39,13 +60,13 @@ const (
)
const (
- MessageInitiationSize = 148 // size of handshake initation message
+ MessageInitiationSize = 148 // size of handshake initiation message
MessageResponseSize = 92 // size of response message
MessageCookieReplySize = 64 // size of cookie reply message
- MessageTransportHeaderSize = 16 // size of data preceeding content in transport message
+ MessageTransportHeaderSize = 16 // size of data preceding content in transport message
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
MessageKeepaliveSize = MessageTransportSize // size of keepalive
- MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message
+ MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
)
const (
@@ -95,11 +116,11 @@ type MessageCookieReply struct {
}
type Handshake struct {
- state int
+ state handshakeState
mutex sync.RWMutex
hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key
- presharedKey NoiseSymmetricKey // psk
+ presharedKey NoisePresharedKey // psk
localEphemeral NoisePrivateKey // ephemeral secret key
localIndex uint32 // used to clear hash-table
remoteIndex uint32 // index for sending
@@ -117,11 +138,11 @@ var (
ZeroNonce [chacha20poly1305.NonceSize]byte
)
-func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
+func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
KDF1(dst, c[:], data)
}
-func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
+func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
hash, _ := blake2s.New256(nil)
hash.Write(h[:])
hash.Write(data)
@@ -135,7 +156,7 @@ func (h *Handshake) Clear() {
setZero(h.chainKey[:])
setZero(h.hash[:])
h.localIndex = 0
- h.state = HandshakeZeroed
+ h.state = handshakeZeroed
}
func (h *Handshake) mixHash(data []byte) {
@@ -154,7 +175,6 @@ func init() {
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
-
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -162,12 +182,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
- if isZero(handshake.precomputedStaticStatic[:]) {
- return nil, errors.New("static shared secret is zero")
- }
-
// create ephemeral key
-
var err error
handshake.hash = InitialHash
handshake.chainKey = InitialChainKey
@@ -176,59 +191,56 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
return nil, err
}
- // assign index
-
- device.indexTable.Delete(handshake.localIndex)
- handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
-
- if err != nil {
- return nil, err
- }
-
handshake.mixHash(handshake.remoteStatic[:])
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
- Sender: handshake.localIndex,
}
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
// encrypt static key
-
- func() {
- var key [chacha20poly1305.KeySize]byte
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- KDF2(
- &handshake.chainKey,
- &key,
- handshake.chainKey[:],
- ss[:],
- )
- aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if err != nil {
+ return nil, err
+ }
+ var key [chacha20poly1305.KeySize]byte
+ KDF2(
+ &handshake.chainKey,
+ &key,
+ handshake.chainKey[:],
+ ss[:],
+ )
+ aead, _ := chacha20poly1305.New(key[:])
+ aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
handshake.mixHash(msg.Static[:])
// encrypt timestamp
-
+ if isZero(handshake.precomputedStaticStatic[:]) {
+ return nil, errInvalidPublicKey
+ }
+ KDF2(
+ &handshake.chainKey,
+ &key,
+ handshake.chainKey[:],
+ handshake.precomputedStaticStatic[:],
+ )
timestamp := tai64n.Now()
- func() {
- var key [chacha20poly1305.KeySize]byte
- KDF2(
- &handshake.chainKey,
- &key,
- handshake.chainKey[:],
- handshake.precomputedStaticStatic[:],
- )
- aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
- }()
+ aead, _ = chacha20poly1305.New(key[:])
+ aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
+
+ // assign index
+ device.indexTable.Delete(handshake.localIndex)
+ msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
+ if err != nil {
+ return nil, err
+ }
+ handshake.localIndex = msg.Sender
handshake.mixHash(msg.Timestamp[:])
- handshake.state = HandshakeInitiationCreated
+ handshake.state = handshakeInitiationCreated
return &msg, nil
}
@@ -250,16 +262,15 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key
-
- var err error
var peerPK NoisePublicKey
- func() {
- var key [chacha20poly1305.KeySize]byte
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- KDF2(&chainKey, &key, chainKey[:], ss[:])
- aead, _ := chacha20poly1305.New(key[:])
- _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
- }()
+ var key [chacha20poly1305.KeySize]byte
+ ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return nil
+ }
+ KDF2(&chainKey, &key, chainKey[:], ss[:])
+ aead, _ := chacha20poly1305.New(key[:])
+ _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
if err != nil {
return nil
}
@@ -268,28 +279,29 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// lookup peer
peer := device.LookupPeer(peerPK)
- if peer == nil {
+ if peer == nil || !peer.isRunning.Load() {
return nil
}
handshake := &peer.handshake
- if isZero(handshake.precomputedStaticStatic[:]) {
- return nil
- }
// verify identity
var timestamp tai64n.Timestamp
- var key [chacha20poly1305.KeySize]byte
handshake.mutex.RLock()
+
+ if isZero(handshake.precomputedStaticStatic[:]) {
+ handshake.mutex.RUnlock()
+ return nil
+ }
KDF2(
&chainKey,
&key,
chainKey[:],
handshake.precomputedStaticStatic[:],
)
- aead, _ := chacha20poly1305.New(key[:])
+ aead, _ = chacha20poly1305.New(key[:])
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
if err != nil {
handshake.mutex.RUnlock()
@@ -299,11 +311,15 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// protect against replay & flood
- var ok bool
- ok = timestamp.After(handshake.lastTimestamp)
- ok = ok && time.Since(handshake.lastInitiationConsumption) > HandshakeInitationRate
+ replay := !timestamp.After(handshake.lastTimestamp)
+ flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
handshake.mutex.RUnlock()
- if !ok {
+ if replay {
+ device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp)
+ return nil
+ }
+ if flood {
+ device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer)
return nil
}
@@ -322,7 +338,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
if now.After(handshake.lastInitiationConsumption) {
handshake.lastInitiationConsumption = now
}
- handshake.state = HandshakeInitiationConsumed
+ handshake.state = handshakeInitiationConsumed
handshake.mutex.Unlock()
@@ -337,7 +353,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
- if handshake.state != HandshakeInitiationConsumed {
+ if handshake.state != handshakeInitiationConsumed {
return nil, errors.New("handshake initiation must be consumed first")
}
@@ -365,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:])
- func() {
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
- handshake.mixKey(ss[:])
- ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- handshake.mixKey(ss[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
+ if err != nil {
+ return nil, err
+ }
+ handshake.mixKey(ss[:])
+ ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if err != nil {
+ return nil, err
+ }
+ handshake.mixKey(ss[:])
// add preshared key
@@ -387,13 +407,11 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(tau[:])
- func() {
- aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
- handshake.mixHash(msg.Empty[:])
- }()
+ aead, _ := chacha20poly1305.New(key[:])
+ aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
+ handshake.mixHash(msg.Empty[:])
- handshake.state = HandshakeResponseCreated
+ handshake.state = handshakeResponseCreated
return &msg, nil
}
@@ -417,13 +435,12 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
)
ok := func() bool {
-
// lock handshake state
handshake.mutex.RLock()
defer handshake.mutex.RUnlock()
- if handshake.state != HandshakeInitiationCreated {
+ if handshake.state != handshakeInitiationCreated {
return false
}
@@ -437,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
- func() {
- ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
- mixKey(&chainKey, &chainKey, ss[:])
- setZero(ss[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return false
+ }
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
- func() {
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- mixKey(&chainKey, &chainKey, ss[:])
- setZero(ss[:])
- }()
+ ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return false
+ }
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
// add preshared key (psk)
@@ -465,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// authenticate transcript
aead, _ := chacha20poly1305.New(key[:])
- _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
+ _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil {
return false
}
@@ -484,7 +503,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake.hash = hash
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
- handshake.state = HandshakeResponseConsumed
+ handshake.state = handshakeResponseConsumed
handshake.mutex.Unlock()
@@ -509,7 +528,7 @@ func (peer *Peer) BeginSymmetricSession() error {
var sendKey [chacha20poly1305.KeySize]byte
var recvKey [chacha20poly1305.KeySize]byte
- if handshake.state == HandshakeResponseConsumed {
+ if handshake.state == handshakeResponseConsumed {
KDF2(
&sendKey,
&recvKey,
@@ -517,7 +536,7 @@ func (peer *Peer) BeginSymmetricSession() error {
nil,
)
isInitiator = true
- } else if handshake.state == HandshakeResponseCreated {
+ } else if handshake.state == handshakeResponseCreated {
KDF2(
&recvKey,
&sendKey,
@@ -526,7 +545,7 @@ func (peer *Peer) BeginSymmetricSession() error {
)
isInitiator = false
} else {
- return errors.New("invalid state for keypair derivation")
+ return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
}
// zero handshake
@@ -534,7 +553,7 @@ func (peer *Peer) BeginSymmetricSession() error {
setZero(handshake.chainKey[:])
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
setZero(handshake.localEphemeral[:])
- peer.handshake.state = HandshakeZeroed
+ peer.handshake.state = handshakeZeroed
// create AEAD instances
@@ -546,8 +565,7 @@ func (peer *Peer) BeginSymmetricSession() error {
setZero(recvKey[:])
keypair.created = time.Now()
- keypair.sendNonce = 0
- keypair.replayFilter.Init()
+ keypair.replayFilter.Reset()
keypair.isInitiator = isInitiator
keypair.localIndex = peer.handshake.localIndex
keypair.remoteIndex = peer.handshake.remoteIndex
@@ -564,12 +582,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock()
previous := keypairs.previous
- next := keypairs.next
+ next := keypairs.next.Load()
current := keypairs.current
if isInitiator {
if next != nil {
- keypairs.next = nil
+ keypairs.next.Store(nil)
keypairs.previous = next
device.DeleteKeypair(current)
} else {
@@ -578,7 +596,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous)
keypairs.current = keypair
} else {
- keypairs.next = keypair
+ keypairs.next.Store(keypair)
device.DeleteKeypair(next)
keypairs.previous = nil
device.DeleteKeypair(previous)
@@ -589,18 +607,19 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs
- if keypairs.next != receivedKeypair {
+
+ if keypairs.next.Load() != receivedKeypair {
return false
}
keypairs.Lock()
defer keypairs.Unlock()
- if keypairs.next != receivedKeypair {
+ if keypairs.next.Load() != receivedKeypair {
return false
}
old := keypairs.previous
keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old)
- keypairs.current = keypairs.next
- keypairs.next = nil
+ keypairs.current = keypairs.next.Load()
+ keypairs.next.Store(nil)
return true
}