diff options
Diffstat (limited to 'src/wg_cookie.c')
-rw-r--r-- | src/wg_cookie.c | 68 |
1 files changed, 26 insertions, 42 deletions
diff --git a/src/wg_cookie.c b/src/wg_cookie.c index e68d662..58d0d8d 100644 --- a/src/wg_cookie.c +++ b/src/wg_cookie.c @@ -31,22 +31,20 @@ #define IPV4_MASK_SIZE 4 /* Use all 4 bytes of IPv4 address */ #define IPV6_MASK_SIZE 8 /* Use top 8 bytes (/64) of IPv6 address */ +struct ratelimit_key { + struct vnet *vnet; + uint8_t ip[IPV6_MASK_SIZE]; +}; + struct ratelimit_entry { LIST_ENTRY(ratelimit_entry) r_entry; - sa_family_t r_af; - union { - struct in_addr r_in; -#ifdef INET6 - struct in6_addr r_in6; -#endif - }; + struct ratelimit_key r_key; sbintime_t r_last_time; /* sbinuptime */ uint64_t r_tokens; }; struct ratelimit { uint8_t rl_secret[SIPHASH_KEY_LENGTH]; - struct rwlock rl_lock; struct callout rl_gc; LIST_HEAD(, ratelimit_entry) rl_table[RATELIMIT_SIZE]; @@ -67,7 +65,7 @@ static void ratelimit_deinit(struct ratelimit *); static void ratelimit_gc_callout(void *); static void ratelimit_gc_schedule(struct ratelimit *); static void ratelimit_gc(struct ratelimit *, bool); -static int ratelimit_allow(struct ratelimit *, struct sockaddr *); +static int ratelimit_allow(struct ratelimit *, struct sockaddr *, struct vnet *); static uint64_t siphash13(const uint8_t [SIPHASH_KEY_LENGTH], const void *, size_t); static struct ratelimit ratelimit_v4; @@ -86,9 +84,9 @@ cookie_init(void) ratelimit_init(&ratelimit_v4); #ifdef INET6 - ratelimit_init(&ratelimit_v6) + ratelimit_init(&ratelimit_v6); #endif - return 0; + return (0); } void @@ -207,7 +205,7 @@ cookie_maker_mac(struct cookie_maker *cm, struct cookie_macs *macs, void *buf, int cookie_checker_validate_macs(struct cookie_checker *cc, struct cookie_macs *macs, - void *buf, size_t len, bool check_cookie, struct sockaddr *sa) + void *buf, size_t len, bool check_cookie, struct sockaddr *sa, struct vnet *vnet) { struct cookie_macs our_macs; uint8_t cookie[COOKIE_COOKIE_SIZE]; @@ -234,10 +232,10 @@ cookie_checker_validate_macs(struct cookie_checker *cc, struct cookie_macs *macs * implying there is no ratelimiting, or we should ratelimit * (refuse) respectively. */ if (sa->sa_family == AF_INET) - return ratelimit_allow(&ratelimit_v4, sa); + return ratelimit_allow(&ratelimit_v4, sa, vnet); #ifdef INET6 else if (sa->sa_family == AF_INET6) - return ratelimit_allow(&ratelimit_v6, sa); + return ratelimit_allow(&ratelimit_v6, sa, vnet); #endif else return EAFNOSUPPORT; @@ -391,40 +389,33 @@ ratelimit_gc(struct ratelimit *rl, bool force) } static int -ratelimit_allow(struct ratelimit *rl, struct sockaddr *sa) +ratelimit_allow(struct ratelimit *rl, struct sockaddr *sa, struct vnet *vnet) { - uint64_t key, tokens; + uint64_t bucket, tokens; sbintime_t diff, now; struct ratelimit_entry *r; int ret = ECONNREFUSED; + struct ratelimit_key key = { .vnet = vnet }; + size_t len = sizeof(key); - if (sa->sa_family == AF_INET) - key = siphash13(rl->rl_secret, &satosin(sa)->sin_addr, - IPV4_MASK_SIZE); + if (sa->sa_family == AF_INET) { + memcpy(key.ip, &satosin(sa)->sin_addr, IPV4_MASK_SIZE); + len -= IPV6_MASK_SIZE - IPV4_MASK_SIZE; + } #ifdef INET6 else if (sa->sa_family == AF_INET6) - key = siphash13(rl->rl_secret, &satosin6(sa)->sin6_addr, - IPV6_MASK_SIZE); + memcpy(key.ip, &satosin6(sa)->sin6_addr, IPV6_MASK_SIZE); #endif else return ret; + bucket = siphash13(rl->rl_secret, &key, len) & RATELIMIT_MASK; rw_wlock(&rl->rl_lock); - LIST_FOREACH(r, &rl->rl_table[key & RATELIMIT_MASK], r_entry) { - if (r->r_af != sa->sa_family) + LIST_FOREACH(r, &rl->rl_table[bucket], r_entry) { + if (bcmp(&r->r_key, &key, len) != 0) continue; - if (r->r_af == AF_INET && bcmp(&r->r_in, - &satosin(sa)->sin_addr, IPV4_MASK_SIZE) != 0) - continue; - -#ifdef INET6 - if (r->r_af == AF_INET6 && bcmp(&r->r_in6, - &satosin6(sa)->sin6_addr, IPV6_MASK_SIZE) != 0) - continue; -#endif - /* If we get to here, we've found an entry for the endpoint. * We apply standard token bucket, by calculating the time * lapsed since our last_time, adding that, ensuring that we @@ -462,15 +453,8 @@ ratelimit_allow(struct ratelimit *rl, struct sockaddr *sa) rl->rl_table_num++; /* Insert entry into the hashtable and ensure it's initialised */ - LIST_INSERT_HEAD(&rl->rl_table[key & RATELIMIT_MASK], r, r_entry); - r->r_af = sa->sa_family; - if (r->r_af == AF_INET) - memcpy(&r->r_in, &satosin(sa)->sin_addr, IPV4_MASK_SIZE); -#ifdef INET6 - else if (r->r_af == AF_INET6) - memcpy(&r->r_in6, &satosin6(sa)->sin6_addr, IPV6_MASK_SIZE); -#endif - + LIST_INSERT_HEAD(&rl->rl_table[bucket], r, r_entry); + r->r_key = key; r->r_last_time = getsbinuptime(); r->r_tokens = TOKEN_MAX - INITIATION_COST; |