From 22af3890f60d11472977f2fbdf50e6d5b406e356 Mon Sep 17 00:00:00 2001 From: Riobard Zhan Date: Thu, 10 Sep 2020 01:55:24 +0800 Subject: replay: clean up internals and better documentation Signed-off-by: Riobard Zhan Signed-off-by: Jason A. Donenfeld --- replay/replay.go | 97 ++++++++++++++++++++------------------------------- replay/replay_test.go | 24 ++++++------- 2 files changed, 50 insertions(+), 71 deletions(-) (limited to 'replay') diff --git a/replay/replay.go b/replay/replay.go index 85647f5..8685712 100644 --- a/replay/replay.go +++ b/replay/replay.go @@ -3,81 +3,60 @@ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. */ +// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479. package replay -/* Implementation of RFC6479 - * https://tools.ietf.org/html/rfc6479 - * - * The implementation is not safe for concurrent use! - */ - -const ( - // See: https://golang.org/src/math/big/arith.go - _Wordm = ^uintptr(0) - _WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1 - _WordSize = 1 << _WordLogSize -) +type block uint64 const ( - CounterRedundantBitsLog = _WordLogSize + 3 - CounterRedundantBits = _WordSize * 8 - CounterBitsTotal = 8192 - CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits) + blockBitLog = 6 // 1<<6 == 64 bits + blockBits = 1 << blockBitLog // must be power of 2 + ringBlocks = 1 << 7 // must be power of 2 + windowSize = (ringBlocks - 1) * blockBits + blockMask = ringBlocks - 1 + bitMask = blockBits - 1 ) -const ( - BacktrackWords = CounterBitsTotal / 8 / _WordSize -) - -func minUint64(a uint64, b uint64) uint64 { - if a > b { - return b - } - return a -} - +// A ReplayFilter rejects replayed messages by checking if message counter value is +// within a sliding window of previously received messages. +// The zero value for ReplayFilter is an empty filter ready to use. +// Filters are unsafe for concurrent use. type ReplayFilter struct { - counter uint64 - backtrack [BacktrackWords]uintptr + last uint64 + ring [ringBlocks]block } -func (filter *ReplayFilter) Init() { - filter.counter = 0 - filter.backtrack[0] = 0 +// Init resets the filter to empty state. +func (f *ReplayFilter) Init() { + f.last = 0 + f.ring[0] = 0 } -func (filter *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool { +// ValidateCounter checks if the counter should be accepted. +// Overlimit counters (>= limit) are always rejected. +func (f *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool { if counter >= limit { return false } - - indexWord := counter >> CounterRedundantBitsLog - - if counter > filter.counter { - - // move window forward - - current := filter.counter >> CounterRedundantBitsLog - diff := minUint64(indexWord-current, BacktrackWords) - for i := uint64(1); i <= diff; i++ { - filter.backtrack[(current+i)%BacktrackWords] = 0 + indexBlock := counter >> blockBitLog + if counter > f.last { // move window forward + current := f.last >> blockBitLog + diff := indexBlock - current + if diff > ringBlocks { + diff = ringBlocks // cap diff to clear the whole ring } - filter.counter = counter - - } else if filter.counter-counter > CounterWindowSize { - - // behind current window - + for i := current + 1; i <= current+diff; i++ { + f.ring[i&blockMask] = 0 + } + f.last = counter + } else if f.last-counter > windowSize { // behind current window return false } - - indexWord %= BacktrackWords - indexBit := counter & uint64(CounterRedundantBits-1) - // check and set bit - - oldValue := filter.backtrack[indexWord] - newValue := oldValue | (1 << indexBit) - filter.backtrack[indexWord] = newValue - return oldValue != newValue + indexBlock &= blockMask + indexBit := counter & bitMask + old := f.ring[indexBlock] + new := old | 1< 0; i-- { + for i := uint64(windowSize + 1); i > 0; i-- { T(i, true) } t.Log("Bulk test 4") filter.Init() testNumber = 0 - for i := CounterWindowSize + 2; i > 1; i-- { + for i := uint64(windowSize + 2); i > 1; i-- { T(i, true) } T(0, false) @@ -102,18 +102,18 @@ func TestReplay(t *testing.T) { t.Log("Bulk test 5") filter.Init() testNumber = 0 - for i := CounterWindowSize; i > 0; i-- { + for i := uint64(windowSize); i > 0; i-- { T(i, true) } - T(CounterWindowSize+1, true) + T(windowSize+1, true) T(0, false) t.Log("Bulk test 6") filter.Init() testNumber = 0 - for i := CounterWindowSize; i > 0; i-- { + for i := uint64(windowSize); i > 0; i-- { T(i, true) } T(0, true) - T(CounterWindowSize+1, true) + T(windowSize+1, true) } -- cgit v1.2.3-59-g8ed1b