From 846d721dfd0cde953f2e9304d6ef50110de050eb Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Wed, 16 May 2018 22:20:15 +0200 Subject: Finer-grained start-stop synchronization --- conn.go | 6 ++++++ device.go | 12 +++++++++++- peer.go | 9 +++++---- receive.go | 4 ++++ send.go | 3 +++ tun.go | 4 ++++ 6 files changed, 33 insertions(+), 5 deletions(-) diff --git a/conn.go b/conn.go index 92f4cfe..d3919ca 100644 --- a/conn.go +++ b/conn.go @@ -12,6 +12,10 @@ import ( "net" ) +const ( + ConnRoutineNumber = 2 +) + /* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic */ type Bind interface { @@ -153,6 +157,8 @@ func (device *Device) BindUpdate() error { // start receiving routines + device.state.starting.Add(ConnRoutineNumber) + device.state.stopping.Add(ConnRoutineNumber) go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) diff --git a/device.go b/device.go index e91ca72..6e1bc94 100644 --- a/device.go +++ b/device.go @@ -15,6 +15,7 @@ import ( const ( DeviceRoutineNumberPerCPU = 3 + DeviceRoutineNumberAdditional = 2 ) type Device struct { @@ -25,6 +26,7 @@ type Device struct { // synchronized resources (locks acquired in order) state struct { + starting sync.WaitGroup stopping sync.WaitGroup mutex sync.Mutex changing AtomicBool @@ -297,7 +299,10 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device { // start workers cpus := runtime.NumCPU() - device.state.stopping.Add(DeviceRoutineNumberPerCPU * cpus) + device.state.starting.Wait() + device.state.stopping.Wait() + device.state.stopping.Add(DeviceRoutineNumberPerCPU * cpus + DeviceRoutineNumberAdditional) + device.state.starting.Add(DeviceRoutineNumberPerCPU * cpus + DeviceRoutineNumberAdditional) for i := 0; i < cpus; i += 1 { go device.RoutineEncryption() go device.RoutineDecryption() @@ -307,6 +312,8 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device { go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() + device.state.starting.Wait() + return device } @@ -363,6 +370,9 @@ func (device *Device) Close() { if device.isClosed.Swap(true) { return } + + device.state.starting.Wait() + device.log.Info.Println("Device closing") device.state.changing.Set(true) device.state.mutex.Lock() diff --git a/peer.go b/peer.go index 4bc1ada..3808ad6 100644 --- a/peer.go +++ b/peer.go @@ -231,20 +231,21 @@ func (peer *Peer) Stop() { // prevent simultaneous start/stop operations - peer.routines.mutex.Lock() - defer peer.routines.mutex.Unlock() - if !peer.isRunning.Swap(false) { return } + peer.routines.starting.Wait() + + peer.routines.mutex.Lock() + defer peer.routines.mutex.Unlock() + peer.device.log.Debug.Println(peer, ": Stopping...") peer.timersStop() // stop & wait for ongoing peer routines - peer.routines.starting.Wait() close(peer.routines.stop) peer.routines.stopping.Wait() diff --git a/receive.go b/receive.go index 77062fa..aa96057 100644 --- a/receive.go +++ b/receive.go @@ -124,9 +124,11 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { logDebug := device.log.Debug defer func() { logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped") + device.state.stopping.Done() }() logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - starting") + device.state.starting.Done() // receive datagrams until conn is closed @@ -257,6 +259,7 @@ func (device *Device) RoutineDecryption() { device.state.stopping.Done() }() logDebug.Println("Routine: decryption worker - started") + device.state.starting.Done() for { select { @@ -324,6 +327,7 @@ func (device *Device) RoutineHandshake() { }() logDebug.Println("Routine: handshake worker - started") + device.state.starting.Done() var elem QueueHandshakeElement var ok bool diff --git a/send.go b/send.go index 9a59abd..5605ad1 100644 --- a/send.go +++ b/send.go @@ -247,9 +247,11 @@ func (device *Device) RoutineReadFromTUN() { defer func() { logDebug.Println("Routine: TUN reader - stopped") + device.state.stopping.Done() }() logDebug.Println("Routine: TUN reader - started") + device.state.starting.Done() for { @@ -424,6 +426,7 @@ func (device *Device) RoutineEncryption() { }() logDebug.Println("Routine: encryption worker - started") + device.state.starting.Done() for { diff --git a/tun.go b/tun.go index ec3ab47..ef80625 100644 --- a/tun.go +++ b/tun.go @@ -35,6 +35,8 @@ func (device *Device) RoutineTUNEventReader() { logInfo := device.log.Info logError := device.log.Error + device.state.starting.Done() + for event := range device.tun.device.Events() { if event&TUNEventMTUUpdate != 0 { mtu, err := device.tun.device.MTU() @@ -63,4 +65,6 @@ func (device *Device) RoutineTUNEventReader() { device.Down() } } + + device.state.stopping.Done() } -- cgit v1.2.3-59-g8ed1b