diff options
Diffstat (limited to 'ratelimiter/ratelimiter.go')
-rw-r--r-- | ratelimiter/ratelimiter.go | 134 |
1 files changed, 53 insertions, 81 deletions
diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go index 772c45a..f7d05ef 100644 --- a/ratelimiter/ratelimiter.go +++ b/ratelimiter/ratelimiter.go @@ -1,12 +1,12 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package ratelimiter import ( - "net" + "net/netip" "sync" "time" ) @@ -20,21 +20,22 @@ const ( ) type RatelimiterEntry struct { - sync.Mutex + mu sync.Mutex lastTime time.Time tokens int64 } type Ratelimiter struct { - sync.RWMutex - stopReset chan struct{} - tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry - tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry + mu sync.RWMutex + timeNow func() time.Time + + stopReset chan struct{} // send to reset, close to stop + table map[netip.Addr]*RatelimiterEntry } func (rate *Ratelimiter) Close() { - rate.Lock() - defer rate.Unlock() + rate.mu.Lock() + defer rate.mu.Unlock() if rate.stopReset != nil { close(rate.stopReset) @@ -42,111 +43,83 @@ func (rate *Ratelimiter) Close() { } func (rate *Ratelimiter) Init() { - rate.Lock() - defer rate.Unlock() + rate.mu.Lock() + defer rate.mu.Unlock() - // stop any ongoing garbage collection routine + if rate.timeNow == nil { + rate.timeNow = time.Now + } + // stop any ongoing garbage collection routine if rate.stopReset != nil { close(rate.stopReset) } rate.stopReset = make(chan struct{}) - rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) - rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) + rate.table = make(map[netip.Addr]*RatelimiterEntry) - // start garbage collection routine + stopReset := rate.stopReset // store in case Init is called again. + // Start garbage collection routine. go func() { ticker := time.NewTicker(time.Second) ticker.Stop() for { select { - case _, ok := <-rate.stopReset: + case _, ok := <-stopReset: ticker.Stop() - if ok { - ticker = time.NewTicker(time.Second) - } else { + if !ok { return } + ticker = time.NewTicker(time.Second) case <-ticker.C: - func() { - rate.Lock() - defer rate.Unlock() - - for key, entry := range rate.tableIPv4 { - entry.Lock() - if time.Since(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv4, key) - } - entry.Unlock() - } - - for key, entry := range rate.tableIPv6 { - entry.Lock() - if time.Since(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv6, key) - } - entry.Unlock() - } - - if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 { - ticker.Stop() - } - }() + if rate.cleanup() { + ticker.Stop() + } } } }() } -func (rate *Ratelimiter) Allow(ip net.IP) bool { - var entry *RatelimiterEntry - var keyIPv4 [net.IPv4len]byte - var keyIPv6 [net.IPv6len]byte - - // lookup entry - - IPv4 := ip.To4() - IPv6 := ip.To16() - - rate.RLock() +func (rate *Ratelimiter) cleanup() (empty bool) { + rate.mu.Lock() + defer rate.mu.Unlock() - if IPv4 != nil { - copy(keyIPv4[:], IPv4) - entry = rate.tableIPv4[keyIPv4] - } else { - copy(keyIPv6[:], IPv6) - entry = rate.tableIPv6[keyIPv6] + for key, entry := range rate.table { + entry.mu.Lock() + if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { + delete(rate.table, key) + } + entry.mu.Unlock() } - rate.RUnlock() + return len(rate.table) == 0 +} - // make new entry if not found +func (rate *Ratelimiter) Allow(ip netip.Addr) bool { + var entry *RatelimiterEntry + // lookup entry + rate.mu.RLock() + entry = rate.table[ip] + rate.mu.RUnlock() + // make new entry if not found if entry == nil { entry = new(RatelimiterEntry) entry.tokens = maxTokens - packetCost - entry.lastTime = time.Now() - rate.Lock() - if IPv4 != nil { - rate.tableIPv4[keyIPv4] = entry - if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 { - rate.stopReset <- struct{}{} - } - } else { - rate.tableIPv6[keyIPv6] = entry - if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 { - rate.stopReset <- struct{}{} - } + entry.lastTime = rate.timeNow() + rate.mu.Lock() + rate.table[ip] = entry + if len(rate.table) == 1 { + rate.stopReset <- struct{}{} } - rate.Unlock() + rate.mu.Unlock() return true } // add tokens to entry - - entry.Lock() - now := time.Now() + entry.mu.Lock() + now := rate.timeNow() entry.tokens += now.Sub(entry.lastTime).Nanoseconds() entry.lastTime = now if entry.tokens > maxTokens { @@ -154,12 +127,11 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { } // subtract cost of packet - if entry.tokens > packetCost { entry.tokens -= packetCost - entry.Unlock() + entry.mu.Unlock() return true } - entry.Unlock() + entry.mu.Unlock() return false } |