diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2020-04-05 00:28:47 -0600 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2020-04-05 00:39:49 -0600 |
commit | 34e3a000c561fd2c8ef65bdb5a296f0b03095a99 (patch) | |
tree | 92557d70ea2433808636e9aafed27bcd5c12bcc2 | |
parent | syntax: insist on 256-bit keys, not 257-bit or 258-bit (diff) | |
download | wireguard-windows-34e3a000c561fd2c8ef65bdb5a296f0b03095a99.tar.xz wireguard-windows-34e3a000c561fd2c8ef65bdb5a296f0b03095a99.zip |
updater: allow updating from the command line
The administrator user may run `wireguard.exe /update`, which will check
for updates and install it if available. A log file may be written using
`wireguard.exe /update path\to\log\file.txt`.
Requested-by: Elliot Saba <staticfloat@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r-- | elevate/doas.go | 39 | ||||
-rw-r--r-- | main.go | 39 | ||||
-rw-r--r-- | manager/service.go | 2 | ||||
-rw-r--r-- | tunnel/defaultroutemonitor.go | 2 | ||||
-rw-r--r-- | updater/downloader.go | 27 |
5 files changed, 101 insertions, 8 deletions
diff --git a/elevate/doas.go b/elevate/doas.go index ceedb78b..8322a5c9 100644 --- a/elevate/doas.go +++ b/elevate/doas.go @@ -7,6 +7,7 @@ package elevate import ( "errors" + "os" "runtime" "strings" "unsafe" @@ -15,6 +16,17 @@ import ( "golang.org/x/sys/windows/svc/mgr" ) +func setAllEnv(env []string) { + windows.Clearenv() + for _, e := range env { + kv := strings.SplitN(e, "=", 2) + if len(kv) != 2 { + continue + } + windows.Setenv(kv[0], kv[1]) + } +} + func DoAsSystem(f func() error) error { runtime.LockOSThread() defer func() { @@ -58,12 +70,14 @@ func DoAsSystem(f func() error) error { return err } processEntry := windows.ProcessEntry32{Size: uint32(unsafe.Sizeof(windows.ProcessEntry32{}))} + var impersonationError error for err = windows.Process32First(processes, &processEntry); err == nil; err = windows.Process32Next(processes, &processEntry) { if strings.ToLower(windows.UTF16ToString(processEntry.ExeFile[:])) != "winlogon.exe" { continue } winlogonProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, processEntry.ProcessID) if err != nil { + impersonationError = err continue } var winlogonToken windows.Token @@ -85,14 +99,26 @@ func DoAsSystem(f func() error) error { if err != nil { return err } + newEnv, err := duplicatedToken.Environ(false) + if err != nil { + duplicatedToken.Close() + return err + } + currentEnv := os.Environ() err = windows.SetThreadToken(nil, duplicatedToken) duplicatedToken.Close() if err != nil { return err } - return f() + setAllEnv(newEnv) + err = f() + setAllEnv(currentEnv) + return err } windows.CloseHandle(processes) + if impersonationError != nil { + return impersonationError + } return errors.New("unable to find winlogon.exe process") } @@ -131,11 +157,20 @@ func DoAsService(serviceName string, f func() error) error { if err != nil { return err } + newEnv, err := duplicatedToken.Environ(false) + if err != nil { + duplicatedToken.Close() + return err + } + currentEnv := os.Environ() err = windows.SetThreadToken(nil, duplicatedToken) duplicatedToken.Close() if err != nil { return err } - return f() + setAllEnv(newEnv) + err = f() + setAllEnv(currentEnv) + return err }) } @@ -7,6 +7,7 @@ package main import ( "fmt" + "log" "os" "strconv" "strings" @@ -20,6 +21,7 @@ import ( "golang.zx2c4.com/wireguard/windows/ringlogger" "golang.zx2c4.com/wireguard/windows/tunnel" "golang.zx2c4.com/wireguard/windows/ui" + "golang.zx2c4.com/wireguard/windows/updater" ) func fatal(v ...interface{}) { @@ -46,6 +48,7 @@ func usage() { "/tunnelservice CONFIG_PATH", "/ui CMD_READ_HANDLE CMD_WRITE_HANDLE CMD_EVENT_HANDLE LOG_MAPPING_HANDLE", "/dumplog OUTPUT_PATH", + "/update [LOG_FILE]", } builder := strings.Builder{} for _, flag := range flags { @@ -224,6 +227,42 @@ func main() { fatal(err) } return + case "/update": + if len(os.Args) != 2 && len(os.Args) != 3 { + usage() + } + var f *os.File + var err error + if len(os.Args) == 2 { + f = os.Stdout + } else { + f, err = os.Create(os.Args[2]) + if err != nil { + fatal(err) + } + defer f.Close() + } + l := log.New(f, "", log.LstdFlags) + for progress := range updater.DownloadVerifyAndExecute(0) { + if len(progress.Activity) > 0 { + if progress.BytesTotal > 0 || progress.BytesDownloaded > 0 { + var percent float64 + if progress.BytesTotal > 0 { + percent = float64(progress.BytesDownloaded) / float64(progress.BytesTotal) * 100.0 + } + l.Printf("%s: %d/%d (%.2f%%)\n", progress.Activity, progress.BytesDownloaded, progress.BytesTotal, percent) + } else { + l.Println(progress.Activity) + } + } + if progress.Error != nil { + l.Printf("Error: %v\n", progress.Error) + } + if progress.Complete || progress.Error != nil { + return + } + } + return } usage() } diff --git a/manager/service.go b/manager/service.go index dacb7864..6c3b039b 100644 --- a/manager/service.go +++ b/manager/service.go @@ -189,7 +189,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest Token: syscall.Token(elevatedToken), }, Files: []*os.File{devNull, devNull, devNull}, - Dir: userProfileDirectory, + Dir: userProfileDirectory, } procsLock.Lock() var proc *os.Process diff --git a/tunnel/defaultroutemonitor.go b/tunnel/defaultroutemonitor.go index 68853fd1..e71aad0d 100644 --- a/tunnel/defaultroutemonitor.go +++ b/tunnel/defaultroutemonitor.go @@ -39,7 +39,7 @@ func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLU continue } - if r[i].Metric + iface.Metric < lowestMetric { + if r[i].Metric+iface.Metric < lowestMetric { lowestMetric = r[i].Metric + iface.Metric index = r[i].InterfaceIndex luid = r[i].InterfaceLUID diff --git a/updater/downloader.go b/updater/downloader.go index a12b5037..fad9534e 100644 --- a/updater/downloader.go +++ b/updater/downloader.go @@ -16,6 +16,8 @@ import ( "sync/atomic" "golang.org/x/crypto/blake2b" + + "golang.zx2c4.com/wireguard/windows/elevate" "golang.zx2c4.com/wireguard/windows/version" ) @@ -47,6 +49,9 @@ type UpdateFound struct { } func CheckForUpdate() (*UpdateFound, error) { + if !version.IsRunningOfficialVersion() { + return nil, errors.New("Build is not official, so updates are disabled") + } request, err := http.NewRequest(http.MethodGet, latestVersionURL, nil) if err != nil { return nil, err @@ -80,17 +85,17 @@ func DownloadVerifyAndExecute(userToken uintptr) (progress chan DownloadProgress return } - go func() { + doIt := func() { defer atomic.StoreUint32(&updateInProgress, 0) - progress <- DownloadProgress{Activity: "Rechecking for update"} + progress <- DownloadProgress{Activity: "Checking for update"} update, err := CheckForUpdate() if err != nil { progress <- DownloadProgress{Error: err} return } if update == nil { - progress <- DownloadProgress{Error: errors.New("No update was found when re-checking for updates")} + progress <- DownloadProgress{Error: errors.New("No update was found")} return } @@ -100,6 +105,7 @@ func DownloadVerifyAndExecute(userToken uintptr) (progress chan DownloadProgress progress <- DownloadProgress{Error: err} return } + progress <- DownloadProgress{Activity: fmt.Sprintf("Msi destination is %#q", file.Name())} defer func() { if file != nil { name := file.Name() @@ -167,7 +173,20 @@ func DownloadVerifyAndExecute(userToken uintptr) (progress chan DownloadProgress } progress <- DownloadProgress{Complete: true} - }() + } + if userToken == 0 { + go func() { + err := elevate.DoAsSystem(func() error { + doIt() + return nil + }) + if err != nil { + progress <- DownloadProgress{Error: err} + } + }() + } else { + go doIt() + } return progress } |