aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2018-06-17 18:25:15 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2018-06-17 18:25:15 +0200
commit6ed04ca462c9273acb990144b79c01ebf0c8d102 (patch)
tree3799ef44ac520763729ee70c3a754becb76e48f3
parentmpmc_ptr_ring: Fix a word (diff)
downloadwireguard-monolithic-historical-6ed04ca462c9273acb990144b79c01ebf0c8d102.tar.xz
wireguard-monolithic-historical-6ed04ca462c9273acb990144b79c01ebf0c8d102.zip
selftest/mpmc_ring: use completion and switch to ptr_ring
-rw-r--r--src/selftest/mpmc_ring.h40
1 files changed, 23 insertions, 17 deletions
diff --git a/src/selftest/mpmc_ring.h b/src/selftest/mpmc_ring.h
index 54250d7..1dc92fd 100644
--- a/src/selftest/mpmc_ring.h
+++ b/src/selftest/mpmc_ring.h
@@ -8,6 +8,7 @@
#include "../mpmc_ptr_ring.h"
#include <linux/kthread.h>
+#include <linux/ptr_ring.h>
#define THREADS_PRODUCER 20
#define THREADS_CONSUMER 20
@@ -20,12 +21,14 @@
#define THREADS_TOTAL (THREADS_PRODUCER + THREADS_CONSUMER)
struct worker_producer {
- struct mpmc_ptr_ring *ring;
+ struct ptr_ring *ring;
+ struct completion completion;
int thread_num;
};
struct worker_consumer {
- struct mpmc_ptr_ring *ring;
+ struct ptr_ring *ring;
+ struct completion completion;
uint64_t total;
uint64_t count;
};
@@ -36,11 +39,12 @@ static __init int producer_function(void *data)
uint64_t i;
for (i = td->thread_num * PER_PRODUCER + 1; i <= (td->thread_num + 1) * PER_PRODUCER; ++i) {
- while (mpmc_ptr_ring_produce(td->ring, (void *)i)) {
+ while (ptr_ring_produce(td->ring, (void *)i)) {
if (need_resched())
schedule();
}
}
+ complete_all(&td->completion);
return 0;
}
@@ -51,7 +55,7 @@ static __init int consumer_function(void *data)
for (i = 0; i < PER_CONSUMER; ++i) {
uintptr_t value;
- while (!(value = (uintptr_t)mpmc_ptr_ring_consume(td->ring))) {
+ while (!(value = (uintptr_t)ptr_ring_consume(td->ring))) {
if (need_resched())
schedule();
}
@@ -59,6 +63,7 @@ static __init int consumer_function(void *data)
td->total += value;
++td->count;
}
+ complete_all(&td->completion);
return 0;
}
@@ -66,36 +71,38 @@ bool __init mpmc_ring_selftest(void)
{
struct worker_producer *producers;
struct worker_consumer *consumers;
- struct task_struct **threads;
- struct mpmc_ptr_ring ring;
+ struct ptr_ring ring;
int64_t total = 0, count = 0;
- int i, j = 0;
+ int i;
producers = kmalloc_array(THREADS_PRODUCER, sizeof(*producers), GFP_KERNEL);
consumers = kmalloc_array(THREADS_CONSUMER, sizeof(*consumers), GFP_KERNEL);
- threads = kmalloc_array(THREADS_CONSUMER + THREADS_PRODUCER, sizeof(*threads), GFP_KERNEL);
- BUG_ON(!producers || !consumers || !threads);
- BUG_ON(mpmc_ptr_ring_init(&ring, QUEUE_SIZE, GFP_KERNEL));
+ BUG_ON(!producers || !consumers);
+ BUG_ON(ptr_ring_init(&ring, QUEUE_SIZE, GFP_KERNEL));
for (i = 0; i < THREADS_PRODUCER; ++i) {
producers[i].ring = &ring;
producers[i].thread_num = i;
- threads[j++] = kthread_run(producer_function, &producers[i], "producer %d", i);
+ init_completion(&producers[i].completion);
+ kthread_run(producer_function, &producers[i], "producer %d", i);
}
for (i = 0; i < THREADS_CONSUMER; ++i) {
consumers[i].ring = &ring;
consumers[i].total = 0;
consumers[i].count = 0;
- threads[j++] = kthread_run(consumer_function, &consumers[i], "consumer %d", i);
+ init_completion(&consumers[i].completion);
+ kthread_run(consumer_function, &consumers[i], "consumer %d", i);
}
- for (j = 0; j < THREADS_CONSUMER + THREADS_PRODUCER; ++j)
- kthread_stop(threads[j]);
+ for (i = 0; i < THREADS_PRODUCER; ++i)
+ wait_for_completion(&producers[i].completion);
+ for (i = 0; i < THREADS_CONSUMER; ++i)
+ wait_for_completion(&consumers[i].completion);
- BUG_ON(!mpmc_ptr_ring_empty(&ring));
- mpmc_ptr_ring_cleanup(&ring, NULL);
+ BUG_ON(!ptr_ring_empty(&ring));
+ ptr_ring_cleanup(&ring, NULL);
for (i = 0; i < THREADS_CONSUMER; ++i) {
total += consumers[i].total;
@@ -104,7 +111,6 @@ bool __init mpmc_ring_selftest(void)
kfree(producers);
kfree(consumers);
- kfree(threads);
if (count == ELEMENT_COUNT && total == EXPECTED_TOTAL) {
pr_info("mpmc_ring self-tests: pass");