aboutsummaryrefslogtreecommitdiffstats
path: root/driver/ratelimiter.c
blob: 8d72afed6d32e50cf969cd3c79a2cdbbca1df8dd (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
/* SPDX-License-Identifier: GPL-2.0
 *
 * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */

#include "containers.h"
#include "ratelimiter.h"
#include "crypto.h"
#include "logging.h"
#include "timers.h"

/* Buckets per address-family table. Must be a power of two: bucket indices
 * are computed as hash & (TABLE_SIZE - 1). */
#define TABLE_SIZE 8192
/* Cap on tracked sources across both tables; once reached, RatelimiterAllow
 * returns FALSE for sources it has no entry for. */
#define MAX_ENTRIES (TABLE_SIZE * 8)

/* Allocator backing every RATELIMITER_ENTRY. */
static LOOKASIDE_ALIGN LOOKASIDE_LIST_EX EntryCache;
/* Key for the bucket hash; filled with random bytes in RatelimiterDriverEntry. */
static HSIPHASH_KEY Key;
/* Serializes insertion into and removal from TableV4/TableV6. */
static KSPIN_LOCK TableLock;
/* Entries currently accounted for: incremented when RatelimiterAllow reserves
 * a slot, decremented only once EntryFree has run (so unlinked-but-not-yet-
 * freed entries still count). */
static LONG TotalEntries = 0;
/* The periodic GC thread and the event used to ask it to exit. */
static struct
{
    KEVENT Terminate;
    PKTHREAD Thread;
} RatelimiterGcEntriesThread;
/* Hash tables of entries keyed by IPv4 source address / IPv6 source /64. */
static HLIST_HEAD TableV4[TABLE_SIZE] = { 0 }, TableV6[TABLE_SIZE] = { 0 };

/* Token-bucket state for one source (IPv4 address or IPv6 /64). */
typedef struct _RATELIMITER_ENTRY
{
    /* LastTime: KeQueryInterruptTime() value at the last update.
     * Tokens:   accumulated credit, in the same time units, capped at TOKEN_MAX.
     * Ip:       IPv4 address, or the first 8 bytes (the /64) of an IPv6 address. */
    UINT64 LastTime, Tokens, Ip;
    KSPIN_LOCK Lock;  /* Serializes the token-bucket update in RatelimiterAllow. */
    HLIST_NODE Hash;  /* Link in TableV4 or TableV6; modified under TableLock. */
    RCU_CALLBACK Rcu; /* Used by EntryUninit to defer EntryFree. */
} RATELIMITER_ENTRY;

/* Token-bucket parameters. Tokens are measured in the same units as time
 * (SYS_TIME_UNITS_PER_SEC per second): one packet costs 1/PACKETS_PER_SECOND
 * of a second of credit, and a bucket holds at most PACKETS_BURSTABLE
 * packets' worth. */
enum
{
    PACKETS_PER_SECOND = 20,
    PACKETS_BURSTABLE = 5,
    PACKET_COST = SYS_TIME_UNITS_PER_SEC / PACKETS_PER_SECOND,
    TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE
};

/* RCU callback scheduled by EntryUninit. Recovers the RATELIMITER_ENTRY that
 * embeds Rcu, returns it to EntryCache, and only then releases its slot in
 * TotalEntries. */
static RCU_CALLBACK_FN EntryFree;
_Use_decl_annotations_
static VOID
EntryFree(RCU_CALLBACK *Rcu)
{
    RATELIMITER_ENTRY *CONST Dead = CONTAINING_RECORD(Rcu, RATELIMITER_ENTRY, Rcu);

    ExFreeToLookasideListEx(&EntryCache, Dead);
    InterlockedDecrement(&TotalEntries);
}

/* Removes Entry from its hash bucket and hands it to RcuCall, which later runs
 * EntryFree. Freeing is deferred rather than immediate so that concurrent
 * RCU readers (HLIST_FOR_EACH_ENTRY_RCU in RatelimiterAllow) never observe
 * freed memory. The unlink must precede the RcuCall.
 * The only caller in this file (RatelimiterGcEntries) holds TableLock. */
static VOID
EntryUninit(_Inout_ RATELIMITER_ENTRY *Entry)
{
    HlistDelRcu(&Entry->Hash);
    RcuCall(&Entry->Rcu, EntryFree);
}

/* Calling this function with a NULL work uninits all entries. */
_IRQL_requires_max_(PASSIVE_LEVEL)
_Function_class_(KSTART_ROUTINE)
static VOID
RatelimiterGcEntries(_In_opt_ PVOID StartContext)
{
    for (;;)
    {
        CONST UINT64 Now = KeQueryInterruptTime();
        RATELIMITER_ENTRY *Entry;
        HLIST_NODE *Temp;
        ULONG i;
        KIRQL Irql;

        for (i = 0; i < TABLE_SIZE; ++i)
        {
            KeAcquireSpinLock(&TableLock, &Irql);
            HLIST_FOR_EACH_ENTRY_SAFE (Entry, Temp, &TableV4[i], RATELIMITER_ENTRY, Hash)
            {
                if (!StartContext || Now - Entry->LastTime > SYS_TIME_UNITS_PER_SEC)
                    EntryUninit(Entry);
            }
            HLIST_FOR_EACH_ENTRY_SAFE (Entry, Temp, &TableV6[i], RATELIMITER_ENTRY, Hash)
            {
                if (!StartContext || Now - Entry->LastTime > SYS_TIME_UNITS_PER_SEC)
                    EntryUninit(Entry);
            }
            KeReleaseSpinLock(&TableLock, Irql);
        }
        if (!StartContext)
            break;
        if (KeWaitForSingleObject(
                &RatelimiterGcEntriesThread.Terminate,
                Executive,
                KernelMode,
                FALSE,
                &(LARGE_INTEGER){ .QuadPart = -SYS_TIME_UNITS_PER_SEC }) == STATUS_SUCCESS)
            break;
    }
}

_Use_decl_annotations_
BOOLEAN
RatelimiterAllow(CONST SOCKADDR *Src)
{
    RATELIMITER_ENTRY *Entry;
    HLIST_HEAD *Bucket;
    UINT64 Ip;
    KIRQL Irql;

    if (Src->sa_family == AF_INET)
    {
        Ip = (UINT64)((SOCKADDR_IN *)Src)->sin_addr.s_addr;
        Bucket = &TableV4[Hsiphash1u32((UINT32)Ip, &Key) & (TABLE_SIZE - 1)];
    }
    else if (Src->sa_family == AF_INET6)
    {
        /* Only use 64 bits, so as to ratelimit the whole /64. */
        RtlCopyMemory(&Ip, &((SOCKADDR_IN6 *)Src)->sin6_addr, sizeof(Ip));
        Bucket = &TableV6[Hsiphash2u32((UINT32)(Ip >> 32), (UINT32)Ip, &Key) & (TABLE_SIZE - 1)];
    }
    else
        return FALSE;
    Irql = RcuReadLock();
    HLIST_FOR_EACH_ENTRY_RCU (Entry, Bucket, RATELIMITER_ENTRY, Hash)
    {
        if (Entry->Ip == Ip)
        {
            UINT64 Now, Tokens;
            BOOLEAN Ret;
            /* Quasi-inspired by nft_limit.c, but this is actually a
             * slightly different algorithm. Namely, we incorporate
             * the burst as part of the maximum tokens, rather than
             * as part of the rate.
             */
            KeAcquireSpinLockAtDpcLevel(&Entry->Lock);
            Now = KeQueryInterruptTime();
            Tokens = min(TOKEN_MAX, Entry->Tokens + Now - Entry->LastTime);
            Entry->LastTime = Now;
            Ret = Tokens >= PACKET_COST;
            Entry->Tokens = Ret ? Tokens - PACKET_COST : Tokens;
            KeReleaseSpinLockFromDpcLevel(&Entry->Lock);
            RcuReadUnlock(Irql);
            return Ret;
        }
    }
    RcuReadUnlock(Irql);

    if ((ULONG)InterlockedIncrement(&TotalEntries) > MAX_ENTRIES)
        goto cleanupOom;

    Entry = ExAllocateFromLookasideListEx(&EntryCache);
    if (!Entry)
        goto cleanupOom;

    Entry->Ip = Ip;
    HlistInit(&Entry->Hash);
    KeInitializeSpinLock(&Entry->Lock);
    Entry->LastTime = KeQueryInterruptTime();
    Entry->Tokens = TOKEN_MAX - PACKET_COST;
    KeAcquireSpinLock(&TableLock, &Irql);
    HlistAddHeadRcu(&Entry->Hash, Bucket);
    KeReleaseSpinLock(&TableLock, Irql);
    return TRUE;

cleanupOom:
    InterlockedDecrement(&TotalEntries);
    return FALSE;
}

#ifdef ALLOC_PRAGMA
#    pragma alloc_text(INIT, RatelimiterDriverEntry)
#endif
/* Load-time initialization (placed in the INIT section). In order, it:
 * creates the entry lookaside list, initializes TableLock and the GC thread's
 * Terminate event, starts the GC thread in periodic mode, keeps a referenced
 * pointer to that thread for RatelimiterUnload, and fills the bucket hash Key
 * with random bytes.
 *
 * Returns STATUS_SUCCESS, or the failing status of ExInitializeLookasideListEx
 * or PsCreateSystemThread. In the failure cases the lookaside list is not left
 * behind. */
_Use_decl_annotations_
NTSTATUS
RatelimiterDriverEntry(VOID)
{
    NTSTATUS Status = ExInitializeLookasideListEx(
        &EntryCache, NULL, NULL, NonPagedPool, 0, sizeof(RATELIMITER_ENTRY), MEMORY_TAG, 0);
    if (!NT_SUCCESS(Status))
        return Status;

    KeInitializeSpinLock(&TableLock);
    KeInitializeEvent(&RatelimiterGcEntriesThread.Terminate, NotificationEvent, FALSE);

    /* A non-NULL start context selects the periodic mode of RatelimiterGcEntries. */
    OBJECT_ATTRIBUTES ThreadAttributes;
    InitializeObjectAttributes(&ThreadAttributes, NULL, OBJ_KERNEL_HANDLE, NULL, NULL);
    HANDLE ThreadHandle;
    Status = PsCreateSystemThread(
        &ThreadHandle, THREAD_ALL_ACCESS, &ThreadAttributes, NULL, NULL, RatelimiterGcEntries, (PVOID)TRUE);
    if (!NT_SUCCESS(Status))
    {
        ExDeleteLookasideListEx(&EntryCache);
        return Status;
    }

    /* RatelimiterUnload waits on this object pointer, so the handle can be
     * closed right away.
     * NOTE(review): the NTSTATUS of ObReferenceObjectByHandle is ignored. If it
     * ever failed, Thread would stay NULL and RatelimiterUnload would wait on
     * NULL. A safe fix must also stop and wait for the already running thread
     * (e.g. via the handle) before failing. */
    ObReferenceObjectByHandle(ThreadHandle, SYNCHRONIZE, NULL, KernelMode, &RatelimiterGcEntriesThread.Thread, NULL);
    ZwClose(ThreadHandle);

    CryptoRandom(&Key, sizeof(Key));
    return STATUS_SUCCESS;
}

/* Tears down the ratelimiter. The steps are order-dependent:
 *  1. Signal Terminate, wait for the GC thread to exit, and drop the thread
 *     reference taken in RatelimiterDriverEntry. After this, no periodic sweep
 *     can race with what follows.
 *  2. RatelimiterGcEntries(NULL) unlinks every remaining entry and queues its
 *     release via RcuCall.
 *  3. RcuBarrier() is relied on to wait until those queued EntryFree callbacks
 *     (which free into EntryCache) have run. This is by analogy with Linux
 *     rcu_barrier — confirm.
 *  4. Only then is EntryCache deleted.
 * KeSetEvent is called with Wait=TRUE, which requires that the very next call
 * be a KeWaitXxx routine. The KeWaitForSingleObject that follows satisfies
 * this. The suppress pragma must remain immediately above KeSetEvent. */
_Use_decl_annotations_
VOID RatelimiterUnload(VOID)
{
#pragma warning(suppress : 28160) /* Acknowledge caution about Wait parameter. */
    KeSetEvent(&RatelimiterGcEntriesThread.Terminate, IO_NO_INCREMENT, TRUE);
    KeWaitForSingleObject(RatelimiterGcEntriesThread.Thread, Executive, KernelMode, FALSE, NULL);
    ObDereferenceObject(RatelimiterGcEntriesThread.Thread);
    RatelimiterGcEntries(NULL);
    RcuBarrier();
    ExDeleteLookasideListEx(&EntryCache);
}

#ifdef DBG
#    include "selftest/ratelimiter.c"
#endif