aboutsummaryrefslogtreecommitdiffstats
path: root/device
diff options
context:
space:
mode:
authorJosh Bleecher Snyder <josh@tailscale.com>2020-12-14 15:28:52 -0800
committerJason A. Donenfeld <Jason@zx2c4.com>2021-01-07 14:49:44 +0100
commit63066ce4062a85224821ce302e3eb8c34e95a658 (patch)
treef6945216e1b48bc5404f612cdcc5564b12e93aae /device
parentdevice: use channel close to shut down and drain encryption channel (diff)
downloadwireguard-go-63066ce4062a85224821ce302e3eb8c34e95a658.tar.xz
wireguard-go-63066ce4062a85224821ce302e3eb8c34e95a658.zip
device: fix persistent_keepalive_interval data races
Co-authored-by: David Anderson <danderson@tailscale.com> Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
Diffstat (limited to 'device')
-rw-r--r--device/device.go2
-rw-r--r--device/device_test.go15
-rw-r--r--device/peer.go2
-rw-r--r--device/timers.go7
-rw-r--r--device/uapi.go5
5 files changed, 22 insertions, 9 deletions
diff --git a/device/device.go b/device/device.go
index d9367e5..99f5e60 100644
--- a/device/device.go
+++ b/device/device.go
@@ -163,7 +163,7 @@ func deviceUpdateState(device *Device) {
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Start()
- if peer.persistentKeepaliveInterval > 0 {
+ if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
peer.SendKeepalive()
}
}
diff --git a/device/device_test.go b/device/device_test.go
index 65942ec..e143914 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -215,7 +215,20 @@ func TestConcurrencySafety(t *testing.T) {
}()
warmup.Wait()
- // coming soon: more things here...
+ // Change persistent_keepalive_interval concurrently with tunnel use.
+ t.Run("persistentKeepaliveInterval", func(t *testing.T) {
+ cfg := uapiCfg(
+ "public_key", "f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725",
+ "persistent_keepalive_interval", "1",
+ )
+ for i := 0; i < 1000; i++ {
+ cfg.Seek(0, io.SeekStart)
+ err := pair[0].dev.IpcSetOperation(cfg)
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ })
close(done)
}
diff --git a/device/peer.go b/device/peer.go
index c2397cc..31b75c7 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -27,7 +27,7 @@ type Peer struct {
handshake Handshake
device *Device
endpoint conn.Endpoint
- persistentKeepaliveInterval uint16
+ persistentKeepaliveInterval uint32 // accessed atomically
disableRoaming bool
// These fields are accessed with atomic operations, which must be
diff --git a/device/timers.go b/device/timers.go
index 48cef94..e94da36 100644
--- a/device/timers.go
+++ b/device/timers.go
@@ -138,7 +138,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
}
func expiredPersistentKeepalive(peer *Peer) {
- if peer.persistentKeepaliveInterval > 0 {
+ if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
peer.SendKeepalive()
}
}
@@ -201,8 +201,9 @@ func (peer *Peer) timersSessionDerived() {
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
- if peer.persistentKeepaliveInterval > 0 && peer.timersActive() {
- peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second)
+ keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval)
+ if keepalive > 0 && peer.timersActive() {
+ peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
}
}
diff --git a/device/uapi.go b/device/uapi.go
index c0e522b..3f26607 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -86,7 +86,7 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) error {
send(fmt.Sprintf("last_handshake_time_nsec=%d", nano))
send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes)))
send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)))
- send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
+ send(fmt.Sprintf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval)))
for _, ip := range device.allowedips.EntriesForPeer(peer) {
send("allowed_ip=" + ip.String())
@@ -333,8 +333,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
return &IPCError{ipc.IpcErrorInvalid}
}
- old := peer.persistentKeepaliveInterval
- peer.persistentKeepaliveInterval = uint16(secs)
+ old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
// send immediate keepalive if we're turning it on and before it wasn't on