aboutsummaryrefslogtreecommitdiffstats
path: root/device
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-02-10 00:12:23 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2021-02-10 00:12:23 +0100
commit587a2b2a2028430893f14f9ac49e1efa5e3f8509 (patch)
tree85294dc6224347a855f640f3770f1ff0d6a854f0 /device
parentrwcancel: add an explicit close call (diff)
downloadwireguard-go-587a2b2a2028430893f14f9ac49e1efa5e3f8509.tar.xz
wireguard-go-587a2b2a2028430893f14f9ac49e1efa5e3f8509.zip
device: return error from Up() and Down()
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'device')
-rw-r--r--device/device.go32
-rw-r--r--device/device_test.go13
2 files changed, 27 insertions, 18 deletions
diff --git a/device/device.go b/device/device.go
index 7f96a1e..432549d 100644
--- a/device/device.go
+++ b/device/device.go
@@ -139,37 +139,42 @@ func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
}
// changeState attempts to change the device state to match want.
-func (device *Device) changeState(want deviceState) {
+func (device *Device) changeState(want deviceState) (err error) {
device.state.Lock()
defer device.state.Unlock()
old := device.deviceState()
if old == deviceStateClosed {
// once closed, always closed
device.log.Verbosef("Interface closed, ignored requested state %s", want)
- return
+ return nil
}
switch want {
case old:
- return
+ return nil
case deviceStateUp:
atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
- if ok := device.upLocked(); ok {
+ err = device.upLocked()
+ if err == nil {
break
}
fallthrough // up failed; bring the device all the way back down
case deviceStateDown:
atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
- device.downLocked()
+ errDown := device.downLocked()
+ if err == nil {
+ err = errDown
+ }
}
device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
+ return
}
// upLocked attempts to bring the device up and reports whether it succeeded.
// The caller must hold device.state.mu and is responsible for updating device.state.state.
-func (device *Device) upLocked() bool {
+func (device *Device) upLocked() error {
if err := device.BindUpdate(); err != nil {
device.log.Errorf("Unable to update bind: %v", err)
- return false
+ return err
}
device.peers.RLock()
@@ -180,12 +185,12 @@ func (device *Device) upLocked() bool {
}
}
device.peers.RUnlock()
- return true
+ return nil
}
// downLocked attempts to bring the device down.
// The caller must hold device.state.mu and is responsible for updating device.state.state.
-func (device *Device) downLocked() {
+func (device *Device) downLocked() error {
err := device.BindClose()
if err != nil {
device.log.Errorf("Bind close failed: %v", err)
@@ -196,14 +201,15 @@ func (device *Device) downLocked() {
peer.Stop()
}
device.peers.RUnlock()
+ return err
}
-func (device *Device) Up() {
- device.changeState(deviceStateUp)
+func (device *Device) Up() error {
+ return device.changeState(deviceStateUp)
}
-func (device *Device) Down() {
- device.changeState(deviceStateDown)
+func (device *Device) Down() error {
+ return device.changeState(deviceStateDown)
}
func (device *Device) IsUnderLoad() bool {
diff --git a/device/device_test.go b/device/device_test.go
index ce1ba9b..c17b350 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -157,14 +157,13 @@ func genTestPair(tb testing.TB) (pair testPair) {
level = LogLevelError
}
p.dev = NewDevice(p.tun.TUN(), NewLogger(level, fmt.Sprintf("dev%d: ", i)))
- p.dev.Up()
if err := p.dev.IpcSet(cfg[i]); err != nil {
tb.Errorf("failed to configure device %d: %v", i, err)
p.dev.Close()
continue
}
- if !p.dev.isUp() {
- tb.Errorf("device %d did not come up", i)
+ if err := p.dev.Up(); err != nil {
+ tb.Errorf("failed to bring up device %d: %v", i, err)
p.dev.Close()
continue
}
@@ -212,9 +211,13 @@ func TestUpDown(t *testing.T) {
go func(d *Device) {
defer wg.Done()
for i := 0; i < itrials; i++ {
- d.Up()
+ if err := d.Up(); err != nil {
+ t.Errorf("failed up bring up device: %v", err)
+ }
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
- d.Down()
+ if err := d.Down(); err != nil {
+ t.Errorf("failed to bring down device: %v", err)
+ }
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
}
}(pair[i].dev)