aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-08-09 02:07:14 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2021-08-09 13:12:47 +0200
commite60a3b02b8e551c169aebe764f0add6c5facc5fe (patch)
tree6d83a4693d4e9262be2511c45effa4f9ebc34a39
parentdriver: break encapsulation and pass timestamp to ringlogger (diff)
downloadwireguard-windows-e60a3b02b8e551c169aebe764f0add6c5facc5fe.tar.xz
wireguard-windows-e60a3b02b8e551c169aebe764f0add6c5facc5fe.zip
manager: track externally created tunnels
Requested-by: Bruno UT1 <bandry@ut1.org> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--manager/service.go2
-rw-r--r--manager/tunneltracker.go121
2 files changed, 94 insertions, 29 deletions
diff --git a/manager/service.go b/manager/service.go
index 5c44cfd8..6f0df505 100644
--- a/manager/service.go
+++ b/manager/service.go
@@ -64,7 +64,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
moveConfigsFromLegacyStore()
- err = trackExistingTunnels()
+ err = watchNewTunnelServices()
if err != nil {
serviceError = services.ErrorTrackTunnels
return
diff --git a/manager/tunneltracker.go b/manager/tunneltracker.go
index b32450a3..6b5b1a02 100644
--- a/manager/tunneltracker.go
+++ b/manager/tunneltracker.go
@@ -24,29 +24,6 @@ import (
"golang.zx2c4.com/wireguard/windows/services"
)
-func trackExistingTunnels() error {
- m, err := serviceManager()
- if err != nil {
- return err
- }
- names, err := conf.ListConfigNames()
- if err != nil {
- return err
- }
- for _, name := range names {
- serviceName, err := services.ServiceNameOfTunnel(name)
- if err != nil {
- continue
- }
- service, err := m.OpenService(serviceName)
- if err != nil {
- continue
- }
- go trackTunnelService(name, service)
- }
- return nil
-}
-
var trackedTunnels = make(map[string]TunnelState)
var trackedTunnelsLock = sync.Mutex{}
@@ -196,16 +173,17 @@ func trackService(service *mgr.Service, callback func(status uint32) bool) error
}
func trackTunnelService(tunnelName string, service *mgr.Service) {
- defer func() {
- service.Close()
- log.Printf("[%s] Tunnel service tracker finished", tunnelName)
- }()
-
trackedTunnelsLock.Lock()
if _, found := trackedTunnels[tunnelName]; found {
trackedTunnelsLock.Unlock()
+ service.Close()
return
}
+
+ defer func() {
+ service.Close()
+ log.Printf("[%s] Tunnel service tracker finished", tunnelName)
+ }()
trackedTunnels[tunnelName] = TunnelUnknown
trackedTunnelsLock.Unlock()
defer func() {
@@ -214,6 +192,15 @@ func trackTunnelService(tunnelName string, service *mgr.Service) {
trackedTunnelsLock.Unlock()
}()
+ for i := 0; i < 20; i++ {
+ if i > 0 {
+ time.Sleep(time.Second / 5)
+ }
+ if status, err := service.Query(); err != nil || status.State != svc.Stopped {
+ break
+ }
+ }
+
checkForDisabled := func() (shouldReturn bool) {
config, err := service.Config()
if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || (err != nil && config.StartType == windows.SERVICE_DISABLED) {
@@ -275,3 +262,81 @@ func trackTunnelService(tunnelName string, service *mgr.Service) {
}
disconnectTunnelServicePipe(tunnelName)
}
+
+func trackExistingTunnels() error {
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ names, err := conf.ListConfigNames()
+ if err != nil {
+ return err
+ }
+ for _, name := range names {
+ trackedTunnelsLock.Lock()
+ if _, found := trackedTunnels[name]; found {
+ trackedTunnelsLock.Unlock()
+ continue
+ }
+ trackedTunnelsLock.Unlock()
+ serviceName, err := services.ServiceNameOfTunnel(name)
+ if err != nil {
+ continue
+ }
+ service, err := m.OpenService(serviceName)
+ if err != nil {
+ continue
+ }
+ go trackTunnelService(name, service)
+ }
+ return nil
+}
+
+var servicesSubscriptionWatcherCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr {
+ trackExistingTunnels()
+ return 0
+})
+
+func watchNewTunnelServices() error {
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ var subscription uintptr
+ err = windows.SubscribeServiceChangeNotifications(m.Handle, windows.SC_EVENT_DATABASE_CHANGE, servicesSubscriptionWatcherCallbackPtr, 0, &subscription)
+ if err == nil {
+ // We probably could do:
+ // defer windows.UnsubscribeServiceChangeNotifications(subscription)
+ // and then terminate after some point, but instead we just let this go forever; it's process-lived.
+ return trackExistingTunnels()
+ }
+ if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
+ return err
+ }
+
+ // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point.
+ go func() {
+ runtime.LockOSThread()
+ notifier := &windows.SERVICE_NOTIFY{
+ Version: windows.SERVICE_NOTIFY_STATUS_CHANGE,
+ NotifyCallback: serviceTrackerCallbackPtr,
+ }
+ for {
+ err := windows.NotifyServiceStatusChange(m.Handle, windows.SERVICE_NOTIFY_CREATED, notifier)
+ if err == nil {
+ windows.SleepEx(windows.INFINITE, true)
+ if notifier.ServiceNames != nil {
+ windows.LocalFree(windows.Handle(unsafe.Pointer(notifier.ServiceNames)))
+ notifier.ServiceNames = nil
+ }
+ trackExistingTunnels()
+ } else if err == windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING {
+ continue
+ } else {
+ time.Sleep(time.Second * 3)
+ trackExistingTunnels()
+ }
+ }
+ }()
+ return trackExistingTunnels()
+}