diff options
Diffstat (limited to 'tunnel/interfacewatcher.go')
-rw-r--r-- | tunnel/interfacewatcher.go | 45 |
1 files changed, 26 insertions, 19 deletions
diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go index e12e5929..32132e93 100644 --- a/tunnel/interfacewatcher.go +++ b/tunnel/interfacewatcher.go @@ -12,8 +12,6 @@ import ( "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/windows/conf" "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/tunnel/firewall" @@ -31,9 +29,10 @@ type interfaceWatcherEvent struct { type interfaceWatcher struct { errors chan interfaceWatcherError - binder conn.BindSocketToInterface - conf *conf.Config - tun *tun.NativeTun + binder conn.BindSocketToInterface + clamper mtuClamper + conf *conf.Config + luid winipcfg.LUID setupMutex sync.Mutex interfaceChangeCallback winipcfg.ChangeCallback @@ -100,15 +99,24 @@ func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) { } var err error - log.Printf("Monitoring default %s routes", ipversion) - *changeCallbacks, err = monitorDefaultRoutes(family, iw.binder, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.tun) - if err != nil { - iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err} - return + if iw.binder != nil && iw.clamper != nil { + log.Printf("Monitoring default %s routes", ipversion) + *changeCallbacks, err = monitorDefaultRoutes(family, iw.binder, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.clamper, iw.luid) + if err != nil { + iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err} + return + } + } else if iw.conf.Interface.MTU == 0 { + log.Printf("Monitoring MTU of default %s routes", ipversion) + *changeCallbacks, err = monitorMTU(family, iw.luid) + if err != nil { + iw.errors <- interfaceWatcherError{services.ErrorMonitorMTUChanges, err} + return + } } log.Printf("Setting device %s addresses", ipversion) - err = configureInterface(family, iw.conf, iw.tun) + err = configureInterface(family, iw.conf, iw.luid, iw.clamper) if err != nil { iw.errors <- interfaceWatcherError{services.ErrorSetNetConfig, err} return @@ -127,11 +135,11 @@ func watchInterface() (*interfaceWatcher, error) { if notificationType != winipcfg.MibAddInstance { return } - if iw.tun == nil { + if iw.luid == 0 { iw.storedEvents = append(iw.storedEvents, interfaceWatcherEvent{iface.InterfaceLUID, iface.Family}) return } - if iface.InterfaceLUID != winipcfg.LUID(iw.tun.LUID()) { + if iface.InterfaceLUID != iw.luid { return } iw.setup(iface.Family) @@ -142,13 +150,13 @@ func watchInterface() (*interfaceWatcher, error) { return iw, nil } -func (iw *interfaceWatcher) Configure(binder conn.BindSocketToInterface, conf *conf.Config, tun *tun.NativeTun) { +func (iw *interfaceWatcher) Configure(binder conn.BindSocketToInterface, clamper mtuClamper, conf *conf.Config, luid winipcfg.LUID) { iw.setupMutex.Lock() defer iw.setupMutex.Unlock() - iw.binder, iw.conf, iw.tun = binder, conf, tun + iw.binder, iw.clamper, iw.conf, iw.luid = binder, clamper, conf, luid for _, event := range iw.storedEvents { - if event.luid == winipcfg.LUID(iw.tun.LUID()) { + if event.luid == luid { iw.setup(event.family) } } @@ -160,7 +168,7 @@ func (iw *interfaceWatcher) Destroy() { changeCallbacks4 := iw.changeCallbacks4 changeCallbacks6 := iw.changeCallbacks6 interfaceChangeCallback := iw.interfaceChangeCallback - tun := iw.tun + luid := iw.luid iw.setupMutex.Unlock() if interfaceChangeCallback != nil { @@ -186,10 +194,9 @@ func (iw *interfaceWatcher) Destroy() { changeCallbacks6 = changeCallbacks6[1:] } firewall.DisableFirewall() - if tun != nil && iw.tun == tun { + if luid != 0 && iw.luid == luid { // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active // routes, so to be certain, just remove everything before destroying. - luid := winipcfg.LUID(tun.LUID()) luid.FlushRoutes(windows.AF_INET) luid.FlushIPAddresses(windows.AF_INET) luid.FlushDNS(windows.AF_INET) |