aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/updater/msirunner.go
diff options
context:
space:
mode:
Diffstat (limited to 'updater/msirunner.go')
-rw-r--r--updater/msirunner.go116
1 files changed, 116 insertions, 0 deletions
diff --git a/updater/msirunner.go b/updater/msirunner.go
new file mode 100644
index 00000000..d7631706
--- /dev/null
+++ b/updater/msirunner.go
@@ -0,0 +1,116 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2020 WireGuard LLC. All Rights Reserved.
+ */
+
+package updater
+
+import (
+ "crypto/rand"
+ "encoding/hex"
+ "errors"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+type tempFile struct {
+ *os.File
+ originalHandle windows.Handle
+}
+
+func (t *tempFile) ExclusivePath() string {
+ if t.originalHandle != 0 {
+ t.Close() // TODO: sort of a toctou, but msi requires unshared file
+ t.originalHandle = 0
+ }
+ return t.Name()
+}
+
+func (t *tempFile) Delete() error {
+ if t.originalHandle == 0 {
+ name16, err := windows.UTF16PtrFromString(t.Name())
+ if err != nil {
+ return err
+ }
+ return windows.DeleteFile(name16) //TODO: how does this deal with reparse points?
+ }
+ disposition := byte(1)
+ err := windows.SetFileInformationByHandle(t.originalHandle, windows.FileDispositionInfo, &disposition, 1)
+ t.originalHandle = 0
+ t.Close()
+ return err
+}
+
+func runMsi(msi *tempFile, userToken uintptr) error {
+ system32, err := windows.GetSystemDirectory()
+ if err != nil {
+ return err
+ }
+ devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0)
+ if err != nil {
+ return err
+ }
+ defer devNull.Close()
+ msiPath := msi.ExclusivePath()
+ attr := &os.ProcAttr{
+ Sys: &syscall.SysProcAttr{
+ Token: syscall.Token(userToken),
+ },
+ Files: []*os.File{devNull, devNull, devNull},
+ Dir: filepath.Dir(msiPath),
+ }
+ msiexec := filepath.Join(system32, "msiexec.exe")
+ proc, err := os.StartProcess(msiexec, []string{msiexec, "/qb!-", "/i", filepath.Base(msiPath)}, attr)
+ if err != nil {
+ return err
+ }
+ state, err := proc.Wait()
+ if err != nil {
+ return err
+ }
+ if !state.Success() {
+ return &exec.ExitError{ProcessState: state}
+ }
+ return nil
+}
+
+func msiTempFile() (*tempFile, error) {
+ var randBytes [32]byte
+ n, err := rand.Read(randBytes[:])
+ if err != nil {
+ return nil, err
+ }
+ if n != int(len(randBytes)) {
+ return nil, errors.New("Unable to generate random bytes")
+ }
+ sd, err := windows.SecurityDescriptorFromString("O:SYD:PAI(A;;FA;;;SY)(A;;FR;;;BA)")
+ if err != nil {
+ return nil, err
+ }
+ sa := &windows.SecurityAttributes{
+ Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
+ SecurityDescriptor: sd,
+ }
+ windir, err := windows.GetWindowsDirectory()
+ if err != nil {
+ return nil, err
+ }
+ name := filepath.Join(windir, "Temp", hex.EncodeToString(randBytes[:]))
+ name16 := windows.StringToUTF16Ptr(name)
+ fileHandle, err := windows.CreateFile(name16, windows.GENERIC_WRITE|windows.DELETE, 0, sa, windows.CREATE_NEW, windows.FILE_ATTRIBUTE_TEMPORARY, 0)
+ runtime.KeepAlive(sd)
+ if err != nil {
+ return nil, err
+ }
+ windows.MoveFileEx(name16, nil, windows.MOVEFILE_DELAY_UNTIL_REBOOT)
+ return &tempFile{
+ File: os.NewFile(uintptr(fileHandle), name),
+ originalHandle: fileHandle,
+ }, nil
+}