aboutsummaryrefslogtreecommitdiffstats
path: root/ratelimiter/ratelimiter_test.go
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--ratelimiter/ratelimiter_test.go54
1 files changed, 35 insertions, 19 deletions
diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go
index 659bdfb..25d5d63 100644
--- a/ratelimiter/ratelimiter_test.go
+++ b/ratelimiter/ratelimiter_test.go
@@ -11,22 +11,21 @@ import (
"time"
)
-type RatelimiterResult struct {
+type result struct {
allowed bool
text string
wait time.Duration
}
func TestRatelimiter(t *testing.T) {
+ var rate Ratelimiter
+ var expectedResults []result
- var ratelimiter Ratelimiter
- var expectedResults []RatelimiterResult
-
- Nano := func(nano int64) time.Duration {
+ nano := func(nano int64) time.Duration {
return time.Nanosecond * time.Duration(nano)
}
- Add := func(res RatelimiterResult) {
+ add := func(res result) {
expectedResults = append(
expectedResults,
res,
@@ -34,40 +33,40 @@ func TestRatelimiter(t *testing.T) {
}
for i := 0; i < packetsBurstable; i++ {
- Add(RatelimiterResult{
+ add(result{
allowed: true,
text: "initial burst",
})
}
- Add(RatelimiterResult{
+ add(result{
allowed: false,
text: "after burst",
})
- Add(RatelimiterResult{
+ add(result{
allowed: true,
- wait: Nano(time.Second.Nanoseconds() / packetsPerSecond),
+ wait: nano(time.Second.Nanoseconds() / packetsPerSecond),
text: "filling tokens for single packet",
})
- Add(RatelimiterResult{
+ add(result{
allowed: false,
text: "not having refilled enough",
})
- Add(RatelimiterResult{
+ add(result{
allowed: true,
- wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)),
+ wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)),
text: "filling tokens for two packet burst",
})
- Add(RatelimiterResult{
+ add(result{
allowed: true,
text: "second packet in 2 packet burst",
})
- Add(RatelimiterResult{
+ add(result{
allowed: false,
text: "packet following 2 packet burst",
})
@@ -89,14 +88,31 @@ func TestRatelimiter(t *testing.T) {
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
}
- ratelimiter.Init()
+ now := time.Now()
+ rate.timeNow = func() time.Time {
+ return now
+ }
+ defer func() {
+ // Lock to avoid data race with cleanup goroutine from Init.
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
+
+ rate.timeNow = time.Now
+ }()
+ timeSleep := func(d time.Duration) {
+ now = now.Add(d + 1)
+ rate.cleanup()
+ }
+
+ rate.Init()
+ defer rate.Close()
for i, res := range expectedResults {
- time.Sleep(res.wait)
+ timeSleep(res.wait)
for _, ip := range ips {
- allowed := ratelimiter.Allow(ip)
+ allowed := rate.Allow(ip)
if allowed != res.allowed {
- t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed)
+ t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed)
}
}
}