aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/service
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2019-02-28 07:19:06 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2019-02-28 08:05:02 +0100
commitf3c3bd215731f55cf042e0d1c9eae4aa880f6257 (patch)
tree243e106096b2df0f27074234d6acd56d518fb407 /service
parentmanager: wire up config migrator (diff)
downloadwireguard-windows-f3c3bd215731f55cf042e0d1c9eae4aa880f6257.tar.xz
wireguard-windows-f3c3bd215731f55cf042e0d1c9eae4aa880f6257.zip
service: track tunnel service status
Diffstat (limited to 'service')
-rw-r--r--service/errors.go1
-rw-r--r--service/install.go4
-rw-r--r--service/ipc_client.go11
-rw-r--r--service/ipc_server.go10
-rw-r--r--service/mksyscall.go2
-rw-r--r--service/service_manager.go8
-rw-r--r--service/tunneltracker.go117
-rw-r--r--service/zsyscall_windows.go42
8 files changed, 180 insertions, 15 deletions
diff --git a/service/errors.go b/service/errors.go
index b6566d00..04f0638c 100644
--- a/service/errors.go
+++ b/service/errors.go
@@ -16,4 +16,5 @@ const (
ERROR_FILE_NOT_FOUND uint32 = 0x00000002
ERROR_SERVER_SID_MISMATCH uint32 = 0x00000274
ERROR_NETWORK_BUSY uint32 = 0x00000036
+ ERROR_NO_TRACKING_SERVICE uint32 = 0x00000494
)
diff --git a/service/install.go b/service/install.go
index 32131f94..9a57504f 100644
--- a/service/install.go
+++ b/service/install.go
@@ -160,8 +160,8 @@ func InstallTunnel(configPath string) error {
if err != nil {
return err
}
- service.Start()
- return service.Close()
+ go trackTunnelService(name, service)
+ return service.Start()
}
func UninstallTunnel(name string) error {
diff --git a/service/ipc_client.go b/service/ipc_client.go
index c3d08897..f2ae2b22 100644
--- a/service/ipc_client.go
+++ b/service/ipc_client.go
@@ -37,7 +37,7 @@ const (
var rpcClient *rpc.Client
type tunnelChangeCallback struct {
- cb func(tunnel string)
+ cb func(tunnel string, state TunnelState)
}
var tunnelChangeCallbacks = make(map[*tunnelChangeCallback]bool)
@@ -65,8 +65,13 @@ func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) {
if err != nil || len(tunnel) == 0 {
continue
}
+ var state TunnelState
+ err = decoder.Decode(&state)
+ if err != nil || state == TunnelUnknown {
+ continue
+ }
for cb := range tunnelChangeCallbacks {
- cb.cb(tunnel)
+ cb.cb(tunnel, state)
}
case TunnelsChangeNotificationType:
for cb := range tunnelsChangeCallbacks {
@@ -122,7 +127,7 @@ func IPCClientQuit(stopTunnelsOnQuit bool) (bool, error) {
return alreadyQuit, rpcClient.Call("ManagerService.Quit", stopTunnelsOnQuit, &alreadyQuit)
}
-func IPCClientRegisterTunnelChange(cb func(tunnel string)) *tunnelChangeCallback {
+func IPCClientRegisterTunnelChange(cb func(tunnel string, state TunnelState)) *tunnelChangeCallback {
s := &tunnelChangeCallback{cb}
tunnelChangeCallbacks[s] = true
return s
diff --git a/service/ipc_server.go b/service/ipc_server.go
index 4388bb00..3e4c7fd3 100644
--- a/service/ipc_server.go
+++ b/service/ipc_server.go
@@ -150,7 +150,7 @@ func IPCServerListen(reader *os.File, writer *os.File, events *os.File) error {
return nil
}
-func notifyAll(notificationType NotificationType, iface interface{}) {
+func notifyAll(notificationType NotificationType, ifaces ...interface{}) {
if len(managerServices) == 0 {
return
}
@@ -161,7 +161,7 @@ func notifyAll(notificationType NotificationType, iface interface{}) {
if err != nil {
return
}
- if iface != nil {
+ for _, iface := range ifaces {
err = encoder.Encode(iface)
if err != nil {
return
@@ -178,10 +178,10 @@ func notifyAll(notificationType NotificationType, iface interface{}) {
managerServicesLock.RUnlock()
}
-func IPCServerNotifyTunnelChange(name string) {
- notifyAll(TunnelChangeNotificationType, name)
+func IPCServerNotifyTunnelChange(name string, state TunnelState) {
+ notifyAll(TunnelChangeNotificationType, name, state)
}
func IPCServerNotifyTunnelsChange() {
- notifyAll(TunnelsChangeNotificationType, nil)
+ notifyAll(TunnelsChangeNotificationType)
}
diff --git a/service/mksyscall.go b/service/mksyscall.go
index e96640aa..ccc4053e 100644
--- a/service/mksyscall.go
+++ b/service/mksyscall.go
@@ -5,4 +5,4 @@
package service
-//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go service_manager.go
+//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go service_manager.go tunneltracker.go
diff --git a/service/service_manager.go b/service/service_manager.go
index b903929b..c049465d 100644
--- a/service/service_manager.go
+++ b/service/service_manager.go
@@ -129,6 +129,14 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
return
}
+ err = trackExistingTunnels()
+ if err != nil {
+ elog.Error(1, "Unable to track existing tunnels: "+err.Error())
+ changes <- svc.Status{State: svc.StopPending}
+ exitCode = ERROR_NO_TRACKING_SERVICE
+ return
+ }
+
conf.RegisterStoreChangeCallback(func() { conf.MigrateUnencryptedConfigs() }) // Ignore return value for now, but could be useful later.
conf.RegisterStoreChangeCallback(IPCServerNotifyTunnelsChange)
diff --git a/service/tunneltracker.go b/service/tunneltracker.go
new file mode 100644
index 00000000..2545930d
--- /dev/null
+++ b/service/tunneltracker.go
@@ -0,0 +1,117 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package service
+
+import (
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/svc/mgr"
+ "golang.zx2c4.com/wireguard/windows/conf"
+ "runtime"
+ "unsafe"
+)
+
+//sys notifyServiceStatusChange(service windows.Handle, notifyMask uint32, notifyBuffer uintptr) (err error) [failretval!=0] = advapi32.NotifyServiceStatusChangeW
+//sys sleepEx(milliseconds uint32, alertable bool) (ret uint32, err error) = kernel32.SleepEx
+
+const (
+ serviceNotify_CREATED uint32 = 0x00000080
+ serviceNotify_CONTINUE_PENDING = 0x00000010
+ serviceNotify_DELETE_PENDING = 0x00000200
+ serviceNotify_DELETED = 0x00000100
+ serviceNotify_PAUSE_PENDING = 0x00000020
+ serviceNotify_PAUSED = 0x00000040
+ serviceNotify_RUNNING = 0x00000008
+ serviceNotify_START_PENDING = 0x00000002
+ serviceNotify_STOP_PENDING = 0x00000004
+ serviceNotify_STOPPED = 0x00000001
+)
+const serviceNotify_STATUS_CHANGE uint32 = 2
+const errorServiceMARKED_FOR_DELETE uint32 = 1072
+
+type serviceNotify struct {
+ version uint32
+ notifyCallback uintptr
+ context uintptr
+ notificationStatus uint32
+ serviceType uint32
+ currentState uint32
+ controlsAccepted uint32
+ win32ExitCode uint32
+ serviceSpecificExitCode uint32
+ checkPoint uint32
+ waitHint uint32
+ processId uint32
+ serviceFlags uint32
+ notificationTriggered uint32
+ serviceNames *uint16
+}
+
+func serviceTrackerCallback(notifier *serviceNotify) uintptr {
+ return 0
+}
+
+var serviceTrackerCallbackPtr uintptr
+
+func init() {
+ serviceTrackerCallbackPtr = windows.NewCallback(serviceTrackerCallback)
+}
+
+func trackExistingTunnels() error {
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ names, err := conf.ListConfigNames()
+ if err != nil {
+ return err
+ }
+ for _, name := range names {
+ serviceName := "WireGuard Tunnel: " + name
+ service, err := m.OpenService(serviceName)
+ if err != nil {
+ continue
+ }
+ go trackTunnelService(name, service)
+ }
+ return nil
+}
+
+func trackTunnelService(tunnelName string, svc *mgr.Service) {
+ runtime.LockOSThread()
+ const serviceNotifications = serviceNotify_RUNNING | serviceNotify_START_PENDING | serviceNotify_STOP_PENDING | serviceNotify_STOPPED | serviceNotify_DELETE_PENDING
+ notifier := &serviceNotify{
+ version: serviceNotify_STATUS_CHANGE,
+ notifyCallback: serviceTrackerCallbackPtr,
+ }
+ defer svc.Close()
+ for {
+ notifier.context = 0
+ err := notifyServiceStatusChange(svc.Handle, serviceNotifications, uintptr(unsafe.Pointer(notifier)))
+ if err != nil {
+ return
+ }
+ sleepEx(windows.INFINITE, true)
+ if notifier.notificationStatus != 0 {
+ return
+ }
+ state := TunnelUnknown
+ if notifier.notificationTriggered&serviceNotify_DELETE_PENDING != 0 {
+ state = TunnelDeleting
+ } else if notifier.notificationTriggered&serviceNotify_STOPPED != 0 {
+ state = TunnelStopped
+ } else if notifier.notificationTriggered&serviceNotify_STOP_PENDING != 0 {
+ state = TunnelStopping
+ } else if notifier.notificationTriggered&serviceNotify_RUNNING != 0 {
+ state = TunnelStarted
+ } else if notifier.notificationTriggered&serviceNotify_START_PENDING != 0 {
+ state = TunnelStarting
+ }
+ IPCServerNotifyTunnelChange(tunnelName, state)
+ if state == TunnelDeleting {
+ return
+ }
+ }
+}
diff --git a/service/zsyscall_windows.go b/service/zsyscall_windows.go
index 649e3581..faf6d780 100644
--- a/service/zsyscall_windows.go
+++ b/service/zsyscall_windows.go
@@ -39,11 +39,14 @@ func errnoErr(e syscall.Errno) error {
var (
modwtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll")
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
+ modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
- procWTSQueryUserToken = modwtsapi32.NewProc("WTSQueryUserToken")
- procWTSEnumerateSessionsW = modwtsapi32.NewProc("WTSEnumerateSessionsW")
- procWTSFreeMemory = modwtsapi32.NewProc("WTSFreeMemory")
- procCreateWellKnownSid = modadvapi32.NewProc("CreateWellKnownSid")
+ procWTSQueryUserToken = modwtsapi32.NewProc("WTSQueryUserToken")
+ procWTSEnumerateSessionsW = modwtsapi32.NewProc("WTSEnumerateSessionsW")
+ procWTSFreeMemory = modwtsapi32.NewProc("WTSFreeMemory")
+ procCreateWellKnownSid = modadvapi32.NewProc("CreateWellKnownSid")
+ procNotifyServiceStatusChangeW = modadvapi32.NewProc("NotifyServiceStatusChangeW")
+ procSleepEx = modkernel32.NewProc("SleepEx")
)
func wtfQueryUserToken(session uint32, token *windows.Token) (err error) {
@@ -86,3 +89,34 @@ func createWellKnownSid(sidType wellKnownSidType, domainSid *windows.SID, sid *w
}
return
}
+
+func notifyServiceStatusChange(service windows.Handle, notifyMask uint32, notifyBuffer uintptr) (err error) {
+ r1, _, e1 := syscall.Syscall(procNotifyServiceStatusChangeW.Addr(), 3, uintptr(service), uintptr(notifyMask), uintptr(notifyBuffer))
+ if r1 != 0 {
+ if e1 != 0 {
+ err = errnoErr(e1)
+ } else {
+ err = syscall.EINVAL
+ }
+ }
+ return
+}
+
+func sleepEx(milliseconds uint32, alertable bool) (ret uint32, err error) {
+ var _p0 uint32
+ if alertable {
+ _p0 = 1
+ } else {
+ _p0 = 0
+ }
+ r0, _, e1 := syscall.Syscall(procSleepEx.Addr(), 2, uintptr(milliseconds), uintptr(_p0), 0)
+ ret = uint32(r0)
+ if ret == 0 {
+ if e1 != 0 {
+ err = errnoErr(e1)
+ } else {
+ err = syscall.EINVAL
+ }
+ }
+ return
+}