aboutsummaryrefslogtreecommitdiffstats
path: root/device/device.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/device.go')
-rw-r--r--device/device.go135
1 files changed, 132 insertions, 3 deletions
diff --git a/device/device.go b/device/device.go
index 569c5a8..3d18ddd 100644
--- a/device/device.go
+++ b/device/device.go
@@ -11,7 +11,11 @@ import (
"sync/atomic"
"time"
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+ "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ratelimiter"
+ "golang.zx2c4.com/wireguard/rwcancel"
"golang.zx2c4.com/wireguard/tun"
)
@@ -39,9 +43,10 @@ type Device struct {
starting sync.WaitGroup
stopping sync.WaitGroup
sync.RWMutex
- bind Bind // bind interface
- port uint16 // listening port
- fwmark uint32 // mark value (0 = disabled)
+ bind conn.Bind // bind interface
+ netlinkCancel *rwcancel.RWCancel
+ port uint16 // listening port
+ fwmark uint32 // mark value (0 = disabled)
}
staticIdentity struct {
@@ -425,3 +430,127 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
}
device.peers.RUnlock()
}
+
+func unsafeCloseBind(device *Device) error {
+ var err error
+ netc := &device.net
+ if netc.netlinkCancel != nil {
+ netc.netlinkCancel.Cancel()
+ }
+ if netc.bind != nil {
+ err = netc.bind.Close()
+ netc.bind = nil
+ }
+ netc.stopping.Wait()
+ return err
+}
+
+func (device *Device) BindSetMark(mark uint32) error {
+
+ device.net.Lock()
+ defer device.net.Unlock()
+
+ // check if modified
+
+ if device.net.fwmark == mark {
+ return nil
+ }
+
+ // update fwmark on existing bind
+
+ device.net.fwmark = mark
+ if device.isUp.Get() && device.net.bind != nil {
+ if err := device.net.bind.SetMark(mark); err != nil {
+ return err
+ }
+ }
+
+ // clear cached source addresses
+
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.Lock()
+ defer peer.Unlock()
+ if peer.endpoint != nil {
+ peer.endpoint.ClearSrc()
+ }
+ }
+ device.peers.RUnlock()
+
+ return nil
+}
+
+func (device *Device) BindUpdate() error {
+
+ device.net.Lock()
+ defer device.net.Unlock()
+
+ // close existing sockets
+
+ if err := unsafeCloseBind(device); err != nil {
+ return err
+ }
+
+ // open new sockets
+
+ if device.isUp.Get() {
+
+ // bind to new port
+
+ var err error
+ netc := &device.net
+ netc.bind, netc.port, err = conn.CreateBind(netc.port, device)
+ if err != nil {
+ netc.bind = nil
+ netc.port = 0
+ return err
+ }
+ netc.netlinkCancel, err = device.startRouteListener(netc.bind)
+ if err != nil {
+ netc.bind.Close()
+ netc.bind = nil
+ netc.port = 0
+ return err
+ }
+
+ // set fwmark
+
+ if netc.fwmark != 0 {
+ err = netc.bind.SetMark(netc.fwmark)
+ if err != nil {
+ return err
+ }
+ }
+
+ // clear cached source addresses
+
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.Lock()
+ defer peer.Unlock()
+ if peer.endpoint != nil {
+ peer.endpoint.ClearSrc()
+ }
+ }
+ device.peers.RUnlock()
+
+ // start receiving routines
+
+ device.net.starting.Add(conn.ConnRoutineNumber)
+ device.net.stopping.Add(conn.ConnRoutineNumber)
+ go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
+ go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
+ device.net.starting.Wait()
+
+ device.log.Debug.Println("UDP bind has been updated")
+ }
+
+ return nil
+}
+
+func (device *Device) BindClose() error {
+ device.net.Lock()
+ err := unsafeCloseBind(device)
+ device.net.Unlock()
+ return err
+}