From e493f911269a2dabab7b05ec28726cdaeffb660e Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 20 May 2019 14:18:01 +0200 Subject: service: split into tunnel and manager --- manager/service.go | 331 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 331 insertions(+) create mode 100644 manager/service.go (limited to 'manager/service.go') diff --git a/manager/service.go b/manager/service.go new file mode 100644 index 00000000..ba7208d8 --- /dev/null +++ b/manager/service.go @@ -0,0 +1,331 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package manager + +import ( + "errors" + "fmt" + "log" + "os" + "runtime" + "runtime/debug" + "strings" + "sync" + "syscall" + "time" + "unsafe" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + + "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/ringlogger" + "golang.zx2c4.com/wireguard/windows/services" + "golang.zx2c4.com/wireguard/windows/version" +) + +type managerService struct{} + +func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) { + changes <- svc.Status{State: svc.StartPending} + + var err error + serviceError := services.ErrorSuccess + + defer func() { + svcSpecificEC, exitCode = services.DetermineErrorCode(err, serviceError) + logErr := services.CombineErrors(err, serviceError) + if logErr != nil { + log.Print(logErr) + } + changes <- svc.Status{State: svc.StopPending} + }() + + err = ringlogger.InitGlobalLogger("MGR") + if err != nil { + serviceError = services.ErrorRingloggerOpen + return + } + defer func() { + if x := recover(); x != nil { + for _, line := range append([]string{fmt.Sprint(x)}, strings.Split(string(debug.Stack()), "\n")...) { + if len(strings.TrimSpace(line)) > 0 { + log.Println(line) + } + } + panic(x) + } + }() + + log.Println("Starting", version.UserAgent()) + + path, err := os.Executable() + if err != nil { + serviceError = services.ErrorDetermineExecutablePath + return + } + + devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0) + if err != nil { + serviceError = services.ErrorOpenNULFile + return + } + + err = trackExistingTunnels() + if err != nil { + serviceError = services.ErrorTrackTunnels + return + } + + conf.RegisterStoreChangeCallback(func() { conf.MigrateUnencryptedConfigs() }) // Ignore return value for now, but could be useful later. + conf.RegisterStoreChangeCallback(IPCServerNotifyTunnelsChange) + + procs := make(map[uint32]*os.Process) + aliveSessions := make(map[uint32]bool) + procsLock := sync.Mutex{} + var startProcess func(session uint32) + stoppingManager := false + + startProcess = func(session uint32) { + defer func() { + runtime.UnlockOSThread() + procsLock.Lock() + delete(aliveSessions, session) + procsLock.Unlock() + }() + + var userToken windows.Token + err := windows.WTSQueryUserToken(session, &userToken) + if err != nil { + return + } + if !services.TokenIsMemberOfBuiltInAdministrator(userToken) { + userToken.Close() + return + } + user, err := userToken.GetTokenUser() + if err != nil { + log.Printf("Unable to lookup user from token: %v", err) + userToken.Close() + return + } + username, domain, accType, err := user.User.Sid.LookupAccount("") + if err != nil { + log.Printf("Unable to lookup username from sid: %v", err) + userToken.Close() + return + } + if accType != windows.SidTypeUser { + userToken.Close() + return + } + var elevatedToken windows.Token + if userToken.IsElevated() { + elevatedToken = userToken + } else { + elevatedToken, err = userToken.GetLinkedToken() + userToken.Close() + if err != nil { + log.Printf("Unable to elevate token: %v", err) + return + } + if !elevatedToken.IsElevated() { + elevatedToken.Close() + log.Println("Linked token is not elevated") + return + } + } + defer elevatedToken.Close() + userToken = 0 + first := true + for { + if stoppingManager { + return + } + + procsLock.Lock() + if alive := aliveSessions[session]; !alive { + procsLock.Unlock() + return + } + procsLock.Unlock() + + if !first { + time.Sleep(time.Second) + } else { + first = false + } + + //TODO: we lock the OS thread so that these inheritable handles don't escape into other processes that + // might be running in parallel Go routines. But the Go runtime is strange and who knows what's really + // happening with these or what is inherited. We need to do some analysis to be certain of what's going on. + runtime.LockOSThread() + ourReader, theirReader, theirReaderStr, ourWriter, theirWriter, theirWriterStr, err := inheritableSocketpairEmulation() + if err != nil { + log.Printf("Unable to create two inheritable pipes: %v", err) + return + } + ourEvents, theirEvents, theirEventStr, err := inheritableEvents() + err = IPCServerListen(ourReader, ourWriter, ourEvents, elevatedToken) + if err != nil { + log.Printf("Unable to listen on IPC pipes: %v", err) + return + } + theirLogMapping, theirLogMappingHandle, err := ringlogger.Global.ExportInheritableMappingHandleStr() + if err != nil { + log.Printf("Unable to export inheritable mapping handle for logging: %v", err) + return + } + + log.Printf("Starting UI process for user '%s@%s' for session %d", username, domain, session) + attr := &os.ProcAttr{ + Sys: &syscall.SysProcAttr{ + Token: syscall.Token(elevatedToken), + }, + Files: []*os.File{devNull, devNull, devNull}, + } + procsLock.Lock() + var proc *os.Process + if alive := aliveSessions[session]; alive { + proc, err = os.StartProcess(path, []string{path, "/ui", theirReaderStr, theirWriterStr, theirEventStr, theirLogMapping}, attr) + } else { + err = errors.New("Session has logged out") + } + procsLock.Unlock() + theirReader.Close() + theirWriter.Close() + theirEvents.Close() + windows.Close(theirLogMappingHandle) + runtime.UnlockOSThread() + if err != nil { + ourReader.Close() + ourWriter.Close() + ourEvents.Close() + log.Printf("Unable to start manager UI process for user '%s@%s' for session %d: %v", username, domain, session, err) + return + } + + procsLock.Lock() + procs[session] = proc + procsLock.Unlock() + + sessionIsDead := false + processStatus, err := proc.Wait() + if err == nil { + exitCode := processStatus.Sys().(syscall.WaitStatus).ExitCode + log.Printf("Exited UI process for user '%s@%s' for session %d with status %x", username, domain, session, exitCode) + const STATUS_DLL_INIT_FAILED_LOGOFF = 0xC000026B + sessionIsDead = exitCode == STATUS_DLL_INIT_FAILED_LOGOFF + } else { + log.Printf("Unable to wait for UI process for user '%s@%s' for session %d: %v", username, domain, session, err) + } + + procsLock.Lock() + delete(procs, session) + procsLock.Unlock() + ourReader.Close() + ourWriter.Close() + ourEvents.Close() + + if sessionIsDead { + return + } + } + } + + go checkForUpdates() + + var sessionsPointer *windows.WTS_SESSION_INFO + var count uint32 + err = windows.WTSEnumerateSessions(0, 0, 1, &sessionsPointer, &count) + if err != nil { + serviceError = services.ErrorEnumerateSessions + return + } + sessions := *(*[]windows.WTS_SESSION_INFO)(unsafe.Pointer(&struct { + addr *windows.WTS_SESSION_INFO + len int + cap int + }{sessionsPointer, int(count), int(count)})) + for _, session := range sessions { + if session.State != windows.WTSActive && session.State != windows.WTSDisconnected { + continue + } + procsLock.Lock() + if alive := aliveSessions[session.SessionID]; !alive { + aliveSessions[session.SessionID] = true + if _, ok := procs[session.SessionID]; !ok { + go startProcess(session.SessionID) + } + } + procsLock.Unlock() + } + windows.WTSFreeMemory(uintptr(unsafe.Pointer(sessionsPointer))) + + changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptSessionChange} + + uninstall := false +loop: + for { + select { + case <-quitManagersChan: + uninstall = true + break loop + case c := <-r: + switch c.Cmd { + case svc.Stop: + break loop + case svc.Interrogate: + changes <- c.CurrentStatus + case svc.SessionChange: + if c.EventType != windows.WTS_SESSION_LOGON && c.EventType != windows.WTS_SESSION_LOGOFF { + continue + } + sessionNotification := (*windows.WTSSESSION_NOTIFICATION)(unsafe.Pointer(c.EventData)) + if uintptr(sessionNotification.Size) != unsafe.Sizeof(*sessionNotification) { + log.Printf("Unexpected size of WTSSESSION_NOTIFICATION: %d", sessionNotification.Size) + continue + } + if c.EventType == windows.WTS_SESSION_LOGOFF { + procsLock.Lock() + delete(aliveSessions, sessionNotification.SessionID) + if proc, ok := procs[sessionNotification.SessionID]; ok { + proc.Kill() + } + procsLock.Unlock() + } else if c.EventType == windows.WTS_SESSION_LOGON { + procsLock.Lock() + if alive := aliveSessions[sessionNotification.SessionID]; !alive { + aliveSessions[sessionNotification.SessionID] = true + if _, ok := procs[sessionNotification.SessionID]; !ok { + go startProcess(sessionNotification.SessionID) + } + } + procsLock.Unlock() + } + + default: + log.Printf("Unexpected service control request #%d", c) + } + } + } + + changes <- svc.Status{State: svc.StopPending} + procsLock.Lock() + stoppingManager = true + IPCServerNotifyManagerStopping() + for _, proc := range procs { + proc.Kill() + } + procsLock.Unlock() + if uninstall { + err = UninstallManager() + if err != nil { + log.Printf("Unable to uninstaller manager when quitting: %v", err) + } + } + return +} -- cgit v1.2.3-59-g8ed1b