diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-08-09 02:07:14 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-08-09 13:12:47 +0200 |
commit | e60a3b02b8e551c169aebe764f0add6c5facc5fe (patch) | |
tree | 6d83a4693d4e9262be2511c45effa4f9ebc34a39 /manager | |
parent | driver: break encapsulation and pass timestamp to ringlogger (diff) | |
download | wireguard-windows-e60a3b02b8e551c169aebe764f0add6c5facc5fe.tar.xz wireguard-windows-e60a3b02b8e551c169aebe764f0add6c5facc5fe.zip |
manager: track externally created tunnels
Requested-by: Bruno UT1 <bandry@ut1.org>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'manager')
-rw-r--r-- | manager/service.go | 2 | ||||
-rw-r--r-- | manager/tunneltracker.go | 121 |
2 files changed, 94 insertions, 29 deletions
diff --git a/manager/service.go b/manager/service.go index 5c44cfd8..6f0df505 100644 --- a/manager/service.go +++ b/manager/service.go @@ -64,7 +64,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest moveConfigsFromLegacyStore() - err = trackExistingTunnels() + err = watchNewTunnelServices() if err != nil { serviceError = services.ErrorTrackTunnels return diff --git a/manager/tunneltracker.go b/manager/tunneltracker.go index b32450a3..6b5b1a02 100644 --- a/manager/tunneltracker.go +++ b/manager/tunneltracker.go @@ -24,29 +24,6 @@ import ( "golang.zx2c4.com/wireguard/windows/services" ) -func trackExistingTunnels() error { - m, err := serviceManager() - if err != nil { - return err - } - names, err := conf.ListConfigNames() - if err != nil { - return err - } - for _, name := range names { - serviceName, err := services.ServiceNameOfTunnel(name) - if err != nil { - continue - } - service, err := m.OpenService(serviceName) - if err != nil { - continue - } - go trackTunnelService(name, service) - } - return nil -} - var trackedTunnels = make(map[string]TunnelState) var trackedTunnelsLock = sync.Mutex{} @@ -196,16 +173,17 @@ func trackService(service *mgr.Service, callback func(status uint32) bool) error } func trackTunnelService(tunnelName string, service *mgr.Service) { - defer func() { - service.Close() - log.Printf("[%s] Tunnel service tracker finished", tunnelName) - }() - trackedTunnelsLock.Lock() if _, found := trackedTunnels[tunnelName]; found { trackedTunnelsLock.Unlock() + service.Close() return } + + defer func() { + service.Close() + log.Printf("[%s] Tunnel service tracker finished", tunnelName) + }() trackedTunnels[tunnelName] = TunnelUnknown trackedTunnelsLock.Unlock() defer func() { @@ -214,6 +192,15 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { trackedTunnelsLock.Unlock() }() + for i := 0; i < 20; i++ { + if i > 0 { + time.Sleep(time.Second / 5) + } + if status, err := service.Query(); err != nil || status.State != svc.Stopped { + break + } + } + checkForDisabled := func() (shouldReturn bool) { config, err := service.Config() if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || (err != nil && config.StartType == windows.SERVICE_DISABLED) { @@ -275,3 +262,81 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { } disconnectTunnelServicePipe(tunnelName) } + +func trackExistingTunnels() error { + m, err := serviceManager() + if err != nil { + return err + } + names, err := conf.ListConfigNames() + if err != nil { + return err + } + for _, name := range names { + trackedTunnelsLock.Lock() + if _, found := trackedTunnels[name]; found { + trackedTunnelsLock.Unlock() + continue + } + trackedTunnelsLock.Unlock() + serviceName, err := services.ServiceNameOfTunnel(name) + if err != nil { + continue + } + service, err := m.OpenService(serviceName) + if err != nil { + continue + } + go trackTunnelService(name, service) + } + return nil +} + +var servicesSubscriptionWatcherCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr { + trackExistingTunnels() + return 0 +}) + +func watchNewTunnelServices() error { + m, err := serviceManager() + if err != nil { + return err + } + var subscription uintptr + err = windows.SubscribeServiceChangeNotifications(m.Handle, windows.SC_EVENT_DATABASE_CHANGE, servicesSubscriptionWatcherCallbackPtr, 0, &subscription) + if err == nil { + // We probably could do: + // defer windows.UnsubscribeServiceChangeNotifications(subscription) + // and then terminate after some point, but instead we just let this go forever; it's process-lived. + return trackExistingTunnels() + } + if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) { + return err + } + + // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point. + go func() { + runtime.LockOSThread() + notifier := &windows.SERVICE_NOTIFY{ + Version: windows.SERVICE_NOTIFY_STATUS_CHANGE, + NotifyCallback: serviceTrackerCallbackPtr, + } + for { + err := windows.NotifyServiceStatusChange(m.Handle, windows.SERVICE_NOTIFY_CREATED, notifier) + if err == nil { + windows.SleepEx(windows.INFINITE, true) + if notifier.ServiceNames != nil { + windows.LocalFree(windows.Handle(unsafe.Pointer(notifier.ServiceNames))) + notifier.ServiceNames = nil + } + trackExistingTunnels() + } else if err == windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING { + continue + } else { + time.Sleep(time.Second * 3) + trackExistingTunnels() + } + } + }() + return trackExistingTunnels() +} |