diff options
author | 2021-04-04 00:15:47 +1100 | |
---|---|---|
committer | 2021-04-13 15:47:30 +1000 | |
commit | e447e5e6911f4dc85b71b8683614807a4bf342e1 (patch) | |
tree | dacae0f734ffeedbeb18355d40ef0ae806c93632 /sys | |
parent | Check iter != NULL (diff) | |
download | wireguard-openbsd-e447e5e6911f4dc85b71b8683614807a4bf342e1.tar.xz wireguard-openbsd-e447e5e6911f4dc85b71b8683614807a4bf342e1.zip |
Add refcnt_take_if_gt()
This function (or of similar nature) is required to safely use a refcnt
and smr_entry together. Such functions exist on other platforms as
kref_get_unless_zero (on Linux) and refcount_acquire_if_gt (on FreeBSD).
The following diagram details the following situation with and without
refcnt_take_if_gt in 3 cases, with the first showing the "invalid" use
of refcnt_take.
Situation:
Thread #1 is removing the global referenc (o).
Thread #2 wants to reference an object (r), using a thread pointer (t).
Case:
1) refcnt_take after Thread #1 has released "o"
2) refcnt_take_if_gt before Thread #1 has released "o"
3) refcnt_take_if_gt after Thread #1 has released "o"
Data:
struct obj {
struct smr_entry smr;
struct refcnt refcnt;
} *o, *r, *t1, *t2;
Thread #1 | Thread #2
---------------------------------+------------------------------------
| r = NULL;
rw_enter_write(&lock); | smr_read_enter();
|
t1 = SMR_PTR_GET_LOCKED(&o); | t2 = SMR_PTR_GET(&o);
SMR_PTR_SET_LOCKED(&o, NULL); |
|
if (refcnt_rele(&t1->refcnt) |
smr_call(&t1->smr, free, t1); |
| if (t2 != NULL) {
| refcnt_take(&t2->refcnt);
| r = t2;
| }
rw_exit_write(&lock); | smr_read_exit();
.....
// called by smr_thread |
free(t1); |
.....
| // use after free
| *r
---------------------------------+------------------------------------
| r = NULL;
rw_enter_write(&lock); | smr_read_enter();
|
t1 = SMR_PTR_GET_LOCKED(&o); | t2 = SMR_PTR_GET(&o);
SMR_PTR_SET_LOCKED(&o, NULL); |
|
if (refcnt_rele(&t1->refcnt) |
smr_call(&t1->smr, free, t1); |
| if (t2 != NULL &&
| refcnt_take_if_gt(&t2->refcnt, 0))
| r = t2;
rw_exit_write(&lock); | smr_read_exit();
.....
// called by smr_thread | // we don't have a valid reference
free(t1); | assert(r == NULL);
---------------------------------+------------------------------------
| r = NULL;
rw_enter_write(&lock); | smr_read_enter();
|
t1 = SMR_PTR_GET_LOCKED(&o); | t2 = SMR_PTR_GET(&o);
SMR_PTR_SET_LOCKED(&o, NULL); |
| if (t2 != NULL &&
| refcnt_take_if_gt(&t2->refcnt, 0))
| r = t2;
if (refcnt_rele(&t1->refcnt) |
smr_call(&t1->smr, free, t1); |
rw_exit_write(&lock); | smr_read_exit();
.....
| // we need to put our reference
| if (refcnt_rele(&t2->refcnt))
| smr_call(&t2->smr, free, t2);
.....
// called by smr_thread |
free(t1); |
---------------------------------+------------------------------------
Currently it uses atomic_add_int_nv to atomically read the refcnt,
but I'm open to suggestions for better ways.
The atomic_cas_uint is used to ensure that refcnt hasn't been modified
since reading `old`.
Diffstat (limited to 'sys')
-rw-r--r-- | sys/kern/kern_synch.c | 13 | ||||
-rw-r--r-- | sys/sys/refcnt.h | 1 |
2 files changed, 14 insertions, 0 deletions
diff --git a/sys/kern/kern_synch.c b/sys/kern/kern_synch.c index b476a6b4253..fe616fdfbfe 100644 --- a/sys/kern/kern_synch.c +++ b/sys/kern/kern_synch.c @@ -818,6 +818,19 @@ refcnt_take(struct refcnt *r) } int +refcnt_take_if_gt(struct refcnt *r, u_int n) +{ + u_int old; + while (1) { + old = READ_ONCE(r->refs); + if (old <= n) + return 0; + if (atomic_cas_uint(&r->refs, old, old + 1) == old) + return 1; + } +} + +int refcnt_rele(struct refcnt *r) { u_int refcnt; diff --git a/sys/sys/refcnt.h b/sys/sys/refcnt.h index 85e84cfdc2d..847e51ce124 100644 --- a/sys/sys/refcnt.h +++ b/sys/sys/refcnt.h @@ -29,6 +29,7 @@ struct refcnt { void refcnt_init(struct refcnt *); void refcnt_take(struct refcnt *); +int refcnt_take_if_gt(struct refcnt *, unsigned int); int refcnt_rele(struct refcnt *); void refcnt_rele_wake(struct refcnt *); void refcnt_finalize(struct refcnt *, const char *); |