aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/tunnel/interfacewatcher.go
diff options
context:
space:
mode:
Diffstat (limited to 'tunnel/interfacewatcher.go')
-rw-r--r--tunnel/interfacewatcher.go45
1 files changed, 26 insertions, 19 deletions
diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go
index e12e5929..32132e93 100644
--- a/tunnel/interfacewatcher.go
+++ b/tunnel/interfacewatcher.go
@@ -12,8 +12,6 @@ import (
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/tun"
-
"golang.zx2c4.com/wireguard/windows/conf"
"golang.zx2c4.com/wireguard/windows/services"
"golang.zx2c4.com/wireguard/windows/tunnel/firewall"
@@ -31,9 +29,10 @@ type interfaceWatcherEvent struct {
type interfaceWatcher struct {
errors chan interfaceWatcherError
- binder conn.BindSocketToInterface
- conf *conf.Config
- tun *tun.NativeTun
+ binder conn.BindSocketToInterface
+ clamper mtuClamper
+ conf *conf.Config
+ luid winipcfg.LUID
setupMutex sync.Mutex
interfaceChangeCallback winipcfg.ChangeCallback
@@ -100,15 +99,24 @@ func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) {
}
var err error
- log.Printf("Monitoring default %s routes", ipversion)
- *changeCallbacks, err = monitorDefaultRoutes(family, iw.binder, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.tun)
- if err != nil {
- iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err}
- return
+ if iw.binder != nil && iw.clamper != nil {
+ log.Printf("Monitoring default %s routes", ipversion)
+ *changeCallbacks, err = monitorDefaultRoutes(family, iw.binder, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.clamper, iw.luid)
+ if err != nil {
+ iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err}
+ return
+ }
+ } else if iw.conf.Interface.MTU == 0 {
+ log.Printf("Monitoring MTU of default %s routes", ipversion)
+ *changeCallbacks, err = monitorMTU(family, iw.luid)
+ if err != nil {
+ iw.errors <- interfaceWatcherError{services.ErrorMonitorMTUChanges, err}
+ return
+ }
}
log.Printf("Setting device %s addresses", ipversion)
- err = configureInterface(family, iw.conf, iw.tun)
+ err = configureInterface(family, iw.conf, iw.luid, iw.clamper)
if err != nil {
iw.errors <- interfaceWatcherError{services.ErrorSetNetConfig, err}
return
@@ -127,11 +135,11 @@ func watchInterface() (*interfaceWatcher, error) {
if notificationType != winipcfg.MibAddInstance {
return
}
- if iw.tun == nil {
+ if iw.luid == 0 {
iw.storedEvents = append(iw.storedEvents, interfaceWatcherEvent{iface.InterfaceLUID, iface.Family})
return
}
- if iface.InterfaceLUID != winipcfg.LUID(iw.tun.LUID()) {
+ if iface.InterfaceLUID != iw.luid {
return
}
iw.setup(iface.Family)
@@ -142,13 +150,13 @@ func watchInterface() (*interfaceWatcher, error) {
return iw, nil
}
-func (iw *interfaceWatcher) Configure(binder conn.BindSocketToInterface, conf *conf.Config, tun *tun.NativeTun) {
+func (iw *interfaceWatcher) Configure(binder conn.BindSocketToInterface, clamper mtuClamper, conf *conf.Config, luid winipcfg.LUID) {
iw.setupMutex.Lock()
defer iw.setupMutex.Unlock()
- iw.binder, iw.conf, iw.tun = binder, conf, tun
+ iw.binder, iw.clamper, iw.conf, iw.luid = binder, clamper, conf, luid
for _, event := range iw.storedEvents {
- if event.luid == winipcfg.LUID(iw.tun.LUID()) {
+ if event.luid == luid {
iw.setup(event.family)
}
}
@@ -160,7 +168,7 @@ func (iw *interfaceWatcher) Destroy() {
changeCallbacks4 := iw.changeCallbacks4
changeCallbacks6 := iw.changeCallbacks6
interfaceChangeCallback := iw.interfaceChangeCallback
- tun := iw.tun
+ luid := iw.luid
iw.setupMutex.Unlock()
if interfaceChangeCallback != nil {
@@ -186,10 +194,9 @@ func (iw *interfaceWatcher) Destroy() {
changeCallbacks6 = changeCallbacks6[1:]
}
firewall.DisableFirewall()
- if tun != nil && iw.tun == tun {
+ if luid != 0 && iw.luid == luid {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active
// routes, so to be certain, just remove everything before destroying.
- luid := winipcfg.LUID(tun.LUID())
luid.FlushRoutes(windows.AF_INET)
luid.FlushIPAddresses(windows.AF_INET)
luid.FlushDNS(windows.AF_INET)