From 568528c747afe3ae991b0340d15cd6a897fd5c9d Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 6 May 2019 09:46:10 +0200 Subject: updater: move into manager --- updater/downloader.go | 57 +++++++++++++---------------------- updater/msirunner_linux.go | 8 +++-- updater/msirunner_windows.go | 71 ++++++++++++++++++++++++++++++++++++++------ updater/updater_test.go | 2 +- 4 files changed, 88 insertions(+), 50 deletions(-) (limited to 'updater') diff --git a/updater/downloader.go b/updater/downloader.go index 382d284b..2f83b9b2 100644 --- a/updater/downloader.go +++ b/updater/downloader.go @@ -15,7 +15,6 @@ import ( "io" "net/http" "os" - "path" "sync/atomic" ) @@ -71,7 +70,7 @@ func CheckForUpdate() (*UpdateFound, error) { var updateInProgress = uint32(0) -func DownloadVerifyAndExecute() (progress chan DownloadProgress) { +func DownloadVerifyAndExecute(userToken uintptr, userEnvironment []string) (progress chan DownloadProgress) { progress = make(chan DownloadProgress, 128) progress <- DownloadProgress{Activity: "Initializing"} @@ -94,33 +93,19 @@ func DownloadVerifyAndExecute() (progress chan DownloadProgress) { return } - progress <- DownloadProgress{Activity: "Creating update file"} - updateDir, err := msiSaveDirectory() - if err != nil { - progress <- DownloadProgress{Error: err} - return - } - // Clean up old updates the brutal way: - os.RemoveAll(updateDir) - - err = os.MkdirAll(updateDir, 0700) - if err != nil { - progress <- DownloadProgress{Error: err} - return - } - destinationFilename := path.Join(updateDir, update.name) - unverifiedDestinationFilename := destinationFilename + ".unverified" - out, err := os.Create(unverifiedDestinationFilename) + progress <- DownloadProgress{Activity: "Creating temporary file"} + file, err := msiTempFile() if err != nil { progress <- DownloadProgress{Error: err} return } defer func() { - if out != nil { - out.Seek(0, io.SeekStart) - out.Truncate(0) - out.Close() - os.Remove(unverifiedDestinationFilename) + if file != nil { + name := file.Name() + file.Seek(0, io.SeekStart) + file.Truncate(0) + file.Close() + os.Remove(name) //TODO: Do we have any sort of TOCTOU here? } }() @@ -149,7 +134,7 @@ func DownloadVerifyAndExecute() (progress chan DownloadProgress) { return } pm := &progressHashWatcher{&dp, progress, hasher} - _, err = io.Copy(out, io.TeeReader(io.LimitReader(response.Body, 1024*1024*100 /* 100 MiB */), pm)) + _, err = io.Copy(file, io.TeeReader(io.LimitReader(response.Body, 1024*1024*100 /* 100 MiB */), pm)) if err != nil { progress <- DownloadProgress{Error: err} return @@ -158,25 +143,23 @@ func DownloadVerifyAndExecute() (progress chan DownloadProgress) { progress <- DownloadProgress{Error: errors.New("The downloaded update has the wrong hash")} return } - out.Close() - out = nil + + //TODO: it would be nice to rename in place from "file.msi.unverified" to "file.msi", but Windows TOCTOU stuff + // is hard, so we'll come back to this later. + name := file.Name() + file.Close() + file = nil progress <- DownloadProgress{Activity: "Verifying authenticode signature"} - if !version.VerifyAuthenticode(unverifiedDestinationFilename) { - os.Remove(unverifiedDestinationFilename) + if !version.VerifyAuthenticode(name) { + os.Remove(name) //TODO: Do we have any sort of TOCTOU here? progress <- DownloadProgress{Error: errors.New("The downloaded update does not have an authentic authenticode signature")} return } progress <- DownloadProgress{Activity: "Installing update"} - err = os.Rename(unverifiedDestinationFilename, destinationFilename) - if err != nil { - os.Remove(unverifiedDestinationFilename) - progress <- DownloadProgress{Error: err} - return - } - err = runMsi(destinationFilename) - os.Remove(unverifiedDestinationFilename) + err = runMsi(name, userToken, userEnvironment) + os.Remove(name) //TODO: Do we have any sort of TOCTOU here? if err != nil { progress <- DownloadProgress{Error: err} return diff --git a/updater/msirunner_linux.go b/updater/msirunner_linux.go index cbb52cf6..6550025c 100644 --- a/updater/msirunner_linux.go +++ b/updater/msirunner_linux.go @@ -7,15 +7,17 @@ package updater import ( "fmt" + "io/ioutil" + "os" "os/exec" ) // This isn't a Linux program, yes, but having the updater package work across platforms is quite helpful for testing. -func runMsi(msiPath string) error { +func runMsi(msiPath string, userToken uintptr, env []string) error { return exec.Command("qarma", "--info", "--text", fmt.Sprintf("It seems to be working! Were we on Windows, ā€˜%sā€™ would be executed.", msiPath)).Run() } -func msiSaveDirectory() (string, error) { - return "/tmp/wireguard-update-test-msi-directory", nil +func msiTempFile() (*os.File, error) { + return ioutil.TempFile(os.TempDir(), "") } diff --git a/updater/msirunner_windows.go b/updater/msirunner_windows.go index dfa921ee..de3fb58e 100644 --- a/updater/msirunner_windows.go +++ b/updater/msirunner_windows.go @@ -6,26 +6,79 @@ package updater import ( + "crypto/rand" + "encoding/hex" + "errors" + "github.com/Microsoft/go-winio" "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/windows/conf" + "os" "os/exec" "path" + "runtime" + "syscall" + "unsafe" ) -func runMsi(msiPath string) error { +func runMsi(msiPath string, userToken uintptr, env []string) error { system32, err := windows.GetSystemDirectory() if err != nil { return err } - cmd := exec.Command(path.Join(system32, "msiexec.exe"), "/qb!-", "/i", path.Base(msiPath)) - cmd.Dir = path.Dir(msiPath) - return cmd.Run() + devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0) + if err != nil { + return err + } + defer devNull.Close() + attr := &os.ProcAttr{ + Sys: &syscall.SysProcAttr{ + Token: syscall.Token(userToken), + }, + Files: []*os.File{devNull, devNull, devNull}, + Env: env, + Dir: path.Dir(msiPath), + } + msiexec := path.Join(system32, "msiexec.exe") + proc, err := os.StartProcess(msiexec, []string{msiexec, "/qb!-", "/i", path.Base(msiPath)}, attr) + if err != nil { + return err + } + state, err := proc.Wait() + if err != nil { + return err + } + if !state.Success() { + return &exec.ExitError{ProcessState: state} + } + return nil } -func msiSaveDirectory() (string, error) { - configRootDir, err := conf.RootDirectory() +func msiTempFile() (*os.File, error) { + var randBytes [32]byte + n, err := rand.Read(randBytes[:]) + if err != nil { + return nil, err + } + if n != int(len(randBytes)) { + return nil, errors.New("Unable to generate random bytes") + } + sd, err := winio.SddlToSecurityDescriptor("O:SYD:PAI(A;;FA;;;SY)(A;;FR;;;BA)") + if err != nil { + return nil, err + } + sa := &windows.SecurityAttributes{ + Length: uint32(len(sd)), + SecurityDescriptor: uintptr(unsafe.Pointer(&sd[0])), + } + //TODO: os.TempDir() returns C:\windows\temp when calling from this context. Supposedly this is mostly secure + // against TOCTOU, but who knows! Look into this! + name := path.Join(os.TempDir(), hex.EncodeToString(randBytes[:])) + name16 := windows.StringToUTF16Ptr(name) + //TODO: it would be nice to specify delete_on_close, but msiexec.exe doesn't open its files with read sharing. + fileHandle, err := windows.CreateFile(name16, windows.GENERIC_WRITE, windows.FILE_SHARE_READ, sa, windows.CREATE_NEW, windows.FILE_ATTRIBUTE_NORMAL, 0) + runtime.KeepAlive(sd) if err != nil { - return "", err + return nil, err } - return path.Join(configRootDir, "Updates"), nil + windows.MoveFileEx(name16, nil, windows.MOVEFILE_DELAY_UNTIL_REBOOT) + return os.NewFile(uintptr(fileHandle), name), nil } diff --git a/updater/updater_test.go b/updater/updater_test.go index 7bc4df8e..fbd1080d 100644 --- a/updater/updater_test.go +++ b/updater/updater_test.go @@ -20,7 +20,7 @@ func TestUpdate(t *testing.T) { return } t.Log("Found update") - progress := DownloadVerifyAndExecute() + progress := DownloadVerifyAndExecute(0, nil) for { dp := <-progress if dp.Error != nil { -- cgit v1.2.3-59-g8ed1b