diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-05-25 16:14:32 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-06-18 11:28:50 +0200 |
commit | 1636a8c94b345c1af61d6bf786b33722c2ec2b07 (patch) | |
tree | ec6a3020c23bc0f9beb3913512c19acdbc88eb63 | |
parent | main: log CLI to stderr/stdout (diff) | |
download | wireguard-windows-1636a8c94b345c1af61d6bf786b33722c2ec2b07.tar.xz wireguard-windows-1636a8c94b345c1af61d6bf786b33722c2ec2b07.zip |
manager: manually use CreateProcess for launching UI process
Go's standard library for this is buggy (PID races, handle races) and
requires passing NUL, which we don't really care about for Windows.
Simplify and speed up process creation by only passing exactly what we
need.
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r-- | manager/service.go | 35 | ||||
-rw-r--r-- | manager/uiprocess.go | 99 | ||||
-rw-r--r-- | services/errors.go | 3 |
3 files changed, 108 insertions, 29 deletions
diff --git a/manager/service.go b/manager/service.go index ac428c86..da6ff497 100644 --- a/manager/service.go +++ b/manager/service.go @@ -12,7 +12,6 @@ import ( "runtime" "strconv" "sync" - "syscall" "time" "unsafe" @@ -57,12 +56,6 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest return } - devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0) - if err != nil { - serviceError = services.ErrorOpenNULFile - return - } - moveConfigsFromLegacyStore() err = trackExistingTunnels() @@ -74,7 +67,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest conf.RegisterStoreChangeCallback(func() { conf.MigrateUnencryptedConfigs(changeTunnelServiceConfigFilePath) }) conf.RegisterStoreChangeCallback(IPCServerNotifyTunnelsChange) - procs := make(map[uint32]*os.Process) + procs := make(map[uint32]*uiProcess) aliveSessions := make(map[uint32]bool) procsLock := sync.Mutex{} stoppingManager := false @@ -196,29 +189,21 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest } log.Printf("Starting UI process for user ā%s@%sā for session %d", username, domain, session) - attr := &os.ProcAttr{ - Sys: &syscall.SysProcAttr{ - Token: syscall.Token(runToken), - AdditionalInheritedHandles: []syscall.Handle{ - syscall.Handle(theirReader.Fd()), - syscall.Handle(theirWriter.Fd()), - syscall.Handle(theirEvents.Fd()), - syscall.Handle(theirLogMapping)}, - }, - Files: []*os.File{devNull, devNull, devNull}, - Dir: userProfileDirectory, - } procsLock.Lock() - var proc *os.Process + var proc *uiProcess if alive := aliveSessions[session]; alive { - proc, err = os.StartProcess(path, []string{ + proc, err = launchUIProcess(path, []string{ path, "/ui", strconv.FormatUint(uint64(theirReader.Fd()), 10), strconv.FormatUint(uint64(theirWriter.Fd()), 10), strconv.FormatUint(uint64(theirEvents.Fd()), 10), strconv.FormatUint(uint64(theirLogMapping), 10), - }, attr) + }, userProfileDirectory, []windows.Handle{ + windows.Handle(theirReader.Fd()), + windows.Handle(theirWriter.Fd()), + windows.Handle(theirEvents.Fd()), + theirLogMapping}, runToken) } else { err = errors.New("Session has logged out") } @@ -240,9 +225,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest procsLock.Unlock() sessionIsDead := false - processStatus, err := proc.Wait() - if err == nil { - exitCode := processStatus.Sys().(syscall.WaitStatus).ExitCode + if exitCode, err := proc.Wait(); err == nil { log.Printf("Exited UI process for user '%s@%s' for session %d with status %x", username, domain, session, exitCode) const STATUS_DLL_INIT_FAILED_LOGOFF = 0xC000026B sessionIsDead = exitCode == STATUS_DLL_INIT_FAILED_LOGOFF diff --git a/manager/uiprocess.go b/manager/uiprocess.go new file mode 100644 index 00000000..80ac8b30 --- /dev/null +++ b/manager/uiprocess.go @@ -0,0 +1,99 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package manager + +import ( + "errors" + "runtime" + "sync/atomic" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +type uiProcess struct { + handle uintptr +} + +func launchUIProcess(executable string, args []string, workingDirectory string, handles []windows.Handle, token windows.Token) (*uiProcess, error) { + executable16, err := windows.UTF16PtrFromString(executable) + if err != nil { + return nil, err + } + args16, err := windows.UTF16PtrFromString(windows.ComposeCommandLine(args)) + if err != nil { + return nil, err + } + workingDirectory16, err := windows.UTF16PtrFromString(workingDirectory) + if err != nil { + return nil, err + } + var environmentBlock *uint16 + err = windows.CreateEnvironmentBlock(&environmentBlock, token, false) + if err != nil { + return nil, err + } + defer windows.DestroyEnvironmentBlock(environmentBlock) + attributeList, err := windows.NewProcThreadAttributeList(1) + if err != nil { + return nil, err + } + defer attributeList.Delete() + si := &windows.StartupInfoEx{ + StartupInfo: windows.StartupInfo{Cb: uint32(unsafe.Sizeof(windows.StartupInfoEx{}))}, + ProcThreadAttributeList: attributeList.List(), + } + if len(handles) == 0 { + handles = []windows.Handle{0} + } + attributeList.Update(windows.PROC_THREAD_ATTRIBUTE_HANDLE_LIST, unsafe.Pointer(&handles[0]), uintptr(len(handles))*unsafe.Sizeof(handles[0])) + pi := new(windows.ProcessInformation) + err = windows.CreateProcessAsUser(token, executable16, args16, nil, nil, true, windows.CREATE_DEFAULT_ERROR_MODE|windows.CREATE_UNICODE_ENVIRONMENT|windows.EXTENDED_STARTUPINFO_PRESENT, environmentBlock, workingDirectory16, &si.StartupInfo, pi) + if err != nil { + return nil, err + } + windows.CloseHandle(pi.Thread) + uiProc := &uiProcess{handle: uintptr(pi.Process)} + runtime.SetFinalizer(uiProc, (*uiProcess).release) + return uiProc, nil +} + +func (p *uiProcess) release() error { + handle := windows.Handle(atomic.SwapUintptr(&p.handle, uintptr(windows.InvalidHandle))) + if handle == windows.InvalidHandle { + return nil + } + err := windows.CloseHandle(handle) + if err != nil { + return err + } + runtime.SetFinalizer(p, nil) + return nil +} + +func (p *uiProcess) Wait() (uint32, error) { + handle := windows.Handle(atomic.LoadUintptr(&p.handle)) + s, err := windows.WaitForSingleObject(handle, syscall.INFINITE) + switch s { + case windows.WAIT_OBJECT_0: + case windows.WAIT_FAILED: + return 0, err + default: + return 0, errors.New("unexpected result from WaitForSingleObject") + } + var exitCode uint32 + err = windows.GetExitCodeProcess(handle, &exitCode) + if err != nil { + return 0, err + } + p.release() + return exitCode, nil +} + +func (p *uiProcess) Kill() error { + return windows.TerminateProcess(windows.Handle(atomic.LoadUintptr(&p.handle)), 1) +} diff --git a/services/errors.go b/services/errors.go index e17cad88..569585a5 100644 --- a/services/errors.go +++ b/services/errors.go @@ -26,7 +26,6 @@ const ( ErrorBindSocketsToDefaultRoutes ErrorSetNetConfig ErrorDetermineExecutablePath - ErrorOpenNULFile ErrorTrackTunnels ErrorEnumerateSessions ErrorDropPrivileges @@ -58,8 +57,6 @@ func (e Error) Error() string { return "Unable to bind sockets to default route" case ErrorSetNetConfig: return "Unable to set interface addresses, routes, dns, and/or interface settings" - case ErrorOpenNULFile: - return "Unable to open NUL file" case ErrorTrackTunnels: return "Unable to track existing tunnels" case ErrorEnumerateSessions: |