aboutsummaryrefslogblamecommitdiffstatshomepage
path: root/manager/tunneltracker.go
blob: 32605b368b6609e441a882a15f31d1842c31593e (plain) (tree)
1
2
3
4
5
6
7
8
9




                                                         
               

        
             
             
                 
              
                 
              



                                          
 
                                                 
                                                     

 









                                            



                                                             








                                                          
                                                                                                    


                
                                                 
















                                                  















                                                                    
                                                                  

                               
                                                                              
           

                                 

                                                          

                      

                                                  





                                                  


                                                                                                                                                                                                                         
                                                          
         
 

                                                        
                                                                                                                   













                                                                                                      
                              
                                      
                                  
             
                                                                                                        

                            
                             
                                                                                                                                




                                                              
                                                             


                                                                  
                                                                                   
                              
                                                                 

                                


                                                                  
                                                                                                                                                         
                                                 

                              
 
                                                                                           
                                     
                                           
                                                                                                                 

                                                                                                          

                                                              
                                
                                                                             
                                                                                                           
                                        
                                                                                                         
                                 
                         
                 
                                       


                                                          


                                                                                   

         
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
 */

package manager

import (
	"fmt"
	"log"
	"runtime"
	"sync"
	"syscall"
	"time"

	"golang.org/x/sys/windows"
	"golang.org/x/sys/windows/svc"
	"golang.org/x/sys/windows/svc/mgr"

	"golang.zx2c4.com/wireguard/windows/conf"
	"golang.zx2c4.com/wireguard/windows/services"
)

// trackExistingTunnels spawns a tracker goroutine for every tunnel that has a
// configuration on disk and an openable Windows service. Tunnels whose service
// name cannot be derived, or whose service cannot be opened, are skipped
// silently. An error is returned only when the service manager or the list of
// configuration names cannot be obtained.
func trackExistingTunnels() error {
	scm, err := serviceManager()
	if err != nil {
		return err
	}
	tunnelNames, err := conf.ListConfigNames()
	if err != nil {
		return err
	}
	for _, tunnelName := range tunnelNames {
		serviceName, nameErr := ServiceNameOfTunnel(tunnelName)
		if nameErr != nil {
			continue
		}
		// A missing service simply means the tunnel is not installed.
		if service, openErr := scm.OpenService(serviceName); openErr == nil {
			go trackTunnelService(tunnelName, service)
		}
	}
	return nil
}

var (
	// serviceTrackerCallbackPtr is the NotifyCallback installed in every
	// SERVICE_NOTIFY buffer. It intentionally does nothing and returns 0: the
	// tracker only relies on its alertable SleepEx returning
	// WAIT_IO_COMPLETION, after which it reads the notify buffer itself.
	serviceTrackerCallbackPtr = windows.NewCallback(func(_ *windows.SERVICE_NOTIFY) uintptr {
		return 0
	})

	// trackedTunnels maps each tracked tunnel name to its last observed state.
	trackedTunnels = make(map[string]TunnelState)

	// trackedTunnelsLock guards trackedTunnels.
	trackedTunnelsLock sync.Mutex
)

// svcStateToTunState translates a Windows service state into the matching
// TunnelState. Service states without a tunnel counterpart yield
// TunnelUnknown.
func svcStateToTunState(s svc.State) (tunState TunnelState) {
	switch s {
	case svc.StartPending:
		tunState = TunnelStarting
	case svc.Running:
		tunState = TunnelStarted
	case svc.StopPending:
		tunState = TunnelStopping
	case svc.Stopped:
		tunState = TunnelStopped
	default:
		tunState = TunnelUnknown
	}
	return tunState
}

// trackedTunnelsGlobalState condenses the states of all tracked tunnels into
// one aggregate state, holding trackedTunnelsLock while it scans:
//   - TunnelStarting or TunnelStopping as soon as a tunnel in that state is
//     seen (if both are present, whichever the unordered map iteration
//     reaches first is returned);
//   - otherwise TunnelStarted if any tunnel is started or in an unknown state;
//   - otherwise TunnelStopped, which includes the case of no tracked tunnels.
func trackedTunnelsGlobalState() TunnelState {
	trackedTunnelsLock.Lock()
	defer trackedTunnelsLock.Unlock()

	anyUp := false
	for _, tunnelState := range trackedTunnels {
		switch tunnelState {
		case TunnelStarting, TunnelStopping:
			return tunnelState
		case TunnelStarted, TunnelUnknown:
			anyUp = true
		}
	}
	if anyUp {
		return TunnelStarted
	}
	return TunnelStopped
}

// trackTunnelService watches the Windows service backing tunnelName. It
// records the state in trackedTunnels and publishes each state change through
// IPCServerNotifyTunnelChange. It takes ownership of service and closes it on
// return.
//
// It returns immediately if another tracker already tracks tunnelName.
// Otherwise it runs until one of the following happens:
//   - the service is marked for deletion;
//   - the service is found disabled (it is then deleted);
//   - monitoring fails, in which case it also sends the service a stop request.
func trackTunnelService(tunnelName string, service *mgr.Service) {
	// Registered first, so it runs last (defers are LIFO). It releases the
	// service handle handed over by the caller, on every exit path, including
	// the "already tracked" early return below.
	defer func() {
		service.Close()
		log.Printf("[%s] Tunnel service tracker finished", tunnelName)
	}()

	// Claim the tunnel name. Only one tracker per tunnel is allowed.
	trackedTunnelsLock.Lock()
	if _, found := trackedTunnels[tunnelName]; found {
		trackedTunnelsLock.Unlock()
		return
	}
	trackedTunnels[tunnelName] = TunnelUnknown
	trackedTunnelsLock.Unlock()
	// From here on, drop the tunnel from trackedTunnels when tracking ends.
	defer func() {
		trackedTunnelsLock.Lock()
		delete(trackedTunnels, tunnelName)
		trackedTunnelsLock.Unlock()
	}()

	// Ask to be notified of every run-state transition, and of pending
	// deletion, so that removed tunnels stop being tracked.
	const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING
	// The same notify buffer is passed to every NotifyServiceStatusChange
	// call below. Its ServiceStatus field is read once a notification has
	// been delivered.
	notifier := &windows.SERVICE_NOTIFY{
		Version:        windows.SERVICE_NOTIFY_STATUS_CHANGE,
		NotifyCallback: serviceTrackerCallbackPtr,
	}

	// checkForDisabled handles a service that is marked for deletion or whose
	// start type is SERVICE_DISABLED. In that case it deletes the service,
	// records and announces TunnelStopped, and reports true so the caller
	// exits.
	// NOTE(review): errors from service.Config() other than
	// ERROR_SERVICE_MARKED_FOR_DELETE are ignored, and config is then the
	// zero value. The error from service.Delete() is also ignored.
	// NOTE(review): the log text says "via timeout", but this check also
	// runs once before the notification loop starts.
	checkForDisabled := func() (shouldReturn bool) {
		config, err := service.Config()
		if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || config.StartType == windows.SERVICE_DISABLED {
			log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName)
			service.Delete()
			trackedTunnelsLock.Lock()
			trackedTunnels[tunnelName] = TunnelStopped
			trackedTunnelsLock.Unlock()
			IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil)
			return true
		}
		return false
	}
	if checkForDisabled() {
		return
	}

	// Per the Win32 documentation, NotifyServiceStatusChange delivers its
	// callback as an APC to the calling thread. That APC only runs during an
	// alertable wait on that same thread. So the goroutine is pinned to one
	// OS thread for both the registration and the SleepEx wait below.
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	lastState := TunnelUnknown
	for {
		err := windows.NotifyServiceStatusChange(service.Handle, serviceNotifications, notifier)
		switch err {
		case nil:
			// Wait alertably in 3-second slices. WAIT_IO_COMPLETION means the
			// APC ran. "break" then leaves only this inner for loop, and
			// execution continues with the state handling after the switch.
			// Each timeout re-checks whether the service has been disabled.
			for {
				if windows.SleepEx(uint32(time.Second*3/time.Millisecond), true) == windows.WAIT_IO_COMPLETION {
					break
				} else if checkForDisabled() {
					return
				}
			}
		case windows.ERROR_SERVICE_MARKED_FOR_DELETE:
			// The service is going away: report the tunnel as stopped and
			// stop tracking.
			trackedTunnelsLock.Lock()
			trackedTunnels[tunnelName] = TunnelStopped
			trackedTunnelsLock.Unlock()
			IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil)
			return
		case windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING:
			// Notifications fell behind; simply register again.
			continue
		default:
			// Any other failure means the service can no longer be monitored.
			// Report the error, ask the service to stop, and give up.
			trackedTunnelsLock.Lock()
			trackedTunnels[tunnelName] = TunnelStopped
			trackedTunnelsLock.Unlock()
			IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %v", err))
			service.Control(svc.Stop)
			return
		}

		// A notification arrived. Map the reported service state to a
		// tunnel state, and for a stopped service derive an error from its
		// exit codes.
		state := svcStateToTunState(svc.State(notifier.ServiceStatus.CurrentState))
		var tunnelError error
		if state == TunnelStopped {
			if notifier.ServiceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) {
				// The service-specific exit code is decoded as a
				// services.Error. ErrorSuccess means a clean stop.
				maybeErr := services.Error(notifier.ServiceStatus.ServiceSpecificExitCode)
				if maybeErr != services.ErrorSuccess {
					tunnelError = maybeErr
				}
			} else {
				// Any Win32 exit code other than NO_ERROR or
				// ERROR_SERVICE_NEVER_STARTED is surfaced as a
				// syscall.Errno.
				switch notifier.ServiceStatus.Win32ExitCode {
				case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED):
				default:
					tunnelError = syscall.Errno(notifier.ServiceStatus.Win32ExitCode)
				}
			}
		}
		// Only publish actual transitions. A repeated state, including its
		// error, is not re-announced.
		if state != lastState {
			trackedTunnelsLock.Lock()
			trackedTunnels[tunnelName] = state
			trackedTunnelsLock.Unlock()
			IPCServerNotifyTunnelChange(tunnelName, state, tunnelError)
			lastState = state
		}
	}
}