aboutsummaryrefslogtreecommitdiffstats
path: root/internal/ratelimiter
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--internal/ratelimiter/ratelimiter.go74
-rw-r--r--internal/ratelimiter/ratelimiter_test.go6
2 files changed, 38 insertions, 42 deletions
diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go
index f9fc673..006900a 100644
--- a/internal/ratelimiter/ratelimiter.go
+++ b/internal/ratelimiter/ratelimiter.go
@@ -2,8 +2,7 @@ package ratelimiter
/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
-/* This file contains a port of the ratelimited from the linux kernel version
- */
+/* This file contains a port of the rate-limiter from the linux kernel version */
import (
"net"
@@ -12,11 +11,11 @@ import (
)
const (
- RatelimiterPacketsPerSecond = 20
- RatelimiterPacketsBurstable = 5
- RatelimiterGarbageCollectTime = time.Second
- RatelimiterPacketCost = 1000000000 / RatelimiterPacketsPerSecond
- RatelimiterMaxTokens = RatelimiterPacketCost * RatelimiterPacketsBurstable
+ packetsPerSecond = 20
+ packetsBurstable = 5
+ garbageCollectTime = time.Second
+ packetCost = 1000000000 / packetsPerSecond
+ maxTokens = packetCost * packetsBurstable
)
type RatelimiterEntry struct {
@@ -45,6 +44,8 @@ func (rate *Ratelimiter) Init() {
rate.mutex.Lock()
defer rate.mutex.Unlock()
+ // stop any ongoing garbage collection routine
+
if rate.stop != nil {
close(rate.stop)
}
@@ -53,6 +54,8 @@ func (rate *Ratelimiter) Init() {
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
+ // start garbage collection routine
+
go func() {
timer := time.NewTimer(time.Second)
for {
@@ -60,39 +63,32 @@ func (rate *Ratelimiter) Init() {
case <-rate.stop:
return
case <-timer.C:
- rate.garbageCollectEntries()
+ func() {
+ rate.mutex.Lock()
+ defer rate.mutex.Unlock()
+
+ for key, entry := range rate.tableIPv4 {
+ entry.mutex.Lock()
+ if time.Now().Sub(entry.lastTime) > garbageCollectTime {
+ delete(rate.tableIPv4, key)
+ }
+ entry.mutex.Unlock()
+ }
+
+ for key, entry := range rate.tableIPv6 {
+ entry.mutex.Lock()
+ if time.Now().Sub(entry.lastTime) > garbageCollectTime {
+ delete(rate.tableIPv6, key)
+ }
+ entry.mutex.Unlock()
+ }
+ }()
timer.Reset(time.Second)
}
}
}()
}
-func (rate *Ratelimiter) garbageCollectEntries() {
- rate.mutex.Lock()
-
- // remove unused IPv4 entries
-
- for key, entry := range rate.tableIPv4 {
- entry.mutex.Lock()
- if time.Now().Sub(entry.lastTime) > RatelimiterGarbageCollectTime {
- delete(rate.tableIPv4, key)
- }
- entry.mutex.Unlock()
- }
-
- // remove unused IPv6 entries
-
- for key, entry := range rate.tableIPv6 {
- entry.mutex.Lock()
- if time.Now().Sub(entry.lastTime) > RatelimiterGarbageCollectTime {
- delete(rate.tableIPv6, key)
- }
- entry.mutex.Unlock()
- }
-
- rate.mutex.Unlock()
-}
-
func (rate *Ratelimiter) Allow(ip net.IP) bool {
var entry *RatelimiterEntry
var KeyIPv4 [net.IPv4len]byte
@@ -120,7 +116,7 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
if entry == nil {
rate.mutex.Lock()
entry = new(RatelimiterEntry)
- entry.tokens = RatelimiterMaxTokens - RatelimiterPacketCost
+ entry.tokens = maxTokens - packetCost
entry.lastTime = time.Now()
if IPv4 != nil {
rate.tableIPv4[KeyIPv4] = entry
@@ -137,14 +133,14 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
now := time.Now()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
entry.lastTime = now
- if entry.tokens > RatelimiterMaxTokens {
- entry.tokens = RatelimiterMaxTokens
+ if entry.tokens > maxTokens {
+ entry.tokens = maxTokens
}
// subtract cost of packet
- if entry.tokens > RatelimiterPacketCost {
- entry.tokens -= RatelimiterPacketCost
+ if entry.tokens > packetCost {
+ entry.tokens -= packetCost
entry.mutex.Unlock()
return true
}
diff --git a/internal/ratelimiter/ratelimiter_test.go b/internal/ratelimiter/ratelimiter_test.go
index a6f618b..37339ee 100644
--- a/internal/ratelimiter/ratelimiter_test.go
+++ b/internal/ratelimiter/ratelimiter_test.go
@@ -28,7 +28,7 @@ func TestRatelimiter(t *testing.T) {
)
}
- for i := 0; i < RatelimiterPacketsBurstable; i++ {
+ for i := 0; i < packetsBurstable; i++ {
Add(RatelimiterResult{
allowed: true,
text: "inital burst",
@@ -42,7 +42,7 @@ func TestRatelimiter(t *testing.T) {
Add(RatelimiterResult{
allowed: true,
- wait: Nano(time.Second.Nanoseconds() / RatelimiterPacketsPerSecond),
+ wait: Nano(time.Second.Nanoseconds() / packetsPerSecond),
text: "filling tokens for single packet",
})
@@ -53,7 +53,7 @@ func TestRatelimiter(t *testing.T) {
Add(RatelimiterResult{
allowed: true,
- wait: 2 * Nano(time.Second.Nanoseconds()/RatelimiterPacketsPerSecond),
+ wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)),
text: "filling tokens for two packet burst",
})