aboutsummaryrefslogtreecommitdiffstats
path: root/src/wg_cookie.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/wg_cookie.c')
-rw-r--r--src/wg_cookie.c68
1 files changed, 26 insertions, 42 deletions
diff --git a/src/wg_cookie.c b/src/wg_cookie.c
index e68d662..58d0d8d 100644
--- a/src/wg_cookie.c
+++ b/src/wg_cookie.c
@@ -31,22 +31,20 @@
#define IPV4_MASK_SIZE 4 /* Use all 4 bytes of IPv4 address */
#define IPV6_MASK_SIZE 8 /* Use top 8 bytes (/64) of IPv6 address */
+struct ratelimit_key {
+ struct vnet *vnet;
+ uint8_t ip[IPV6_MASK_SIZE];
+};
+
struct ratelimit_entry {
LIST_ENTRY(ratelimit_entry) r_entry;
- sa_family_t r_af;
- union {
- struct in_addr r_in;
-#ifdef INET6
- struct in6_addr r_in6;
-#endif
- };
+ struct ratelimit_key r_key;
sbintime_t r_last_time; /* sbinuptime */
uint64_t r_tokens;
};
struct ratelimit {
uint8_t rl_secret[SIPHASH_KEY_LENGTH];
-
struct rwlock rl_lock;
struct callout rl_gc;
LIST_HEAD(, ratelimit_entry) rl_table[RATELIMIT_SIZE];
@@ -67,7 +65,7 @@ static void ratelimit_deinit(struct ratelimit *);
static void ratelimit_gc_callout(void *);
static void ratelimit_gc_schedule(struct ratelimit *);
static void ratelimit_gc(struct ratelimit *, bool);
-static int ratelimit_allow(struct ratelimit *, struct sockaddr *);
+static int ratelimit_allow(struct ratelimit *, struct sockaddr *, struct vnet *);
static uint64_t siphash13(const uint8_t [SIPHASH_KEY_LENGTH], const void *, size_t);
static struct ratelimit ratelimit_v4;
@@ -86,9 +84,9 @@ cookie_init(void)
ratelimit_init(&ratelimit_v4);
#ifdef INET6
- ratelimit_init(&ratelimit_v6)
+ ratelimit_init(&ratelimit_v6);
#endif
- return 0;
+ return (0);
}
void
@@ -207,7 +205,7 @@ cookie_maker_mac(struct cookie_maker *cm, struct cookie_macs *macs, void *buf,
int
cookie_checker_validate_macs(struct cookie_checker *cc, struct cookie_macs *macs,
- void *buf, size_t len, bool check_cookie, struct sockaddr *sa)
+ void *buf, size_t len, bool check_cookie, struct sockaddr *sa, struct vnet *vnet)
{
struct cookie_macs our_macs;
uint8_t cookie[COOKIE_COOKIE_SIZE];
@@ -234,10 +232,10 @@ cookie_checker_validate_macs(struct cookie_checker *cc, struct cookie_macs *macs
* implying there is no ratelimiting, or we should ratelimit
* (refuse) respectively. */
if (sa->sa_family == AF_INET)
- return ratelimit_allow(&ratelimit_v4, sa);
+ return ratelimit_allow(&ratelimit_v4, sa, vnet);
#ifdef INET6
else if (sa->sa_family == AF_INET6)
- return ratelimit_allow(&ratelimit_v6, sa);
+ return ratelimit_allow(&ratelimit_v6, sa, vnet);
#endif
else
return EAFNOSUPPORT;
@@ -391,40 +389,33 @@ ratelimit_gc(struct ratelimit *rl, bool force)
}
static int
-ratelimit_allow(struct ratelimit *rl, struct sockaddr *sa)
+ratelimit_allow(struct ratelimit *rl, struct sockaddr *sa, struct vnet *vnet)
{
- uint64_t key, tokens;
+ uint64_t bucket, tokens;
sbintime_t diff, now;
struct ratelimit_entry *r;
int ret = ECONNREFUSED;
+ struct ratelimit_key key = { .vnet = vnet };
+ size_t len = sizeof(key);
- if (sa->sa_family == AF_INET)
- key = siphash13(rl->rl_secret, &satosin(sa)->sin_addr,
- IPV4_MASK_SIZE);
+ if (sa->sa_family == AF_INET) {
+ memcpy(key.ip, &satosin(sa)->sin_addr, IPV4_MASK_SIZE);
+ len -= IPV6_MASK_SIZE - IPV4_MASK_SIZE;
+ }
#ifdef INET6
else if (sa->sa_family == AF_INET6)
- key = siphash13(rl->rl_secret, &satosin6(sa)->sin6_addr,
- IPV6_MASK_SIZE);
+ memcpy(key.ip, &satosin6(sa)->sin6_addr, IPV6_MASK_SIZE);
#endif
else
return ret;
+ bucket = siphash13(rl->rl_secret, &key, len) & RATELIMIT_MASK;
rw_wlock(&rl->rl_lock);
- LIST_FOREACH(r, &rl->rl_table[key & RATELIMIT_MASK], r_entry) {
- if (r->r_af != sa->sa_family)
+ LIST_FOREACH(r, &rl->rl_table[bucket], r_entry) {
+ if (bcmp(&r->r_key, &key, len) != 0)
continue;
- if (r->r_af == AF_INET && bcmp(&r->r_in,
- &satosin(sa)->sin_addr, IPV4_MASK_SIZE) != 0)
- continue;
-
-#ifdef INET6
- if (r->r_af == AF_INET6 && bcmp(&r->r_in6,
- &satosin6(sa)->sin6_addr, IPV6_MASK_SIZE) != 0)
- continue;
-#endif
-
/* If we get to here, we've found an entry for the endpoint.
* We apply standard token bucket, by calculating the time
* lapsed since our last_time, adding that, ensuring that we
@@ -462,15 +453,8 @@ ratelimit_allow(struct ratelimit *rl, struct sockaddr *sa)
rl->rl_table_num++;
/* Insert entry into the hashtable and ensure it's initialised */
- LIST_INSERT_HEAD(&rl->rl_table[key & RATELIMIT_MASK], r, r_entry);
- r->r_af = sa->sa_family;
- if (r->r_af == AF_INET)
- memcpy(&r->r_in, &satosin(sa)->sin_addr, IPV4_MASK_SIZE);
-#ifdef INET6
- else if (r->r_af == AF_INET6)
- memcpy(&r->r_in6, &satosin6(sa)->sin6_addr, IPV6_MASK_SIZE);
-#endif
-
+ LIST_INSERT_HEAD(&rl->rl_table[bucket], r, r_entry);
+ r->r_key = key;
r->r_last_time = getsbinuptime();
r->r_tokens = TOKEN_MAX - INITIATION_COST;