aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2017-11-30 23:22:40 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2017-11-30 23:22:40 +0100
commit02ce67294cd28bde9d61924fe6d0365638cc924e (patch)
treea30445a4e15e6ab651525e9dd40ba6399e00d1d8
parentFixed typos (diff)
downloadwireguard-go-02ce67294cd28bde9d61924fe6d0365638cc924e.tar.xz
wireguard-go-02ce67294cd28bde9d61924fe6d0365638cc924e.zip
Refactor timers.go
Diffstat (limited to '')
-rw-r--r--src/noise_protocol.go3
-rw-r--r--src/peer.go59
-rw-r--r--src/receive.go3
-rw-r--r--src/send.go19
-rw-r--r--src/signal.go45
-rw-r--r--src/timer.go65
-rw-r--r--src/timers.go214
-rw-r--r--src/uapi.go4
8 files changed, 249 insertions, 163 deletions
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index 9e5fdd8..2f9e1d5 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -532,7 +532,6 @@ func (peer *Peer) NewKeyPair() *KeyPair {
kp := &peer.keyPairs
kp.mutex.Lock()
- // TODO: Adapt kernel behavior noise.c:161
if isInitiator {
if kp.previous != nil {
device.DeleteKeyPair(kp.previous)
@@ -545,7 +544,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
} else {
kp.previous = kp.current
kp.current = keyPair
- signalSend(peer.signal.newKeyPair) // TODO: This more places (after confirming the key)
+ peer.signal.newKeyPair.Send()
}
} else {
diff --git a/src/peer.go b/src/peer.go
index f3eb6c2..f582556 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -28,30 +28,26 @@ type Peer struct {
nextKeepalive time.Time
}
signal struct {
- newKeyPair chan struct{} // (size 1) : a new key pair was generated
- handshakeBegin chan struct{} // (size 1) : request that a new handshake be started ("queue handshake")
- handshakeCompleted chan struct{} // (size 1) : handshake completed
- handshakeReset chan struct{} // (size 1) : reset handshake negotiation state
- flushNonceQueue chan struct{} // (size 1) : empty queued packets
- messageSend chan struct{} // (size 1) : a message was send to the peer
- messageReceived chan struct{} // (size 1) : an authenticated message was received
- stop chan struct{} // (size 0) : close to stop all goroutines for peer
+ newKeyPair Signal // size 1, new key pair was generated
+ handshakeCompleted Signal // size 1, handshake completed
+ handshakeBegin Signal // size 1, begin new handshake begin
+ flushNonceQueue Signal // size 1, empty queued packets
+ messageSend Signal // size 1, message was send to peer
+ messageReceived Signal // size 1, authenticated message recv
+ stop Signal // size 0, stop all goroutines
}
timer struct {
// state related to WireGuard timers
- keepalivePersistent *time.Timer // set for persistent keepalives
- keepalivePassive *time.Timer // set upon recieving messages
- newHandshake *time.Timer // begin a new handshake (after Keepalive + RekeyTimeout)
- zeroAllKeys *time.Timer // zero all key material (after RejectAfterTime*3)
- handshakeDeadline *time.Timer // Current handshake must be completed
+ keepalivePersistent Timer // set for persistent keepalives
+ keepalivePassive Timer // set upon recieving messages
+ newHandshake Timer // begin a new handshake (stale)
+ zeroAllKeys Timer // zero all key material
+ handshakeDeadline Timer // complete handshake timeout
+ handshakeTimeout Timer // current handshake message timeout
- pendingKeepalivePassive bool
- pendingNewHandshake bool
- pendingZeroAllKeys bool
-
- needAnotherKeepalive bool
sendLastMinuteHandshake bool
+ needAnotherKeepalive bool
}
queue struct {
nonce chan *QueueOutboundElement // nonce / pre-handshake queue
@@ -71,10 +67,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.mac.Init(pk)
peer.device = device
- peer.timer.keepalivePersistent = NewStoppedTimer()
- peer.timer.keepalivePassive = NewStoppedTimer()
- peer.timer.newHandshake = NewStoppedTimer()
- peer.timer.zeroAllKeys = NewStoppedTimer()
+ peer.timer.keepalivePersistent = NewTimer()
+ peer.timer.keepalivePassive = NewTimer()
+ peer.timer.newHandshake = NewTimer()
+ peer.timer.zeroAllKeys = NewTimer()
+ peer.timer.handshakeDeadline = NewTimer()
+ peer.timer.handshakeTimeout = NewTimer()
// assign id for debugging
@@ -102,7 +100,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake := &peer.handshake
handshake.mutex.Lock()
handshake.remoteStatic = pk
- handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
+ handshake.precomputedStaticStatic =
+ device.privateKey.sharedSecret(handshake.remoteStatic)
handshake.mutex.Unlock()
// reset endpoint
@@ -117,16 +116,14 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// prepare signaling & routines
- peer.signal.stop = make(chan struct{})
- peer.signal.newKeyPair = make(chan struct{}, 1)
- peer.signal.handshakeBegin = make(chan struct{}, 1)
- peer.signal.handshakeReset = make(chan struct{}, 1)
- peer.signal.handshakeCompleted = make(chan struct{}, 1)
- peer.signal.flushNonceQueue = make(chan struct{}, 1)
+ peer.signal.stop = NewSignal()
+ peer.signal.newKeyPair = NewSignal()
+ peer.signal.handshakeBegin = NewSignal()
+ peer.signal.handshakeCompleted = NewSignal()
+ peer.signal.flushNonceQueue = NewSignal()
go peer.RoutineNonce()
go peer.RoutineTimerHandler()
- go peer.RoutineHandshakeInitiator()
go peer.RoutineSequentialSender()
go peer.RoutineSequentialReceiver()
@@ -163,5 +160,5 @@ func (peer *Peer) String() string {
}
func (peer *Peer) Close() {
- close(peer.signal.stop)
+ peer.signal.stop.Broadcast()
}
diff --git a/src/receive.go b/src/receive.go
index 0b0efbf..7d493b0 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -482,7 +482,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
for {
select {
- case <-peer.signal.stop:
+
+ case <-peer.signal.stop.Wait():
logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id)
return
diff --git a/src/send.go b/src/send.go
index 52872f6..35a4a6e 100644
--- a/src/send.go
+++ b/src/send.go
@@ -164,7 +164,7 @@ func (device *Device) RoutineReadFromTUN() {
// insert into nonce/pre-handshake queue
- signalSend(peer.signal.handshakeReset)
+ peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
addToOutboundQueue(peer.queue.nonce, elem)
elem = device.NewOutboundElement()
}
@@ -186,7 +186,7 @@ func (peer *Peer) RoutineNonce() {
for {
NextPacket:
select {
- case <-peer.signal.stop:
+ case <-peer.signal.stop.Wait():
return
case elem := <-peer.queue.nonce:
@@ -201,16 +201,17 @@ func (peer *Peer) RoutineNonce() {
}
}
- signalSend(peer.signal.handshakeBegin)
+ peer.signal.handshakeBegin.Send()
+
logDebug.Println("Awaiting key-pair for", peer.String())
select {
- case <-peer.signal.newKeyPair:
- case <-peer.signal.flushNonceQueue:
+ case <-peer.signal.newKeyPair.Wait():
+ case <-peer.signal.flushNonceQueue.Wait():
logDebug.Println("Clearing queue for", peer.String())
peer.FlushNonceQueue()
goto NextPacket
- case <-peer.signal.stop:
+ case <-peer.signal.stop.Wait():
return
}
}
@@ -309,8 +310,10 @@ func (peer *Peer) RoutineSequentialSender() {
for {
select {
- case <-peer.signal.stop:
- logDebug.Println("Routine, sequential sender, stopped for", peer.String())
+
+ case <-peer.signal.stop.Wait():
+ logDebug.Println(
+ "Routine, sequential sender, stopped for", peer.String())
return
case elem := <-peer.queue.outbound:
diff --git a/src/signal.go b/src/signal.go
new file mode 100644
index 0000000..96b21bb
--- /dev/null
+++ b/src/signal.go
@@ -0,0 +1,45 @@
+package main
+
+type Signal struct {
+ enabled AtomicBool
+ C chan struct{}
+}
+
+func NewSignal() (s Signal) {
+ s.C = make(chan struct{}, 1)
+ s.Enable()
+ return
+}
+
+func (s *Signal) Disable() {
+ s.enabled.Set(false)
+ s.Clear()
+}
+
+func (s *Signal) Enable() {
+ s.enabled.Set(true)
+}
+
+func (s *Signal) Send() {
+ if s.enabled.Get() {
+ select {
+ case s.C <- struct{}{}:
+ default:
+ }
+ }
+}
+
+func (s Signal) Clear() {
+ select {
+ case <-s.C:
+ default:
+ }
+}
+
+func (s Signal) Broadcast() {
+ close(s.C) // unblocks all selectors
+}
+
+func (s Signal) Wait() chan struct{} {
+ return s.C
+}
diff --git a/src/timer.go b/src/timer.go
new file mode 100644
index 0000000..3def253
--- /dev/null
+++ b/src/timer.go
@@ -0,0 +1,65 @@
+package main
+
+import (
+ "time"
+)
+
+type Timer struct {
+ pending AtomicBool
+ timer *time.Timer
+}
+
+/* Starts the timer if not already pending
+ */
+func (t *Timer) Start(dur time.Duration) bool {
+ set := t.pending.Swap(true)
+ if !set {
+ t.timer.Reset(dur)
+ return true
+ }
+ return false
+}
+
+/* Stops the timer
+ */
+func (t *Timer) Stop() {
+ set := t.pending.Swap(true)
+ if set {
+ t.timer.Stop()
+ select {
+ case <-t.timer.C:
+ default:
+ }
+ }
+ t.pending.Set(false)
+}
+
+func (t *Timer) Pending() bool {
+ return t.pending.Get()
+}
+
+func (t *Timer) Reset(dur time.Duration) {
+ t.pending.Set(false)
+ t.Start(dur)
+}
+
+func (t *Timer) Push(dur time.Duration) {
+ if t.pending.Get() {
+ t.Reset(dur)
+ }
+}
+
+func (t *Timer) Wait() <-chan time.Time {
+ return t.timer.C
+}
+
+func NewTimer() (t Timer) {
+ t.pending.Set(false)
+ t.timer = time.NewTimer(0)
+ t.timer.Stop()
+ select {
+ case <-t.timer.C:
+ default:
+ }
+ return
+}
diff --git a/src/timers.go b/src/timers.go
index 5848b2a..64aeca8 100644
--- a/src/timers.go
+++ b/src/timers.go
@@ -18,10 +18,10 @@ func (peer *Peer) KeepKeyFreshSending() {
}
nonce := atomic.LoadUint64(&kp.sendNonce)
if nonce > RekeyAfterMessages {
- signalSend(peer.signal.handshakeBegin)
+ peer.signal.handshakeBegin.Send()
}
if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {
- signalSend(peer.signal.handshakeBegin)
+ peer.signal.handshakeBegin.Send()
}
}
@@ -44,7 +44,7 @@ func (peer *Peer) KeepKeyFreshReceiving() {
send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
if send {
// do a last minute attempt at initiating a new handshake
- signalSend(peer.signal.handshakeBegin)
+ peer.signal.handshakeBegin.Send()
peer.timer.sendLastMinuteHandshake = true
}
}
@@ -69,34 +69,36 @@ func (peer *Peer) SendKeepAlive() bool {
* Sent non-empty (authenticated) transport message
*/
func (peer *Peer) TimerDataSent() {
- timerStop(peer.timer.keepalivePassive)
- if !peer.timer.pendingNewHandshake {
- peer.timer.pendingNewHandshake = true
+ peer.timer.keepalivePassive.Stop()
+ if peer.timer.newHandshake.Pending() {
peer.timer.newHandshake.Reset(NewHandshakeTime)
}
}
/* Event:
* Received non-empty (authenticated) transport message
+ *
+ * Action:
+ * Set a timer to confirm the message using a keep-alive (if not already set)
*/
func (peer *Peer) TimerDataReceived() {
- if peer.timer.pendingKeepalivePassive {
+ if !peer.timer.keepalivePassive.Start(KeepaliveTimeout) {
peer.timer.needAnotherKeepalive = true
- return
}
- peer.timer.pendingKeepalivePassive = false
- peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
}
/* Event:
* Any (authenticated) packet received
*/
func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
- timerStop(peer.timer.newHandshake)
+ peer.timer.newHandshake.Stop()
}
/* Event:
* Any authenticated packet send / received.
+ *
+ * Action:
+ * Push persistent keep-alive into the future
*/
func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
@@ -117,7 +119,7 @@ func (peer *Peer) TimerHandshakeComplete() {
&peer.stats.lastHandshakeNano,
time.Now().UnixNano(),
)
- signalSend(peer.signal.handshakeCompleted)
+ peer.signal.handshakeCompleted.Send()
peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
}
@@ -129,7 +131,8 @@ func (peer *Peer) TimerHandshakeComplete() {
* CreateMessageInitiation
* CreateMessageResponse
*
- * Schedules the deletion of all key material
+ * Action:
+ * Schedule the deletion of all key material
* upon failure to complete a handshake
*/
func (peer *Peer) TimerEphemeralKeyCreated() {
@@ -139,18 +142,18 @@ func (peer *Peer) TimerEphemeralKeyCreated() {
func (peer *Peer) RoutineTimerHandler() {
device := peer.device
+ logInfo := device.log.Info
logDebug := device.log.Debug
logDebug.Println("Routine, timer handler, started for peer", peer.String())
for {
select {
- case <-peer.signal.stop:
- return
+ /* timers */
- // keep-alives
+ // keep-alive
- case <-peer.timer.keepalivePersistent.C:
+ case <-peer.timer.keepalivePersistent.Wait():
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
if interval > 0 {
@@ -158,7 +161,7 @@ func (peer *Peer) RoutineTimerHandler() {
peer.SendKeepAlive()
}
- case <-peer.timer.keepalivePassive.C:
+ case <-peer.timer.keepalivePassive.Wait():
logDebug.Println("Sending keep-alive to", peer.String())
@@ -169,17 +172,9 @@ func (peer *Peer) RoutineTimerHandler() {
peer.timer.needAnotherKeepalive = false
}
- // unresponsive session
+ // clear key material timer
- case <-peer.timer.newHandshake.C:
-
- logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply")
-
- signalSend(peer.signal.handshakeBegin)
-
- // clear key material
-
- case <-peer.timer.zeroAllKeys.C:
+ case <-peer.timer.zeroAllKeys.Wait():
logDebug.Println("Clearing all key material for", peer.String())
@@ -215,125 +210,106 @@ func (peer *Peer) RoutineTimerHandler() {
setZero(hs.chainKey[:])
setZero(hs.hash[:])
hs.mutex.Unlock()
- }
- }
-}
-/* This is the state machine for handshake initiation
- *
- * Associated with this routine is the signal "handshakeBegin"
- * The routine will read from the "handshakeBegin" channel
- * at most every RekeyTimeout seconds
- */
-func (peer *Peer) RoutineHandshakeInitiator() {
- device := peer.device
+ // handshake timers
- logInfo := device.log.Info
- logError := device.log.Error
- logDebug := device.log.Debug
- logDebug.Println("Routine, handshake initiator, started for", peer.String())
+ case <-peer.timer.newHandshake.Wait():
+ logInfo.Println("Retrying handshake with", peer.String())
+ peer.signal.handshakeBegin.Send()
- var temp [256]byte
+ case <-peer.timer.handshakeTimeout.Wait():
- for {
+ // clear source (in case this is causing problems)
- // wait for signal
+ peer.mutex.Lock()
+ if peer.endpoint != nil {
+ peer.endpoint.ClearSrc()
+ }
+ peer.mutex.Unlock()
- select {
- case <-peer.signal.handshakeBegin:
- case <-peer.signal.stop:
- return
- }
+ // send new handshake
- // set deadline
+ err := peer.sendNewHandshake()
+ if err != nil {
+ logInfo.Println(
+ "Failed to send handshake to peer:", peer.String())
+ }
- BeginHandshakes:
+ case <-peer.timer.handshakeDeadline.Wait():
- signalClear(peer.signal.handshakeReset)
- deadline := time.NewTimer(RekeyAttemptTime)
+ // clear all queued packets and stop keep-alive
- AttemptHandshakes:
+ logInfo.Println(
+ "Handshake negotiation timed out for:", peer.String())
- for attempts := uint(1); ; attempts++ {
+ peer.signal.flushNonceQueue.Send()
+ peer.timer.keepalivePersistent.Stop()
+ peer.signal.handshakeBegin.Enable()
- // check if deadline reached
+ /* signals */
- select {
- case <-deadline.C:
- logInfo.Println("Handshake negotiation timed out for:", peer.String())
- signalSend(peer.signal.flushNonceQueue)
- timerStop(peer.timer.keepalivePersistent)
- break
- case <-peer.signal.stop:
- return
- default:
- }
+ case <-peer.signal.stop.Wait():
+ return
- signalClear(peer.signal.handshakeCompleted)
+ case <-peer.signal.handshakeBegin.Wait():
- // create initiation message
+ peer.signal.handshakeBegin.Disable()
- msg, err := peer.device.CreateMessageInitiation(peer)
+ err := peer.sendNewHandshake()
if err != nil {
- logError.Println("Failed to create handshake initiation message:", err)
- break AttemptHandshakes
+ logInfo.Println(
+ "Failed to send handshake to peer:", peer.String())
}
- // marshal handshake message
-
- writer := bytes.NewBuffer(temp[:0])
- binary.Write(writer, binary.LittleEndian, msg)
- packet := writer.Bytes()
- peer.mac.AddMacs(packet)
-
- // send to endpoint
-
- err = peer.SendBuffer(packet)
- jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
- timeout := time.NewTimer(RekeyTimeout + jitter)
- if err == nil {
- peer.TimerAnyAuthenticatedPacketTraversal()
- logDebug.Println(
- "Handshake initiation attempt",
- attempts, "sent to", peer.String(),
- )
- } else {
- logError.Println(
- "Failed to send handshake initiation message to",
- peer.String(), ":", err,
- )
- }
+ peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
- // wait for handshake or timeout
+ case <-peer.signal.handshakeCompleted.Wait():
- select {
+ logInfo.Println(
+ "Handshake completed for:", peer.String())
- case <-peer.signal.stop:
- return
+ peer.timer.handshakeTimeout.Stop()
+ peer.timer.handshakeDeadline.Stop()
+ peer.signal.handshakeBegin.Enable()
+ }
+ }
+}
- case <-peer.signal.handshakeCompleted:
- <-timeout.C
- peer.timer.sendLastMinuteHandshake = false
- break AttemptHandshakes
+/* Sends a new handshake initiation message to the peer (endpoint)
+ */
+func (peer *Peer) sendNewHandshake() error {
- case <-peer.signal.handshakeReset:
- <-timeout.C
- goto BeginHandshakes
+ // temporarily disable the handshake complete signal
- case <-timeout.C:
+ peer.signal.handshakeCompleted.Disable()
- // clear source address of peer
+ // create initiation message
- peer.mutex.Lock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- peer.mutex.Unlock()
- }
- }
+ msg, err := peer.device.CreateMessageInitiation(peer)
+ if err != nil {
+ return err
+ }
+
+ // marshal handshake message
- // clear signal set in the meantime
+ var buff [MessageInitiationSize]byte
+ writer := bytes.NewBuffer(buff[:0])
+ binary.Write(writer, binary.LittleEndian, msg)
+ packet := writer.Bytes()
+ peer.mac.AddMacs(packet)
- signalClear(peer.signal.handshakeBegin)
+ // send to endpoint
+
+ err = peer.SendBuffer(packet)
+ if err == nil {
+ peer.TimerAnyAuthenticatedPacketTraversal()
+ peer.signal.handshakeCompleted.Enable()
}
+
+ // set timeout
+
+ jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
+ peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)
+
+ return err
}
diff --git a/src/uapi.go b/src/uapi.go
index 7ab3c4a..155f483 100644
--- a/src/uapi.go
+++ b/src/uapi.go
@@ -221,7 +221,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalid}
}
}
- signalSend(peer.signal.handshakeReset)
+ peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
dummy = false
}
@@ -265,7 +265,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return err
}
peer.endpoint = endpoint
- signalSend(peer.signal.handshakeReset)
+ peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
return nil
}()