/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package service
import (
"fmt"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/mgr"
"golang.zx2c4.com/wireguard/windows/conf"
"runtime"
"sync"
"syscall"
"unsafe"
)
//sys notifyServiceStatusChange(service windows.Handle, notifyMask uint32, notifyBuffer uintptr) (status uint32) = advapi32.NotifyServiceStatusChangeW
//sys sleepEx(milliseconds uint32, alertable bool) (ret uint32, err error) = kernel32.SleepEx
// Win32 SERVICE_NOTIFY_* mask bits for NotifyServiceStatusChangeW, listed in
// ascending bit order. NOTE: only serviceNotify_CREATED is explicitly typed
// uint32; the other mask bits are untyped integer constants.
const (
	serviceNotify_STOPPED                 = 0x00000001
	serviceNotify_START_PENDING           = 0x00000002
	serviceNotify_STOP_PENDING            = 0x00000004
	serviceNotify_RUNNING                 = 0x00000008
	serviceNotify_CONTINUE_PENDING        = 0x00000010
	serviceNotify_PAUSE_PENDING           = 0x00000020
	serviceNotify_PAUSED                  = 0x00000040
	serviceNotify_CREATED          uint32 = 0x00000080
	serviceNotify_DELETED                 = 0x00000100
	serviceNotify_DELETE_PENDING          = 0x00000200
)

// Other Win32 values used together with NotifyServiceStatusChangeW.
const (
	// serviceNotify_STATUS_CHANGE (SERVICE_NOTIFY_STATUS_CHANGE) is the
	// version stamped into the serviceNotify buffer handed to the API.
	serviceNotify_STATUS_CHANGE uint32 = 2
	// errorServiceMARKED_FOR_DELETE is ERROR_SERVICE_MARKED_FOR_DELETE.
	errorServiceMARKED_FOR_DELETE uint32 = 1072
	// errorServiceNOTIFY_CLIENT_LAGGING is ERROR_SERVICE_NOTIFY_CLIENT_LAGGING.
	errorServiceNOTIFY_CLIENT_LAGGING uint32 = 1294
)
// serviceStatus mirrors the Win32 SERVICE_STATUS_PROCESS structure. It is
// embedded in serviceNotify, which is handed to NotifyServiceStatusChangeW and
// filled in natively, so field order and widths must match the native layout
// exactly — do not reorder, retype, or insert fields.
type serviceStatus struct {
	serviceType uint32
	// currentState holds a Windows service state; the tracker converts it
	// with svc.State(...) before mapping it to a TunnelState.
	currentState     uint32
	controlsAccepted uint32
	// win32ExitCode is inspected when the service has stopped; a value of
	// ERROR_SERVICE_SPECIFIC_ERROR means serviceSpecificExitCode carries the
	// real exit code.
	win32ExitCode           uint32
	serviceSpecificExitCode uint32
	checkPoint              uint32
	waitHint                uint32
	processId               uint32
	serviceFlags            uint32
}
// serviceNotify mirrors the Win32 SERVICE_NOTIFY_2W structure whose address is
// passed to NotifyServiceStatusChangeW. Per the Win32 contract the buffer must
// remain valid until the notification has been delivered, and it is written
// natively at that point, so field order and widths must match the native
// layout exactly (including the pointer-sized fields).
type serviceNotify struct {
	// version must be serviceNotify_STATUS_CHANGE.
	version uint32
	// notifyCallback is the native callback pointer (serviceTrackerCallbackPtr).
	notifyCallback uintptr
	// context is an opaque value for the callback; not set by the tracker.
	context uintptr
	// notificationStatus is the Win32 status of the delivered notification.
	notificationStatus uint32
	// serviceStatus is the service's status as captured by the notification.
	serviceStatus serviceStatus
	// notificationTriggered reports which serviceNotify_* bit fired.
	notificationTriggered uint32
	// serviceNames is a native string pointer; not read by the tracker.
	serviceNames *uint16
}
// trackExistingTunnels launches one trackTunnelService goroutine for every
// configured tunnel whose Windows service can be opened. Tunnels whose
// service name cannot be derived, or whose service cannot be opened, are
// skipped without error. It fails only when the service manager or the list
// of configuration names cannot be obtained.
func trackExistingTunnels() error {
	manager, err := serviceManager()
	if err != nil {
		return err
	}
	configNames, err := conf.ListConfigNames()
	if err != nil {
		return err
	}
	for i := range configNames {
		tunnelName := configNames[i]
		svcName, nameErr := ServiceNameOfTunnel(tunnelName)
		if nameErr != nil {
			continue
		}
		// The tracker goroutine takes ownership of the opened service handle.
		if handle, openErr := manager.OpenService(svcName); openErr == nil {
			go trackTunnelService(tunnelName, handle)
		}
	}
	return nil
}
// serviceTrackerCallbackPtr is the native callback installed in every
// serviceNotify buffer. It intentionally does nothing: the tracker waits in an
// alertable SleepEx and reads the results straight out of the buffer once
// that wait returns.
var serviceTrackerCallbackPtr uintptr = windows.NewCallback(func(_ *serviceNotify) uintptr {
	return 0
})
// Bookkeeping shared by all trackTunnelService goroutines.
var (
	// trackedTunnels maps a tunnel name to the last state recorded for it.
	// An entry exists only while a tracker goroutine owns that tunnel.
	trackedTunnels = make(map[string]TunnelState)
	// trackedTunnelsLock guards every access to trackedTunnels.
	trackedTunnelsLock sync.Mutex
)
// svcStateToTunState maps a Windows service state onto the coarser
// TunnelState vocabulary. Paused, pause-pending, continue-pending and any
// other unrecognized state map to TunnelUnknown.
func svcStateToTunState(s svc.State) TunnelState {
	tunnelState := TunnelUnknown
	switch s {
	case svc.Running:
		tunnelState = TunnelStarted
	case svc.Stopped:
		tunnelState = TunnelStopped
	case svc.StartPending:
		tunnelState = TunnelStarting
	case svc.StopPending:
		tunnelState = TunnelStopping
	}
	return tunnelState
}
// trackedTunnelsGlobalState folds the states of all tracked tunnels into a
// single summary state:
//
//   - TunnelStarting if any tunnel is starting;
//   - otherwise TunnelStopping if any tunnel is stopping;
//   - otherwise TunnelStarted if any tunnel is started or in an unknown state;
//   - otherwise TunnelStopped (including when no tunnel is tracked).
//
// Starting deliberately outranks stopping so that the result no longer depends
// on Go's randomized map iteration order when both transitions are in flight.
func trackedTunnelsGlobalState() (state TunnelState) {
	trackedTunnelsLock.Lock()
	defer trackedTunnelsLock.Unlock()

	anyStopping, anyStarted := false, false
	for _, s := range trackedTunnels {
		switch s {
		case TunnelStarting:
			// Highest precedence: nothing else can change the answer.
			return TunnelStarting
		case TunnelStopping:
			anyStopping = true
		case TunnelStarted, TunnelUnknown:
			anyStarted = true
		}
	}
	switch {
	case anyStopping:
		return TunnelStopping
	case anyStarted:
		return TunnelStarted
	default:
		return TunnelStopped
	}
}
// trackTunnelService watches the Windows service backing tunnelName. It
// publishes each state transition through IPCServerNotifyTunnelChange and
// mirrors the latest state into trackedTunnels. It returns when the service is
// marked for deletion, or when monitoring fails; on failure it also asks the
// service to stop.
//
// It takes ownership of service and closes it before returning. If tunnelName
// is already tracked, it returns immediately. It is meant to run on its own
// goroutine: it locks that goroutine to its OS thread and blocks in alertable
// waits for as long as it tracks the service.
func trackTunnelService(tunnelName string, service *mgr.Service) {
	defer service.Close()

	// Claim the tunnel; bail out if another tracker already owns it.
	trackedTunnelsLock.Lock()
	if _, found := trackedTunnels[tunnelName]; found {
		trackedTunnelsLock.Unlock()
		return
	}
	trackedTunnels[tunnelName] = TunnelUnknown
	trackedTunnelsLock.Unlock()
	defer func() {
		trackedTunnelsLock.Lock()
		delete(trackedTunnels, tunnelName)
		trackedTunnelsLock.Unlock()
	}()

	// setState records state for this tunnel in the shared map.
	setState := func(state TunnelState) {
		trackedTunnelsLock.Lock()
		trackedTunnels[tunnelName] = state
		trackedTunnelsLock.Unlock()
	}

	const serviceNotifications = serviceNotify_RUNNING | serviceNotify_START_PENDING | serviceNotify_STOP_PENDING | serviceNotify_STOPPED | serviceNotify_DELETE_PENDING
	// notifier is filled in natively when the notification fires. It stays
	// referenced by this loop, and each registration is waited for before the
	// buffer is reused or dropped.
	notifier := &serviceNotify{
		version:        serviceNotify_STATUS_CHANGE,
		notifyCallback: serviceTrackerCallbackPtr,
	}

	// The notification is delivered as an APC to the thread that registered
	// it. Registration and the alertable SleepEx must therefore run on the
	// same OS thread. The lock is never released: per the runtime docs, a
	// goroutine that exits while locked terminates its thread.
	// NOTE(review): presumably deliberate, so that no other goroutine reuses
	// an APC-touched thread — confirm before adding runtime.UnlockOSThread.
	runtime.LockOSThread()

	lastState := TunnelUnknown
	for {
		ret := notifyServiceStatusChange(service.Handle, serviceNotifications, uintptr(unsafe.Pointer(notifier)))
		if ret == 0 {
			// Registered: wait alertably until the notification APC has run.
			sleepEx(windows.INFINITE, true)
			// serviceStatus is only valid if the notification itself
			// succeeded. Otherwise treat its status like a registration
			// result (e.g. ERROR_SERVICE_MARKED_FOR_DELETE).
			ret = notifier.notificationStatus
		}
		switch ret {
		case 0:
		case errorServiceMARKED_FOR_DELETE:
			setState(TunnelStopped)
			IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil)
			return
		case errorServiceNOTIFY_CLIENT_LAGGING:
			// The SCM dropped our registration; simply register again.
			continue
		default:
			setState(TunnelStopped)
			IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %v", syscall.Errno(ret)))
			// Best effort: we can no longer observe the service, so stop it.
			service.Control(svc.Stop)
			return
		}

		state := svcStateToTunState(svc.State(notifier.serviceStatus.currentState))
		var tunnelError error
		if state == TunnelStopped {
			tunnelError = tunnelErrorOfServiceStatus(&notifier.serviceStatus)
		}
		// Only report genuine transitions.
		if state != lastState {
			setState(state)
			IPCServerNotifyTunnelChange(tunnelName, state, tunnelError)
			lastState = state
		}
	}
}

// tunnelErrorOfServiceStatus returns the error to report for a stopped
// service, or nil if it stopped cleanly. A service-specific exit code is
// interpreted as this package's Error type, where ErrorSuccess means no error.
// Any other Win32 exit code, except NO_ERROR and serviceNEVER_STARTED, is
// reported as a raw syscall.Errno.
func tunnelErrorOfServiceStatus(status *serviceStatus) error {
	if status.win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) {
		if maybeErr := Error(status.serviceSpecificExitCode); maybeErr != ErrorSuccess {
			return maybeErr
		}
		return nil
	}
	switch status.win32ExitCode {
	case uint32(windows.NO_ERROR), serviceNEVER_STARTED:
		return nil
	default:
		return syscall.Errno(status.win32ExitCode)
	}
}