diff options
Diffstat (limited to 'main.go')
-rw-r--r-- | main.go | 104 |
1 files changed, 55 insertions, 49 deletions
@@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package main @@ -9,6 +9,7 @@ import ( "debug/pe" "errors" "fmt" + "io" "log" "os" "strconv" @@ -16,8 +17,9 @@ import ( "time" "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/elevate" "golang.zx2c4.com/wireguard/windows/l18n" "golang.zx2c4.com/wireguard/windows/manager" @@ -27,21 +29,41 @@ import ( "golang.zx2c4.com/wireguard/windows/updater" ) -func fatal(v ...interface{}) { - windows.MessageBox(0, windows.StringToUTF16Ptr(fmt.Sprint(v...)), windows.StringToUTF16Ptr(l18n.Sprintf("Error")), windows.MB_ICONERROR) - os.Exit(1) +func setLogFile() { + logHandle, err := windows.GetStdHandle(windows.STD_ERROR_HANDLE) + if logHandle == 0 || err != nil { + logHandle, err = windows.GetStdHandle(windows.STD_OUTPUT_HANDLE) + } + if logHandle == 0 || err != nil { + log.SetOutput(io.Discard) + } else { + log.SetOutput(os.NewFile(uintptr(logHandle), "stderr")) + } } -func fatalf(format string, v ...interface{}) { +func fatal(v ...any) { + if log.Writer() == io.Discard { + windows.MessageBox(0, windows.StringToUTF16Ptr(fmt.Sprint(v...)), windows.StringToUTF16Ptr(l18n.Sprintf("Error")), windows.MB_ICONERROR) + os.Exit(1) + } else { + log.Fatal(append([]any{l18n.Sprintf("Error: ")}, v...)) + } +} + +func fatalf(format string, v ...any) { fatal(l18n.Sprintf(format, v...)) } -func info(title string, format string, v ...interface{}) { - windows.MessageBox(0, windows.StringToUTF16Ptr(l18n.Sprintf(format, v...)), windows.StringToUTF16Ptr(title), windows.MB_ICONINFORMATION) +func info(title, format string, v ...any) { + if log.Writer() == io.Discard { + windows.MessageBox(0, windows.StringToUTF16Ptr(l18n.Sprintf(format, v...)), windows.StringToUTF16Ptr(title), windows.MB_ICONINFORMATION) + } else { + log.Printf(title+":\n"+format, v...) + } } func usage() { - var flags = [...]string{ + flags := [...]string{ l18n.Sprintf("(no argument): elevate and install manager service"), "/installmanagerservice", "/installtunnelservice CONFIG_PATH", @@ -50,9 +72,9 @@ func usage() { "/managerservice", "/tunnelservice CONFIG_PATH", "/ui CMD_READ_HANDLE CMD_WRITE_HANDLE CMD_EVENT_HANDLE LOG_MAPPING_HANDLE", - "/dumplog OUTPUT_PATH", - "/update [LOG_FILE]", - "/removealladapters [LOG_FILE]", + "/dumplog [/tail]", + "/update", + "/removedriver", } builder := strings.Builder{} for _, flag := range flags { @@ -133,6 +155,7 @@ func main() { panic("failed to restrict dll search path") } + setLogFile() checkForWow64() if len(os.Args) <= 1 { @@ -246,35 +269,31 @@ func main() { ui.RunUI() return case "/dumplog": - if len(os.Args) != 3 { + if len(os.Args) != 2 && len(os.Args) != 3 { usage() } - file, err := os.Create(os.Args[2]) + outputHandle, err := windows.GetStdHandle(windows.STD_OUTPUT_HANDLE) if err != nil { fatal(err) } + if outputHandle == 0 { + fatal("Stdout must be set") + } + file := os.NewFile(uintptr(outputHandle), "stdout") defer file.Close() - err = ringlogger.DumpTo(file, true) + logPath, err := conf.LogFile(false) + if err != nil { + fatal(err) + } + err = ringlogger.DumpTo(logPath, file, len(os.Args) == 3 && os.Args[2] == "/tail") if err != nil { fatal(err) } return case "/update": - if len(os.Args) != 2 && len(os.Args) != 3 { + if len(os.Args) != 2 { usage() } - var f *os.File - var err error - if len(os.Args) == 2 { - f = os.Stdout - } else { - f, err = os.Create(os.Args[2]) - if err != nil { - fatal(err) - } - defer f.Close() - } - l := log.New(f, "", log.LstdFlags) for progress := range updater.DownloadVerifyAndExecute(0) { if len(progress.Activity) > 0 { if progress.BytesTotal > 0 || progress.BytesDownloaded > 0 { @@ -282,40 +301,27 @@ func main() { if progress.BytesTotal > 0 { percent = float64(progress.BytesDownloaded) / float64(progress.BytesTotal) * 100.0 } - l.Printf("%s: %d/%d (%.2f%%)\n", progress.Activity, progress.BytesDownloaded, progress.BytesTotal, percent) + log.Printf("%s: %d/%d (%.2f%%)\n", progress.Activity, progress.BytesDownloaded, progress.BytesTotal, percent) } else { - l.Println(progress.Activity) + log.Println(progress.Activity) } } if progress.Error != nil { - l.Printf("Error: %v\n", progress.Error) + log.Printf("Error: %v\n", progress.Error) } if progress.Complete || progress.Error != nil { return } } return - case "/removealladapters": - if len(os.Args) != 2 && len(os.Args) != 3 { + case "/removedriver": + if len(os.Args) != 2 { usage() } - var f *os.File - var err error - if len(os.Args) == 2 { - f = os.Stdout - } else { - f, err = os.Create(os.Args[2]) - if err != nil { - fatal(err) - } - defer f.Close() - } - log.SetOutput(f) - rebootRequired, err := tun.WintunPool.DeleteDriver() + _ = driver.UninstallLegacyWintun() // Best effort + err := driver.Uninstall() if err != nil { - log.Printf("Error: %v\n", err) - } else if rebootRequired { - log.Println("A reboot may be required") + fatal(err) } return } |