aboutsummaryrefslogtreecommitdiffstats
path: root/device
diff options
context:
space:
mode:
Diffstat (limited to 'device')
-rw-r--r--device/device.go3
-rw-r--r--device/noise-protocol.go100
-rw-r--r--device/peer.go9
3 files changed, 49 insertions, 63 deletions
diff --git a/device/device.go b/device/device.go
index 0b909a7..8c08f1c 100644
--- a/device/device.go
+++ b/device/device.go
@@ -240,9 +240,6 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
- if isZero(handshake.precomputedStaticStatic[:]) {
- panic("an invalid peer public key made it into the configuration")
- }
expiredPeers = append(expiredPeers, peer)
}
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index 1c08e0a..ee327d2 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -154,6 +154,7 @@ func init() {
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
+ var errZeroECDHResult = errors.New("ECDH returned all zeros")
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -162,12 +163,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
- if isZero(handshake.precomputedStaticStatic[:]) {
- return nil, errors.New("static shared secret is zero")
- }
-
// create ephemeral key
-
var err error
handshake.hash = InitialHash
handshake.chainKey = InitialChainKey
@@ -176,56 +172,53 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
return nil, err
}
- // assign index
-
- device.indexTable.Delete(handshake.localIndex)
- handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
-
- if err != nil {
- return nil, err
- }
-
handshake.mixHash(handshake.remoteStatic[:])
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
- Sender: handshake.localIndex,
}
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
// encrypt static key
-
- func() {
- var key [chacha20poly1305.KeySize]byte
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- KDF2(
- &handshake.chainKey,
- &key,
- handshake.chainKey[:],
- ss[:],
- )
- aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
- }()
+ ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if isZero(ss[:]) {
+ return nil, errZeroECDHResult
+ }
+ var key [chacha20poly1305.KeySize]byte
+ KDF2(
+ &handshake.chainKey,
+ &key,
+ handshake.chainKey[:],
+ ss[:],
+ )
+ aead, _ := chacha20poly1305.New(key[:])
+ aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
handshake.mixHash(msg.Static[:])
// encrypt timestamp
-
+ if isZero(handshake.precomputedStaticStatic[:]) {
+ return nil, errZeroECDHResult
+ }
+ KDF2(
+ &handshake.chainKey,
+ &key,
+ handshake.chainKey[:],
+ handshake.precomputedStaticStatic[:],
+ )
timestamp := tai64n.Now()
- func() {
- var key [chacha20poly1305.KeySize]byte
- KDF2(
- &handshake.chainKey,
- &key,
- handshake.chainKey[:],
- handshake.precomputedStaticStatic[:],
- )
- aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
- }()
+ aead, _ = chacha20poly1305.New(key[:])
+ aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
+
+ // assign index
+ device.indexTable.Delete(handshake.localIndex)
+ msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
+ if err != nil {
+ return nil, err
+ }
+ handshake.localIndex = msg.Sender
handshake.mixHash(msg.Timestamp[:])
handshake.state = HandshakeInitiationCreated
@@ -250,16 +243,16 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key
-
var err error
var peerPK NoisePublicKey
- func() {
- var key [chacha20poly1305.KeySize]byte
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- KDF2(&chainKey, &key, chainKey[:], ss[:])
- aead, _ := chacha20poly1305.New(key[:])
- _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
- }()
+ var key [chacha20poly1305.KeySize]byte
+ ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if isZero(ss[:]) {
+ return nil
+ }
+ KDF2(&chainKey, &key, chainKey[:], ss[:])
+ aead, _ := chacha20poly1305.New(key[:])
+ _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
if err != nil {
return nil
}
@@ -273,23 +266,24 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
}
handshake := &peer.handshake
- if isZero(handshake.precomputedStaticStatic[:]) {
- return nil
- }
// verify identity
var timestamp tai64n.Timestamp
- var key [chacha20poly1305.KeySize]byte
handshake.mutex.RLock()
+
+ if isZero(handshake.precomputedStaticStatic[:]) {
+ handshake.mutex.RUnlock()
+ return nil
+ }
KDF2(
&chainKey,
&key,
chainKey[:],
handshake.precomputedStaticStatic[:],
)
- aead, _ := chacha20poly1305.New(key[:])
+ aead, _ = chacha20poly1305.New(key[:])
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
if err != nil {
handshake.mutex.RUnlock()
diff --git a/device/peer.go b/device/peer.go
index 91d975a..8a8224c 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -108,7 +108,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake := &peer.handshake
handshake.mutex.Lock()
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
- ssIsZero := isZero(handshake.precomputedStaticStatic[:])
handshake.remoteStatic = pk
handshake.mutex.Unlock()
@@ -116,13 +115,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.endpoint = nil
- // conditionally add
+ // add
- if !ssIsZero {
- device.peers.keyMap[pk] = peer
- } else {
- return nil, nil
- }
+ device.peers.keyMap[pk] = peer
// start peer