diff options
Diffstat (limited to '')
-rw-r--r-- | manager/install.go | 59 | ||||
-rw-r--r-- | manager/interfacecleanup.go | 58 | ||||
-rw-r--r-- | manager/ipc_client.go | 17 | ||||
-rw-r--r-- | manager/ipc_driver.go | 61 | ||||
-rw-r--r-- | manager/ipc_pipe.go | 55 | ||||
-rw-r--r-- | manager/ipc_server.go | 134 | ||||
-rw-r--r-- | manager/service.go | 157 | ||||
-rw-r--r-- | manager/tunneltracker.go | 357 | ||||
-rw-r--r-- | manager/uiprocess.go | 103 | ||||
-rw-r--r-- | manager/updatestate.go | 36 |
10 files changed, 671 insertions, 366 deletions
diff --git a/manager/install.go b/manager/install.go index f84a96ae..44a744cf 100644 --- a/manager/install.go +++ b/manager/install.go @@ -1,13 +1,15 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package manager import ( "errors" + "log" "os" + "strings" "time" "golang.org/x/sys/windows" @@ -15,7 +17,6 @@ import ( "golang.org/x/sys/windows/svc/mgr" "golang.zx2c4.com/wireguard/windows/conf" - "golang.zx2c4.com/wireguard/windows/services" ) var cachedServiceManager *mgr.Mgr @@ -56,6 +57,12 @@ func InstallManager() error { } if status.State != svc.Stopped { service.Close() + if status.State == svc.StartPending { + // We were *just* started by something else, so return success here, assuming the other program + // starting this does the right thing. This can happen when, e.g., the updater relaunches the + // manager service and then invokes wireguard.exe to raise the UI. + return nil + } return ErrManagerAlreadyRunning } err = service.Delete() @@ -122,7 +129,7 @@ func InstallTunnel(configPath string) error { return err } - serviceName, err := services.ServiceNameOfTunnel(name) + serviceName, err := conf.ServiceNameOfTunnel(name) if err != nil { return err } @@ -156,7 +163,7 @@ func InstallTunnel(configPath string) error { ServiceType: windows.SERVICE_WIN32_OWN_PROCESS, StartType: mgr.StartAutomatic, ErrorControl: mgr.ErrorNormal, - Dependencies: []string{"Nsi"}, + Dependencies: []string{"Nsi", "TcpIp"}, DisplayName: "WireGuard Tunnel: " + name, SidType: windows.SERVICE_SID_TYPE_UNRESTRICTED, } @@ -175,7 +182,7 @@ func UninstallTunnel(name string) error { if err != nil { return err } - serviceName, err := services.ServiceNameOfTunnel(name) + serviceName, err := conf.ServiceNameOfTunnel(name) if err != nil { return err } @@ -191,3 +198,45 @@ func UninstallTunnel(name string) error { } return err2 } + +func changeTunnelServiceConfigFilePath(name, oldPath, newPath string) { + var err error + defer func() { + if err != nil { + log.Printf("Unable to change tunnel service command line argument from %#q to %#q: %v", oldPath, newPath, err) + } + }() + m, err := serviceManager() + if err != nil { + return + } + serviceName, err := conf.ServiceNameOfTunnel(name) + if err != nil { + return + } + service, err := m.OpenService(serviceName) + if err == windows.ERROR_SERVICE_DOES_NOT_EXIST { + err = nil + return + } else if err != nil { + return + } + defer service.Close() + config, err := service.Config() + if err != nil { + return + } + exePath, err := os.Executable() + if err != nil { + return + } + args, err := windows.DecomposeCommandLine(config.BinaryPathName) + if err != nil || len(args) != 3 || + !strings.EqualFold(args[0], exePath) || args[1] != "/tunnelservice" || !strings.EqualFold(args[2], oldPath) { + err = nil + return + } + args[2] = newPath + config.BinaryPathName = windows.ComposeCommandLine(args) + err = service.UpdateConfig(config) +} diff --git a/manager/interfacecleanup.go b/manager/interfacecleanup.go deleted file mode 100644 index f5d9ef48..00000000 --- a/manager/interfacecleanup.go +++ /dev/null @@ -1,58 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package manager - -import ( - "log" - - "golang.org/x/sys/windows" - "golang.org/x/sys/windows/svc" - "golang.org/x/sys/windows/svc/mgr" - "golang.zx2c4.com/wireguard/tun/wintun" - - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/windows/services" -) - -func cleanupStaleWintunInterfaces() { - defer printPanic() - - m, err := mgr.Connect() - if err != nil { - return - } - defer m.Disconnect() - - tun.WintunPool.DeleteMatchingInterfaces(func(wintun *wintun.Interface) bool { - interfaceName, err := wintun.Name() - if err != nil { - log.Printf("Removing Wintun interface %s because determining interface name failed: %v", wintun.GUID().String(), err) - return true - } - serviceName, err := services.ServiceNameOfTunnel(interfaceName) - if err != nil { - log.Printf("Removing Wintun interface ‘%s’ because determining tunnel service name failed: %v", interfaceName, err) - return true - } - service, err := m.OpenService(serviceName) - if err == windows.ERROR_SERVICE_DOES_NOT_EXIST { - log.Printf("Removing Wintun interface ‘%s’ because no service for it exists", interfaceName) - return true - } else if err != nil { - return false - } - defer service.Close() - status, err := service.Query() - if err != nil { - return false - } - if status.State == svc.Stopped { - log.Printf("Removing Wintun interface ‘%s’ because its service is stopped", interfaceName) - return true - } - return false - }) -} diff --git a/manager/ipc_client.go b/manager/ipc_client.go index c8b2f852..8c9c4c04 100644 --- a/manager/ipc_client.go +++ b/manager/ipc_client.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package manager @@ -64,7 +64,7 @@ var ( ) type TunnelChangeCallback struct { - cb func(tunnel *Tunnel, state TunnelState, globalState TunnelState, err error) + cb func(tunnel *Tunnel, state, globalState TunnelState, err error) } var tunnelChangeCallbacks = make(map[*TunnelChangeCallback]bool) @@ -93,7 +93,7 @@ type UpdateProgressCallback struct { var updateProgressCallbacks = make(map[*UpdateProgressCallback]bool) -func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) { +func InitializeIPCClient(reader, writer, events *os.File) { rpcDecoder = gob.NewDecoder(reader) rpcEncoder = gob.NewEncoder(writer) go func() { @@ -431,43 +431,52 @@ func IPCClientUpdate() error { return rpcEncoder.Encode(UpdateMethodType) } -func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state TunnelState, globalState TunnelState, err error)) *TunnelChangeCallback { +func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state, globalState TunnelState, err error)) *TunnelChangeCallback { s := &TunnelChangeCallback{cb} tunnelChangeCallbacks[s] = true return s } + func (cb *TunnelChangeCallback) Unregister() { delete(tunnelChangeCallbacks, cb) } + func IPCClientRegisterTunnelsChange(cb func()) *TunnelsChangeCallback { s := &TunnelsChangeCallback{cb} tunnelsChangeCallbacks[s] = true return s } + func (cb *TunnelsChangeCallback) Unregister() { delete(tunnelsChangeCallbacks, cb) } + func IPCClientRegisterManagerStopping(cb func()) *ManagerStoppingCallback { s := &ManagerStoppingCallback{cb} managerStoppingCallbacks[s] = true return s } + func (cb *ManagerStoppingCallback) Unregister() { delete(managerStoppingCallbacks, cb) } + func IPCClientRegisterUpdateFound(cb func(updateState UpdateState)) *UpdateFoundCallback { s := &UpdateFoundCallback{cb} updateFoundCallbacks[s] = true return s } + func (cb *UpdateFoundCallback) Unregister() { delete(updateFoundCallbacks, cb) } + func IPCClientRegisterUpdateProgress(cb func(dp updater.DownloadProgress)) *UpdateProgressCallback { s := &UpdateProgressCallback{cb} updateProgressCallbacks[s] = true return s } + func (cb *UpdateProgressCallback) Unregister() { delete(updateProgressCallbacks, cb) } diff --git a/manager/ipc_driver.go b/manager/ipc_driver.go new file mode 100644 index 00000000..6cb43c38 --- /dev/null +++ b/manager/ipc_driver.go @@ -0,0 +1,61 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package manager + +import ( + "sync" + + "golang.zx2c4.com/wireguard/windows/driver" +) + +type lockedDriverAdapter struct { + *driver.Adapter + sync.Mutex +} + +var ( + driverAdapters = make(map[string]*lockedDriverAdapter) + driverAdaptersLock sync.RWMutex +) + +func findDriverAdapter(tunnelName string) (*lockedDriverAdapter, error) { + driverAdaptersLock.RLock() + driverAdapter, ok := driverAdapters[tunnelName] + if ok { + driverAdapter.Lock() + driverAdaptersLock.RUnlock() + return driverAdapter, nil + } + driverAdaptersLock.RUnlock() + driverAdaptersLock.Lock() + defer driverAdaptersLock.Unlock() + driverAdapter, ok = driverAdapters[tunnelName] + if ok { + driverAdapter.Lock() + return driverAdapter, nil + } + driverAdapter = &lockedDriverAdapter{} + var err error + driverAdapter.Adapter, err = driver.OpenAdapter(tunnelName) + if err != nil { + return nil, err + } + driverAdapters[tunnelName] = driverAdapter + driverAdapter.Lock() + return driverAdapter, nil +} + +func releaseDriverAdapter(tunnelName string) { + driverAdaptersLock.Lock() + defer driverAdaptersLock.Unlock() + driverAdapter, ok := driverAdapters[tunnelName] + if !ok { + return + } + driverAdapter.Lock() + delete(driverAdapters, tunnelName) + driverAdapter.Unlock() +} diff --git a/manager/ipc_pipe.go b/manager/ipc_pipe.go deleted file mode 100644 index d4214ac0..00000000 --- a/manager/ipc_pipe.go +++ /dev/null @@ -1,55 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package manager - -import ( - "os" - "strconv" - - "golang.org/x/sys/windows" -) - -func makeInheritableAndGetStr(f *os.File) (str string, err error) { - sc, err := f.SyscallConn() - if err != nil { - return - } - err2 := sc.Control(func(fd uintptr) { - err = windows.SetHandleInformation(windows.Handle(fd), windows.HANDLE_FLAG_INHERIT, windows.HANDLE_FLAG_INHERIT) - str = strconv.FormatUint(uint64(fd), 10) - }) - if err2 != nil { - err = err2 - } - return -} - -func inheritableEvents() (ourEvents *os.File, theirEvents *os.File, theirEventStr string, err error) { - theirEvents, ourEvents, err = os.Pipe() - if err != nil { - return - } - theirEventStr, err = makeInheritableAndGetStr(theirEvents) - return -} - -func inheritableSocketpairEmulation() (ourReader *os.File, theirReader *os.File, theirReaderStr string, ourWriter *os.File, theirWriter *os.File, theirWriterStr string, err error) { - ourReader, theirWriter, err = os.Pipe() - if err != nil { - return - } - theirWriterStr, err = makeInheritableAndGetStr(theirWriter) - if err != nil { - return - } - - theirReader, ourWriter, err = os.Pipe() - if err != nil { - return - } - theirReaderStr, err = makeInheritableAndGetStr(theirReader) - return -} diff --git a/manager/ipc_server.go b/manager/ipc_server.go index 1367c2e9..e21ffaf0 100644 --- a/manager/ipc_server.go +++ b/manager/ipc_server.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package manager @@ -10,7 +10,6 @@ import ( "encoding/gob" "fmt" "io" - "io/ioutil" "log" "os" "sync" @@ -20,64 +19,73 @@ import ( "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" - "golang.zx2c4.com/wireguard/ipc/winpipe" - "golang.zx2c4.com/wireguard/windows/conf" - "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/updater" ) -var managerServices = make(map[*ManagerService]bool) -var managerServicesLock sync.RWMutex -var haveQuit uint32 -var quitManagersChan = make(chan struct{}, 1) +var ( + managerServices = make(map[*ManagerService]bool) + managerServicesLock sync.RWMutex + haveQuit uint32 + quitManagersChan = make(chan struct{}, 1) +) type ManagerService struct { events *os.File + eventLock sync.Mutex elevatedToken windows.Token } func (s *ManagerService) StoredConfig(tunnelName string) (*conf.Config, error) { - return conf.LoadFromName(tunnelName) -} - -func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) { - storedConfig, err := conf.LoadFromName(tunnelName) + conf, err := conf.LoadFromName(tunnelName) if err != nil { return nil, err } - pipePath, err := services.PipePathOfTunnel(storedConfig.Name) - if err != nil { - return nil, err + if s.elevatedToken == 0 { + conf.Redact() } - localSystem, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid) + return conf, nil +} + +func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) { + storedConfig, err := conf.LoadFromName(tunnelName) if err != nil { return nil, err } - pipe, err := winpipe.DialPipe(pipePath, nil, localSystem) + driverAdapter, err := findDriverAdapter(tunnelName) if err != nil { return nil, err } - defer pipe.Close() - pipe.SetWriteDeadline(time.Now().Add(time.Second * 2)) - _, err = pipe.Write([]byte("get=1\n\n")) + runtimeConfig, err := driverAdapter.Configuration() if err != nil { + driverAdapter.Unlock() + releaseDriverAdapter(tunnelName) return nil, err } - pipe.SetReadDeadline(time.Now().Add(time.Second * 2)) - resp, err := ioutil.ReadAll(pipe) - if err != nil { - return nil, err + conf := conf.FromDriverConfiguration(runtimeConfig, storedConfig) + driverAdapter.Unlock() + if s.elevatedToken == 0 { + conf.Redact() } - return conf.FromUAPI(string(resp), storedConfig) + return conf, nil } func (s *ManagerService) Start(tunnelName string) error { - // For now, enforce only one tunnel at a time. Later we'll remove this silly restriction. + c, err := conf.LoadFromName(tunnelName) + if err != nil { + return err + } + + // Figure out which tunnels have intersecting addresses/routes and stop those. trackedTunnelsLock.Lock() tt := make([]string, 0, len(trackedTunnels)) var inTransition string for t, state := range trackedTunnels { + c2, err := conf.LoadFromName(t) + if err != nil || !c.IntersectsWith(c2) { + // If we can't get the config, assume it doesn't intersect. + continue + } tt = append(tt, t) if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) { inTransition = t @@ -88,6 +96,8 @@ func (s *ManagerService) Start(tunnelName string) error { if len(inTransition) != 0 { return fmt.Errorf("Please allow the tunnel ‘%s’ to finish activating", inTransition) } + + // Stop those intersecting tunnels asynchronously. go func() { for _, t := range tt { s.Stop(t) @@ -101,13 +111,7 @@ func (s *ManagerService) Start(tunnelName string) error { } } }() - time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces) - - // After that process is started -- it's somewhat asynchronous -- we install the new one. - c, err := conf.LoadFromName(tunnelName) - if err != nil { - return err - } + // After the stop process has begun, but before it's finished, we install the new one. path, err := c.Path() if err != nil { return err @@ -116,8 +120,6 @@ func (s *ManagerService) Start(tunnelName string) error { } func (s *ManagerService) Stop(tunnelName string) error { - time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces) - err := UninstallTunnel(tunnelName) if err == windows.ERROR_SERVICE_DOES_NOT_EXIST { _, notExistsError := conf.LoadFromName(tunnelName) @@ -129,7 +131,7 @@ func (s *ManagerService) Stop(tunnelName string) error { } func (s *ManagerService) WaitForStop(tunnelName string) error { - serviceName, err := services.ServiceNameOfTunnel(tunnelName) + serviceName, err := conf.ServiceNameOfTunnel(tunnelName) if err != nil { return err } @@ -149,6 +151,9 @@ func (s *ManagerService) WaitForStop(tunnelName string) error { } func (s *ManagerService) Delete(tunnelName string) error { + if s.elevatedToken == 0 { + return windows.ERROR_ACCESS_DENIED + } err := s.Stop(tunnelName) if err != nil { return err @@ -157,7 +162,7 @@ func (s *ManagerService) Delete(tunnelName string) error { } func (s *ManagerService) State(tunnelName string) (TunnelState, error) { - serviceName, err := services.ServiceNameOfTunnel(tunnelName) + serviceName, err := conf.ServiceNameOfTunnel(tunnelName) if err != nil { return 0, err } @@ -193,7 +198,10 @@ func (s *ManagerService) GlobalState() TunnelState { } func (s *ManagerService) Create(tunnelConfig *conf.Config) (*Tunnel, error) { - err := tunnelConfig.Save() + if s.elevatedToken == 0 { + return nil, windows.ERROR_ACCESS_DENIED + } + err := tunnelConfig.Save(true) if err != nil { return nil, err } @@ -209,19 +217,25 @@ func (s *ManagerService) Tunnels() ([]Tunnel, error) { } tunnels := make([]Tunnel, len(names)) for i := 0; i < len(tunnels); i++ { - (tunnels)[i].Name = names[i] + tunnels[i].Name = names[i] } return tunnels, nil // TODO: account for running ones that aren't in the configuration store somehow } func (s *ManagerService) Quit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) { + if s.elevatedToken == 0 { + return false, windows.ERROR_ACCESS_DENIED + } if !atomic.CompareAndSwapUint32(&haveQuit, 0, 1) { return true, nil } // Work around potential race condition of delivering messages to the wrong process by removing from notifications. managerServicesLock.Lock() + s.eventLock.Lock() + s.events = nil + s.eventLock.Unlock() delete(managerServices, s) managerServicesLock.Unlock() @@ -244,6 +258,9 @@ func (s *ManagerService) UpdateState() UpdateState { } func (s *ManagerService) Update() { + if s.elevatedToken == 0 { + return + } progress := updater.DownloadVerifyAndExecute(uintptr(s.elevatedToken)) go func() { for { @@ -374,6 +391,9 @@ func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) { return } tunnel, retErr := s.Create(&config) + if tunnel == nil { + tunnel = &Tunnel{} + } err = encoder.Encode(tunnel) if err != nil { return @@ -421,26 +441,27 @@ func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) { } } -func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevatedToken windows.Token) { +func IPCServerListen(reader, writer, events *os.File, elevatedToken windows.Token) { service := &ManagerService{ events: events, elevatedToken: elevatedToken, } go func() { - defer printPanic() managerServicesLock.Lock() managerServices[service] = true managerServicesLock.Unlock() service.ServeConn(reader, writer) managerServicesLock.Lock() + service.eventLock.Lock() + service.events = nil + service.eventLock.Unlock() delete(managerServices, service) managerServicesLock.Unlock() - }() } -func notifyAll(notificationType NotificationType, ifaces ...interface{}) { +func notifyAll(notificationType NotificationType, adminOnly bool, ifaces ...any) { if len(managerServices) == 0 { return } @@ -460,8 +481,17 @@ func notifyAll(notificationType NotificationType, ifaces ...interface{}) { managerServicesLock.RLock() for m := range managerServices { - m.events.SetWriteDeadline(time.Now().Add(time.Second)) - m.events.Write(buf.Bytes()) + if m.elevatedToken == 0 && adminOnly { + continue + } + go func(m *ManagerService) { + m.eventLock.Lock() + defer m.eventLock.Unlock() + if m.events != nil { + m.events.SetWriteDeadline(time.Now().Add(time.Second)) + m.events.Write(buf.Bytes()) + } + }(m) } managerServicesLock.RUnlock() } @@ -474,22 +504,22 @@ func errToString(err error) string { } func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) { - notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), errToString(err)) + notifyAll(TunnelChangeNotificationType, false, name, state, trackedTunnelsGlobalState(), errToString(err)) } func IPCServerNotifyTunnelsChange() { - notifyAll(TunnelsChangeNotificationType) + notifyAll(TunnelsChangeNotificationType, false) } func IPCServerNotifyUpdateFound(state UpdateState) { - notifyAll(UpdateFoundNotificationType, state) + notifyAll(UpdateFoundNotificationType, false, state) } func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) { - notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, errToString(dp.Error), dp.Complete) + notifyAll(UpdateProgressNotificationType, true, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, errToString(dp.Error), dp.Complete) } func IPCServerNotifyManagerStopping() { - notifyAll(ManagerStoppingNotificationType) + notifyAll(ManagerStoppingNotificationType, false) time.Sleep(time.Millisecond * 200) } diff --git a/manager/service.go b/manager/service.go index 6c3b039b..47e20d45 100644 --- a/manager/service.go +++ b/manager/service.go @@ -1,46 +1,32 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package manager import ( "errors" - "fmt" "log" "os" "runtime" - "runtime/debug" - "strings" + "strconv" "sync" - "syscall" "time" "unsafe" "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" + "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/conf" "golang.zx2c4.com/wireguard/windows/elevate" "golang.zx2c4.com/wireguard/windows/ringlogger" "golang.zx2c4.com/wireguard/windows/services" - "golang.zx2c4.com/wireguard/windows/version" ) type managerService struct{} -func printPanic() { - if x := recover(); x != nil { - for _, line := range append([]string{fmt.Sprint(x)}, strings.Split(string(debug.Stack()), "\n")...) { - if len(strings.TrimSpace(line)) > 0 { - log.Println(line) - } - } - panic(x) - } -} - func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) { changes <- svc.Status{State: svc.StartPending} @@ -56,40 +42,40 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest changes <- svc.Status{State: svc.StopPending} }() - err = ringlogger.InitGlobalLogger("MGR") + var logFile string + logFile, err = conf.LogFile(true) if err != nil { serviceError = services.ErrorRingloggerOpen return } - defer printPanic() - - log.Println("Starting", version.UserAgent()) - - path, err := os.Executable() + err = ringlogger.InitGlobalLogger(logFile, "MGR") if err != nil { - serviceError = services.ErrorDetermineExecutablePath + serviceError = services.ErrorRingloggerOpen return } - devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0) + services.PrintStarting() + + path, err := os.Executable() if err != nil { - serviceError = services.ErrorOpenNULFile + serviceError = services.ErrorDetermineExecutablePath return } - err = trackExistingTunnels() + err = watchNewTunnelServices() if err != nil { serviceError = services.ErrorTrackTunnels return } - conf.RegisterStoreChangeCallback(func() { conf.MigrateUnencryptedConfigs() }) // Ignore return value for now, but could be useful later. + conf.RegisterStoreChangeCallback(func() { conf.MigrateUnencryptedConfigs(changeTunnelServiceConfigFilePath) }) conf.RegisterStoreChangeCallback(IPCServerNotifyTunnelsChange) - procs := make(map[uint32]*os.Process) + procs := make(map[uint32]*uiProcess) aliveSessions := make(map[uint32]bool) procsLock := sync.Mutex{} stoppingManager := false + operatorGroupSid, _ := windows.CreateWellKnownSid(windows.WinBuiltinNetworkConfigurationOperatorsSid) startProcess := func(session uint32) { defer func() { @@ -104,7 +90,24 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest if err != nil { return } - if !elevate.TokenIsElevatedOrElevatable(userToken) { + isAdmin := elevate.TokenIsElevatedOrElevatable(userToken) + isOperator := false + if !isAdmin && conf.AdminBool("LimitedOperatorUI") && operatorGroupSid != nil { + linkedToken, err := userToken.GetLinkedToken() + var impersonationToken windows.Token + if err == nil { + err = windows.DuplicateTokenEx(linkedToken, windows.TOKEN_QUERY, nil, windows.SecurityImpersonation, windows.TokenImpersonation, &impersonationToken) + linkedToken.Close() + } else { + err = windows.DuplicateTokenEx(userToken, windows.TOKEN_QUERY, nil, windows.SecurityImpersonation, windows.TokenImpersonation, &impersonationToken) + } + if err == nil { + isOperator, err = impersonationToken.IsMember(operatorGroupSid) + isOperator = isOperator && err == nil + impersonationToken.Close() + } + } + if !isAdmin && !isOperator { userToken.Close() return } @@ -125,23 +128,28 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest return } userProfileDirectory, _ := userToken.GetUserProfileDirectory() - var elevatedToken windows.Token - if userToken.IsElevated() { - elevatedToken = userToken - } else { - elevatedToken, err = userToken.GetLinkedToken() - userToken.Close() - if err != nil { - log.Printf("Unable to elevate token: %v", err) - return - } - if !elevatedToken.IsElevated() { - elevatedToken.Close() - log.Println("Linked token is not elevated") - return + var elevatedToken, runToken windows.Token + if isAdmin { + if userToken.IsElevated() { + elevatedToken = userToken + } else { + elevatedToken, err = userToken.GetLinkedToken() + userToken.Close() + if err != nil { + log.Printf("Unable to elevate token: %v", err) + return + } + if !elevatedToken.IsElevated() { + elevatedToken.Close() + log.Println("Linked token is not elevated") + return + } } + runToken = elevatedToken + } else { + runToken = userToken } - defer elevatedToken.Close() + defer runToken.Close() userToken = 0 first := true for { @@ -162,39 +170,45 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest first = false } - // TODO: we lock the OS thread so that these inheritable handles don't escape into other processes that - // might be running in parallel Go routines. But the Go runtime is strange and who knows what's really - // happening with these or what is inherited. We need to do some analysis to be certain of what's going on. - runtime.LockOSThread() - ourReader, theirReader, theirReaderStr, ourWriter, theirWriter, theirWriterStr, err := inheritableSocketpairEmulation() + ourReader, theirWriter, err := os.Pipe() + if err != nil { + log.Printf("Unable to create pipe: %v", err) + return + } + theirReader, ourWriter, err := os.Pipe() if err != nil { - log.Printf("Unable to create two inheritable RPC pipes: %v", err) + log.Printf("Unable to create pipe: %v", err) return } - ourEvents, theirEvents, theirEventStr, err := inheritableEvents() + theirEvents, ourEvents, err := os.Pipe() if err != nil { - log.Printf("Unable to create one inheritable events pipe: %v", err) + log.Printf("Unable to create pipe: %v", err) return } IPCServerListen(ourReader, ourWriter, ourEvents, elevatedToken) - theirLogMapping, theirLogMappingHandle, err := ringlogger.Global.ExportInheritableMappingHandleStr() + theirLogMapping, err := ringlogger.Global.ExportInheritableMappingHandle() if err != nil { log.Printf("Unable to export inheritable mapping handle for logging: %v", err) return } log.Printf("Starting UI process for user ‘%s@%s’ for session %d", username, domain, session) - attr := &os.ProcAttr{ - Sys: &syscall.SysProcAttr{ - Token: syscall.Token(elevatedToken), - }, - Files: []*os.File{devNull, devNull, devNull}, - Dir: userProfileDirectory, - } procsLock.Lock() - var proc *os.Process + var proc *uiProcess if alive := aliveSessions[session]; alive { - proc, err = os.StartProcess(path, []string{path, "/ui", theirReaderStr, theirWriterStr, theirEventStr, theirLogMapping}, attr) + proc, err = launchUIProcess(path, []string{ + path, + "/ui", + strconv.FormatUint(uint64(theirReader.Fd()), 10), + strconv.FormatUint(uint64(theirWriter.Fd()), 10), + strconv.FormatUint(uint64(theirEvents.Fd()), 10), + strconv.FormatUint(uint64(theirLogMapping), 10), + }, userProfileDirectory, []windows.Handle{ + windows.Handle(theirReader.Fd()), + windows.Handle(theirWriter.Fd()), + windows.Handle(theirEvents.Fd()), + theirLogMapping, + }, runToken) } else { err = errors.New("Session has logged out") } @@ -202,8 +216,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest theirReader.Close() theirWriter.Close() theirEvents.Close() - windows.Close(theirLogMappingHandle) - runtime.UnlockOSThread() + windows.CloseHandle(theirLogMapping) if err != nil { ourReader.Close() ourWriter.Close() @@ -217,9 +230,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest procsLock.Unlock() sessionIsDead := false - processStatus, err := proc.Wait() - if err == nil { - exitCode := processStatus.Sys().(syscall.WaitStatus).ExitCode + if exitCode, err := proc.Wait(); err == nil { log.Printf("Exited UI process for user '%s@%s' for session %d with status %x", username, domain, session, exitCode) const STATUS_DLL_INIT_FAILED_LOGOFF = 0xC000026B sessionIsDead = exitCode == STATUS_DLL_INIT_FAILED_LOGOFF @@ -243,14 +254,13 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest goStartProcess := func(session uint32) { procsGroup.Add(1) go func() { - defer printPanic() startProcess(session) procsGroup.Done() }() } - time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces) go checkForUpdates() + go driver.UninstallLegacyWintun() // We uninstall opportunistically here, so that we don't have to carry around the uninstaller code forever. var sessionsPointer *windows.WTS_SESSION_INFO var count uint32 @@ -259,12 +269,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest serviceError = services.ErrorEnumerateSessions return } - sessions := *(*[]windows.WTS_SESSION_INFO)(unsafe.Pointer(&struct { - addr *windows.WTS_SESSION_INFO - len int - cap int - }{sessionsPointer, int(count), int(count)})) - for _, session := range sessions { + for _, session := range unsafe.Slice(sessionsPointer, count) { if session.State != windows.WTSActive && session.State != windows.WTSDisconnected { continue } diff --git a/manager/tunneltracker.go b/manager/tunneltracker.go index 0f222aac..9003d445 100644 --- a/manager/tunneltracker.go +++ b/manager/tunneltracker.go @@ -1,17 +1,20 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package manager import ( + "errors" "fmt" "log" "runtime" "sync" + "sync/atomic" "syscall" "time" + "unsafe" "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" @@ -21,50 +24,10 @@ import ( "golang.zx2c4.com/wireguard/windows/services" ) -func trackExistingTunnels() error { - m, err := serviceManager() - if err != nil { - return err - } - names, err := conf.ListConfigNames() - if err != nil { - return err - } - for _, name := range names { - serviceName, err := services.ServiceNameOfTunnel(name) - if err != nil { - continue - } - service, err := m.OpenService(serviceName) - if err != nil { - continue - } - go trackTunnelService(name, service) - } - return nil -} - -var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr { - return 0 -}) - -var trackedTunnels = make(map[string]TunnelState) -var trackedTunnelsLock = sync.Mutex{} - -func svcStateToTunState(s svc.State) TunnelState { - switch s { - case svc.StartPending: - return TunnelStarting - case svc.Running: - return TunnelStarted - case svc.StopPending: - return TunnelStopping - case svc.Stopped: - return TunnelStopped - default: - return TunnelUnknown - } -} +var ( + trackedTunnels = make(map[string]TunnelState) + trackedTunnelsLock = sync.Mutex{} +) func trackedTunnelsGlobalState() (state TunnelState) { state = TunnelStopped @@ -82,50 +45,95 @@ func trackedTunnelsGlobalState() (state TunnelState) { return } -func trackTunnelService(tunnelName string, service *mgr.Service) { - defer func() { - service.Close() - log.Printf("[%s] Tunnel service tracker finished", tunnelName) - }() +var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr { + return 0 +}) - trackedTunnelsLock.Lock() - if _, found := trackedTunnels[tunnelName]; found { - trackedTunnelsLock.Unlock() - return +type serviceSubscriptionState struct { + service *mgr.Service + cb func(status uint32) bool + done sync.WaitGroup + once uint32 +} + +var serviceSubscriptionCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr { + state := (*serviceSubscriptionState)(unsafe.Pointer(context)) + if atomic.LoadUint32(&state.once) != 0 { + return 0 } - trackedTunnels[tunnelName] = TunnelUnknown - trackedTunnelsLock.Unlock() - defer func() { - trackedTunnelsLock.Lock() - delete(trackedTunnels, tunnelName) - trackedTunnelsLock.Unlock() - }() + if notification == 0 { + status, err := state.service.Query() + if err == nil { + notification = svcStateToNotifyState(uint32(status.State)) + } + } + if state.cb(notification) && atomic.CompareAndSwapUint32(&state.once, 0, 1) { + state.done.Done() + } + return 0 +}) - const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING - notifier := &windows.SERVICE_NOTIFY{ - Version: windows.SERVICE_NOTIFY_STATUS_CHANGE, - NotifyCallback: serviceTrackerCallbackPtr, +func svcStateToNotifyState(s uint32) uint32 { + switch s { + case windows.SERVICE_STOPPED: + return windows.SERVICE_NOTIFY_STOPPED + case windows.SERVICE_START_PENDING: + return windows.SERVICE_NOTIFY_START_PENDING + case windows.SERVICE_STOP_PENDING: + return windows.SERVICE_NOTIFY_STOP_PENDING + case windows.SERVICE_RUNNING: + return windows.SERVICE_NOTIFY_RUNNING + case windows.SERVICE_CONTINUE_PENDING: + return windows.SERVICE_NOTIFY_CONTINUE_PENDING + case windows.SERVICE_PAUSE_PENDING: + return windows.SERVICE_NOTIFY_PAUSE_PENDING + case windows.SERVICE_PAUSED: + return windows.SERVICE_NOTIFY_PAUSED + case windows.SERVICE_NO_CHANGE: + return 0 + default: + return 0 } +} - checkForDisabled := func() (shouldReturn bool) { - config, err := service.Config() - if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || config.StartType == windows.SERVICE_DISABLED { - log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName) - service.Delete() - trackedTunnelsLock.Lock() - trackedTunnels[tunnelName] = TunnelStopped - trackedTunnelsLock.Unlock() - IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil) - return true +func notifyStateToTunState(s uint32) TunnelState { + if s&(windows.SERVICE_NOTIFY_STOPPED|windows.SERVICE_NOTIFY_DELETED) != 0 { + return TunnelStopped + } else if s&(windows.SERVICE_NOTIFY_DELETE_PENDING|windows.SERVICE_NOTIFY_STOP_PENDING) != 0 { + return TunnelStopping + } else if s&windows.SERVICE_NOTIFY_RUNNING != 0 { + return TunnelStarted + } else if s&windows.SERVICE_NOTIFY_START_PENDING != 0 { + return TunnelStarting + } else { + return TunnelUnknown + } +} + +func trackService(service *mgr.Service, callback func(status uint32) bool) error { + var subscription uintptr + state := &serviceSubscriptionState{service: service, cb: callback} + state.done.Add(1) + err := windows.SubscribeServiceChangeNotifications(service.Handle, windows.SC_EVENT_STATUS_CHANGE, serviceSubscriptionCallbackPtr, uintptr(unsafe.Pointer(state)), &subscription) + if err == nil { + defer windows.UnsubscribeServiceChangeNotifications(subscription) + status, err := service.Query() + if err == nil { + if callback(svcStateToNotifyState(uint32(status.State))) { + return nil + } } - return false + state.done.Wait() + runtime.KeepAlive(state.cb) + return nil } - if checkForDisabled() { - return + if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) { + return err } - runtime.LockOSThread() + // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point. + runtime.LockOSThread() // This line would be fitting but is intentionally commented out: // // defer runtime.UnlockOSThread() @@ -134,7 +142,11 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { // with the thread local context, which in turn appears to corrupt Go's own usage of TLS, // leading to crashes sometime later (usually in runtime_unlock()) when the thread is recycled. - lastState := TunnelUnknown + const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING + notifier := &windows.SERVICE_NOTIFY{ + Version: windows.SERVICE_NOTIFY_STATUS_CHANGE, + NotifyCallback: serviceTrackerCallbackPtr, + } for { err := windows.NotifyServiceStatusChange(service.Handle, serviceNotifications, notifier) switch err { @@ -142,42 +154,94 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { for { if windows.SleepEx(uint32(time.Second*3/time.Millisecond), true) == windows.WAIT_IO_COMPLETION { break - } else if checkForDisabled() { - return + } else if callback(0) { + return nil } } case windows.ERROR_SERVICE_MARKED_FOR_DELETE: - trackedTunnelsLock.Lock() - trackedTunnels[tunnelName] = TunnelStopped - trackedTunnelsLock.Unlock() - IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil) - return + // Should be SERVICE_NOTIFY_DELETE_PENDING, but actually, we must release the handle and return here; otherwise it never deletes. + if callback(windows.SERVICE_NOTIFY_DELETED) { + return nil + } case windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING: continue default: + return err + } + if callback(svcStateToNotifyState(notifier.ServiceStatus.CurrentState)) { + return nil + } + } +} + +func trackTunnelService(tunnelName string, service *mgr.Service) { + trackedTunnelsLock.Lock() + if _, found := trackedTunnels[tunnelName]; found { + trackedTunnelsLock.Unlock() + service.Close() + return + } + + defer func() { + service.Close() + log.Printf("[%s] Tunnel service tracker finished", tunnelName) + }() + trackedTunnels[tunnelName] = TunnelUnknown + trackedTunnelsLock.Unlock() + defer func() { + trackedTunnelsLock.Lock() + delete(trackedTunnels, tunnelName) + trackedTunnelsLock.Unlock() + }() + + for i := 0; i < 20; i++ { + if i > 0 { + time.Sleep(time.Second / 5) + } + if status, err := service.Query(); err != nil || status.State != svc.Stopped { + break + } + } + + checkForDisabled := func() (shouldReturn bool) { + config, err := service.Config() + if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || (err != nil && config.StartType == windows.SERVICE_DISABLED) { + log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName) + service.Delete() trackedTunnelsLock.Lock() trackedTunnels[tunnelName] = TunnelStopped trackedTunnelsLock.Unlock() - IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %v", err)) - service.Control(svc.Stop) - return + IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil) + return true } - - state := svcStateToTunState(svc.State(notifier.ServiceStatus.CurrentState)) + return false + } + if checkForDisabled() { + return + } + lastState := TunnelUnknown + err := trackService(service, func(status uint32) bool { + state := notifyStateToTunState(status) var tunnelError error if state == TunnelStopped { - if notifier.ServiceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) { - maybeErr := services.Error(notifier.ServiceStatus.ServiceSpecificExitCode) - if maybeErr != services.ErrorSuccess { - tunnelError = maybeErr - } - } else { - switch notifier.ServiceStatus.Win32ExitCode { - case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED): - default: - tunnelError = syscall.Errno(notifier.ServiceStatus.Win32ExitCode) + serviceStatus, err := service.Query() + if err == nil { + if serviceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) { + maybeErr := services.Error(serviceStatus.ServiceSpecificExitCode) + if maybeErr != services.ErrorSuccess { + tunnelError = maybeErr + } + } else { + switch serviceStatus.Win32ExitCode { + case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED): + default: + tunnelError = syscall.Errno(serviceStatus.Win32ExitCode) + } } } + if tunnelError != nil { + service.Delete() + } } if state != lastState { trackedTunnelsLock.Lock() @@ -186,5 +250,94 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { IPCServerNotifyTunnelChange(tunnelName, state, tunnelError) lastState = state } + if state == TunnelUnknown && checkForDisabled() { + return true + } + return state == TunnelStopped + }) + if err != nil && !checkForDisabled() { + trackedTunnelsLock.Lock() + trackedTunnels[tunnelName] = TunnelStopped + trackedTunnelsLock.Unlock() + IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %w", err)) + service.Control(svc.Stop) + } +} + +func trackExistingTunnels() error { + m, err := serviceManager() + if err != nil { + return err + } + names, err := conf.ListConfigNames() + if err != nil { + return err + } + for _, name := range names { + trackedTunnelsLock.Lock() + if _, found := trackedTunnels[name]; found { + trackedTunnelsLock.Unlock() + continue + } + trackedTunnelsLock.Unlock() + serviceName, err := conf.ServiceNameOfTunnel(name) + if err != nil { + continue + } + service, err := m.OpenService(serviceName) + if err != nil { + continue + } + go trackTunnelService(name, service) } + return nil +} + +var servicesSubscriptionWatcherCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr { + trackExistingTunnels() + return 0 +}) + +func watchNewTunnelServices() error { + m, err := serviceManager() + if err != nil { + return err + } + var subscription uintptr + err = windows.SubscribeServiceChangeNotifications(m.Handle, windows.SC_EVENT_DATABASE_CHANGE, servicesSubscriptionWatcherCallbackPtr, 0, &subscription) + if err == nil { + // We probably could do: + // defer windows.UnsubscribeServiceChangeNotifications(subscription) + // and then terminate after some point, but instead we just let this go forever; it's process-lived. + return trackExistingTunnels() + } + if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) { + return err + } + + // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point. + go func() { + runtime.LockOSThread() + notifier := &windows.SERVICE_NOTIFY{ + Version: windows.SERVICE_NOTIFY_STATUS_CHANGE, + NotifyCallback: serviceTrackerCallbackPtr, + } + for { + err := windows.NotifyServiceStatusChange(m.Handle, windows.SERVICE_NOTIFY_CREATED, notifier) + if err == nil { + windows.SleepEx(windows.INFINITE, true) + if notifier.ServiceNames != nil { + windows.LocalFree(windows.Handle(unsafe.Pointer(notifier.ServiceNames))) + notifier.ServiceNames = nil + } + trackExistingTunnels() + } else if err == windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING { + continue + } else { + time.Sleep(time.Second * 3) + trackExistingTunnels() + } + } + }() + return trackExistingTunnels() } diff --git a/manager/uiprocess.go b/manager/uiprocess.go new file mode 100644 index 00000000..b33b1ad3 --- /dev/null +++ b/manager/uiprocess.go @@ -0,0 +1,103 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package manager + +import ( + "errors" + "runtime" + "sync/atomic" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +type uiProcess struct { + handle uintptr +} + +func launchUIProcess(executable string, args []string, workingDirectory string, handles []windows.Handle, token windows.Token) (*uiProcess, error) { + executable16, err := windows.UTF16PtrFromString(executable) + if err != nil { + return nil, err + } + args16, err := windows.UTF16PtrFromString(windows.ComposeCommandLine(args)) + if err != nil { + return nil, err + } + workingDirectory16, err := windows.UTF16PtrFromString(workingDirectory) + if err != nil { + return nil, err + } + var environmentBlock *uint16 + err = windows.CreateEnvironmentBlock(&environmentBlock, token, false) + if err != nil { + return nil, err + } + defer windows.DestroyEnvironmentBlock(environmentBlock) + attributeList, err := windows.NewProcThreadAttributeList(1) + if err != nil { + return nil, err + } + defer attributeList.Delete() + si := &windows.StartupInfoEx{ + StartupInfo: windows.StartupInfo{Cb: uint32(unsafe.Sizeof(windows.StartupInfoEx{}))}, + ProcThreadAttributeList: attributeList.List(), + } + if len(handles) == 0 { + handles = []windows.Handle{0} + } + attributeList.Update(windows.PROC_THREAD_ATTRIBUTE_HANDLE_LIST, unsafe.Pointer(&handles[0]), uintptr(len(handles))*unsafe.Sizeof(handles[0])) + pi := new(windows.ProcessInformation) + err = windows.CreateProcessAsUser(token, executable16, args16, nil, nil, true, windows.CREATE_DEFAULT_ERROR_MODE|windows.CREATE_UNICODE_ENVIRONMENT|windows.EXTENDED_STARTUPINFO_PRESENT, environmentBlock, workingDirectory16, &si.StartupInfo, pi) + if err != nil { + return nil, err + } + windows.CloseHandle(pi.Thread) + uiProc := &uiProcess{handle: uintptr(pi.Process)} + runtime.SetFinalizer(uiProc, (*uiProcess).release) + return uiProc, nil +} + +func (p *uiProcess) release() error { + handle := windows.Handle(atomic.SwapUintptr(&p.handle, uintptr(windows.InvalidHandle))) + if handle == windows.InvalidHandle { + return nil + } + err := windows.CloseHandle(handle) + if err != nil { + return err + } + runtime.SetFinalizer(p, nil) + return nil +} + +func (p *uiProcess) Wait() (uint32, error) { + handle := windows.Handle(atomic.LoadUintptr(&p.handle)) + s, err := windows.WaitForSingleObject(handle, syscall.INFINITE) + switch s { + case windows.WAIT_OBJECT_0: + case windows.WAIT_FAILED: + return 0, err + default: + return 0, errors.New("unexpected result from WaitForSingleObject") + } + var exitCode uint32 + err = windows.GetExitCodeProcess(handle, &exitCode) + if err != nil { + return 0, err + } + p.release() + return exitCode, nil +} + +func (p *uiProcess) Kill() error { + handle := windows.Handle(atomic.LoadUintptr(&p.handle)) + if handle == windows.InvalidHandle { + return nil + } + return windows.TerminateProcess(handle, 1) +} diff --git a/manager/updatestate.go b/manager/updatestate.go index b54cc367..d5a19c8d 100644 --- a/manager/updatestate.go +++ b/manager/updatestate.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package manager @@ -8,11 +8,16 @@ package manager import ( "log" "time" + _ "unsafe" + "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/updater" "golang.zx2c4.com/wireguard/windows/version" ) +//go:linkname fastrandn runtime.fastrandn +func fastrandn(n uint32) uint32 + type UpdateState uint32 const ( @@ -23,35 +28,38 @@ const ( var updateState = UpdateStateUnknown -func checkForUpdates() { - defer printPanic() +func jitterSleep(min, max time.Duration) { + time.Sleep(min + time.Millisecond*time.Duration(fastrandn(uint32((max-min+1)/time.Millisecond)))) +} +func checkForUpdates() { if !version.IsRunningOfficialVersion() { log.Println("Build is not official, so updates are disabled") updateState = UpdateStateUpdatesDisabledUnofficialBuild IPCServerNotifyUpdateFound(updateState) return } - - first := true + if services.StartedAtBoot() { + jitterSleep(time.Minute*2, time.Minute*5) + } + noError, didNotify := true, false for { update, err := updater.CheckForUpdate() - if err == nil && update != nil { + if err == nil && update != nil && !didNotify { log.Println("An update is available") updateState = UpdateStateFoundUpdate IPCServerNotifyUpdateFound(updateState) - return - } - if err != nil { + didNotify = true + } else if err != nil && !didNotify { log.Printf("Update checker: %v", err) - if first { - time.Sleep(time.Minute * 4) - first = false + if noError { + jitterSleep(time.Minute*4, time.Minute*6) + noError = false } else { - time.Sleep(time.Minute * 25) + jitterSleep(time.Minute*25, time.Minute*30) } } else { - time.Sleep(time.Hour) + jitterSleep(time.Hour-time.Minute*3, time.Hour+time.Minute*3) } } } |