From 1df85964a33a4ff037aa60915f3d1f24140a2a62 Mon Sep 17 00:00:00 2001 From: Matt Dunwoodie Date: Sun, 4 Apr 2021 22:35:43 +1000 Subject: Replace timer lock with SMR The lock was not used to protect any data structures, it was purely to ensure race-free setting of t_disabled. That is, that no other thread was halfway through any wg_timers_run_* function. With smr_* we can ensure this is still the case by calling smr_barrier() after setting t_disabled. --- sys/net/if_wg.c | 67 ++++++++++++++++++++++++++------------------------------- 1 file changed, 31 insertions(+), 36 deletions(-) diff --git a/sys/net/if_wg.c b/sys/net/if_wg.c index d29d98218dc..ae29bf4d49c 100644 --- a/sys/net/if_wg.c +++ b/sys/net/if_wg.c @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -129,9 +130,6 @@ struct wg_endpoint { }; struct wg_timers { - /* t_lock is for blocking wg_timers_event_* when setting t_disabled. */ - struct rwlock t_lock; - int t_disabled; int t_need_another_keepalive; uint16_t t_persistent_keepalive_interval; @@ -787,7 +785,7 @@ wg_tag_get(struct mbuf *m) /* * The following section handles the timeout callbacks for a WireGuard session. - * These functions provide an "event based" model for controlling wg(8) session + * These functions provide an "event based" model for controlling wg(4) session * timers. All function calls occur after the specified event below. * * wg_timers_event_data_sent: @@ -815,7 +813,6 @@ void wg_timers_init(struct wg_timers *t) { bzero(t, sizeof(*t)); - rw_init(&t->t_lock, "wg_timers"); mtx_init(&t->t_handshake_mtx, IPL_NET); timeout_set_flags(&t->t_new_handshake, @@ -833,19 +830,16 @@ wg_timers_init(struct wg_timers *t) void wg_timers_enable(struct wg_timers *t) { - rw_enter_write(&t->t_lock); - t->t_disabled = 0; - rw_exit_write(&t->t_lock); + WRITE_ONCE(t->t_disabled, 0); wg_timers_run_persistent_keepalive(t); } void wg_timers_disable(struct wg_timers *t) { - rw_enter_write(&t->t_lock); - t->t_disabled = 1; - t->t_need_another_keepalive = 0; - rw_exit_write(&t->t_lock); + WRITE_ONCE(t->t_disabled, 1); + smr_barrier(); + WRITE_ONCE(t->t_need_another_keepalive, 0); timeout_del_barrier(&t->t_new_handshake); timeout_del_barrier(&t->t_send_keepalive); @@ -857,13 +851,13 @@ wg_timers_disable(struct wg_timers *t) void wg_timers_set_persistent_keepalive(struct wg_timers *t, uint16_t interval) { - rw_enter_read(&t->t_lock); if (interval != t->t_persistent_keepalive_interval) { - t->t_persistent_keepalive_interval = interval; + WRITE_ONCE(t->t_persistent_keepalive_interval, interval); + smr_read_enter(); if (!t->t_disabled) wg_timers_run_persistent_keepalive(t); + smr_read_leave(); } - rw_exit_read(&t->t_lock); } int @@ -887,24 +881,24 @@ wg_timers_event_data_sent(struct wg_timers *t) int msecs = NEW_HANDSHAKE_TIMEOUT * 1000; msecs += arc4random_uniform(REKEY_TIMEOUT_JITTER); - rw_enter_read(&t->t_lock); + smr_read_enter(); if (!t->t_disabled && !timeout_pending(&t->t_new_handshake)) timeout_add_msec(&t->t_new_handshake, msecs); - rw_exit_read(&t->t_lock); + smr_read_leave(); } void wg_timers_event_data_received(struct wg_timers *t) { - rw_enter_read(&t->t_lock); + smr_read_enter(); if (!t->t_disabled) { if (!timeout_pending(&t->t_send_keepalive)) timeout_add_sec(&t->t_send_keepalive, KEEPALIVE_TIMEOUT); else - t->t_need_another_keepalive = 1; + WRITE_ONCE(t->t_need_another_keepalive, 1); } - rw_exit_read(&t->t_lock); + smr_read_leave(); } void @@ -922,11 +916,12 @@ wg_timers_event_any_authenticated_packet_received(struct wg_timers *t) void wg_timers_event_any_authenticated_packet_traversal(struct wg_timers *t) { - rw_enter_read(&t->t_lock); - if (!t->t_disabled && t->t_persistent_keepalive_interval > 0) - timeout_add_sec(&t->t_persistent_keepalive, - t->t_persistent_keepalive_interval); - rw_exit_read(&t->t_lock); + uint16_t interval; + smr_read_enter(); + interval = READ_ONCE(t->t_persistent_keepalive_interval); + if (!t->t_disabled && interval > 0) + timeout_add_sec(&t->t_persistent_keepalive, interval); + smr_read_leave(); } void @@ -935,16 +930,16 @@ wg_timers_event_handshake_initiated(struct wg_timers *t) int msecs = REKEY_TIMEOUT * 1000; msecs += arc4random_uniform(REKEY_TIMEOUT_JITTER); - rw_enter_read(&t->t_lock); + smr_read_enter(); if (!t->t_disabled) timeout_add_msec(&t->t_retry_handshake, msecs); - rw_exit_read(&t->t_lock); + smr_read_leave(); } void wg_timers_event_handshake_complete(struct wg_timers *t) { - rw_enter_read(&t->t_lock); + smr_read_enter(); if (!t->t_disabled) { mtx_enter(&t->t_handshake_mtx); timeout_del(&t->t_retry_handshake); @@ -953,25 +948,25 @@ wg_timers_event_handshake_complete(struct wg_timers *t) mtx_leave(&t->t_handshake_mtx); wg_timers_run_send_keepalive(t); } - rw_exit_read(&t->t_lock); + smr_read_leave(); } void wg_timers_event_session_derived(struct wg_timers *t) { - rw_enter_read(&t->t_lock); + smr_read_enter(); if (!t->t_disabled) timeout_add_sec(&t->t_zero_key_material, REJECT_AFTER_TIME * 3); - rw_exit_read(&t->t_lock); + smr_read_leave(); } void wg_timers_event_want_initiation(struct wg_timers *t) { - rw_enter_read(&t->t_lock); + smr_read_enter(); if (!t->t_disabled) wg_timers_run_send_initiation(t, 0); - rw_exit_read(&t->t_lock); + smr_read_leave(); } void @@ -1023,8 +1018,8 @@ wg_timers_run_send_keepalive(void *_t) struct wg_peer *peer = CONTAINER_OF(t, struct wg_peer, p_timers); wg_send_keepalive(peer); - if (t->t_need_another_keepalive) { - t->t_need_another_keepalive = 0; + if (READ_ONCE(t->t_need_another_keepalive)) { + WRITE_ONCE(t->t_need_another_keepalive, 0); timeout_add_sec(&t->t_send_keepalive, KEEPALIVE_TIMEOUT); } } @@ -1061,7 +1056,7 @@ wg_timers_run_persistent_keepalive(void *_t) struct wg_timers *t = _t; struct wg_peer *peer = CONTAINER_OF(t, struct wg_peer, p_timers); - if (t->t_persistent_keepalive_interval > 0) + if (READ_ONCE(t->t_persistent_keepalive_interval) > 0) wg_send_keepalive(peer); } -- cgit v1.2.3-59-g8ed1b