diff options
Diffstat (limited to 'main.go')
-rw-r--r-- | main.go | 137 |
1 files changed, 99 insertions, 38 deletions
@@ -1,12 +1,15 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package main import ( + "debug/pe" + "errors" "fmt" + "io" "log" "os" "strconv" @@ -15,6 +18,8 @@ import ( "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/elevate" "golang.zx2c4.com/wireguard/windows/l18n" "golang.zx2c4.com/wireguard/windows/manager" @@ -24,21 +29,41 @@ import ( "golang.zx2c4.com/wireguard/windows/updater" ) -func fatal(v ...interface{}) { - windows.MessageBox(0, windows.StringToUTF16Ptr(fmt.Sprint(v...)), windows.StringToUTF16Ptr(l18n.Sprintf("Error")), windows.MB_ICONERROR) - os.Exit(1) +func setLogFile() { + logHandle, err := windows.GetStdHandle(windows.STD_ERROR_HANDLE) + if logHandle == 0 || err != nil { + logHandle, err = windows.GetStdHandle(windows.STD_OUTPUT_HANDLE) + } + if logHandle == 0 || err != nil { + log.SetOutput(io.Discard) + } else { + log.SetOutput(os.NewFile(uintptr(logHandle), "stderr")) + } +} + +func fatal(v ...any) { + if log.Writer() == io.Discard { + windows.MessageBox(0, windows.StringToUTF16Ptr(fmt.Sprint(v...)), windows.StringToUTF16Ptr(l18n.Sprintf("Error")), windows.MB_ICONERROR) + os.Exit(1) + } else { + log.Fatal(append([]any{l18n.Sprintf("Error: ")}, v...)) + } } -func fatalf(format string, v ...interface{}) { +func fatalf(format string, v ...any) { fatal(l18n.Sprintf(format, v...)) } -func info(title string, format string, v ...interface{}) { - windows.MessageBox(0, windows.StringToUTF16Ptr(l18n.Sprintf(format, v...)), windows.StringToUTF16Ptr(title), windows.MB_ICONINFORMATION) +func info(title, format string, v ...any) { + if log.Writer() == io.Discard { + windows.MessageBox(0, windows.StringToUTF16Ptr(l18n.Sprintf(format, v...)), windows.StringToUTF16Ptr(title), windows.MB_ICONINFORMATION) + } else { + log.Printf(title+":\n"+format, v...) + } } func usage() { - var flags = [...]string{ + flags := [...]string{ l18n.Sprintf("(no argument): elevate and install manager service"), "/installmanagerservice", "/installtunnelservice CONFIG_PATH", @@ -47,8 +72,9 @@ func usage() { "/managerservice", "/tunnelservice CONFIG_PATH", "/ui CMD_READ_HANDLE CMD_WRITE_HANDLE CMD_EVENT_HANDLE LOG_MAPPING_HANDLE", - "/dumplog OUTPUT_PATH", - "/update [LOG_FILE]", + "/dumplog [/tail]", + "/update", + "/removedriver", } builder := strings.Builder{} for _, flag := range flags { @@ -59,13 +85,27 @@ func usage() { } func checkForWow64() { - var b bool - err := windows.IsWow64Process(windows.CurrentProcess(), &b) + b, err := func() (bool, error) { + var processMachine, nativeMachine uint16 + err := windows.IsWow64Process2(windows.CurrentProcess(), &processMachine, &nativeMachine) + if err == nil { + return processMachine != pe.IMAGE_FILE_MACHINE_UNKNOWN, nil + } + if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) { + return false, err + } + var b bool + err = windows.IsWow64Process(windows.CurrentProcess(), &b) + if err != nil { + return false, err + } + return b, nil + }() if err != nil { fatalf("Unable to determine whether the process is running under WOW64: %v", err) } if b { - fatalf("You must use the 64-bit version of WireGuard on this computer.") + fatalf("You must use the native version of WireGuard on this computer.") } } @@ -95,11 +135,11 @@ func execElevatedManagerServiceInstaller() error { return err } err = elevate.ShellExecute(path, "/installmanagerservice", "", windows.SW_SHOW) - if err != nil { + if err != nil && err != windows.ERROR_CANCELLED { return err } os.Exit(0) - return windows.ERROR_ACCESS_DENIED // Not reached + return windows.ERROR_UNHANDLED_EXCEPTION // Not reached } func pipeFromHandleArgument(handleStr string) (*os.File, error) { @@ -111,13 +151,18 @@ func pipeFromHandleArgument(handleStr string) (*os.File, error) { } func main() { + if windows.SetDllDirectory("") != nil || windows.SetDefaultDllDirectories(windows.LOAD_LIBRARY_SEARCH_SYSTEM32) != nil { + panic("failed to restrict dll search path") + } + + setLogFile() checkForWow64() if len(os.Args) <= 1 { - checkForAdminGroup() if ui.RaiseUI() { return } + checkForAdminGroup() err := execElevatedManagerServiceInstaller() if err != nil { fatal(err) @@ -190,9 +235,18 @@ func main() { if len(os.Args) != 6 { usage() } - err := elevate.DropAllPrivileges(false) - if err != nil { - fatal(err) + var processToken windows.Token + isAdmin := false + err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY|windows.TOKEN_DUPLICATE, &processToken) + if err == nil { + isAdmin = elevate.TokenIsElevatedOrElevatable(processToken) + processToken.Close() + } + if isAdmin { + err := elevate.DropAllPrivileges(false) + if err != nil { + fatal(err) + } } readPipe, err := pipeFromHandleArgument(os.Args[2]) if err != nil { @@ -211,38 +265,35 @@ func main() { fatal(err) } manager.InitializeIPCClient(readPipe, writePipe, eventPipe) + ui.IsAdmin = isAdmin ui.RunUI() return case "/dumplog": - if len(os.Args) != 3 { + if len(os.Args) != 2 && len(os.Args) != 3 { usage() } - file, err := os.Create(os.Args[2]) + outputHandle, err := windows.GetStdHandle(windows.STD_OUTPUT_HANDLE) if err != nil { fatal(err) } + if outputHandle == 0 { + fatal("Stdout must be set") + } + file := os.NewFile(uintptr(outputHandle), "stdout") defer file.Close() - err = ringlogger.DumpTo(file, true) + logPath, err := conf.LogFile(false) + if err != nil { + fatal(err) + } + err = ringlogger.DumpTo(logPath, file, len(os.Args) == 3 && os.Args[2] == "/tail") if err != nil { fatal(err) } return case "/update": - if len(os.Args) != 2 && len(os.Args) != 3 { + if len(os.Args) != 2 { usage() } - var f *os.File - var err error - if len(os.Args) == 2 { - f = os.Stdout - } else { - f, err = os.Create(os.Args[2]) - if err != nil { - fatal(err) - } - defer f.Close() - } - l := log.New(f, "", log.LstdFlags) for progress := range updater.DownloadVerifyAndExecute(0) { if len(progress.Activity) > 0 { if progress.BytesTotal > 0 || progress.BytesDownloaded > 0 { @@ -250,19 +301,29 @@ func main() { if progress.BytesTotal > 0 { percent = float64(progress.BytesDownloaded) / float64(progress.BytesTotal) * 100.0 } - l.Printf("%s: %d/%d (%.2f%%)\n", progress.Activity, progress.BytesDownloaded, progress.BytesTotal, percent) + log.Printf("%s: %d/%d (%.2f%%)\n", progress.Activity, progress.BytesDownloaded, progress.BytesTotal, percent) } else { - l.Println(progress.Activity) + log.Println(progress.Activity) } } if progress.Error != nil { - l.Printf("Error: %v\n", progress.Error) + log.Printf("Error: %v\n", progress.Error) } if progress.Complete || progress.Error != nil { return } } return + case "/removedriver": + if len(os.Args) != 2 { + usage() + } + _ = driver.UninstallLegacyWintun() // Best effort + err := driver.Uninstall() + if err != nil { + fatal(err) + } + return } usage() } |