summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2018-05-13 18:23:40 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2018-05-13 18:26:09 +0200
commit2c27ab205c992d3387574aa6d57780744d35d36f (patch)
treee95bb88db16bac5d2050d3db13cd28570f44cd72
parentRewrite timers and related state machines (diff)
downloadwireguard-go-2c27ab205c992d3387574aa6d57780744d35d36f.tar.xz
wireguard-go-2c27ab205c992d3387574aa6d57780744d35d36f.zip
Rework index hashtable
-rw-r--r--device.go6
-rw-r--r--indextable.go (renamed from index.go)47
-rw-r--r--keypair.go2
-rw-r--r--noise-protocol.go51
-rw-r--r--peer.go6
-rw-r--r--receive.go24
-rw-r--r--send.go20
-rw-r--r--timers.go4
8 files changed, 75 insertions, 85 deletions
diff --git a/device.go b/device.go
index e127b5b..3db3609 100644
--- a/device.go
+++ b/device.go
@@ -56,8 +56,8 @@ type Device struct {
// unprotected / "self-synchronising resources"
- indices IndexTable
- mac CookieChecker
+ indexTable IndexTable
+ mac CookieChecker
rate struct {
underLoadUntil atomic.Value
@@ -283,7 +283,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
// initialize noise & crypt-key routine
- device.indices.Init()
+ device.indexTable.Init()
device.routing.table.Reset()
// setup buffer pool
diff --git a/index.go b/indextable.go
index 4a78d55..2d947cd 100644
--- a/index.go
+++ b/indextable.go
@@ -7,18 +7,14 @@ package main
import (
"crypto/rand"
- "encoding/binary"
"sync"
+ "unsafe"
)
-/* Index=0 is reserved for unset indecies
- *
- */
-
type IndexTableEntry struct {
peer *Peer
handshake *Handshake
- keyPair *Keypair
+ keypair *Keypair
}
type IndexTable struct {
@@ -27,34 +23,38 @@ type IndexTable struct {
}
func randUint32() (uint32, error) {
- var buff [4]byte
- _, err := rand.Read(buff[:])
- value := binary.LittleEndian.Uint32(buff[:])
- return value, err
+ var integer [4]byte
+ _, err := rand.Read(integer[:])
+ return *(*uint32)(unsafe.Pointer(&integer[0])), err
}
func (table *IndexTable) Init() {
table.mutex.Lock()
+ defer table.mutex.Unlock()
table.table = make(map[uint32]IndexTableEntry)
- table.mutex.Unlock()
}
func (table *IndexTable) Delete(index uint32) {
- if index == 0 {
- return
- }
table.mutex.Lock()
+ defer table.mutex.Unlock()
delete(table.table, index)
- table.mutex.Unlock()
}
-func (table *IndexTable) Insert(key uint32, value IndexTableEntry) {
+func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) {
table.mutex.Lock()
- table.table[key] = value
- table.mutex.Unlock()
+ defer table.mutex.Unlock()
+ entry, ok := table.table[index]
+ if !ok {
+ return
+ }
+ table.table[index] = IndexTableEntry{
+ peer: entry.peer,
+ keypair: keypair,
+ handshake: nil,
+ }
}
-func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
+func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) {
for {
// generate random index
@@ -62,9 +62,6 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
if err != nil {
return index, err
}
- if index == 0 {
- continue
- }
// check if index used
@@ -75,7 +72,7 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
continue
}
- // map index to handshake
+ // check again while locked
table.mutex.Lock()
_, found := table.table[index]
@@ -85,8 +82,8 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
}
table.table[index] = IndexTableEntry{
peer: peer,
- handshake: &peer.handshake,
- keyPair: nil,
+ handshake: handshake,
+ keypair: nil,
}
table.mutex.Unlock()
return index, nil
diff --git a/keypair.go b/keypair.go
index 07a183d..6f6f7c0 100644
--- a/keypair.go
+++ b/keypair.go
@@ -44,6 +44,6 @@ func (kp *Keypairs) Current() *Keypair {
func (device *Device) DeleteKeypair(key *Keypair) {
if key != nil {
- device.indices.Delete(key.localIndex)
+ device.indexTable.Delete(key.localIndex)
}
}
diff --git a/noise-protocol.go b/noise-protocol.go
index 3abbe4b..82d553e 100644
--- a/noise-protocol.go
+++ b/noise-protocol.go
@@ -161,7 +161,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
defer handshake.mutex.Unlock()
if isZero(handshake.precomputedStaticStatic[:]) {
- return nil, errors.New("Static shared secret is zero")
+ return nil, errors.New("static shared secret is zero")
}
// create ephemeral key
@@ -176,8 +176,8 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
// assign index
- device.indices.Delete(handshake.localIndex)
- handshake.localIndex, err = device.indices.NewIndex(peer)
+ device.indexTable.Delete(handshake.localIndex)
+ handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil {
return nil, err
@@ -328,14 +328,14 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
defer handshake.mutex.Unlock()
if handshake.state != HandshakeInitiationConsumed {
- return nil, errors.New("handshake initation must be consumed first")
+ return nil, errors.New("handshake initiation must be consumed first")
}
// assign index
var err error
- device.indices.Delete(handshake.localIndex)
- handshake.localIndex, err = device.indices.NewIndex(peer)
+ device.indexTable.Delete(handshake.localIndex)
+ handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil {
return nil, err
}
@@ -393,9 +393,9 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
return nil
}
- // lookup handshake by reciever
+ // lookup handshake by receiver
- lookup := device.indices.Lookup(msg.Receiver)
+ lookup := device.indexTable.Lookup(msg.Receiver)
handshake := lookup.handshake
if handshake == nil {
return nil
@@ -528,35 +528,28 @@ func (peer *Peer) NewKeypair() *Keypair {
// create AEAD instances
- keyPair := new(Keypair)
- keyPair.send, _ = chacha20poly1305.New(sendKey[:])
- keyPair.receive, _ = chacha20poly1305.New(recvKey[:])
+ keypair := new(Keypair)
+ keypair.send, _ = chacha20poly1305.New(sendKey[:])
+ keypair.receive, _ = chacha20poly1305.New(recvKey[:])
setZero(sendKey[:])
setZero(recvKey[:])
- keyPair.created = time.Now()
- keyPair.sendNonce = 0
- keyPair.replayFilter.Init()
- keyPair.isInitiator = isInitiator
- keyPair.localIndex = peer.handshake.localIndex
- keyPair.remoteIndex = peer.handshake.remoteIndex
+ keypair.created = time.Now()
+ keypair.sendNonce = 0
+ keypair.replayFilter.Init()
+ keypair.isInitiator = isInitiator
+ keypair.localIndex = peer.handshake.localIndex
+ keypair.remoteIndex = peer.handshake.remoteIndex
// remap index
- device.indices.Insert(
- handshake.localIndex,
- IndexTableEntry{
- peer: peer,
- keyPair: keyPair,
- handshake: nil,
- },
- )
+ device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair)
handshake.localIndex = 0
// rotate key pairs
- kp := &peer.keyPairs
+ kp := &peer.keypairs
kp.mutex.Lock()
peer.timersSessionDerived()
@@ -574,14 +567,14 @@ func (peer *Peer) NewKeypair() *Keypair {
kp.previous = current
}
device.DeleteKeypair(previous)
- kp.current = keyPair
+ kp.current = keypair
} else {
- kp.next = keyPair
+ kp.next = keypair
device.DeleteKeypair(next)
kp.previous = nil
device.DeleteKeypair(previous)
}
kp.mutex.Unlock()
- return keyPair
+ return keypair
}
diff --git a/peer.go b/peer.go
index 242729e..f49f806 100644
--- a/peer.go
+++ b/peer.go
@@ -20,7 +20,7 @@ const (
type Peer struct {
isRunning AtomicBool
mutex sync.RWMutex
- keyPairs Keypairs
+ keypairs Keypairs
handshake Handshake
device *Device
endpoint Endpoint
@@ -234,7 +234,7 @@ func (peer *Peer) Stop() {
// clear key pairs
- kp := &peer.keyPairs
+ kp := &peer.keypairs
kp.mutex.Lock()
device.DeleteKeypair(kp.previous)
@@ -250,7 +250,7 @@ func (peer *Peer) Stop() {
hs := &peer.handshake
hs.mutex.Lock()
- device.indices.Delete(hs.localIndex)
+ device.indexTable.Delete(hs.localIndex)
hs.Clear()
hs.mutex.Unlock()
diff --git a/receive.go b/receive.go
index 0f22a3f..60a2510 100644
--- a/receive.go
+++ b/receive.go
@@ -31,7 +31,7 @@ type QueueInboundElement struct {
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
- keyPair *Keypair
+ keypair *Keypair
endpoint Endpoint
}
@@ -107,7 +107,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
if peer.timers.sentLastMinuteHandshake {
return
}
- kp := peer.keyPairs.Current()
+ kp := peer.keypairs.Current()
if kp != nil && kp.isInitiator && time.Now().Sub(kp.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
peer.timers.sentLastMinuteHandshake = true
peer.SendHandshakeInitiation(false)
@@ -183,15 +183,15 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
- value := device.indices.Lookup(receiver)
- keyPair := value.keyPair
- if keyPair == nil {
+ value := device.indexTable.Lookup(receiver)
+ keypair := value.keypair
+ if keypair == nil {
continue
}
// check key-pair expiry
- if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
+ if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
continue
}
@@ -201,7 +201,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
elem := &QueueInboundElement{
packet: packet,
buffer: buffer,
- keyPair: keyPair,
+ keypair: keypair,
dropped: AtomicFalse,
endpoint: endpoint,
}
@@ -296,7 +296,7 @@ func (device *Device) RoutineDecryption() {
var err error
elem.counter = binary.LittleEndian.Uint64(counter)
- elem.packet, err = elem.keyPair.receive.Open(
+ elem.packet, err = elem.keypair.receive.Open(
content[:0],
nonce[:],
content,
@@ -358,7 +358,7 @@ func (device *Device) RoutineHandshake() {
// lookup peer from index
- entry := device.indices.Lookup(reply.Receiver)
+ entry := device.indexTable.Lookup(reply.Receiver)
if entry.peer == nil {
continue
@@ -587,7 +587,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
// check for replay
- if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
+ if !elem.keypair.replayFilter.ValidateCounter(elem.counter) {
continue
}
@@ -599,9 +599,9 @@ func (peer *Peer) RoutineSequentialReceiver() {
// check if using new key-pair
- kp := &peer.keyPairs
+ kp := &peer.keypairs
kp.mutex.Lock() //TODO: make this into an RW lock to reduce contention here for the equality check which is rarely true
- if kp.next == elem.keyPair {
+ if kp.next == elem.keypair {
old := kp.previous
kp.previous = kp.current
device.DeleteKeypair(old)
diff --git a/send.go b/send.go
index 1b35e27..35e0d00 100644
--- a/send.go
+++ b/send.go
@@ -47,7 +47,7 @@ type QueueOutboundElement struct {
buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption
- keyPair *Keypair // key-pair for encryption
+ keypair *Keypair // key-pair for encryption
peer *Peer // related peer
}
@@ -161,7 +161,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
*
*/
func (peer *Peer) keepKeyFreshSending() {
- kp := peer.keyPairs.Current()
+ kp := peer.keypairs.Current()
if kp == nil {
return
}
@@ -260,7 +260,7 @@ func (peer *Peer) FlushNonceQueue() {
* Obs. A single instance per peer
*/
func (peer *Peer) RoutineNonce() {
- var keyPair *Keypair
+ var keypair *Keypair
device := peer.device
logDebug := device.log.Debug
@@ -291,9 +291,9 @@ func (peer *Peer) RoutineNonce() {
// wait for key pair
for {
- keyPair = peer.keyPairs.Current()
- if keyPair != nil && keyPair.sendNonce < RejectAfterMessages {
- if time.Now().Sub(keyPair.created) < RejectAfterTime {
+ keypair = peer.keypairs.Current()
+ if keypair != nil && keypair.sendNonce < RejectAfterMessages {
+ if time.Now().Sub(keypair.created) < RejectAfterTime {
break
}
}
@@ -328,12 +328,12 @@ func (peer *Peer) RoutineNonce() {
// populate work element
elem.peer = peer
- elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
+ elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
// double check in case of race condition added by future code
if elem.nonce >= RejectAfterMessages {
goto NextPacket
}
- elem.keyPair = keyPair
+ elem.keypair = keypair
elem.dropped = AtomicFalse
elem.mutex.Lock()
@@ -392,7 +392,7 @@ func (device *Device) RoutineEncryption() {
fieldNonce := header[8:16]
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
- binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex)
+ binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16
@@ -408,7 +408,7 @@ func (device *Device) RoutineEncryption() {
// encrypt content and release to consumer
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
- elem.packet = elem.keyPair.send.Seal(
+ elem.packet = elem.keypair.send.Seal(
header,
nonce[:],
elem.packet,
diff --git a/timers.go b/timers.go
index 5c72efd..9e633ee 100644
--- a/timers.go
+++ b/timers.go
@@ -108,7 +108,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
hs := &peer.handshake
hs.mutex.Lock()
- kp := &peer.keyPairs
+ kp := &peer.keypairs
kp.mutex.Lock()
if kp.previous != nil {
@@ -125,7 +125,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
}
kp.mutex.Unlock()
- peer.device.indices.Delete(hs.localIndex)
+ peer.device.indexTable.Delete(hs.localIndex)
hs.Clear()
hs.mutex.Unlock()
}