aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2020-04-05 00:28:47 -0600
committerJason A. Donenfeld <Jason@zx2c4.com>2020-04-05 00:39:49 -0600
commit34e3a000c561fd2c8ef65bdb5a296f0b03095a99 (patch)
tree92557d70ea2433808636e9aafed27bcd5c12bcc2
parentsyntax: insist on 256-bit keys, not 257-bit or 258-bit (diff)
downloadwireguard-windows-34e3a000c561fd2c8ef65bdb5a296f0b03095a99.tar.xz
wireguard-windows-34e3a000c561fd2c8ef65bdb5a296f0b03095a99.zip
updater: allow updating from the command line
The administrator user may run `wireguard.exe /update`, which will check for updates and install it if available. A log file may be written using `wireguard.exe /update path\to\log\file.txt`. Requested-by: Elliot Saba <staticfloat@gmail.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--elevate/doas.go39
-rw-r--r--main.go39
-rw-r--r--manager/service.go2
-rw-r--r--tunnel/defaultroutemonitor.go2
-rw-r--r--updater/downloader.go27
5 files changed, 101 insertions, 8 deletions
diff --git a/elevate/doas.go b/elevate/doas.go
index ceedb78b..8322a5c9 100644
--- a/elevate/doas.go
+++ b/elevate/doas.go
@@ -7,6 +7,7 @@ package elevate
import (
"errors"
+ "os"
"runtime"
"strings"
"unsafe"
@@ -15,6 +16,17 @@ import (
"golang.org/x/sys/windows/svc/mgr"
)
+func setAllEnv(env []string) {
+ windows.Clearenv()
+ for _, e := range env {
+ kv := strings.SplitN(e, "=", 2)
+ if len(kv) != 2 {
+ continue
+ }
+ windows.Setenv(kv[0], kv[1])
+ }
+}
+
func DoAsSystem(f func() error) error {
runtime.LockOSThread()
defer func() {
@@ -58,12 +70,14 @@ func DoAsSystem(f func() error) error {
return err
}
processEntry := windows.ProcessEntry32{Size: uint32(unsafe.Sizeof(windows.ProcessEntry32{}))}
+ var impersonationError error
for err = windows.Process32First(processes, &processEntry); err == nil; err = windows.Process32Next(processes, &processEntry) {
if strings.ToLower(windows.UTF16ToString(processEntry.ExeFile[:])) != "winlogon.exe" {
continue
}
winlogonProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, processEntry.ProcessID)
if err != nil {
+ impersonationError = err
continue
}
var winlogonToken windows.Token
@@ -85,14 +99,26 @@ func DoAsSystem(f func() error) error {
if err != nil {
return err
}
+ newEnv, err := duplicatedToken.Environ(false)
+ if err != nil {
+ duplicatedToken.Close()
+ return err
+ }
+ currentEnv := os.Environ()
err = windows.SetThreadToken(nil, duplicatedToken)
duplicatedToken.Close()
if err != nil {
return err
}
- return f()
+ setAllEnv(newEnv)
+ err = f()
+ setAllEnv(currentEnv)
+ return err
}
windows.CloseHandle(processes)
+ if impersonationError != nil {
+ return impersonationError
+ }
return errors.New("unable to find winlogon.exe process")
}
@@ -131,11 +157,20 @@ func DoAsService(serviceName string, f func() error) error {
if err != nil {
return err
}
+ newEnv, err := duplicatedToken.Environ(false)
+ if err != nil {
+ duplicatedToken.Close()
+ return err
+ }
+ currentEnv := os.Environ()
err = windows.SetThreadToken(nil, duplicatedToken)
duplicatedToken.Close()
if err != nil {
return err
}
- return f()
+ setAllEnv(newEnv)
+ err = f()
+ setAllEnv(currentEnv)
+ return err
})
}
diff --git a/main.go b/main.go
index 132753cd..79dfcdfc 100644
--- a/main.go
+++ b/main.go
@@ -7,6 +7,7 @@ package main
import (
"fmt"
+ "log"
"os"
"strconv"
"strings"
@@ -20,6 +21,7 @@ import (
"golang.zx2c4.com/wireguard/windows/ringlogger"
"golang.zx2c4.com/wireguard/windows/tunnel"
"golang.zx2c4.com/wireguard/windows/ui"
+ "golang.zx2c4.com/wireguard/windows/updater"
)
func fatal(v ...interface{}) {
@@ -46,6 +48,7 @@ func usage() {
"/tunnelservice CONFIG_PATH",
"/ui CMD_READ_HANDLE CMD_WRITE_HANDLE CMD_EVENT_HANDLE LOG_MAPPING_HANDLE",
"/dumplog OUTPUT_PATH",
+ "/update [LOG_FILE]",
}
builder := strings.Builder{}
for _, flag := range flags {
@@ -224,6 +227,42 @@ func main() {
fatal(err)
}
return
+ case "/update":
+ if len(os.Args) != 2 && len(os.Args) != 3 {
+ usage()
+ }
+ var f *os.File
+ var err error
+ if len(os.Args) == 2 {
+ f = os.Stdout
+ } else {
+ f, err = os.Create(os.Args[2])
+ if err != nil {
+ fatal(err)
+ }
+ defer f.Close()
+ }
+ l := log.New(f, "", log.LstdFlags)
+ for progress := range updater.DownloadVerifyAndExecute(0) {
+ if len(progress.Activity) > 0 {
+ if progress.BytesTotal > 0 || progress.BytesDownloaded > 0 {
+ var percent float64
+ if progress.BytesTotal > 0 {
+ percent = float64(progress.BytesDownloaded) / float64(progress.BytesTotal) * 100.0
+ }
+ l.Printf("%s: %d/%d (%.2f%%)\n", progress.Activity, progress.BytesDownloaded, progress.BytesTotal, percent)
+ } else {
+ l.Println(progress.Activity)
+ }
+ }
+ if progress.Error != nil {
+ l.Printf("Error: %v\n", progress.Error)
+ }
+ if progress.Complete || progress.Error != nil {
+ return
+ }
+ }
+ return
}
usage()
}
diff --git a/manager/service.go b/manager/service.go
index dacb7864..6c3b039b 100644
--- a/manager/service.go
+++ b/manager/service.go
@@ -189,7 +189,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
Token: syscall.Token(elevatedToken),
},
Files: []*os.File{devNull, devNull, devNull},
- Dir: userProfileDirectory,
+ Dir: userProfileDirectory,
}
procsLock.Lock()
var proc *os.Process
diff --git a/tunnel/defaultroutemonitor.go b/tunnel/defaultroutemonitor.go
index 68853fd1..e71aad0d 100644
--- a/tunnel/defaultroutemonitor.go
+++ b/tunnel/defaultroutemonitor.go
@@ -39,7 +39,7 @@ func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLU
continue
}
- if r[i].Metric + iface.Metric < lowestMetric {
+ if r[i].Metric+iface.Metric < lowestMetric {
lowestMetric = r[i].Metric + iface.Metric
index = r[i].InterfaceIndex
luid = r[i].InterfaceLUID
diff --git a/updater/downloader.go b/updater/downloader.go
index a12b5037..fad9534e 100644
--- a/updater/downloader.go
+++ b/updater/downloader.go
@@ -16,6 +16,8 @@ import (
"sync/atomic"
"golang.org/x/crypto/blake2b"
+
+ "golang.zx2c4.com/wireguard/windows/elevate"
"golang.zx2c4.com/wireguard/windows/version"
)
@@ -47,6 +49,9 @@ type UpdateFound struct {
}
func CheckForUpdate() (*UpdateFound, error) {
+ if !version.IsRunningOfficialVersion() {
+ return nil, errors.New("Build is not official, so updates are disabled")
+ }
request, err := http.NewRequest(http.MethodGet, latestVersionURL, nil)
if err != nil {
return nil, err
@@ -80,17 +85,17 @@ func DownloadVerifyAndExecute(userToken uintptr) (progress chan DownloadProgress
return
}
- go func() {
+ doIt := func() {
defer atomic.StoreUint32(&updateInProgress, 0)
- progress <- DownloadProgress{Activity: "Rechecking for update"}
+ progress <- DownloadProgress{Activity: "Checking for update"}
update, err := CheckForUpdate()
if err != nil {
progress <- DownloadProgress{Error: err}
return
}
if update == nil {
- progress <- DownloadProgress{Error: errors.New("No update was found when re-checking for updates")}
+ progress <- DownloadProgress{Error: errors.New("No update was found")}
return
}
@@ -100,6 +105,7 @@ func DownloadVerifyAndExecute(userToken uintptr) (progress chan DownloadProgress
progress <- DownloadProgress{Error: err}
return
}
+ progress <- DownloadProgress{Activity: fmt.Sprintf("Msi destination is %#q", file.Name())}
defer func() {
if file != nil {
name := file.Name()
@@ -167,7 +173,20 @@ func DownloadVerifyAndExecute(userToken uintptr) (progress chan DownloadProgress
}
progress <- DownloadProgress{Complete: true}
- }()
+ }
+ if userToken == 0 {
+ go func() {
+ err := elevate.DoAsSystem(func() error {
+ doIt()
+ return nil
+ })
+ if err != nil {
+ progress <- DownloadProgress{Error: err}
+ }
+ }()
+ } else {
+ go doIt()
+ }
return progress
}