aboutsummaryrefslogtreecommitdiffstats
path: root/device/noise-protocol.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/noise-protocol.go')
-rw-r--r--device/noise-protocol.go189
1 files changed, 140 insertions, 49 deletions
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index be92b4b..5cf1702 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -1,11 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
import (
+ "encoding/binary"
"errors"
"fmt"
"sync"
@@ -20,7 +21,6 @@ import (
type handshakeState int
-// TODO(crawshaw): add commentary describing each state and the transitions
const (
handshakeZeroed = handshakeState(iota)
handshakeInitiationCreated
@@ -116,12 +116,104 @@ type MessageCookieReply struct {
Cookie [blake2s.Size128 + poly1305.TagSize]byte
}
+var errMessageLengthMismatch = errors.New("message length mismatch")
+
+func (msg *MessageInitiation) unmarshal(b []byte) error {
+ if len(b) != MessageInitiationSize {
+ return errMessageLengthMismatch
+ }
+
+ msg.Type = binary.LittleEndian.Uint32(b)
+ msg.Sender = binary.LittleEndian.Uint32(b[4:])
+ copy(msg.Ephemeral[:], b[8:])
+ copy(msg.Static[:], b[8+len(msg.Ephemeral):])
+ copy(msg.Timestamp[:], b[8+len(msg.Ephemeral)+len(msg.Static):])
+ copy(msg.MAC1[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):])
+ copy(msg.MAC2[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):])
+
+ return nil
+}
+
+func (msg *MessageInitiation) marshal(b []byte) error {
+ if len(b) != MessageInitiationSize {
+ return errMessageLengthMismatch
+ }
+
+ binary.LittleEndian.PutUint32(b, msg.Type)
+ binary.LittleEndian.PutUint32(b[4:], msg.Sender)
+ copy(b[8:], msg.Ephemeral[:])
+ copy(b[8+len(msg.Ephemeral):], msg.Static[:])
+ copy(b[8+len(msg.Ephemeral)+len(msg.Static):], msg.Timestamp[:])
+ copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):], msg.MAC1[:])
+ copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):], msg.MAC2[:])
+
+ return nil
+}
+
+func (msg *MessageResponse) unmarshal(b []byte) error {
+ if len(b) != MessageResponseSize {
+ return errMessageLengthMismatch
+ }
+
+ msg.Type = binary.LittleEndian.Uint32(b)
+ msg.Sender = binary.LittleEndian.Uint32(b[4:])
+ msg.Receiver = binary.LittleEndian.Uint32(b[8:])
+ copy(msg.Ephemeral[:], b[12:])
+ copy(msg.Empty[:], b[12+len(msg.Ephemeral):])
+ copy(msg.MAC1[:], b[12+len(msg.Ephemeral)+len(msg.Empty):])
+ copy(msg.MAC2[:], b[12+len(msg.Ephemeral)+len(msg.Empty)+len(msg.MAC1):])
+
+ return nil
+}
+
+func (msg *MessageResponse) marshal(b []byte) error {
+ if len(b) != MessageResponseSize {
+ return errMessageLengthMismatch
+ }
+
+ binary.LittleEndian.PutUint32(b, msg.Type)
+ binary.LittleEndian.PutUint32(b[4:], msg.Sender)
+ binary.LittleEndian.PutUint32(b[8:], msg.Receiver)
+ copy(b[12:], msg.Ephemeral[:])
+ copy(b[12+len(msg.Ephemeral):], msg.Empty[:])
+ copy(b[12+len(msg.Ephemeral)+len(msg.Empty):], msg.MAC1[:])
+ copy(b[12+len(msg.Ephemeral)+len(msg.Empty)+len(msg.MAC1):], msg.MAC2[:])
+
+ return nil
+}
+
+func (msg *MessageCookieReply) unmarshal(b []byte) error {
+ if len(b) != MessageCookieReplySize {
+ return errMessageLengthMismatch
+ }
+
+ msg.Type = binary.LittleEndian.Uint32(b)
+ msg.Receiver = binary.LittleEndian.Uint32(b[4:])
+ copy(msg.Nonce[:], b[8:])
+ copy(msg.Cookie[:], b[8+len(msg.Nonce):])
+
+ return nil
+}
+
+func (msg *MessageCookieReply) marshal(b []byte) error {
+ if len(b) != MessageCookieReplySize {
+ return errMessageLengthMismatch
+ }
+
+ binary.LittleEndian.PutUint32(b, msg.Type)
+ binary.LittleEndian.PutUint32(b[4:], msg.Receiver)
+ copy(b[8:], msg.Nonce[:])
+ copy(b[8+len(msg.Nonce):], msg.Cookie[:])
+
+ return nil
+}
+
type Handshake struct {
state handshakeState
mutex sync.RWMutex
hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key
- presharedKey NoiseSymmetricKey // psk
+ presharedKey NoisePresharedKey // psk
localEphemeral NoisePrivateKey // ephemeral secret key
localIndex uint32 // used to clear hash-table
remoteIndex uint32 // index for sending
@@ -139,11 +231,11 @@ var (
ZeroNonce [chacha20poly1305.NonceSize]byte
)
-func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
+func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
KDF1(dst, c[:], data)
}
-func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
+func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
hash, _ := blake2s.New256(nil)
hash.Write(h[:])
hash.Write(data)
@@ -176,8 +268,6 @@ func init() {
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
- var errZeroECDHResult = errors.New("ECDH returned all zeros")
-
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -205,9 +295,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(msg.Ephemeral[:])
// encrypt static key
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- if isZero(ss[:]) {
- return nil, errZeroECDHResult
+ ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if err != nil {
+ return nil, err
}
var key [chacha20poly1305.KeySize]byte
KDF2(
@@ -222,7 +312,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
// encrypt timestamp
if isZero(handshake.precomputedStaticStatic[:]) {
- return nil, errZeroECDHResult
+ return nil, errInvalidPublicKey
}
KDF2(
&handshake.chainKey,
@@ -265,11 +355,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key
- var err error
var peerPK NoisePublicKey
var key [chacha20poly1305.KeySize]byte
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- if isZero(ss[:]) {
+ ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if err != nil {
return nil
}
KDF2(&chainKey, &key, chainKey[:], ss[:])
@@ -283,7 +372,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// lookup peer
peer := device.LookupPeer(peerPK)
- if peer == nil {
+ if peer == nil || !peer.isRunning.Load() {
return nil
}
@@ -319,11 +408,11 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
handshake.mutex.RUnlock()
if replay {
- device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake replay @ %v\n", peer, timestamp)
+ device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp)
return nil
}
if flood {
- device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake flood\n", peer)
+ device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer)
return nil
}
@@ -385,12 +474,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:])
- func() {
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
- handshake.mixKey(ss[:])
- ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- handshake.mixKey(ss[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
+ if err != nil {
+ return nil, err
+ }
+ handshake.mixKey(ss[:])
+ ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if err != nil {
+ return nil, err
+ }
+ handshake.mixKey(ss[:])
// add preshared key
@@ -407,11 +500,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(tau[:])
- func() {
- aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
- handshake.mixHash(msg.Empty[:])
- }()
+ aead, _ := chacha20poly1305.New(key[:])
+ aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
+ handshake.mixHash(msg.Empty[:])
handshake.state = handshakeResponseCreated
@@ -437,7 +528,6 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
)
ok := func() bool {
-
// lock handshake state
handshake.mutex.RLock()
@@ -457,17 +547,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
- func() {
- ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
- mixKey(&chainKey, &chainKey, ss[:])
- setZero(ss[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return false
+ }
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
- func() {
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- mixKey(&chainKey, &chainKey, ss[:])
- setZero(ss[:])
- }()
+ ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return false
+ }
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
// add preshared key (psk)
@@ -485,7 +577,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// authenticate transcript
aead, _ := chacha20poly1305.New(key[:])
- _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
+ _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil {
return false
}
@@ -566,8 +658,7 @@ func (peer *Peer) BeginSymmetricSession() error {
setZero(recvKey[:])
keypair.created = time.Now()
- keypair.sendNonce = 0
- keypair.replayFilter.Init()
+ keypair.replayFilter.Reset()
keypair.isInitiator = isInitiator
keypair.localIndex = peer.handshake.localIndex
keypair.remoteIndex = peer.handshake.remoteIndex
@@ -584,12 +675,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock()
previous := keypairs.previous
- next := keypairs.loadNext()
+ next := keypairs.next.Load()
current := keypairs.current
if isInitiator {
if next != nil {
- keypairs.storeNext(nil)
+ keypairs.next.Store(nil)
keypairs.previous = next
device.DeleteKeypair(current)
} else {
@@ -598,7 +689,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous)
keypairs.current = keypair
} else {
- keypairs.storeNext(keypair)
+ keypairs.next.Store(keypair)
device.DeleteKeypair(next)
keypairs.previous = nil
device.DeleteKeypair(previous)
@@ -610,18 +701,18 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs
- if keypairs.loadNext() != receivedKeypair {
+ if keypairs.next.Load() != receivedKeypair {
return false
}
keypairs.Lock()
defer keypairs.Unlock()
- if keypairs.loadNext() != receivedKeypair {
+ if keypairs.next.Load() != receivedKeypair {
return false
}
old := keypairs.previous
keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old)
- keypairs.current = keypairs.loadNext()
- keypairs.storeNext(nil)
+ keypairs.current = keypairs.next.Load()
+ keypairs.next.Store(nil)
return true
}