diff options
Diffstat (limited to 'device/noise-protocol.go')
-rw-r--r-- | device/noise-protocol.go | 251 |
1 files changed, 135 insertions, 116 deletions
diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 88c6aae..e8f6145 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -1,29 +1,50 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "errors" + "fmt" "sync" "time" "golang.org/x/crypto/blake2s" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/poly1305" + "golang.zx2c4.com/wireguard/tai64n" ) +type handshakeState int + const ( - HandshakeZeroed = iota - HandshakeInitiationCreated - HandshakeInitiationConsumed - HandshakeResponseCreated - HandshakeResponseConsumed + handshakeZeroed = handshakeState(iota) + handshakeInitiationCreated + handshakeInitiationConsumed + handshakeResponseCreated + handshakeResponseConsumed ) +func (hs handshakeState) String() string { + switch hs { + case handshakeZeroed: + return "handshakeZeroed" + case handshakeInitiationCreated: + return "handshakeInitiationCreated" + case handshakeInitiationConsumed: + return "handshakeInitiationConsumed" + case handshakeResponseCreated: + return "handshakeResponseCreated" + case handshakeResponseConsumed: + return "handshakeResponseConsumed" + default: + return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs)) + } +} + const ( NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" @@ -39,13 +60,13 @@ const ( ) const ( - MessageInitiationSize = 148 // size of handshake initation message + MessageInitiationSize = 148 // size of handshake initiation message MessageResponseSize = 92 // size of response message MessageCookieReplySize = 64 // size of cookie reply message - MessageTransportHeaderSize = 16 // size of data preceeding content in transport message + MessageTransportHeaderSize = 16 // size of data preceding content in transport message MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport MessageKeepaliveSize = MessageTransportSize // size of keepalive - MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message + MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message ) const ( @@ -95,11 +116,11 @@ type MessageCookieReply struct { } type Handshake struct { - state int + state handshakeState mutex sync.RWMutex hash [blake2s.Size]byte // hash value chainKey [blake2s.Size]byte // chain key - presharedKey NoiseSymmetricKey // psk + presharedKey NoisePresharedKey // psk localEphemeral NoisePrivateKey // ephemeral secret key localIndex uint32 // used to clear hash-table remoteIndex uint32 // index for sending @@ -117,11 +138,11 @@ var ( ZeroNonce [chacha20poly1305.NonceSize]byte ) -func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { +func mixKey(dst, c *[blake2s.Size]byte, data []byte) { KDF1(dst, c[:], data) } -func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { +func mixHash(dst, h *[blake2s.Size]byte, data []byte) { hash, _ := blake2s.New256(nil) hash.Write(h[:]) hash.Write(data) @@ -135,7 +156,7 @@ func (h *Handshake) Clear() { setZero(h.chainKey[:]) setZero(h.hash[:]) h.localIndex = 0 - h.state = HandshakeZeroed + h.state = handshakeZeroed } func (h *Handshake) mixHash(data []byte) { @@ -154,7 +175,6 @@ func init() { } func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { - device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -162,12 +182,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mutex.Lock() defer handshake.mutex.Unlock() - if isZero(handshake.precomputedStaticStatic[:]) { - return nil, errors.New("static shared secret is zero") - } - // create ephemeral key - var err error handshake.hash = InitialHash handshake.chainKey = InitialChainKey @@ -176,59 +191,56 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e return nil, err } - // assign index - - device.indexTable.Delete(handshake.localIndex) - handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) - - if err != nil { - return nil, err - } - handshake.mixHash(handshake.remoteStatic[:]) msg := MessageInitiation{ Type: MessageInitiationType, Ephemeral: handshake.localEphemeral.publicKey(), - Sender: handshake.localIndex, } handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) // encrypt static key - - func() { - var key [chacha20poly1305.KeySize]byte - ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - KDF2( - &handshake.chainKey, - &key, - handshake.chainKey[:], - ss[:], - ) - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) - }() + ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + if err != nil { + return nil, err + } + var key [chacha20poly1305.KeySize]byte + KDF2( + &handshake.chainKey, + &key, + handshake.chainKey[:], + ss[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) handshake.mixHash(msg.Static[:]) // encrypt timestamp - + if isZero(handshake.precomputedStaticStatic[:]) { + return nil, errInvalidPublicKey + } + KDF2( + &handshake.chainKey, + &key, + handshake.chainKey[:], + handshake.precomputedStaticStatic[:], + ) timestamp := tai64n.Now() - func() { - var key [chacha20poly1305.KeySize]byte - KDF2( - &handshake.chainKey, - &key, - handshake.chainKey[:], - handshake.precomputedStaticStatic[:], - ) - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) - }() + aead, _ = chacha20poly1305.New(key[:]) + aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) + + // assign index + device.indexTable.Delete(handshake.localIndex) + msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake) + if err != nil { + return nil, err + } + handshake.localIndex = msg.Sender handshake.mixHash(msg.Timestamp[:]) - handshake.state = HandshakeInitiationCreated + handshake.state = handshakeInitiationCreated return &msg, nil } @@ -250,16 +262,15 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) // decrypt static key - - var err error var peerPK NoisePublicKey - func() { - var key [chacha20poly1305.KeySize]byte - ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) - KDF2(&chainKey, &key, chainKey[:], ss[:]) - aead, _ := chacha20poly1305.New(key[:]) - _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) - }() + var key [chacha20poly1305.KeySize]byte + ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + if err != nil { + return nil + } + KDF2(&chainKey, &key, chainKey[:], ss[:]) + aead, _ := chacha20poly1305.New(key[:]) + _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) if err != nil { return nil } @@ -268,28 +279,29 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { // lookup peer peer := device.LookupPeer(peerPK) - if peer == nil { + if peer == nil || !peer.isRunning.Load() { return nil } handshake := &peer.handshake - if isZero(handshake.precomputedStaticStatic[:]) { - return nil - } // verify identity var timestamp tai64n.Timestamp - var key [chacha20poly1305.KeySize]byte handshake.mutex.RLock() + + if isZero(handshake.precomputedStaticStatic[:]) { + handshake.mutex.RUnlock() + return nil + } KDF2( &chainKey, &key, chainKey[:], handshake.precomputedStaticStatic[:], ) - aead, _ := chacha20poly1305.New(key[:]) + aead, _ = chacha20poly1305.New(key[:]) _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) if err != nil { handshake.mutex.RUnlock() @@ -299,11 +311,15 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { // protect against replay & flood - var ok bool - ok = timestamp.After(handshake.lastTimestamp) - ok = ok && time.Since(handshake.lastInitiationConsumption) > HandshakeInitationRate + replay := !timestamp.After(handshake.lastTimestamp) + flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate handshake.mutex.RUnlock() - if !ok { + if replay { + device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp) + return nil + } + if flood { + device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer) return nil } @@ -322,7 +338,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { if now.After(handshake.lastInitiationConsumption) { handshake.lastInitiationConsumption = now } - handshake.state = HandshakeInitiationConsumed + handshake.state = handshakeInitiationConsumed handshake.mutex.Unlock() @@ -337,7 +353,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error handshake.mutex.Lock() defer handshake.mutex.Unlock() - if handshake.state != HandshakeInitiationConsumed { + if handshake.state != handshakeInitiationConsumed { return nil, errors.New("handshake initiation must be consumed first") } @@ -365,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error handshake.mixHash(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:]) - func() { - ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) - handshake.mixKey(ss[:]) - ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - handshake.mixKey(ss[:]) - }() + ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) + if err != nil { + return nil, err + } + handshake.mixKey(ss[:]) + ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + if err != nil { + return nil, err + } + handshake.mixKey(ss[:]) // add preshared key @@ -387,13 +407,11 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error handshake.mixHash(tau[:]) - func() { - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) - handshake.mixHash(msg.Empty[:]) - }() + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) + handshake.mixHash(msg.Empty[:]) - handshake.state = HandshakeResponseCreated + handshake.state = handshakeResponseCreated return &msg, nil } @@ -417,13 +435,12 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { ) ok := func() bool { - // lock handshake state handshake.mutex.RLock() defer handshake.mutex.RUnlock() - if handshake.state != HandshakeInitiationCreated { + if handshake.state != handshakeInitiationCreated { return false } @@ -437,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) - func() { - ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) - mixKey(&chainKey, &chainKey, ss[:]) - setZero(ss[:]) - }() + ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral) + if err != nil { + return false + } + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) - func() { - ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) - mixKey(&chainKey, &chainKey, ss[:]) - setZero(ss[:]) - }() + ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + if err != nil { + return false + } + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) // add preshared key (psk) @@ -465,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { // authenticate transcript aead, _ := chacha20poly1305.New(key[:]) - _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) + _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) if err != nil { return false } @@ -484,7 +503,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { handshake.hash = hash handshake.chainKey = chainKey handshake.remoteIndex = msg.Sender - handshake.state = HandshakeResponseConsumed + handshake.state = handshakeResponseConsumed handshake.mutex.Unlock() @@ -509,7 +528,7 @@ func (peer *Peer) BeginSymmetricSession() error { var sendKey [chacha20poly1305.KeySize]byte var recvKey [chacha20poly1305.KeySize]byte - if handshake.state == HandshakeResponseConsumed { + if handshake.state == handshakeResponseConsumed { KDF2( &sendKey, &recvKey, @@ -517,7 +536,7 @@ func (peer *Peer) BeginSymmetricSession() error { nil, ) isInitiator = true - } else if handshake.state == HandshakeResponseCreated { + } else if handshake.state == handshakeResponseCreated { KDF2( &recvKey, &sendKey, @@ -526,7 +545,7 @@ func (peer *Peer) BeginSymmetricSession() error { ) isInitiator = false } else { - return errors.New("invalid state for keypair derivation") + return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state) } // zero handshake @@ -534,7 +553,7 @@ func (peer *Peer) BeginSymmetricSession() error { setZero(handshake.chainKey[:]) setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. setZero(handshake.localEphemeral[:]) - peer.handshake.state = HandshakeZeroed + peer.handshake.state = handshakeZeroed // create AEAD instances @@ -546,8 +565,7 @@ func (peer *Peer) BeginSymmetricSession() error { setZero(recvKey[:]) keypair.created = time.Now() - keypair.sendNonce = 0 - keypair.replayFilter.Init() + keypair.replayFilter.Reset() keypair.isInitiator = isInitiator keypair.localIndex = peer.handshake.localIndex keypair.remoteIndex = peer.handshake.remoteIndex @@ -564,12 +582,12 @@ func (peer *Peer) BeginSymmetricSession() error { defer keypairs.Unlock() previous := keypairs.previous - next := keypairs.next + next := keypairs.next.Load() current := keypairs.current if isInitiator { if next != nil { - keypairs.next = nil + keypairs.next.Store(nil) keypairs.previous = next device.DeleteKeypair(current) } else { @@ -578,7 +596,7 @@ func (peer *Peer) BeginSymmetricSession() error { device.DeleteKeypair(previous) keypairs.current = keypair } else { - keypairs.next = keypair + keypairs.next.Store(keypair) device.DeleteKeypair(next) keypairs.previous = nil device.DeleteKeypair(previous) @@ -589,18 +607,19 @@ func (peer *Peer) BeginSymmetricSession() error { func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { keypairs := &peer.keypairs - if keypairs.next != receivedKeypair { + + if keypairs.next.Load() != receivedKeypair { return false } keypairs.Lock() defer keypairs.Unlock() - if keypairs.next != receivedKeypair { + if keypairs.next.Load() != receivedKeypair { return false } old := keypairs.previous keypairs.previous = keypairs.current peer.device.DeleteKeypair(old) - keypairs.current = keypairs.next - keypairs.next = nil + keypairs.current = keypairs.next.Load() + keypairs.next.Store(nil) return true } |