From f3c3bd215731f55cf042e0d1c9eae4aa880f6257 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Thu, 28 Feb 2019 07:19:06 +0100 Subject: service: track tunnel service status --- service/errors.go | 1 + service/install.go | 4 +- service/ipc_client.go | 11 +++-- service/ipc_server.go | 10 ++-- service/mksyscall.go | 2 +- service/service_manager.go | 8 +++ service/tunneltracker.go | 117 ++++++++++++++++++++++++++++++++++++++++++++ service/zsyscall_windows.go | 42 ++++++++++++++-- 8 files changed, 180 insertions(+), 15 deletions(-) create mode 100644 service/tunneltracker.go (limited to 'service') diff --git a/service/errors.go b/service/errors.go index b6566d00..04f0638c 100644 --- a/service/errors.go +++ b/service/errors.go @@ -16,4 +16,5 @@ const ( ERROR_FILE_NOT_FOUND uint32 = 0x00000002 ERROR_SERVER_SID_MISMATCH uint32 = 0x00000274 ERROR_NETWORK_BUSY uint32 = 0x00000036 + ERROR_NO_TRACKING_SERVICE uint32 = 0x00000494 ) diff --git a/service/install.go b/service/install.go index 32131f94..9a57504f 100644 --- a/service/install.go +++ b/service/install.go @@ -160,8 +160,8 @@ func InstallTunnel(configPath string) error { if err != nil { return err } - service.Start() - return service.Close() + go trackTunnelService(name, service) + return service.Start() } func UninstallTunnel(name string) error { diff --git a/service/ipc_client.go b/service/ipc_client.go index c3d08897..f2ae2b22 100644 --- a/service/ipc_client.go +++ b/service/ipc_client.go @@ -37,7 +37,7 @@ const ( var rpcClient *rpc.Client type tunnelChangeCallback struct { - cb func(tunnel string) + cb func(tunnel string, state TunnelState) } var tunnelChangeCallbacks = make(map[*tunnelChangeCallback]bool) @@ -65,8 +65,13 @@ func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) { if err != nil || len(tunnel) == 0 { continue } + var state TunnelState + err = decoder.Decode(&state) + if err != nil || state == TunnelUnknown { + continue + } for cb := range tunnelChangeCallbacks { - cb.cb(tunnel) + cb.cb(tunnel, state) } case TunnelsChangeNotificationType: for cb := range tunnelsChangeCallbacks { @@ -122,7 +127,7 @@ func IPCClientQuit(stopTunnelsOnQuit bool) (bool, error) { return alreadyQuit, rpcClient.Call("ManagerService.Quit", stopTunnelsOnQuit, &alreadyQuit) } -func IPCClientRegisterTunnelChange(cb func(tunnel string)) *tunnelChangeCallback { +func IPCClientRegisterTunnelChange(cb func(tunnel string, state TunnelState)) *tunnelChangeCallback { s := &tunnelChangeCallback{cb} tunnelChangeCallbacks[s] = true return s diff --git a/service/ipc_server.go b/service/ipc_server.go index 4388bb00..3e4c7fd3 100644 --- a/service/ipc_server.go +++ b/service/ipc_server.go @@ -150,7 +150,7 @@ func IPCServerListen(reader *os.File, writer *os.File, events *os.File) error { return nil } -func notifyAll(notificationType NotificationType, iface interface{}) { +func notifyAll(notificationType NotificationType, ifaces ...interface{}) { if len(managerServices) == 0 { return } @@ -161,7 +161,7 @@ func notifyAll(notificationType NotificationType, iface interface{}) { if err != nil { return } - if iface != nil { + for _, iface := range ifaces { err = encoder.Encode(iface) if err != nil { return @@ -178,10 +178,10 @@ func notifyAll(notificationType NotificationType, iface interface{}) { managerServicesLock.RUnlock() } -func IPCServerNotifyTunnelChange(name string) { - notifyAll(TunnelChangeNotificationType, name) +func IPCServerNotifyTunnelChange(name string, state TunnelState) { + notifyAll(TunnelChangeNotificationType, name, state) } func IPCServerNotifyTunnelsChange() { - notifyAll(TunnelsChangeNotificationType, nil) + notifyAll(TunnelsChangeNotificationType) } diff --git a/service/mksyscall.go b/service/mksyscall.go index e96640aa..ccc4053e 100644 --- a/service/mksyscall.go +++ b/service/mksyscall.go @@ -5,4 +5,4 @@ package service -//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go service_manager.go +//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go service_manager.go tunneltracker.go diff --git a/service/service_manager.go b/service/service_manager.go index b903929b..c049465d 100644 --- a/service/service_manager.go +++ b/service/service_manager.go @@ -129,6 +129,14 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest return } + err = trackExistingTunnels() + if err != nil { + elog.Error(1, "Unable to track existing tunnels: "+err.Error()) + changes <- svc.Status{State: svc.StopPending} + exitCode = ERROR_NO_TRACKING_SERVICE + return + } + conf.RegisterStoreChangeCallback(func() { conf.MigrateUnencryptedConfigs() }) // Ignore return value for now, but could be useful later. conf.RegisterStoreChangeCallback(IPCServerNotifyTunnelsChange) diff --git a/service/tunneltracker.go b/service/tunneltracker.go new file mode 100644 index 00000000..2545930d --- /dev/null +++ b/service/tunneltracker.go @@ -0,0 +1,117 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package service + +import ( + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc/mgr" + "golang.zx2c4.com/wireguard/windows/conf" + "runtime" + "unsafe" +) + +//sys notifyServiceStatusChange(service windows.Handle, notifyMask uint32, notifyBuffer uintptr) (err error) [failretval!=0] = advapi32.NotifyServiceStatusChangeW +//sys sleepEx(milliseconds uint32, alertable bool) (ret uint32, err error) = kernel32.SleepEx + +const ( + serviceNotify_CREATED uint32 = 0x00000080 + serviceNotify_CONTINUE_PENDING = 0x00000010 + serviceNotify_DELETE_PENDING = 0x00000200 + serviceNotify_DELETED = 0x00000100 + serviceNotify_PAUSE_PENDING = 0x00000020 + serviceNotify_PAUSED = 0x00000040 + serviceNotify_RUNNING = 0x00000008 + serviceNotify_START_PENDING = 0x00000002 + serviceNotify_STOP_PENDING = 0x00000004 + serviceNotify_STOPPED = 0x00000001 +) +const serviceNotify_STATUS_CHANGE uint32 = 2 +const errorServiceMARKED_FOR_DELETE uint32 = 1072 + +type serviceNotify struct { + version uint32 + notifyCallback uintptr + context uintptr + notificationStatus uint32 + serviceType uint32 + currentState uint32 + controlsAccepted uint32 + win32ExitCode uint32 + serviceSpecificExitCode uint32 + checkPoint uint32 + waitHint uint32 + processId uint32 + serviceFlags uint32 + notificationTriggered uint32 + serviceNames *uint16 +} + +func serviceTrackerCallback(notifier *serviceNotify) uintptr { + return 0 +} + +var serviceTrackerCallbackPtr uintptr + +func init() { + serviceTrackerCallbackPtr = windows.NewCallback(serviceTrackerCallback) +} + +func trackExistingTunnels() error { + m, err := serviceManager() + if err != nil { + return err + } + names, err := conf.ListConfigNames() + if err != nil { + return err + } + for _, name := range names { + serviceName := "WireGuard Tunnel: " + name + service, err := m.OpenService(serviceName) + if err != nil { + continue + } + go trackTunnelService(name, service) + } + return nil +} + +func trackTunnelService(tunnelName string, svc *mgr.Service) { + runtime.LockOSThread() + const serviceNotifications = serviceNotify_RUNNING | serviceNotify_START_PENDING | serviceNotify_STOP_PENDING | serviceNotify_STOPPED | serviceNotify_DELETE_PENDING + notifier := &serviceNotify{ + version: serviceNotify_STATUS_CHANGE, + notifyCallback: serviceTrackerCallbackPtr, + } + defer svc.Close() + for { + notifier.context = 0 + err := notifyServiceStatusChange(svc.Handle, serviceNotifications, uintptr(unsafe.Pointer(notifier))) + if err != nil { + return + } + sleepEx(windows.INFINITE, true) + if notifier.notificationStatus != 0 { + return + } + state := TunnelUnknown + if notifier.notificationTriggered&serviceNotify_DELETE_PENDING != 0 { + state = TunnelDeleting + } else if notifier.notificationTriggered&serviceNotify_STOPPED != 0 { + state = TunnelStopped + } else if notifier.notificationTriggered&serviceNotify_STOP_PENDING != 0 { + state = TunnelStopping + } else if notifier.notificationTriggered&serviceNotify_RUNNING != 0 { + state = TunnelStarted + } else if notifier.notificationTriggered&serviceNotify_START_PENDING != 0 { + state = TunnelStarting + } + IPCServerNotifyTunnelChange(tunnelName, state) + if state == TunnelDeleting { + return + } + } +} diff --git a/service/zsyscall_windows.go b/service/zsyscall_windows.go index 649e3581..faf6d780 100644 --- a/service/zsyscall_windows.go +++ b/service/zsyscall_windows.go @@ -39,11 +39,14 @@ func errnoErr(e syscall.Errno) error { var ( modwtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll") modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - procWTSQueryUserToken = modwtsapi32.NewProc("WTSQueryUserToken") - procWTSEnumerateSessionsW = modwtsapi32.NewProc("WTSEnumerateSessionsW") - procWTSFreeMemory = modwtsapi32.NewProc("WTSFreeMemory") - procCreateWellKnownSid = modadvapi32.NewProc("CreateWellKnownSid") + procWTSQueryUserToken = modwtsapi32.NewProc("WTSQueryUserToken") + procWTSEnumerateSessionsW = modwtsapi32.NewProc("WTSEnumerateSessionsW") + procWTSFreeMemory = modwtsapi32.NewProc("WTSFreeMemory") + procCreateWellKnownSid = modadvapi32.NewProc("CreateWellKnownSid") + procNotifyServiceStatusChangeW = modadvapi32.NewProc("NotifyServiceStatusChangeW") + procSleepEx = modkernel32.NewProc("SleepEx") ) func wtfQueryUserToken(session uint32, token *windows.Token) (err error) { @@ -86,3 +89,34 @@ func createWellKnownSid(sidType wellKnownSidType, domainSid *windows.SID, sid *w } return } + +func notifyServiceStatusChange(service windows.Handle, notifyMask uint32, notifyBuffer uintptr) (err error) { + r1, _, e1 := syscall.Syscall(procNotifyServiceStatusChangeW.Addr(), 3, uintptr(service), uintptr(notifyMask), uintptr(notifyBuffer)) + if r1 != 0 { + if e1 != 0 { + err = errnoErr(e1) + } else { + err = syscall.EINVAL + } + } + return +} + +func sleepEx(milliseconds uint32, alertable bool) (ret uint32, err error) { + var _p0 uint32 + if alertable { + _p0 = 1 + } else { + _p0 = 0 + } + r0, _, e1 := syscall.Syscall(procSleepEx.Addr(), 2, uintptr(milliseconds), uintptr(_p0), 0) + ret = uint32(r0) + if ret == 0 { + if e1 != 0 { + err = errnoErr(e1) + } else { + err = syscall.EINVAL + } + } + return +} -- cgit v1.2.3-59-g8ed1b