aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2017-09-01 14:21:53 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2017-09-01 14:21:53 +0200
commit0294a5c0dd753786996e62236b7d8d524201ace4 (patch)
tree6e4623154072100ff402b45c2ac26fcff30da0fd /src
parentRenamed config.go to follow general naming pattern (diff)
downloadwireguard-go-0294a5c0dd753786996e62236b7d8d524201ace4.tar.xz
wireguard-go-0294a5c0dd753786996e62236b7d8d524201ace4.zip
Improved handling of key-material
Diffstat (limited to 'src')
-rw-r--r--src/keypair.go33
-rw-r--r--src/noise_helpers.go48
-rw-r--r--src/noise_protocol.go136
-rw-r--r--src/receive.go24
-rw-r--r--src/send.go19
-rw-r--r--src/timers.go32
-rw-r--r--src/tun_linux.go2
7 files changed, 203 insertions, 91 deletions
diff --git a/src/keypair.go b/src/keypair.go
index ba9c437..644d040 100644
--- a/src/keypair.go
+++ b/src/keypair.go
@@ -2,14 +2,39 @@ package main
import (
"crypto/cipher"
+ "golang.org/x/crypto/chacha20poly1305"
+ "reflect"
"sync"
"time"
)
+type safeAEAD struct {
+ mutex sync.RWMutex
+ aead cipher.AEAD
+}
+
+func (con *safeAEAD) clear() {
+ // TODO: improve handling of key material
+ con.mutex.Lock()
+ if con.aead != nil {
+ val := reflect.ValueOf(con.aead)
+ elm := val.Elem()
+ typ := elm.Type()
+ elm.Set(reflect.Zero(typ))
+ con.aead = nil
+ }
+ con.mutex.Unlock()
+}
+
+func (con *safeAEAD) setKey(key *[chacha20poly1305.KeySize]byte) {
+ // TODO: improve handling of key material
+ con.aead, _ = chacha20poly1305.New(key[:])
+}
+
type KeyPair struct {
- receive cipher.AEAD
+ send safeAEAD
+ receive safeAEAD
replayFilter ReplayFilter
- send cipher.AEAD
sendNonce uint64
isInitiator bool
created time.Time
@@ -31,7 +56,7 @@ func (kp *KeyPairs) Current() *KeyPair {
}
func (device *Device) DeleteKeyPair(key *KeyPair) {
- key.send = nil
- key.receive = nil
+ key.send.clear()
+ key.receive.clear()
device.indices.Delete(key.localIndex)
}
diff --git a/src/noise_helpers.go b/src/noise_helpers.go
index 105f78f..24302c0 100644
--- a/src/noise_helpers.go
+++ b/src/noise_helpers.go
@@ -13,37 +13,47 @@ import (
* https://tools.ietf.org/html/rfc5869
*/
-func HMAC(sum *[blake2s.Size]byte, key []byte, input []byte) {
+func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) {
mac := hmac.New(func() hash.Hash {
h, _ := blake2s.New256(nil)
return h
}, key)
- mac.Write(input)
+ mac.Write(in0)
mac.Sum(sum[:0])
}
-func KDF1(key []byte, input []byte) (t0 [blake2s.Size]byte) {
- HMAC(&t0, key, input)
- HMAC(&t0, t0[:], []byte{0x1})
+func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) {
+ mac := hmac.New(func() hash.Hash {
+ h, _ := blake2s.New256(nil)
+ return h
+ }, key)
+ mac.Write(in0)
+ mac.Write(in1)
+ mac.Sum(sum[:0])
+}
+
+func KDF1(t0 *[blake2s.Size]byte, key, input []byte) {
+ HMAC1(t0, key, input)
+ HMAC1(t0, t0[:], []byte{0x1})
return
}
-func KDF2(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byte) {
+func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
var prk [blake2s.Size]byte
- HMAC(&prk, key, input)
- HMAC(&t0, prk[:], []byte{0x1})
- HMAC(&t1, prk[:], append(t0[:], 0x2))
- prk = [blake2s.Size]byte{}
+ HMAC1(&prk, key, input)
+ HMAC1(t0, prk[:], []byte{0x1})
+ HMAC2(t1, prk[:], t0[:], []byte{0x2})
+ setZero(prk[:])
return
}
-func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byte, t2 [blake2s.Size]byte) {
+func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
var prk [blake2s.Size]byte
- HMAC(&prk, key, input)
- HMAC(&t0, prk[:], []byte{0x1})
- HMAC(&t1, prk[:], append(t0[:], 0x2))
- HMAC(&t2, prk[:], append(t1[:], 0x3))
- prk = [blake2s.Size]byte{}
+ HMAC1(&prk, key, input)
+ HMAC1(t0, prk[:], []byte{0x1})
+ HMAC2(t1, prk[:], t0[:], []byte{0x2})
+ HMAC2(t2, prk[:], t1[:], []byte{0x3})
+ setZero(prk[:])
return
}
@@ -55,6 +65,12 @@ func isZero(val []byte) bool {
return acc == 0
}
+func setZero(arr []byte) {
+ for i := range arr {
+ arr[i] = 0
+ }
+}
+
/* curve25519 wrappers */
func newPrivateKey() (sk NoisePrivateKey, err error) {
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index 1f1301e..a50e3dc 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -109,27 +109,31 @@ var (
ZeroNonce [chacha20poly1305.NonceSize]byte
)
-func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
- return KDF1(c[:], data)
+func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
+ KDF1(dst, c[:], data)
}
-func mixHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
- return blake2s.Sum256(append(h[:], data...))
+func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
+ hsh, _ := blake2s.New256(nil)
+ hsh.Write(h[:])
+ hsh.Write(data)
+ hsh.Sum(dst[:0])
+ hsh.Reset()
}
func (h *Handshake) mixHash(data []byte) {
- h.hash = mixHash(h.hash, data)
+ mixHash(&h.hash, &h.hash, data)
}
func (h *Handshake) mixKey(data []byte) {
- h.chainKey = mixKey(h.chainKey, data)
+ mixKey(&h.chainKey, &h.chainKey, data)
}
/* Do basic precomputations
*/
func init() {
InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
- InitialHash = mixHash(InitialChainKey, []byte(WGIdentifier))
+ mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier))
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
@@ -176,7 +180,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
func() {
var key [chacha20poly1305.KeySize]byte
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- handshake.chainKey, key = KDF2(handshake.chainKey[:], ss[:])
+ KDF2(
+ &handshake.chainKey,
+ &key,
+ handshake.chainKey[:],
+ ss[:],
+ )
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
}()
@@ -187,7 +196,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
timestamp := Timestamp()
func() {
var key [chacha20poly1305.KeySize]byte
- handshake.chainKey, key = KDF2(
+ KDF2(
+ &handshake.chainKey,
+ &key,
handshake.chainKey[:],
handshake.precomputedStaticStatic[:],
)
@@ -197,7 +208,6 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(msg.Timestamp[:])
handshake.state = HandshakeInitiationCreated
-
return &msg, nil
}
@@ -206,9 +216,14 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
return nil
}
- hash := mixHash(InitialHash, device.publicKey[:])
- hash = mixHash(hash, msg.Ephemeral[:])
- chainKey := mixKey(InitialChainKey, msg.Ephemeral[:])
+ var (
+ hash [blake2s.Size]byte
+ chainKey [blake2s.Size]byte
+ )
+
+ mixHash(&hash, &InitialHash, device.publicKey[:])
+ mixHash(&hash, &hash, msg.Ephemeral[:])
+ mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key
@@ -217,14 +232,14 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
func() {
var key [chacha20poly1305.KeySize]byte
ss := device.privateKey.sharedSecret(msg.Ephemeral)
- chainKey, key = KDF2(chainKey[:], ss[:])
+ KDF2(&chainKey, &key, chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
}()
if err != nil {
return nil
}
- hash = mixHash(hash, msg.Static[:])
+ mixHash(&hash, &hash, msg.Static[:])
// lookup peer
@@ -244,7 +259,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
var key [chacha20poly1305.KeySize]byte
handshake.mutex.RLock()
- chainKey, key = KDF2(
+ KDF2(
+ &chainKey,
+ &key,
chainKey[:],
handshake.precomputedStaticStatic[:],
)
@@ -254,7 +271,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
handshake.mutex.RUnlock()
return nil
}
- hash = mixHash(hash, msg.Timestamp[:])
+ mixHash(&hash, &hash, msg.Timestamp[:])
// protect against replay & flood
@@ -327,7 +344,15 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte
- handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:])
+
+ KDF3(
+ &handshake.chainKey,
+ &tau,
+ &key,
+ handshake.chainKey[:],
+ handshake.presharedKey[:],
+ )
+
handshake.mixHash(tau[:])
func() {
@@ -337,6 +362,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}()
handshake.state = HandshakeResponseCreated
+
return &msg, nil
}
@@ -371,22 +397,33 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// finish 3-way DH
- hash = mixHash(handshake.hash, msg.Ephemeral[:])
- chainKey = mixKey(handshake.chainKey, msg.Ephemeral[:])
+ mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
+ mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
func() {
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
- chainKey = mixKey(chainKey, ss[:])
- ss = device.privateKey.sharedSecret(msg.Ephemeral)
- chainKey = mixKey(chainKey, ss[:])
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
+ }()
+
+ func() {
+ ss := device.privateKey.sharedSecret(msg.Ephemeral)
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
}()
// add preshared key (psk)
var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte
- chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
- hash = mixHash(hash, tau[:])
+ KDF3(
+ &chainKey,
+ &tau,
+ &key,
+ chainKey[:],
+ handshake.presharedKey[:],
+ )
+ mixHash(&hash, &hash, tau[:])
// authenticate
@@ -396,7 +433,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.log.Debug.Println("failed to open")
return false
}
- hash = mixHash(hash, msg.Empty[:])
+ mixHash(&hash, &hash, msg.Empty[:])
return true
}()
@@ -415,6 +452,9 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake.mutex.Unlock()
+ setZero(hash[:])
+ setZero(chainKey[:])
+
return lookup.peer
}
@@ -422,6 +462,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
*
*/
func (peer *Peer) NewKeyPair() *KeyPair {
+ device := peer.device
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
@@ -433,10 +474,20 @@ func (peer *Peer) NewKeyPair() *KeyPair {
var recvKey [chacha20poly1305.KeySize]byte
if handshake.state == HandshakeResponseConsumed {
- sendKey, recvKey = KDF2(handshake.chainKey[:], nil)
+ KDF2(
+ &sendKey,
+ &recvKey,
+ handshake.chainKey[:],
+ nil,
+ )
isInitiator = true
} else if handshake.state == HandshakeResponseCreated {
- recvKey, sendKey = KDF2(handshake.chainKey[:], nil)
+ KDF2(
+ &recvKey,
+ &sendKey,
+ handshake.chainKey[:],
+ nil,
+ )
isInitiator = false
} else {
return nil
@@ -444,16 +495,20 @@ func (peer *Peer) NewKeyPair() *KeyPair {
// zero handshake
- handshake.chainKey = [blake2s.Size]byte{}
- handshake.localEphemeral = NoisePrivateKey{}
+ setZero(handshake.chainKey[:])
+ setZero(handshake.localEphemeral[:])
peer.handshake.state = HandshakeZeroed
// create AEAD instances
keyPair := new(KeyPair)
+ keyPair.send.setKey(&sendKey)
+ keyPair.receive.setKey(&recvKey)
+
+ setZero(sendKey[:])
+ setZero(recvKey[:])
+
keyPair.created = time.Now()
- keyPair.send, _ = chacha20poly1305.New(sendKey[:])
- keyPair.receive, _ = chacha20poly1305.New(recvKey[:])
keyPair.sendNonce = 0
keyPair.replayFilter.Init()
keyPair.isInitiator = isInitiator
@@ -462,12 +517,14 @@ func (peer *Peer) NewKeyPair() *KeyPair {
// remap index
- indices := &peer.device.indices
- indices.Insert(handshake.localIndex, IndexTableEntry{
- peer: peer,
- keyPair: keyPair,
- handshake: nil,
- })
+ device.indices.Insert(
+ handshake.localIndex,
+ IndexTableEntry{
+ peer: peer,
+ keyPair: keyPair,
+ handshake: nil,
+ },
+ )
handshake.localIndex = 0
// rotate key pairs
@@ -479,7 +536,8 @@ func (peer *Peer) NewKeyPair() *KeyPair {
// TODO: Adapt kernel behavior noise.c:161
if isInitiator {
if kp.previous != nil {
- indices.Delete(kp.previous.localIndex)
+ device.DeleteKeyPair(kp.previous)
+ kp.previous = nil
}
if kp.next != nil {
diff --git a/src/receive.go b/src/receive.go
index ca7bb6e..97646d8 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -251,15 +251,22 @@ func (device *Device) RoutineDecryption() {
var err error
copy(nonce[4:], counter)
elem.counter = binary.LittleEndian.Uint64(counter)
- elem.packet, err = elem.keyPair.receive.Open(
- elem.buffer[:0],
- nonce[:],
- content,
- nil,
- )
- if err != nil {
+ elem.keyPair.receive.mutex.RLock()
+ if elem.keyPair.receive.aead == nil {
+ // very unlikely (the key was deleted during queuing)
elem.Drop()
+ } else {
+ elem.packet, err = elem.keyPair.receive.aead.Open(
+ elem.buffer[:0],
+ nonce[:],
+ content,
+ nil,
+ )
+ if err != nil {
+ elem.Drop()
+ }
}
+ elem.keyPair.receive.mutex.RUnlock()
elem.mutex.Unlock()
}
}
@@ -507,6 +514,9 @@ func (peer *Peer) RoutineSequentialReceiver() {
kp.mutex.Lock()
if kp.next == elem.keyPair {
peer.TimerHandshakeComplete()
+ if kp.previous != nil {
+ device.DeleteKeyPair(kp.previous)
+ }
kp.previous = kp.current
kp.current = kp.next
kp.next = nil
diff --git a/src/send.go b/src/send.go
index 7d4014a..c598ad4 100644
--- a/src/send.go
+++ b/src/send.go
@@ -349,12 +349,19 @@ func (device *Device) RoutineEncryption() {
// encrypt content (append to header)
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
- elem.packet = elem.keyPair.send.Seal(
- header,
- nonce[:],
- elem.packet,
- nil,
- )
+ elem.keyPair.send.mutex.RLock()
+ if elem.keyPair.send.aead == nil {
+ // very unlikely (the key was deleted during queuing)
+ elem.Drop()
+ } else {
+ elem.packet = elem.keyPair.send.aead.Seal(
+ header,
+ nonce[:],
+ elem.packet,
+ nil,
+ )
+ }
+ elem.keyPair.send.mutex.RUnlock()
elem.mutex.Unlock()
// refresh key if necessary
diff --git a/src/timers.go b/src/timers.go
index de54a96..ad8866f 100644
--- a/src/timers.go
+++ b/src/timers.go
@@ -3,7 +3,6 @@ package main
import (
"bytes"
"encoding/binary"
- "golang.org/x/crypto/blake2s"
"math/rand"
"sync/atomic"
"time"
@@ -134,7 +133,6 @@ func (peer *Peer) TimerEphemeralKeyCreated() {
func (peer *Peer) RoutineTimerHandler() {
device := peer.device
- indices := &device.indices
logDebug := device.log.Debug
logDebug.Println("Routine, timer handler, started for peer", peer.String())
@@ -186,35 +184,31 @@ func (peer *Peer) RoutineTimerHandler() {
kp := &peer.keyPairs
kp.mutex.Lock()
- // unmap indecies
+ // remove key-pairs
- indices.mutex.Lock()
if kp.previous != nil {
- delete(indices.table, kp.previous.localIndex)
+ device.DeleteKeyPair(kp.previous)
+ kp.previous = nil
}
if kp.current != nil {
- delete(indices.table, kp.current.localIndex)
+ device.DeleteKeyPair(kp.current)
+ kp.current = nil
}
if kp.next != nil {
- delete(indices.table, kp.next.localIndex)
+ device.DeleteKeyPair(kp.next)
+ kp.next = nil
}
- delete(indices.table, hs.localIndex)
- indices.mutex.Unlock()
-
- // zero out key pairs (TODO: better than wait for GC)
-
- kp.current = nil
- kp.previous = nil
- kp.next = nil
kp.mutex.Unlock()
// zero out handshake
+ device.indices.Delete(hs.localIndex)
+
hs.localIndex = 0
- hs.localEphemeral = NoisePrivateKey{}
- hs.remoteEphemeral = NoisePublicKey{}
- hs.chainKey = [blake2s.Size]byte{}
- hs.hash = [blake2s.Size]byte{}
+ setZero(hs.localEphemeral[:])
+ setZero(hs.remoteEphemeral[:])
+ setZero(hs.chainKey[:])
+ setZero(hs.hash[:])
hs.mutex.Unlock()
}
}
diff --git a/src/tun_linux.go b/src/tun_linux.go
index b9541c9..58a762a 100644
--- a/src/tun_linux.go
+++ b/src/tun_linux.go
@@ -63,6 +63,8 @@ func (tun *NativeTun) RoutineNetlinkListener() {
return
}
+ tun.events <- TUNEventUp // TODO: Fix network namespace problem
+
for msg := make([]byte, 1<<16); ; {
msgn, _, _, _, err := unix.Recvmsg(sock, msg[:], nil, 0)