aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src/ratelimiter.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/ratelimiter.c')
-rw-r--r--src/ratelimiter.c249
1 files changed, 153 insertions, 96 deletions
diff --git a/src/ratelimiter.c b/src/ratelimiter.c
index ab8f93d..2d2e758 100644
--- a/src/ratelimiter.c
+++ b/src/ratelimiter.c
@@ -1,138 +1,195 @@
/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
#include "ratelimiter.h"
-#include "peer.h"
-#include "device.h"
-
-#include <linux/module.h>
-#include <linux/netfilter/x_tables.h>
+#include <linux/siphash.h>
+#include <linux/vmalloc.h>
+#include <linux/slab.h>
+#include <linux/hashtable.h>
#include <net/ip.h>
-static struct xt_match *v4_match __read_mostly;
+static struct kmem_cache *entry_cache;
+static hsiphash_key_t key;
+static spinlock_t table_lock = __SPIN_LOCK_UNLOCKED("ratelimiter_table_lock");
+static atomic64_t refcnt = ATOMIC64_INIT(0);
+static atomic_t total_entries = ATOMIC_INIT(0);
+static unsigned int max_entries, table_size;
+static void gc_entries(struct work_struct *);
+static DECLARE_DEFERRABLE_WORK(gc_work, gc_entries);
+static struct hlist_head *table_v4;
#if IS_ENABLED(CONFIG_IPV6)
-static struct xt_match *v6_match __read_mostly;
+static struct hlist_head *table_v6;
#endif
+struct entry {
+ u64 last_time_ns, tokens;
+ void *net;
+ __be32 ip[3];
+ spinlock_t lock;
+ struct hlist_node hash;
+ struct rcu_head rcu;
+};
+
enum {
- RATELIMITER_PACKETS_PER_SECOND = 30,
- RATELIMITER_PACKETS_BURSTABLE = 5
+ PACKETS_PER_SECOND = 20,
+ PACKETS_BURSTABLE = 5,
+ PACKET_COST = NSEC_PER_SEC / PACKETS_PER_SECOND,
+ TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE
};
-static inline void cfg_init(struct hashlimit_cfg1 *cfg, int family)
+static void entry_free(struct rcu_head *rcu)
{
- memset(cfg, 0, sizeof(struct hashlimit_cfg1));
- if (family == NFPROTO_IPV4)
- cfg->srcmask = 32;
- else if (family == NFPROTO_IPV6)
- cfg->srcmask = 96;
- cfg->mode = XT_HASHLIMIT_HASH_SIP; /* source IP only -- we could also do source port by ORing this with XT_HASHLIMIT_HASH_SPT, but we don't really want to do that. It would also cause problems since we skb_pull early on, and hashlimit's nexthdr stuff isn't so nice. */
- cfg->avg = XT_HASHLIMIT_SCALE / RATELIMITER_PACKETS_PER_SECOND; /* 30 per second per IP */
- cfg->burst = RATELIMITER_PACKETS_BURSTABLE; /* Allow bursts of 5 at a time */
- cfg->gc_interval = 1000; /* same as expiration date */
- cfg->expire = 1000; /* Units of avg (seconds = 1) times 1000 */
- /* cfg->size and cfg->max are computed based on the memory size of left to zero */
+ kmem_cache_free(entry_cache, container_of(rcu, struct entry, rcu));
+ atomic_dec(&total_entries);
}
-int ratelimiter_init(struct ratelimiter *ratelimiter, struct wireguard_device *wg)
+static void entry_uninit(struct entry *entry)
{
- struct net_device *dev = netdev_pub(wg);
- struct xt_mtchk_param chk = { .net = wg->creating_net };
- int ret;
-
- memset(ratelimiter, 0, sizeof(struct ratelimiter));
-
- cfg_init(&ratelimiter->v4_info.cfg, NFPROTO_IPV4);
- memcpy(ratelimiter->v4_info.name, dev->name, IFNAMSIZ);
- chk.matchinfo = &ratelimiter->v4_info;
- chk.match = v4_match;
- chk.family = NFPROTO_IPV4;
- ret = v4_match->checkentry(&chk);
- if (ret < 0)
- return ret;
-
-#if IS_ENABLED(CONFIG_IPV6)
- cfg_init(&ratelimiter->v6_info.cfg, NFPROTO_IPV6);
- memcpy(ratelimiter->v6_info.name, dev->name, IFNAMSIZ);
- chk.matchinfo = &ratelimiter->v6_info;
- chk.match = v6_match;
- chk.family = NFPROTO_IPV6;
- ret = v6_match->checkentry(&chk);
- if (ret < 0) {
- struct xt_mtdtor_param dtor_v4 = {
- .net = wg->creating_net,
- .match = v4_match,
- .matchinfo = &ratelimiter->v4_info,
- .family = NFPROTO_IPV4
- };
- v4_match->destroy(&dtor_v4);
- return ret;
- }
-#endif
-
- ratelimiter->net = wg->creating_net;
- return 0;
+ hlist_del_rcu(&entry->hash);
+ call_rcu_bh(&entry->rcu, entry_free);
}
-void ratelimiter_uninit(struct ratelimiter *ratelimiter)
+/* Calling this function with a NULL work uninits all entries. */
+static void gc_entries(struct work_struct *work)
{
- struct xt_mtdtor_param dtor = { .net = ratelimiter->net };
-
- dtor.match = v4_match;
- dtor.matchinfo = &ratelimiter->v4_info;
- dtor.family = NFPROTO_IPV4;
- v4_match->destroy(&dtor);
-
+ unsigned int i;
+ struct entry *entry;
+ struct hlist_node *temp;
+ const u64 now = ktime_get_ns();
+
+ for (i = 0; i < table_size; ++i) {
+ spin_lock(&table_lock);
+ hlist_for_each_entry_safe (entry, temp, &table_v4[i], hash) {
+ if (unlikely(!work) || now - entry->last_time_ns > NSEC_PER_SEC)
+ entry_uninit(entry);
+ }
#if IS_ENABLED(CONFIG_IPV6)
- dtor.match = v6_match;
- dtor.matchinfo = &ratelimiter->v6_info;
- dtor.family = NFPROTO_IPV6;
- v6_match->destroy(&dtor);
+ hlist_for_each_entry_safe (entry, temp, &table_v6[i], hash) {
+ if (unlikely(!work) || now - entry->last_time_ns > NSEC_PER_SEC)
+ entry_uninit(entry);
+ }
#endif
+ spin_unlock(&table_lock);
+ if (likely(work))
+ cond_resched();
+ }
+ if (likely(work))
+ queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
}
-bool ratelimiter_allow(struct ratelimiter *ratelimiter, struct sk_buff *skb)
+bool ratelimiter_allow(struct sk_buff *skb, struct net *net)
{
- struct xt_action_param action = { { NULL } };
- if (unlikely(skb->len < sizeof(struct iphdr)))
- return false;
- if (ip_hdr(skb)->version == 4) {
- action.match = v4_match;
- action.matchinfo = &ratelimiter->v4_info;
- action.thoff = ip_hdrlen(skb);
+ struct entry *entry;
+ struct hlist_head *bucket;
+ struct { u32 net; __be32 ip[3]; } data = { .net = (unsigned long)net & 0xffffffff };
+
+ if (skb->len >= sizeof(struct iphdr) && ip_hdr(skb)->version == 4) {
+ data.ip[0] = ip_hdr(skb)->saddr;
+ bucket = &table_v4[hsiphash(&data, sizeof(u32) * 2, &key) & (table_size - 1)];
}
#if IS_ENABLED(CONFIG_IPV6)
- else if (ip_hdr(skb)->version == 6) {
- action.match = v6_match;
- action.matchinfo = &ratelimiter->v6_info;
+ else if (skb->len >= sizeof(struct ipv6hdr) && ip_hdr(skb)->version == 6) {
+ memcpy(data.ip, &ipv6_hdr(skb)->saddr, sizeof(u32) * 3); /* Only 96 bits */
+ bucket = &table_v6[hsiphash(&data, sizeof(u32) * 4, &key) & (table_size - 1)];
}
#endif
else
return false;
- return action.match->match(skb, &action);
+ rcu_read_lock();
+ hlist_for_each_entry_rcu (entry, bucket, hash) {
+ if (entry->net == net && !memcmp(entry->ip, data.ip, sizeof(data.ip))) {
+ u64 now, tokens;
+ bool ret;
+ /* Inspired by nft_limit.c, but this is actually a slightly different
+ * algorithm. Namely, we incorporate the burst as part of the maximum
+ * tokens, rather than as part of the rate. */
+ spin_lock(&entry->lock);
+ now = ktime_get_ns();
+ tokens = min_t(u64, TOKEN_MAX, entry->tokens + now - entry->last_time_ns);
+ entry->last_time_ns = now;
+ ret = tokens >= PACKET_COST;
+ entry->tokens = ret ? tokens - PACKET_COST : tokens;
+ spin_unlock(&entry->lock);
+ rcu_read_unlock();
+ return ret;
+ }
+ }
+ rcu_read_unlock();
+
+ if (atomic_inc_return(&total_entries) > max_entries)
+ goto err_oom;
+
+ entry = kmem_cache_alloc(entry_cache, GFP_KERNEL);
+ if (!entry)
+ goto err_oom;
+
+ entry->net = net;
+ memcpy(entry->ip, data.ip, sizeof(data.ip));
+ INIT_HLIST_NODE(&entry->hash);
+ spin_lock_init(&entry->lock);
+ entry->last_time_ns = ktime_get_ns();
+ entry->tokens = TOKEN_MAX - PACKET_COST;
+ spin_lock(&table_lock);
+ hlist_add_head_rcu(&entry->hash, bucket);
+ spin_unlock(&table_lock);
+ return true;
+
+err_oom:
+ atomic_dec(&total_entries);
+ return false;
}
-int ratelimiter_module_init(void)
+int ratelimiter_init(void)
{
- v4_match = xt_request_find_match(NFPROTO_IPV4, "hashlimit", 1);
- if (IS_ERR(v4_match)) {
- pr_err("The xt_hashlimit module for IPv4 is required\n");
- return PTR_ERR(v4_match);
- }
+ if (atomic64_inc_return(&refcnt) != 1)
+ return 0;
+
+ entry_cache = kmem_cache_create("wireguard_ratelimiter", sizeof(struct entry), 0, 0, NULL);
+ if (!entry_cache)
+ goto err;
+
+ /* xt_hashlimit.c uses a slightly different algorithm for ratelimiting,
+ * but what it shares in common is that it uses a massive hashtable. So,
+ * we borrow their wisdom about good table sizes on different systems
+ * dependent on RAM. This calculation here comes from there. */
+ table_size = (totalram_pages > (1 << 30) / PAGE_SIZE) ? 8192 : max_t(unsigned long, 16, roundup_pow_of_two((totalram_pages << PAGE_SHIFT) / (1 << 14) / sizeof(struct hlist_head)));
+ max_entries = table_size * 8;
+
+ table_v4 = vmalloc(table_size * sizeof(struct hlist_head));
+ if (!table_v4)
+ goto err_kmemcache;
+ __hash_init(table_v4, table_size);
+
#if IS_ENABLED(CONFIG_IPV6)
- v6_match = xt_request_find_match(NFPROTO_IPV6, "hashlimit", 1);
- if (IS_ERR(v6_match)) {
- pr_err("The xt_hashlimit module for IPv6 is required\n");
- module_put(v4_match->me);
- return PTR_ERR(v6_match);
+ table_v6 = vmalloc(table_size * sizeof(struct hlist_head));
+ if (!table_v6) {
+ vfree(table_v4);
+ goto err_kmemcache;
}
+ __hash_init(table_v6, table_size);
#endif
+
+ queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
+ get_random_bytes(&key, sizeof(key));
return 0;
+
+err_kmemcache:
+ kmem_cache_destroy(entry_cache);
+err:
+ atomic64_dec(&refcnt);
+ return -ENOMEM;
}
-void ratelimiter_module_deinit(void)
+void ratelimiter_uninit(void)
{
- module_put(v4_match->me);
+ if (atomic64_dec_return(&refcnt))
+ return;
+
+ cancel_delayed_work_sync(&gc_work);
+ gc_entries(NULL);
+ synchronize_rcu();
+ vfree(table_v4);
#if IS_ENABLED(CONFIG_IPV6)
- module_put(v6_match->me);
+ vfree(table_v6);
#endif
+ kmem_cache_destroy(entry_cache);
}