aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/main.go
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--main.go137
1 files changed, 99 insertions, 38 deletions
diff --git a/main.go b/main.go
index 79dfcdfc..62a0b559 100644
--- a/main.go
+++ b/main.go
@@ -1,12 +1,15 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package main
import (
+ "debug/pe"
+ "errors"
"fmt"
+ "io"
"log"
"os"
"strconv"
@@ -15,6 +18,8 @@ import (
"golang.org/x/sys/windows"
+ "golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/elevate"
"golang.zx2c4.com/wireguard/windows/l18n"
"golang.zx2c4.com/wireguard/windows/manager"
@@ -24,21 +29,41 @@ import (
"golang.zx2c4.com/wireguard/windows/updater"
)
-func fatal(v ...interface{}) {
- windows.MessageBox(0, windows.StringToUTF16Ptr(fmt.Sprint(v...)), windows.StringToUTF16Ptr(l18n.Sprintf("Error")), windows.MB_ICONERROR)
- os.Exit(1)
+func setLogFile() {
+ logHandle, err := windows.GetStdHandle(windows.STD_ERROR_HANDLE)
+ if logHandle == 0 || err != nil {
+ logHandle, err = windows.GetStdHandle(windows.STD_OUTPUT_HANDLE)
+ }
+ if logHandle == 0 || err != nil {
+ log.SetOutput(io.Discard)
+ } else {
+ log.SetOutput(os.NewFile(uintptr(logHandle), "stderr"))
+ }
+}
+
+func fatal(v ...any) {
+ if log.Writer() == io.Discard {
+ windows.MessageBox(0, windows.StringToUTF16Ptr(fmt.Sprint(v...)), windows.StringToUTF16Ptr(l18n.Sprintf("Error")), windows.MB_ICONERROR)
+ os.Exit(1)
+ } else {
+ log.Fatal(append([]any{l18n.Sprintf("Error: ")}, v...))
+ }
}
-func fatalf(format string, v ...interface{}) {
+func fatalf(format string, v ...any) {
fatal(l18n.Sprintf(format, v...))
}
-func info(title string, format string, v ...interface{}) {
- windows.MessageBox(0, windows.StringToUTF16Ptr(l18n.Sprintf(format, v...)), windows.StringToUTF16Ptr(title), windows.MB_ICONINFORMATION)
+func info(title, format string, v ...any) {
+ if log.Writer() == io.Discard {
+ windows.MessageBox(0, windows.StringToUTF16Ptr(l18n.Sprintf(format, v...)), windows.StringToUTF16Ptr(title), windows.MB_ICONINFORMATION)
+ } else {
+ log.Printf(title+":\n"+format, v...)
+ }
}
func usage() {
- var flags = [...]string{
+ flags := [...]string{
l18n.Sprintf("(no argument): elevate and install manager service"),
"/installmanagerservice",
"/installtunnelservice CONFIG_PATH",
@@ -47,8 +72,9 @@ func usage() {
"/managerservice",
"/tunnelservice CONFIG_PATH",
"/ui CMD_READ_HANDLE CMD_WRITE_HANDLE CMD_EVENT_HANDLE LOG_MAPPING_HANDLE",
- "/dumplog OUTPUT_PATH",
- "/update [LOG_FILE]",
+ "/dumplog [/tail]",
+ "/update",
+ "/removedriver",
}
builder := strings.Builder{}
for _, flag := range flags {
@@ -59,13 +85,27 @@ func usage() {
}
func checkForWow64() {
- var b bool
- err := windows.IsWow64Process(windows.CurrentProcess(), &b)
+ b, err := func() (bool, error) {
+ var processMachine, nativeMachine uint16
+ err := windows.IsWow64Process2(windows.CurrentProcess(), &processMachine, &nativeMachine)
+ if err == nil {
+ return processMachine != pe.IMAGE_FILE_MACHINE_UNKNOWN, nil
+ }
+ if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
+ return false, err
+ }
+ var b bool
+ err = windows.IsWow64Process(windows.CurrentProcess(), &b)
+ if err != nil {
+ return false, err
+ }
+ return b, nil
+ }()
if err != nil {
fatalf("Unable to determine whether the process is running under WOW64: %v", err)
}
if b {
- fatalf("You must use the 64-bit version of WireGuard on this computer.")
+ fatalf("You must use the native version of WireGuard on this computer.")
}
}
@@ -95,11 +135,11 @@ func execElevatedManagerServiceInstaller() error {
return err
}
err = elevate.ShellExecute(path, "/installmanagerservice", "", windows.SW_SHOW)
- if err != nil {
+ if err != nil && err != windows.ERROR_CANCELLED {
return err
}
os.Exit(0)
- return windows.ERROR_ACCESS_DENIED // Not reached
+ return windows.ERROR_UNHANDLED_EXCEPTION // Not reached
}
func pipeFromHandleArgument(handleStr string) (*os.File, error) {
@@ -111,13 +151,18 @@ func pipeFromHandleArgument(handleStr string) (*os.File, error) {
}
func main() {
+ if windows.SetDllDirectory("") != nil || windows.SetDefaultDllDirectories(windows.LOAD_LIBRARY_SEARCH_SYSTEM32) != nil {
+ panic("failed to restrict dll search path")
+ }
+
+ setLogFile()
checkForWow64()
if len(os.Args) <= 1 {
- checkForAdminGroup()
if ui.RaiseUI() {
return
}
+ checkForAdminGroup()
err := execElevatedManagerServiceInstaller()
if err != nil {
fatal(err)
@@ -190,9 +235,18 @@ func main() {
if len(os.Args) != 6 {
usage()
}
- err := elevate.DropAllPrivileges(false)
- if err != nil {
- fatal(err)
+ var processToken windows.Token
+ isAdmin := false
+ err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY|windows.TOKEN_DUPLICATE, &processToken)
+ if err == nil {
+ isAdmin = elevate.TokenIsElevatedOrElevatable(processToken)
+ processToken.Close()
+ }
+ if isAdmin {
+ err := elevate.DropAllPrivileges(false)
+ if err != nil {
+ fatal(err)
+ }
}
readPipe, err := pipeFromHandleArgument(os.Args[2])
if err != nil {
@@ -211,38 +265,35 @@ func main() {
fatal(err)
}
manager.InitializeIPCClient(readPipe, writePipe, eventPipe)
+ ui.IsAdmin = isAdmin
ui.RunUI()
return
case "/dumplog":
- if len(os.Args) != 3 {
+ if len(os.Args) != 2 && len(os.Args) != 3 {
usage()
}
- file, err := os.Create(os.Args[2])
+ outputHandle, err := windows.GetStdHandle(windows.STD_OUTPUT_HANDLE)
if err != nil {
fatal(err)
}
+ if outputHandle == 0 {
+ fatal("Stdout must be set")
+ }
+ file := os.NewFile(uintptr(outputHandle), "stdout")
defer file.Close()
- err = ringlogger.DumpTo(file, true)
+ logPath, err := conf.LogFile(false)
+ if err != nil {
+ fatal(err)
+ }
+ err = ringlogger.DumpTo(logPath, file, len(os.Args) == 3 && os.Args[2] == "/tail")
if err != nil {
fatal(err)
}
return
case "/update":
- if len(os.Args) != 2 && len(os.Args) != 3 {
+ if len(os.Args) != 2 {
usage()
}
- var f *os.File
- var err error
- if len(os.Args) == 2 {
- f = os.Stdout
- } else {
- f, err = os.Create(os.Args[2])
- if err != nil {
- fatal(err)
- }
- defer f.Close()
- }
- l := log.New(f, "", log.LstdFlags)
for progress := range updater.DownloadVerifyAndExecute(0) {
if len(progress.Activity) > 0 {
if progress.BytesTotal > 0 || progress.BytesDownloaded > 0 {
@@ -250,19 +301,29 @@ func main() {
if progress.BytesTotal > 0 {
percent = float64(progress.BytesDownloaded) / float64(progress.BytesTotal) * 100.0
}
- l.Printf("%s: %d/%d (%.2f%%)\n", progress.Activity, progress.BytesDownloaded, progress.BytesTotal, percent)
+ log.Printf("%s: %d/%d (%.2f%%)\n", progress.Activity, progress.BytesDownloaded, progress.BytesTotal, percent)
} else {
- l.Println(progress.Activity)
+ log.Println(progress.Activity)
}
}
if progress.Error != nil {
- l.Printf("Error: %v\n", progress.Error)
+ log.Printf("Error: %v\n", progress.Error)
}
if progress.Complete || progress.Error != nil {
return
}
}
return
+ case "/removedriver":
+ if len(os.Args) != 2 {
+ usage()
+ }
+ _ = driver.UninstallLegacyWintun() // Best effort
+ err := driver.Uninstall()
+ if err != nil {
+ fatal(err)
+ }
+ return
}
usage()
}