aboutsummaryrefslogblamecommitdiffstats
path: root/ratelimiter/ratelimiter.go
blob: 772c45aed6636765665b8ac99418e09aca4af6b2 (plain) (tree)
1
2
3
4
5
6
                               
  
                                                              
   
 
                   







              




                                                          


                              
                  




                          
                    
                               




                                                         

                           
 

                                     
         


                                 

                           
 

                                                      

                                     

         
                                            

                                                                      
 

                                           
                   
                                                     
                             

                                
                                                       
                                             




                                                                            
                                        
                                        

                                                           

                                                                                
                                                            
                                                                                                    

                                                                                   
                                                              


                                                                                
                                                            
                                                                                                    

                                                                                   
                                                              
                                         



                                                                                                 
                                   


                         

 

                                                

                                     





                         
                    

                        

                                               
                

                                               

         
                      



                                      
                                             
                                                     
                                           
                           
                                
                                                       


                                                                                 
                        
                                                       


                                                                                 
                 
                             




                              
                    


                                                             

                                        



                                  

                                          
                              

                           
                      

                    
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
 */

package ratelimiter

import (
	"net"
	"sync"
	"time"
)

// Tunable limits, applied independently to every source address.
const (
	// packetsPerSecond is the sustained number of packets allowed per second.
	packetsPerSecond = 20
	// packetsBurstable is the number of packets a full bucket can pay for.
	packetsBurstable = 5
	// garbageCollectTime is how long an entry may stay untouched before the
	// garbage collection routine removes it.
	garbageCollectTime = time.Second
)

// Derived token-bucket values. Tokens are measured in nanoseconds: a bucket
// is refilled with the nanoseconds elapsed since it was last touched.
const (
	// maxTokens is the capacity of a bucket.
	maxTokens = packetCost * packetsBurstable
	// packetCost is the number of tokens a single packet consumes.
	packetCost = 1000000000 / packetsPerSecond
)

// RatelimiterEntry is the token bucket kept for a single source address.
// The embedded mutex is held while lastTime and tokens are read or updated.
type RatelimiterEntry struct {
	sync.Mutex
	// lastTime is when the entry was created or last refilled by Allow; the
	// garbage collection routine compares it against garbageCollectTime.
	lastTime time.Time
	// tokens is the remaining budget in nanoseconds, capped at maxTokens.
	tokens   int64
}

// Ratelimiter is a per-source-address token-bucket rate limiter. It must be
// set up with Init before Allow is called. The embedded RWMutex guards
// stopReset and both tables; each entry carries its own mutex for its state.
type Ratelimiter struct {
	sync.RWMutex
	// stopReset controls the garbage collection goroutine: closing it makes
	// the goroutine exit, while a value sent on it (by Allow, when the first
	// entry is inserted into empty tables) restarts the goroutine's ticker.
	stopReset chan struct{}
	// tableIPv4 maps a 4-byte address to its entry.
	tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
	// tableIPv6 maps a 16-byte address to its entry.
	tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
}

// Close stops the garbage collection goroutine by closing stopReset.
//
// It is a no-op on a Ratelimiter that was never initialised and on one that
// has already been closed, so calling Close more than once does not panic.
// The address tables are left untouched.
func (rate *Ratelimiter) Close() {
	rate.Lock()
	defer rate.Unlock()

	if rate.stopReset == nil {
		return
	}

	// Values are only ever sent on stopReset by a goroutine holding the write
	// lock, which is held here. A receive that succeeds without blocking can
	// therefore only mean the channel is already closed; closing it a second
	// time would panic.
	select {
	case <-rate.stopReset:
		// already closed, nothing to do
	default:
		close(rate.stopReset)
	}
}

// Init (re)initialises the rate limiter: it stops a garbage collection
// goroutine left over from a previous Init, replaces both address tables with
// empty ones and starts a fresh garbage collection goroutine.
//
// The goroutine's ticker starts out stopped. It is started when a value
// arrives on stopReset (sent by Allow when the first entry is inserted) and
// stopped again once a collection pass leaves both tables empty. Closing
// stopReset terminates the goroutine.
func (rate *Ratelimiter) Init() {
	rate.Lock()
	defer rate.Unlock()

	// stop any ongoing garbage collection routine

	if rate.stopReset != nil {
		// Values are only sent on stopReset under the write lock, which is
		// held here, so a receive that does not block means the channel was
		// already closed (for instance by Close) and must not be closed again.
		select {
		case <-rate.stopReset:
		default:
			close(rate.stopReset)
		}
	}

	// Keep a local reference for the goroutine below. If it read the struct
	// field instead, a later Init could swap the channel underneath it: the
	// read would race with that write, and a goroutine that was busy
	// collecting at that moment would start listening on the new channel and
	// never observe the close meant for it.
	stopReset := make(chan struct{})
	rate.stopReset = stopReset
	rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
	rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)

	// start garbage collection routine

	go func() {
		// Create the ticker stopped; it only runs while entries exist.
		ticker := time.NewTicker(time.Second)
		ticker.Stop()
		for {
			select {
			case _, ok := <-stopReset:
				ticker.Stop()
				if !ok {
					// channel closed: shut down
					return
				}
				// value received: an entry was added to empty tables
				ticker = time.NewTicker(time.Second)
			case <-ticker.C:
				if rate.collectGarbage() {
					ticker.Stop()
				}
			}
		}
	}()
}

// collectGarbage removes every entry whose lastTime is more than
// garbageCollectTime in the past and reports whether both tables are empty
// afterwards. It takes the write lock for the whole pass.
func (rate *Ratelimiter) collectGarbage() bool {
	rate.Lock()
	defer rate.Unlock()

	for key, entry := range rate.tableIPv4 {
		entry.Lock()
		if time.Since(entry.lastTime) > garbageCollectTime {
			delete(rate.tableIPv4, key)
		}
		entry.Unlock()
	}

	for key, entry := range rate.tableIPv6 {
		entry.Lock()
		if time.Since(entry.lastTime) > garbageCollectTime {
			delete(rate.tableIPv6, key)
		}
		entry.Unlock()
	}

	return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
}

// Allow reports whether a packet from ip may be processed under the rate
// limit, and charges the address for it if so.
//
// The first packet seen from an address is always allowed and creates a bucket
// holding maxTokens minus the cost of that packet. Later packets first refill
// the bucket with the nanoseconds elapsed since it was last touched (capped at
// maxTokens) and are allowed only if the bucket then holds strictly more than
// packetCost tokens.
//
// An ip for which To4 returns nil is accounted in the IPv6 table; if To16 also
// returns nil (an invalid address) the all-zero key is used.
//
// Init must have been called before Allow.
func (rate *Ratelimiter) Allow(ip net.IP) bool {
	var entry *RatelimiterEntry
	var keyIPv4 [net.IPv4len]byte
	var keyIPv6 [net.IPv6len]byte

	// build the table key

	IPv4 := ip.To4()
	isIPv4 := IPv4 != nil
	if isIPv4 {
		copy(keyIPv4[:], IPv4)
	} else {
		copy(keyIPv6[:], ip.To16())
	}

	// lookup entry

	rate.RLock()
	if isIPv4 {
		entry = rate.tableIPv4[keyIPv4]
	} else {
		entry = rate.tableIPv6[keyIPv6]
	}
	rate.RUnlock()

	// make new entry if not found

	if entry == nil {
		rate.Lock()

		// Look again now that the write lock is held: another goroutine may
		// have inserted an entry for this address after the read lock was
		// released. Without this check, concurrent first packets from one
		// address would each install a fresh bucket and all be allowed.
		if isIPv4 {
			entry = rate.tableIPv4[keyIPv4]
		} else {
			entry = rate.tableIPv6[keyIPv6]
		}

		if entry == nil {
			entry = new(RatelimiterEntry)
			entry.tokens = maxTokens - packetCost
			entry.lastTime = time.Now()

			wasEmpty := len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
			if isIPv4 {
				rate.tableIPv4[keyIPv4] = entry
			} else {
				rate.tableIPv6[keyIPv6] = entry
			}
			if wasEmpty {
				// First entry overall: tell the garbage collection goroutine
				// to restart its ticker. The send is made under the write
				// lock, which Close and Init also take before closing or
				// replacing the channel.
				rate.stopReset <- struct{}{}
			}
			rate.Unlock()
			return true
		}

		// Lost the race: charge the packet against the existing bucket.
		rate.Unlock()
	}

	// add tokens to entry

	entry.Lock()
	defer entry.Unlock()

	now := time.Now()
	entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
	entry.lastTime = now
	if entry.tokens > maxTokens {
		entry.tokens = maxTokens
	}

	// subtract cost of packet

	if entry.tokens > packetCost {
		entry.tokens -= packetCost
		return true
	}
	return false
}