summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMatt Dunwoodie <ncon@noconroy.net>2021-04-04 00:15:47 +1100
committerMatt Dunwoodie <ncon@noconroy.net>2021-04-13 15:47:30 +1000
commite447e5e6911f4dc85b71b8683614807a4bf342e1 (patch)
treedacae0f734ffeedbeb18355d40ef0ae806c93632
parentCheck iter != NULL (diff)
downloadwireguard-openbsd-e447e5e6911f4dc85b71b8683614807a4bf342e1.tar.xz
wireguard-openbsd-e447e5e6911f4dc85b71b8683614807a4bf342e1.zip
Add refcnt_take_if_gt()
This function (or of similar nature) is required to safely use a refcnt and smr_entry together. Such functions exist on other platforms as kref_get_unless_zero (on Linux) and refcount_acquire_if_gt (on FreeBSD). The following diagram details the following situation with and without refcnt_take_if_gt in 3 cases, with the first showing the "invalid" use of refcnt_take. Situation: Thread #1 is removing the global referenc (o). Thread #2 wants to reference an object (r), using a thread pointer (t). Case: 1) refcnt_take after Thread #1 has released "o" 2) refcnt_take_if_gt before Thread #1 has released "o" 3) refcnt_take_if_gt after Thread #1 has released "o" Data: struct obj { struct smr_entry smr; struct refcnt refcnt; } *o, *r, *t1, *t2; Thread #1 | Thread #2 ---------------------------------+------------------------------------ | r = NULL; rw_enter_write(&lock); | smr_read_enter(); | t1 = SMR_PTR_GET_LOCKED(&o); | t2 = SMR_PTR_GET(&o); SMR_PTR_SET_LOCKED(&o, NULL); | | if (refcnt_rele(&t1->refcnt) | smr_call(&t1->smr, free, t1); | | if (t2 != NULL) { | refcnt_take(&t2->refcnt); | r = t2; | } rw_exit_write(&lock); | smr_read_exit(); ..... // called by smr_thread | free(t1); | ..... | // use after free | *r ---------------------------------+------------------------------------ | r = NULL; rw_enter_write(&lock); | smr_read_enter(); | t1 = SMR_PTR_GET_LOCKED(&o); | t2 = SMR_PTR_GET(&o); SMR_PTR_SET_LOCKED(&o, NULL); | | if (refcnt_rele(&t1->refcnt) | smr_call(&t1->smr, free, t1); | | if (t2 != NULL && | refcnt_take_if_gt(&t2->refcnt, 0)) | r = t2; rw_exit_write(&lock); | smr_read_exit(); ..... // called by smr_thread | // we don't have a valid reference free(t1); | assert(r == NULL); ---------------------------------+------------------------------------ | r = NULL; rw_enter_write(&lock); | smr_read_enter(); | t1 = SMR_PTR_GET_LOCKED(&o); | t2 = SMR_PTR_GET(&o); SMR_PTR_SET_LOCKED(&o, NULL); | | if (t2 != NULL && | refcnt_take_if_gt(&t2->refcnt, 0)) | r = t2; if (refcnt_rele(&t1->refcnt) | smr_call(&t1->smr, free, t1); | rw_exit_write(&lock); | smr_read_exit(); ..... | // we need to put our reference | if (refcnt_rele(&t2->refcnt)) | smr_call(&t2->smr, free, t2); ..... // called by smr_thread | free(t1); | ---------------------------------+------------------------------------ Currently it uses atomic_add_int_nv to atomically read the refcnt, but I'm open to suggestions for better ways. The atomic_cas_uint is used to ensure that refcnt hasn't been modified since reading `old`.
-rw-r--r--share/man/man9/refcnt_init.916
-rw-r--r--sys/kern/kern_synch.c13
-rw-r--r--sys/sys/refcnt.h1
3 files changed, 30 insertions, 0 deletions
diff --git a/share/man/man9/refcnt_init.9 b/share/man/man9/refcnt_init.9
index 1b19d3b653a..8f45785d45b 100644
--- a/share/man/man9/refcnt_init.9
+++ b/share/man/man9/refcnt_init.9
@@ -20,6 +20,7 @@
.Sh NAME
.Nm refcnt_init ,
.Nm refcnt_take ,
+.Nm refcnt_take_if_gt ,
.Nm refcnt_rele ,
.Nm refcnt_rele_wake ,
.Nm refcnt_finalize ,
@@ -32,6 +33,8 @@
.Ft void
.Fn "refcnt_take" "struct refcnt *r"
.Ft int
+.Fn "refcnt_take_if_gt" "struct refcnt *r" "unsigned int n"
+.Ft int
.Fn "refcnt_rele" "struct refcnt *r"
.Ft void
.Fn "refcnt_rele_wake" "struct refcnt *r"
@@ -51,6 +54,14 @@ is used to acquire a new reference.
It is the responsibility of the caller to guarantee that it holds
a valid reference before taking a new reference.
.Pp
+.Fn refcnt_take_if_gt
+is used to conditionally acquire a new reference.
+If the count is greater than
+.Fa n ,
+a reference is taken.
+This allows the caller to safely reference a SMR-protected object in an SMR
+read-side critical section.
+.Pp
.Fn refcnt_rele
is used to release an existing reference.
.Pp
@@ -73,6 +84,7 @@ initialises a declaration of a refcnt to 1.
.Sh CONTEXT
.Fn refcnt_init ,
.Fn refcnt_take ,
+.Fn refcnt_take_if_gt ,
.Fn refcnt_rele ,
and
.Fn refcnt_rele_wake
@@ -82,6 +94,10 @@ context.
.Fn refcnt_finalize
can be called from process context.
.Sh RETURN VALUES
+.Fn refcnt_take_if_gt
+returns a non-zero value if a reference has been taken,
+otherwise 0.
+.Pp
.Fn refcnt_rele
returns a non-zero value if the last reference has been released,
otherwise 0.
diff --git a/sys/kern/kern_synch.c b/sys/kern/kern_synch.c
index b476a6b4253..fe616fdfbfe 100644
--- a/sys/kern/kern_synch.c
+++ b/sys/kern/kern_synch.c
@@ -818,6 +818,19 @@ refcnt_take(struct refcnt *r)
}
int
+refcnt_take_if_gt(struct refcnt *r, u_int n)
+{
+ u_int old;
+ while (1) {
+ old = READ_ONCE(r->refs);
+ if (old <= n)
+ return 0;
+ if (atomic_cas_uint(&r->refs, old, old + 1) == old)
+ return 1;
+ }
+}
+
+int
refcnt_rele(struct refcnt *r)
{
u_int refcnt;
diff --git a/sys/sys/refcnt.h b/sys/sys/refcnt.h
index 85e84cfdc2d..847e51ce124 100644
--- a/sys/sys/refcnt.h
+++ b/sys/sys/refcnt.h
@@ -29,6 +29,7 @@ struct refcnt {
void refcnt_init(struct refcnt *);
void refcnt_take(struct refcnt *);
+int refcnt_take_if_gt(struct refcnt *, unsigned int);
int refcnt_rele(struct refcnt *);
void refcnt_rele_wake(struct refcnt *);
void refcnt_finalize(struct refcnt *, const char *);