aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--replay/replay.go97
-rw-r--r--replay/replay_test.go24
2 files changed, 50 insertions, 71 deletions
diff --git a/replay/replay.go b/replay/replay.go
index 85647f5..8685712 100644
--- a/replay/replay.go
+++ b/replay/replay.go
@@ -3,81 +3,60 @@
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
+// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
package replay
-/* Implementation of RFC6479
- * https://tools.ietf.org/html/rfc6479
- *
- * The implementation is not safe for concurrent use!
- */
-
-const (
- // See: https://golang.org/src/math/big/arith.go
- _Wordm = ^uintptr(0)
- _WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1
- _WordSize = 1 << _WordLogSize
-)
+type block uint64
const (
- CounterRedundantBitsLog = _WordLogSize + 3
- CounterRedundantBits = _WordSize * 8
- CounterBitsTotal = 8192
- CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits)
+ blockBitLog = 6 // 1<<6 == 64 bits
+ blockBits = 1 << blockBitLog // must be power of 2
+ ringBlocks = 1 << 7 // must be power of 2
+ windowSize = (ringBlocks - 1) * blockBits
+ blockMask = ringBlocks - 1
+ bitMask = blockBits - 1
)
-const (
- BacktrackWords = CounterBitsTotal / 8 / _WordSize
-)
-
-func minUint64(a uint64, b uint64) uint64 {
- if a > b {
- return b
- }
- return a
-}
-
+// A ReplayFilter rejects replayed messages by checking if message counter value is
+// within a sliding window of previously received messages.
+// The zero value for ReplayFilter is an empty filter ready to use.
+// Filters are unsafe for concurrent use.
type ReplayFilter struct {
- counter uint64
- backtrack [BacktrackWords]uintptr
+ last uint64
+ ring [ringBlocks]block
}
-func (filter *ReplayFilter) Init() {
- filter.counter = 0
- filter.backtrack[0] = 0
+// Init resets the filter to empty state.
+func (f *ReplayFilter) Init() {
+ f.last = 0
+ f.ring[0] = 0
}
-func (filter *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool {
+// ValidateCounter checks if the counter should be accepted.
+// Overlimit counters (>= limit) are always rejected.
+func (f *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool {
if counter >= limit {
return false
}
-
- indexWord := counter >> CounterRedundantBitsLog
-
- if counter > filter.counter {
-
- // move window forward
-
- current := filter.counter >> CounterRedundantBitsLog
- diff := minUint64(indexWord-current, BacktrackWords)
- for i := uint64(1); i <= diff; i++ {
- filter.backtrack[(current+i)%BacktrackWords] = 0
+ indexBlock := counter >> blockBitLog
+ if counter > f.last { // move window forward
+ current := f.last >> blockBitLog
+ diff := indexBlock - current
+ if diff > ringBlocks {
+ diff = ringBlocks // cap diff to clear the whole ring
}
- filter.counter = counter
-
- } else if filter.counter-counter > CounterWindowSize {
-
- // behind current window
-
+ for i := current + 1; i <= current+diff; i++ {
+ f.ring[i&blockMask] = 0
+ }
+ f.last = counter
+ } else if f.last-counter > windowSize { // behind current window
return false
}
-
- indexWord %= BacktrackWords
- indexBit := counter & uint64(CounterRedundantBits-1)
-
// check and set bit
-
- oldValue := filter.backtrack[indexWord]
- newValue := oldValue | (1 << indexBit)
- filter.backtrack[indexWord] = newValue
- return oldValue != newValue
+ indexBlock &= blockMask
+ indexBit := counter & bitMask
+ old := f.ring[indexBlock]
+ new := old | 1<<indexBit
+ f.ring[indexBlock] = new
+ return old != new
}
diff --git a/replay/replay_test.go b/replay/replay_test.go
index ceae2f3..5af66ff 100644
--- a/replay/replay_test.go
+++ b/replay/replay_test.go
@@ -19,13 +19,13 @@ const RejectAfterMessages = (1 << 64) - (1 << 4) - 1
func TestReplay(t *testing.T) {
var filter ReplayFilter
- T_LIM := CounterWindowSize + 1
+ const T_LIM = windowSize + 1
testNumber := 0
- T := func(n uint64, v bool) {
+ T := func(n uint64, expected bool) {
testNumber++
- if filter.ValidateCounter(n, RejectAfterMessages) != v {
- t.Fatal("Test", testNumber, "failed", n, v)
+ if filter.ValidateCounter(n, RejectAfterMessages) != expected {
+ t.Fatal("Test", testNumber, "failed", n, expected)
}
}
@@ -69,7 +69,7 @@ func TestReplay(t *testing.T) {
t.Log("Bulk test 1")
filter.Init()
testNumber = 0
- for i := uint64(1); i <= CounterWindowSize; i++ {
+ for i := uint64(1); i <= windowSize; i++ {
T(i, true)
}
T(0, true)
@@ -78,7 +78,7 @@ func TestReplay(t *testing.T) {
t.Log("Bulk test 2")
filter.Init()
testNumber = 0
- for i := uint64(2); i <= CounterWindowSize+1; i++ {
+ for i := uint64(2); i <= windowSize+1; i++ {
T(i, true)
}
T(1, true)
@@ -87,14 +87,14 @@ func TestReplay(t *testing.T) {
t.Log("Bulk test 3")
filter.Init()
testNumber = 0
- for i := CounterWindowSize + 1; i > 0; i-- {
+ for i := uint64(windowSize + 1); i > 0; i-- {
T(i, true)
}
t.Log("Bulk test 4")
filter.Init()
testNumber = 0
- for i := CounterWindowSize + 2; i > 1; i-- {
+ for i := uint64(windowSize + 2); i > 1; i-- {
T(i, true)
}
T(0, false)
@@ -102,18 +102,18 @@ func TestReplay(t *testing.T) {
t.Log("Bulk test 5")
filter.Init()
testNumber = 0
- for i := CounterWindowSize; i > 0; i-- {
+ for i := uint64(windowSize); i > 0; i-- {
T(i, true)
}
- T(CounterWindowSize+1, true)
+ T(windowSize+1, true)
T(0, false)
t.Log("Bulk test 6")
filter.Init()
testNumber = 0
- for i := CounterWindowSize; i > 0; i-- {
+ for i := uint64(windowSize); i > 0; i-- {
T(i, true)
}
T(0, true)
- T(CounterWindowSize+1, true)
+ T(windowSize+1, true)
}