aboutsummaryrefslogtreecommitdiffstats
path: root/ratelimiter/ratelimiter.go
diff options
context:
space:
mode:
Diffstat (limited to 'ratelimiter/ratelimiter.go')
-rw-r--r--ratelimiter/ratelimiter.go59
1 files changed, 12 insertions, 47 deletions
diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go
index 2f7aa2a..f7d05ef 100644
--- a/ratelimiter/ratelimiter.go
+++ b/ratelimiter/ratelimiter.go
@@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ratelimiter
import (
- "net"
+ "net/netip"
"sync"
"time"
)
@@ -30,8 +30,7 @@ type Ratelimiter struct {
timeNow func() time.Time
stopReset chan struct{} // send to reset, close to stop
- tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
- tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
+ table map[netip.Addr]*RatelimiterEntry
}
func (rate *Ratelimiter) Close() {
@@ -57,8 +56,7 @@ func (rate *Ratelimiter) Init() {
}
rate.stopReset = make(chan struct{})
- rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
- rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
+ rate.table = make(map[netip.Addr]*RatelimiterEntry)
stopReset := rate.stopReset // store in case Init is called again.
@@ -87,71 +85,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) {
rate.mu.Lock()
defer rate.mu.Unlock()
- for key, entry := range rate.tableIPv4 {
+ for key, entry := range rate.table {
entry.mu.Lock()
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv4, key)
+ delete(rate.table, key)
}
entry.mu.Unlock()
}
- for key, entry := range rate.tableIPv6 {
- entry.mu.Lock()
- if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv6, key)
- }
- entry.mu.Unlock()
- }
-
- return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
+ return len(rate.table) == 0
}
-func (rate *Ratelimiter) Allow(ip net.IP) bool {
+func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
var entry *RatelimiterEntry
- var keyIPv4 [net.IPv4len]byte
- var keyIPv6 [net.IPv6len]byte
-
// lookup entry
-
- IPv4 := ip.To4()
- IPv6 := ip.To16()
-
rate.mu.RLock()
-
- if IPv4 != nil {
- copy(keyIPv4[:], IPv4)
- entry = rate.tableIPv4[keyIPv4]
- } else {
- copy(keyIPv6[:], IPv6)
- entry = rate.tableIPv6[keyIPv6]
- }
-
+ entry = rate.table[ip]
rate.mu.RUnlock()
// make new entry if not found
-
if entry == nil {
entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost
entry.lastTime = rate.timeNow()
rate.mu.Lock()
- if IPv4 != nil {
- rate.tableIPv4[keyIPv4] = entry
- if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
- rate.stopReset <- struct{}{}
- }
- } else {
- rate.tableIPv6[keyIPv6] = entry
- if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
- rate.stopReset <- struct{}{}
- }
+ rate.table[ip] = entry
+ if len(rate.table) == 1 {
+ rate.stopReset <- struct{}{}
}
rate.mu.Unlock()
return true
}
// add tokens to entry
-
entry.mu.Lock()
now := rate.timeNow()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
@@ -161,7 +127,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
}
// subtract cost of packet
-
if entry.tokens > packetCost {
entry.tokens -= packetCost
entry.mu.Unlock()