aboutsummaryrefslogtreecommitdiffstats
path: root/ratelimiter/ratelimiter.go
diff options
context:
space:
mode:
Diffstat (limited to 'ratelimiter/ratelimiter.go')
-rw-r--r--ratelimiter/ratelimiter.go134
1 files changed, 53 insertions, 81 deletions
diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go
index 772c45a..f7d05ef 100644
--- a/ratelimiter/ratelimiter.go
+++ b/ratelimiter/ratelimiter.go
@@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ratelimiter
import (
- "net"
+ "net/netip"
"sync"
"time"
)
@@ -20,21 +20,22 @@ const (
)
type RatelimiterEntry struct {
- sync.Mutex
+ mu sync.Mutex
lastTime time.Time
tokens int64
}
type Ratelimiter struct {
- sync.RWMutex
- stopReset chan struct{}
- tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
- tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
+ mu sync.RWMutex
+ timeNow func() time.Time
+
+ stopReset chan struct{} // send to reset, close to stop
+ table map[netip.Addr]*RatelimiterEntry
}
func (rate *Ratelimiter) Close() {
- rate.Lock()
- defer rate.Unlock()
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
if rate.stopReset != nil {
close(rate.stopReset)
@@ -42,111 +43,83 @@ func (rate *Ratelimiter) Close() {
}
func (rate *Ratelimiter) Init() {
- rate.Lock()
- defer rate.Unlock()
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
- // stop any ongoing garbage collection routine
+ if rate.timeNow == nil {
+ rate.timeNow = time.Now
+ }
+ // stop any ongoing garbage collection routine
if rate.stopReset != nil {
close(rate.stopReset)
}
rate.stopReset = make(chan struct{})
- rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
- rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
+ rate.table = make(map[netip.Addr]*RatelimiterEntry)
- // start garbage collection routine
+ stopReset := rate.stopReset // store in case Init is called again.
+ // Start garbage collection routine.
go func() {
ticker := time.NewTicker(time.Second)
ticker.Stop()
for {
select {
- case _, ok := <-rate.stopReset:
+ case _, ok := <-stopReset:
ticker.Stop()
- if ok {
- ticker = time.NewTicker(time.Second)
- } else {
+ if !ok {
return
}
+ ticker = time.NewTicker(time.Second)
case <-ticker.C:
- func() {
- rate.Lock()
- defer rate.Unlock()
-
- for key, entry := range rate.tableIPv4 {
- entry.Lock()
- if time.Since(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv4, key)
- }
- entry.Unlock()
- }
-
- for key, entry := range rate.tableIPv6 {
- entry.Lock()
- if time.Since(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv6, key)
- }
- entry.Unlock()
- }
-
- if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 {
- ticker.Stop()
- }
- }()
+ if rate.cleanup() {
+ ticker.Stop()
+ }
}
}
}()
}
-func (rate *Ratelimiter) Allow(ip net.IP) bool {
- var entry *RatelimiterEntry
- var keyIPv4 [net.IPv4len]byte
- var keyIPv6 [net.IPv6len]byte
-
- // lookup entry
-
- IPv4 := ip.To4()
- IPv6 := ip.To16()
-
- rate.RLock()
+func (rate *Ratelimiter) cleanup() (empty bool) {
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
- if IPv4 != nil {
- copy(keyIPv4[:], IPv4)
- entry = rate.tableIPv4[keyIPv4]
- } else {
- copy(keyIPv6[:], IPv6)
- entry = rate.tableIPv6[keyIPv6]
+ for key, entry := range rate.table {
+ entry.mu.Lock()
+ if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
+ delete(rate.table, key)
+ }
+ entry.mu.Unlock()
}
- rate.RUnlock()
+ return len(rate.table) == 0
+}
- // make new entry if not found
+func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
+ var entry *RatelimiterEntry
+ // lookup entry
+ rate.mu.RLock()
+ entry = rate.table[ip]
+ rate.mu.RUnlock()
+ // make new entry if not found
if entry == nil {
entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost
- entry.lastTime = time.Now()
- rate.Lock()
- if IPv4 != nil {
- rate.tableIPv4[keyIPv4] = entry
- if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
- rate.stopReset <- struct{}{}
- }
- } else {
- rate.tableIPv6[keyIPv6] = entry
- if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
- rate.stopReset <- struct{}{}
- }
+ entry.lastTime = rate.timeNow()
+ rate.mu.Lock()
+ rate.table[ip] = entry
+ if len(rate.table) == 1 {
+ rate.stopReset <- struct{}{}
}
- rate.Unlock()
+ rate.mu.Unlock()
return true
}
// add tokens to entry
-
- entry.Lock()
- now := time.Now()
+ entry.mu.Lock()
+ now := rate.timeNow()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
entry.lastTime = now
if entry.tokens > maxTokens {
@@ -154,12 +127,11 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
}
// subtract cost of packet
-
if entry.tokens > packetCost {
entry.tokens -= packetCost
- entry.Unlock()
+ entry.mu.Unlock()
return true
}
- entry.Unlock()
+ entry.mu.Unlock()
return false
}