diff options
Diffstat (limited to 'manager')
-rw-r--r-- | manager/install.go | 51 | ||||
-rw-r--r-- | manager/interfacecleanup.go | 56 | ||||
-rw-r--r-- | manager/ipc_client.go | 17 | ||||
-rw-r--r-- | manager/ipc_driver.go | 61 | ||||
-rw-r--r-- | manager/ipc_server.go | 128 | ||||
-rw-r--r-- | manager/ipc_uapi.go | 71 | ||||
-rw-r--r-- | manager/legacystore.go | 129 | ||||
-rw-r--r-- | manager/service.go | 65 | ||||
-rw-r--r-- | manager/tunneltracker.go | 130 | ||||
-rw-r--r-- | manager/uiprocess.go | 103 | ||||
-rw-r--r-- | manager/updatestate.go | 34 |
11 files changed, 424 insertions, 421 deletions
diff --git a/manager/install.go b/manager/install.go index bd7d23c7..44a744cf 100644 --- a/manager/install.go +++ b/manager/install.go @@ -1,13 +1,15 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package manager import ( "errors" + "log" "os" + "strings" "time" "golang.org/x/sys/windows" @@ -15,7 +17,6 @@ import ( "golang.org/x/sys/windows/svc/mgr" "golang.zx2c4.com/wireguard/windows/conf" - "golang.zx2c4.com/wireguard/windows/services" ) var cachedServiceManager *mgr.Mgr @@ -128,7 +129,7 @@ func InstallTunnel(configPath string) error { return err } - serviceName, err := services.ServiceNameOfTunnel(name) + serviceName, err := conf.ServiceNameOfTunnel(name) if err != nil { return err } @@ -181,7 +182,7 @@ func UninstallTunnel(name string) error { if err != nil { return err } - serviceName, err := services.ServiceNameOfTunnel(name) + serviceName, err := conf.ServiceNameOfTunnel(name) if err != nil { return err } @@ -197,3 +198,45 @@ func UninstallTunnel(name string) error { } return err2 } + +func changeTunnelServiceConfigFilePath(name, oldPath, newPath string) { + var err error + defer func() { + if err != nil { + log.Printf("Unable to change tunnel service command line argument from %#q to %#q: %v", oldPath, newPath, err) + } + }() + m, err := serviceManager() + if err != nil { + return + } + serviceName, err := conf.ServiceNameOfTunnel(name) + if err != nil { + return + } + service, err := m.OpenService(serviceName) + if err == windows.ERROR_SERVICE_DOES_NOT_EXIST { + err = nil + return + } else if err != nil { + return + } + defer service.Close() + config, err := service.Config() + if err != nil { + return + } + exePath, err := os.Executable() + if err != nil { + return + } + args, err := windows.DecomposeCommandLine(config.BinaryPathName) + if err != nil || len(args) != 3 || + !strings.EqualFold(args[0], exePath) || args[1] != "/tunnelservice" || !strings.EqualFold(args[2], oldPath) { + err = nil + return + } + args[2] = newPath + config.BinaryPathName = windows.ComposeCommandLine(args) + err = service.UpdateConfig(config) +} diff --git a/manager/interfacecleanup.go b/manager/interfacecleanup.go deleted file mode 100644 index c270f4ab..00000000 --- a/manager/interfacecleanup.go +++ /dev/null @@ -1,56 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. - */ - -package manager - -import ( - "log" - - "golang.org/x/sys/windows" - "golang.org/x/sys/windows/svc" - "golang.org/x/sys/windows/svc/mgr" - "golang.zx2c4.com/wireguard/tun/wintun" - - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/windows/services" -) - -func cleanupStaleWintunInterfaces() { - m, err := mgr.Connect() - if err != nil { - return - } - defer m.Disconnect() - - tun.WintunPool.DeleteMatchingAdapters(func(wintun *wintun.Adapter) bool { - interfaceName, err := wintun.Name() - if err != nil { - log.Printf("Removing Wintun interface because determining interface name failed: %v", err) - return true - } - serviceName, err := services.ServiceNameOfTunnel(interfaceName) - if err != nil { - log.Printf("Removing Wintun interface ‘%s’ because determining tunnel service name failed: %v", interfaceName, err) - return true - } - service, err := m.OpenService(serviceName) - if err == windows.ERROR_SERVICE_DOES_NOT_EXIST { - log.Printf("Removing Wintun interface ‘%s’ because no service for it exists", interfaceName) - return true - } else if err != nil { - return false - } - defer service.Close() - status, err := service.Query() - if err != nil { - return false - } - if status.State == svc.Stopped { - log.Printf("Removing Wintun interface ‘%s’ because its service is stopped", interfaceName) - return true - } - return false - }, false) -} diff --git a/manager/ipc_client.go b/manager/ipc_client.go index 2f78a47e..8c9c4c04 100644 --- a/manager/ipc_client.go +++ b/manager/ipc_client.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package manager @@ -64,7 +64,7 @@ var ( ) type TunnelChangeCallback struct { - cb func(tunnel *Tunnel, state TunnelState, globalState TunnelState, err error) + cb func(tunnel *Tunnel, state, globalState TunnelState, err error) } var tunnelChangeCallbacks = make(map[*TunnelChangeCallback]bool) @@ -93,7 +93,7 @@ type UpdateProgressCallback struct { var updateProgressCallbacks = make(map[*UpdateProgressCallback]bool) -func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) { +func InitializeIPCClient(reader, writer, events *os.File) { rpcDecoder = gob.NewDecoder(reader) rpcEncoder = gob.NewEncoder(writer) go func() { @@ -431,43 +431,52 @@ func IPCClientUpdate() error { return rpcEncoder.Encode(UpdateMethodType) } -func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state TunnelState, globalState TunnelState, err error)) *TunnelChangeCallback { +func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state, globalState TunnelState, err error)) *TunnelChangeCallback { s := &TunnelChangeCallback{cb} tunnelChangeCallbacks[s] = true return s } + func (cb *TunnelChangeCallback) Unregister() { delete(tunnelChangeCallbacks, cb) } + func IPCClientRegisterTunnelsChange(cb func()) *TunnelsChangeCallback { s := &TunnelsChangeCallback{cb} tunnelsChangeCallbacks[s] = true return s } + func (cb *TunnelsChangeCallback) Unregister() { delete(tunnelsChangeCallbacks, cb) } + func IPCClientRegisterManagerStopping(cb func()) *ManagerStoppingCallback { s := &ManagerStoppingCallback{cb} managerStoppingCallbacks[s] = true return s } + func (cb *ManagerStoppingCallback) Unregister() { delete(managerStoppingCallbacks, cb) } + func IPCClientRegisterUpdateFound(cb func(updateState UpdateState)) *UpdateFoundCallback { s := &UpdateFoundCallback{cb} updateFoundCallbacks[s] = true return s } + func (cb *UpdateFoundCallback) Unregister() { delete(updateFoundCallbacks, cb) } + func IPCClientRegisterUpdateProgress(cb func(dp updater.DownloadProgress)) *UpdateProgressCallback { s := &UpdateProgressCallback{cb} updateProgressCallbacks[s] = true return s } + func (cb *UpdateProgressCallback) Unregister() { delete(updateProgressCallbacks, cb) } diff --git a/manager/ipc_driver.go b/manager/ipc_driver.go new file mode 100644 index 00000000..6cb43c38 --- /dev/null +++ b/manager/ipc_driver.go @@ -0,0 +1,61 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package manager + +import ( + "sync" + + "golang.zx2c4.com/wireguard/windows/driver" +) + +type lockedDriverAdapter struct { + *driver.Adapter + sync.Mutex +} + +var ( + driverAdapters = make(map[string]*lockedDriverAdapter) + driverAdaptersLock sync.RWMutex +) + +func findDriverAdapter(tunnelName string) (*lockedDriverAdapter, error) { + driverAdaptersLock.RLock() + driverAdapter, ok := driverAdapters[tunnelName] + if ok { + driverAdapter.Lock() + driverAdaptersLock.RUnlock() + return driverAdapter, nil + } + driverAdaptersLock.RUnlock() + driverAdaptersLock.Lock() + defer driverAdaptersLock.Unlock() + driverAdapter, ok = driverAdapters[tunnelName] + if ok { + driverAdapter.Lock() + return driverAdapter, nil + } + driverAdapter = &lockedDriverAdapter{} + var err error + driverAdapter.Adapter, err = driver.OpenAdapter(tunnelName) + if err != nil { + return nil, err + } + driverAdapters[tunnelName] = driverAdapter + driverAdapter.Lock() + return driverAdapter, nil +} + +func releaseDriverAdapter(tunnelName string) { + driverAdaptersLock.Lock() + defer driverAdaptersLock.Unlock() + driverAdapter, ok := driverAdapters[tunnelName] + if !ok { + return + } + driverAdapter.Lock() + delete(driverAdapters, tunnelName) + driverAdapter.Unlock() +} diff --git a/manager/ipc_server.go b/manager/ipc_server.go index 8ba050f9..e21ffaf0 100644 --- a/manager/ipc_server.go +++ b/manager/ipc_server.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package manager @@ -20,14 +20,15 @@ import ( "golang.org/x/sys/windows/svc" "golang.zx2c4.com/wireguard/windows/conf" - "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/updater" ) -var managerServices = make(map[*ManagerService]bool) -var managerServicesLock sync.RWMutex -var haveQuit uint32 -var quitManagersChan = make(chan struct{}, 1) +var ( + managerServices = make(map[*ManagerService]bool) + managerServicesLock sync.RWMutex + haveQuit uint32 + quitManagersChan = make(chan struct{}, 1) +) type ManagerService struct { events *os.File @@ -51,33 +52,18 @@ func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) if err != nil { return nil, err } - pipe, err := connectTunnelServicePipe(tunnelName) - if err != nil { - return nil, err - } - pipe.SetDeadline(time.Now().Add(time.Second * 2)) - _, err = pipe.Write([]byte("get=1\n\n")) - if err == windows.ERROR_NO_DATA { - log.Println("IPC pipe closed unexpectedly, so reopening") - pipe.Unlock() - disconnectTunnelServicePipe(tunnelName) - pipe, err = connectTunnelServicePipe(tunnelName) - if err != nil { - return nil, err - } - pipe.SetDeadline(time.Now().Add(time.Second * 2)) - _, err = pipe.Write([]byte("get=1\n\n")) - } + driverAdapter, err := findDriverAdapter(tunnelName) if err != nil { - pipe.Unlock() - disconnectTunnelServicePipe(tunnelName) return nil, err } - conf, err := conf.FromUAPI(pipe, storedConfig) - pipe.Unlock() + runtimeConfig, err := driverAdapter.Configuration() if err != nil { + driverAdapter.Unlock() + releaseDriverAdapter(tunnelName) return nil, err } + conf := conf.FromDriverConfiguration(runtimeConfig, storedConfig) + driverAdapter.Unlock() if s.elevatedToken == 0 { conf.Redact() } @@ -85,44 +71,47 @@ func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) } func (s *ManagerService) Start(tunnelName string) error { - // TODO: Rather than being lazy and gating this behind a knob (yuck!), we should instead keep track of the routes - // of each tunnel, and only deactivate in the case of a tunnel with identical routes being added. - if !conf.AdminBool("MultipleSimultaneousTunnels") { - trackedTunnelsLock.Lock() - tt := make([]string, 0, len(trackedTunnels)) - var inTransition string - for t, state := range trackedTunnels { - tt = append(tt, t) - if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) { - inTransition = t - break - } - } - trackedTunnelsLock.Unlock() - if len(inTransition) != 0 { - return fmt.Errorf("Please allow the tunnel ‘%s’ to finish activating", inTransition) - } - go func() { - for _, t := range tt { - s.Stop(t) - } - for _, t := range tt { - state, err := s.State(t) - if err == nil && (state == TunnelStarted || state == TunnelStarting) { - log.Printf("[%s] Trying again to stop zombie tunnel", t) - s.Stop(t) - time.Sleep(time.Millisecond * 100) - } - } - }() - } - time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces) - - // After that process is started -- it's somewhat asynchronous -- we install the new one. c, err := conf.LoadFromName(tunnelName) if err != nil { return err } + + // Figure out which tunnels have intersecting addresses/routes and stop those. + trackedTunnelsLock.Lock() + tt := make([]string, 0, len(trackedTunnels)) + var inTransition string + for t, state := range trackedTunnels { + c2, err := conf.LoadFromName(t) + if err != nil || !c.IntersectsWith(c2) { + // If we can't get the config, assume it doesn't intersect. + continue + } + tt = append(tt, t) + if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) { + inTransition = t + break + } + } + trackedTunnelsLock.Unlock() + if len(inTransition) != 0 { + return fmt.Errorf("Please allow the tunnel ‘%s’ to finish activating", inTransition) + } + + // Stop those intersecting tunnels asynchronously. + go func() { + for _, t := range tt { + s.Stop(t) + } + for _, t := range tt { + state, err := s.State(t) + if err == nil && (state == TunnelStarted || state == TunnelStarting) { + log.Printf("[%s] Trying again to stop zombie tunnel", t) + s.Stop(t) + time.Sleep(time.Millisecond * 100) + } + } + }() + // After the stop process has begun, but before it's finished, we install the new one. path, err := c.Path() if err != nil { return err @@ -131,8 +120,6 @@ func (s *ManagerService) Start(tunnelName string) error { } func (s *ManagerService) Stop(tunnelName string) error { - time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces) - err := UninstallTunnel(tunnelName) if err == windows.ERROR_SERVICE_DOES_NOT_EXIST { _, notExistsError := conf.LoadFromName(tunnelName) @@ -144,7 +131,7 @@ func (s *ManagerService) Stop(tunnelName string) error { } func (s *ManagerService) WaitForStop(tunnelName string) error { - serviceName, err := services.ServiceNameOfTunnel(tunnelName) + serviceName, err := conf.ServiceNameOfTunnel(tunnelName) if err != nil { return err } @@ -175,7 +162,7 @@ func (s *ManagerService) Delete(tunnelName string) error { } func (s *ManagerService) State(tunnelName string) (TunnelState, error) { - serviceName, err := services.ServiceNameOfTunnel(tunnelName) + serviceName, err := conf.ServiceNameOfTunnel(tunnelName) if err != nil { return 0, err } @@ -230,7 +217,7 @@ func (s *ManagerService) Tunnels() ([]Tunnel, error) { } tunnels := make([]Tunnel, len(names)) for i := 0; i < len(tunnels); i++ { - (tunnels)[i].Name = names[i] + tunnels[i].Name = names[i] } return tunnels, nil // TODO: account for running ones that aren't in the configuration store somehow @@ -267,9 +254,6 @@ func (s *ManagerService) Quit(stopTunnelsOnQuit bool) (alreadyQuit bool, err err } func (s *ManagerService) UpdateState() UpdateState { - if s.elevatedToken == 0 { - return UpdateStateUnknown - } return updateState } @@ -457,7 +441,7 @@ func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) { } } -func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevatedToken windows.Token) { +func IPCServerListen(reader, writer, events *os.File, elevatedToken windows.Token) { service := &ManagerService{ events: events, elevatedToken: elevatedToken, @@ -477,7 +461,7 @@ func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevated }() } -func notifyAll(notificationType NotificationType, adminOnly bool, ifaces ...interface{}) { +func notifyAll(notificationType NotificationType, adminOnly bool, ifaces ...any) { if len(managerServices) == 0 { return } @@ -528,7 +512,7 @@ func IPCServerNotifyTunnelsChange() { } func IPCServerNotifyUpdateFound(state UpdateState) { - notifyAll(UpdateFoundNotificationType, true, state) + notifyAll(UpdateFoundNotificationType, false, state) } func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) { diff --git a/manager/ipc_uapi.go b/manager/ipc_uapi.go deleted file mode 100644 index 85477125..00000000 --- a/manager/ipc_uapi.go +++ /dev/null @@ -1,71 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. - */ - -package manager - -import ( - "net" - "sync" - - "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/ipc/winpipe" - - "golang.zx2c4.com/wireguard/windows/services" -) - -type connectedTunnel struct { - net.Conn - sync.Mutex -} - -var connectedTunnelServicePipes = make(map[string]*connectedTunnel) -var connectedTunnelServicePipesLock sync.RWMutex - -func connectTunnelServicePipe(tunnelName string) (*connectedTunnel, error) { - connectedTunnelServicePipesLock.RLock() - pipe, ok := connectedTunnelServicePipes[tunnelName] - if ok { - pipe.Lock() - connectedTunnelServicePipesLock.RUnlock() - return pipe, nil - } - connectedTunnelServicePipesLock.RUnlock() - connectedTunnelServicePipesLock.Lock() - defer connectedTunnelServicePipesLock.Unlock() - pipe, ok = connectedTunnelServicePipes[tunnelName] - if ok { - pipe.Lock() - return pipe, nil - } - pipePath, err := services.PipePathOfTunnel(tunnelName) - if err != nil { - return nil, err - } - localSystem, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid) - if err != nil { - return nil, err - } - pipe = &connectedTunnel{} - pipe.Conn, err = winpipe.Dial(pipePath, nil, &winpipe.DialConfig{ExpectedOwner: localSystem}) - if err != nil { - return nil, err - } - connectedTunnelServicePipes[tunnelName] = pipe - pipe.Lock() - return pipe, nil -} - -func disconnectTunnelServicePipe(tunnelName string) { - connectedTunnelServicePipesLock.Lock() - defer connectedTunnelServicePipesLock.Unlock() - pipe, ok := connectedTunnelServicePipes[tunnelName] - if !ok { - return - } - pipe.Lock() - pipe.Close() - delete(connectedTunnelServicePipes, tunnelName) - pipe.Unlock() -} diff --git a/manager/legacystore.go b/manager/legacystore.go deleted file mode 100644 index 4125cb86..00000000 --- a/manager/legacystore.go +++ /dev/null @@ -1,129 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. - */ - -package manager - -import ( - "fmt" - "log" - "os" - "path/filepath" - "regexp" - "strings" - - "golang.org/x/sys/windows" - "golang.org/x/sys/windows/registry" - "golang.org/x/sys/windows/svc/mgr" - - "golang.zx2c4.com/wireguard/windows/conf" -) - -func moveConfigsFromLegacyStore() { - oldRoot, err := windows.KnownFolderPath(windows.FOLDERID_LocalAppData, windows.KF_FLAG_DEFAULT) - if err != nil { - return - } - oldC := filepath.Join(oldRoot, "WireGuard", "Configurations") - files, err := os.ReadDir(oldC) - if err != nil { - return - } - pendingDeletion := make(map[string]bool) - if key, err := registry.OpenKey(registry.LOCAL_MACHINE, `SYSTEM\CurrentControlSet\Control\Session Manager`, registry.READ); err == nil { - if ntPaths, _, err := key.GetStringsValue("PendingFileRenameOperations"); err == nil { - for _, ntPath := range ntPaths { - pendingDeletion[strings.ToLower(strings.TrimPrefix(ntPath, `\??\`))] = true - } - } - key.Close() - } - migratedConfigs := make(map[string]string) - for i := range files { - if files[i].IsDir() { - continue - } - fileName := files[i].Name() - oldPath := filepath.Join(oldC, fileName) - if pendingDeletion[strings.ToLower(oldPath)] { - continue - } - config, err := conf.LoadFromPath(oldPath) - if err != nil { - continue - } - newPath, err := config.Path() - if err != nil { - continue - } - err = config.Save(false) - if err != nil { - continue - } - oldPath16, err := windows.UTF16PtrFromString(oldPath) - if err == nil { - windows.MoveFileEx(oldPath16, nil, windows.MOVEFILE_DELAY_UNTIL_REBOOT) - } - migratedConfigs[strings.ToLower(oldPath)] = newPath - log.Printf("Migrated configuration from ‘%s’ to ‘%s’", oldPath, newPath) - } - oldC16, err := windows.UTF16PtrFromString(oldC) - if err == nil { - windows.MoveFileEx(oldC16, nil, windows.MOVEFILE_DELAY_UNTIL_REBOOT) - } - oldLog16, err := windows.UTF16PtrFromString(filepath.Join(oldRoot, "WireGuard", "log.bin")) - if err == nil { - windows.MoveFileEx(oldLog16, nil, windows.MOVEFILE_DELAY_UNTIL_REBOOT) - } - oldRoot16, err := windows.UTF16PtrFromString(filepath.Join(oldRoot, "WireGuard")) - if err == nil { - windows.MoveFileEx(oldRoot16, nil, windows.MOVEFILE_DELAY_UNTIL_REBOOT) - } - if len(migratedConfigs) == 0 { - return - } - m, err := mgr.Connect() - if err != nil { - return - } - defer m.Disconnect() - services, err := m.ListServices() - if err != nil { - return - } - matcher, err := regexp.Compile(" /tunnelservice \"?([^\"]+)\"?$") - if err != nil { - return - } - for _, svcName := range services { - if !strings.HasPrefix(svcName, "WireGuardTunnel$") { - continue - } - svc, err := m.OpenService(svcName) - if err != nil { - continue - } - config, err := svc.Config() - if err != nil { - continue - } - matches := matcher.FindStringSubmatchIndex(config.BinaryPathName) - if len(matches) != 4 { - svc.Close() - continue - } - newName, found := migratedConfigs[strings.ToLower(config.BinaryPathName[matches[2]:])] - if !found { - svc.Close() - continue - } - config.BinaryPathName = config.BinaryPathName[:matches[0]] + fmt.Sprintf(" /tunnelservice \"%s\"", newName) - err = svc.UpdateConfig(config) - svc.Close() - if err != nil { - continue - } - log.Printf("Migrated service command line arguments for ‘%s’", svcName) - } -} diff --git a/manager/service.go b/manager/service.go index 99187eab..47e20d45 100644 --- a/manager/service.go +++ b/manager/service.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package manager @@ -12,18 +12,17 @@ import ( "runtime" "strconv" "sync" - "syscall" "time" "unsafe" "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" + "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/conf" "golang.zx2c4.com/wireguard/windows/elevate" "golang.zx2c4.com/wireguard/windows/ringlogger" "golang.zx2c4.com/wireguard/windows/services" - "golang.zx2c4.com/wireguard/windows/version" ) type managerService struct{} @@ -43,38 +42,36 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest changes <- svc.Status{State: svc.StopPending} }() - err = ringlogger.InitGlobalLogger("MGR") + var logFile string + logFile, err = conf.LogFile(true) if err != nil { serviceError = services.ErrorRingloggerOpen return } - - log.Println("Starting", version.UserAgent()) - - path, err := os.Executable() + err = ringlogger.InitGlobalLogger(logFile, "MGR") if err != nil { - serviceError = services.ErrorDetermineExecutablePath + serviceError = services.ErrorRingloggerOpen return } - devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0) + services.PrintStarting() + + path, err := os.Executable() if err != nil { - serviceError = services.ErrorOpenNULFile + serviceError = services.ErrorDetermineExecutablePath return } - moveConfigsFromLegacyStore() - - err = trackExistingTunnels() + err = watchNewTunnelServices() if err != nil { serviceError = services.ErrorTrackTunnels return } - conf.RegisterStoreChangeCallback(conf.MigrateUnencryptedConfigs) + conf.RegisterStoreChangeCallback(func() { conf.MigrateUnencryptedConfigs(changeTunnelServiceConfigFilePath) }) conf.RegisterStoreChangeCallback(IPCServerNotifyTunnelsChange) - procs := make(map[uint32]*os.Process) + procs := make(map[uint32]*uiProcess) aliveSessions := make(map[uint32]bool) procsLock := sync.Mutex{} stoppingManager := false @@ -196,29 +193,22 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest } log.Printf("Starting UI process for user ‘%s@%s’ for session %d", username, domain, session) - attr := &os.ProcAttr{ - Sys: &syscall.SysProcAttr{ - Token: syscall.Token(runToken), - AdditionalInheritedHandles: []syscall.Handle{ - syscall.Handle(theirReader.Fd()), - syscall.Handle(theirWriter.Fd()), - syscall.Handle(theirEvents.Fd()), - syscall.Handle(theirLogMapping)}, - }, - Files: []*os.File{devNull, devNull, devNull}, - Dir: userProfileDirectory, - } procsLock.Lock() - var proc *os.Process + var proc *uiProcess if alive := aliveSessions[session]; alive { - proc, err = os.StartProcess(path, []string{ + proc, err = launchUIProcess(path, []string{ path, "/ui", strconv.FormatUint(uint64(theirReader.Fd()), 10), strconv.FormatUint(uint64(theirWriter.Fd()), 10), strconv.FormatUint(uint64(theirEvents.Fd()), 10), strconv.FormatUint(uint64(theirLogMapping), 10), - }, attr) + }, userProfileDirectory, []windows.Handle{ + windows.Handle(theirReader.Fd()), + windows.Handle(theirWriter.Fd()), + windows.Handle(theirEvents.Fd()), + theirLogMapping, + }, runToken) } else { err = errors.New("Session has logged out") } @@ -240,9 +230,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest procsLock.Unlock() sessionIsDead := false - processStatus, err := proc.Wait() - if err == nil { - exitCode := processStatus.Sys().(syscall.WaitStatus).ExitCode + if exitCode, err := proc.Wait(); err == nil { log.Printf("Exited UI process for user '%s@%s' for session %d with status %x", username, domain, session, exitCode) const STATUS_DLL_INIT_FAILED_LOGOFF = 0xC000026B sessionIsDead = exitCode == STATUS_DLL_INIT_FAILED_LOGOFF @@ -271,8 +259,8 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest }() } - time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces) go checkForUpdates() + go driver.UninstallLegacyWintun() // We uninstall opportunistically here, so that we don't have to carry around the uninstaller code forever. var sessionsPointer *windows.WTS_SESSION_INFO var count uint32 @@ -281,12 +269,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest serviceError = services.ErrorEnumerateSessions return } - sessions := *(*[]windows.WTS_SESSION_INFO)(unsafe.Pointer(&struct { - addr *windows.WTS_SESSION_INFO - len int - cap int - }{sessionsPointer, int(count), int(count)})) - for _, session := range sessions { + for _, session := range unsafe.Slice(sessionsPointer, count) { if session.State != windows.WTSActive && session.State != windows.WTSDisconnected { continue } diff --git a/manager/tunneltracker.go b/manager/tunneltracker.go index b32450a3..9003d445 100644 --- a/manager/tunneltracker.go +++ b/manager/tunneltracker.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package manager @@ -24,31 +24,10 @@ import ( "golang.zx2c4.com/wireguard/windows/services" ) -func trackExistingTunnels() error { - m, err := serviceManager() - if err != nil { - return err - } - names, err := conf.ListConfigNames() - if err != nil { - return err - } - for _, name := range names { - serviceName, err := services.ServiceNameOfTunnel(name) - if err != nil { - continue - } - service, err := m.OpenService(serviceName) - if err != nil { - continue - } - go trackTunnelService(name, service) - } - return nil -} - -var trackedTunnels = make(map[string]TunnelState) -var trackedTunnelsLock = sync.Mutex{} +var ( + trackedTunnels = make(map[string]TunnelState) + trackedTunnelsLock = sync.Mutex{} +) func trackedTunnelsGlobalState() (state TunnelState) { state = TunnelStopped @@ -196,16 +175,17 @@ func trackService(service *mgr.Service, callback func(status uint32) bool) error } func trackTunnelService(tunnelName string, service *mgr.Service) { - defer func() { - service.Close() - log.Printf("[%s] Tunnel service tracker finished", tunnelName) - }() - trackedTunnelsLock.Lock() if _, found := trackedTunnels[tunnelName]; found { trackedTunnelsLock.Unlock() + service.Close() return } + + defer func() { + service.Close() + log.Printf("[%s] Tunnel service tracker finished", tunnelName) + }() trackedTunnels[tunnelName] = TunnelUnknown trackedTunnelsLock.Unlock() defer func() { @@ -214,6 +194,15 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { trackedTunnelsLock.Unlock() }() + for i := 0; i < 20; i++ { + if i > 0 { + time.Sleep(time.Second / 5) + } + if status, err := service.Query(); err != nil || status.State != svc.Stopped { + break + } + } + checkForDisabled := func() (shouldReturn bool) { config, err := service.Config() if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || (err != nil && config.StartType == windows.SERVICE_DISABLED) { @@ -273,5 +262,82 @@ func trackTunnelService(tunnelName string, service *mgr.Service) { IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %w", err)) service.Control(svc.Stop) } - disconnectTunnelServicePipe(tunnelName) +} + +func trackExistingTunnels() error { + m, err := serviceManager() + if err != nil { + return err + } + names, err := conf.ListConfigNames() + if err != nil { + return err + } + for _, name := range names { + trackedTunnelsLock.Lock() + if _, found := trackedTunnels[name]; found { + trackedTunnelsLock.Unlock() + continue + } + trackedTunnelsLock.Unlock() + serviceName, err := conf.ServiceNameOfTunnel(name) + if err != nil { + continue + } + service, err := m.OpenService(serviceName) + if err != nil { + continue + } + go trackTunnelService(name, service) + } + return nil +} + +var servicesSubscriptionWatcherCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr { + trackExistingTunnels() + return 0 +}) + +func watchNewTunnelServices() error { + m, err := serviceManager() + if err != nil { + return err + } + var subscription uintptr + err = windows.SubscribeServiceChangeNotifications(m.Handle, windows.SC_EVENT_DATABASE_CHANGE, servicesSubscriptionWatcherCallbackPtr, 0, &subscription) + if err == nil { + // We probably could do: + // defer windows.UnsubscribeServiceChangeNotifications(subscription) + // and then terminate after some point, but instead we just let this go forever; it's process-lived. + return trackExistingTunnels() + } + if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) { + return err + } + + // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point. + go func() { + runtime.LockOSThread() + notifier := &windows.SERVICE_NOTIFY{ + Version: windows.SERVICE_NOTIFY_STATUS_CHANGE, + NotifyCallback: serviceTrackerCallbackPtr, + } + for { + err := windows.NotifyServiceStatusChange(m.Handle, windows.SERVICE_NOTIFY_CREATED, notifier) + if err == nil { + windows.SleepEx(windows.INFINITE, true) + if notifier.ServiceNames != nil { + windows.LocalFree(windows.Handle(unsafe.Pointer(notifier.ServiceNames))) + notifier.ServiceNames = nil + } + trackExistingTunnels() + } else if err == windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING { + continue + } else { + time.Sleep(time.Second * 3) + trackExistingTunnels() + } + } + }() + return trackExistingTunnels() } diff --git a/manager/uiprocess.go b/manager/uiprocess.go new file mode 100644 index 00000000..b33b1ad3 --- /dev/null +++ b/manager/uiprocess.go @@ -0,0 +1,103 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package manager + +import ( + "errors" + "runtime" + "sync/atomic" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +type uiProcess struct { + handle uintptr +} + +func launchUIProcess(executable string, args []string, workingDirectory string, handles []windows.Handle, token windows.Token) (*uiProcess, error) { + executable16, err := windows.UTF16PtrFromString(executable) + if err != nil { + return nil, err + } + args16, err := windows.UTF16PtrFromString(windows.ComposeCommandLine(args)) + if err != nil { + return nil, err + } + workingDirectory16, err := windows.UTF16PtrFromString(workingDirectory) + if err != nil { + return nil, err + } + var environmentBlock *uint16 + err = windows.CreateEnvironmentBlock(&environmentBlock, token, false) + if err != nil { + return nil, err + } + defer windows.DestroyEnvironmentBlock(environmentBlock) + attributeList, err := windows.NewProcThreadAttributeList(1) + if err != nil { + return nil, err + } + defer attributeList.Delete() + si := &windows.StartupInfoEx{ + StartupInfo: windows.StartupInfo{Cb: uint32(unsafe.Sizeof(windows.StartupInfoEx{}))}, + ProcThreadAttributeList: attributeList.List(), + } + if len(handles) == 0 { + handles = []windows.Handle{0} + } + attributeList.Update(windows.PROC_THREAD_ATTRIBUTE_HANDLE_LIST, unsafe.Pointer(&handles[0]), uintptr(len(handles))*unsafe.Sizeof(handles[0])) + pi := new(windows.ProcessInformation) + err = windows.CreateProcessAsUser(token, executable16, args16, nil, nil, true, windows.CREATE_DEFAULT_ERROR_MODE|windows.CREATE_UNICODE_ENVIRONMENT|windows.EXTENDED_STARTUPINFO_PRESENT, environmentBlock, workingDirectory16, &si.StartupInfo, pi) + if err != nil { + return nil, err + } + windows.CloseHandle(pi.Thread) + uiProc := &uiProcess{handle: uintptr(pi.Process)} + runtime.SetFinalizer(uiProc, (*uiProcess).release) + return uiProc, nil +} + +func (p *uiProcess) release() error { + handle := windows.Handle(atomic.SwapUintptr(&p.handle, uintptr(windows.InvalidHandle))) + if handle == windows.InvalidHandle { + return nil + } + err := windows.CloseHandle(handle) + if err != nil { + return err + } + runtime.SetFinalizer(p, nil) + return nil +} + +func (p *uiProcess) Wait() (uint32, error) { + handle := windows.Handle(atomic.LoadUintptr(&p.handle)) + s, err := windows.WaitForSingleObject(handle, syscall.INFINITE) + switch s { + case windows.WAIT_OBJECT_0: + case windows.WAIT_FAILED: + return 0, err + default: + return 0, errors.New("unexpected result from WaitForSingleObject") + } + var exitCode uint32 + err = windows.GetExitCodeProcess(handle, &exitCode) + if err != nil { + return 0, err + } + p.release() + return exitCode, nil +} + +func (p *uiProcess) Kill() error { + handle := windows.Handle(atomic.LoadUintptr(&p.handle)) + if handle == windows.InvalidHandle { + return nil + } + return windows.TerminateProcess(handle, 1) +} diff --git a/manager/updatestate.go b/manager/updatestate.go index 069e9b8a..d5a19c8d 100644 --- a/manager/updatestate.go +++ b/manager/updatestate.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package manager @@ -8,11 +8,16 @@ package manager import ( "log" "time" + _ "unsafe" + "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/updater" "golang.zx2c4.com/wireguard/windows/version" ) +//go:linkname fastrandn runtime.fastrandn +func fastrandn(n uint32) uint32 + type UpdateState uint32 const ( @@ -23,6 +28,10 @@ const ( var updateState = UpdateStateUnknown +func jitterSleep(min, max time.Duration) { + time.Sleep(min + time.Millisecond*time.Duration(fastrandn(uint32((max-min+1)/time.Millisecond)))) +} + func checkForUpdates() { if !version.IsRunningOfficialVersion() { log.Println("Build is not official, so updates are disabled") @@ -30,26 +39,27 @@ func checkForUpdates() { IPCServerNotifyUpdateFound(updateState) return } - - first := true + if services.StartedAtBoot() { + jitterSleep(time.Minute*2, time.Minute*5) + } + noError, didNotify := true, false for { update, err := updater.CheckForUpdate() - if err == nil && update != nil { + if err == nil && update != nil && !didNotify { log.Println("An update is available") updateState = UpdateStateFoundUpdate IPCServerNotifyUpdateFound(updateState) - return - } - if err != nil { + didNotify = true + } else if err != nil && !didNotify { log.Printf("Update checker: %v", err) - if first { - time.Sleep(time.Minute * 4) - first = false + if noError { + jitterSleep(time.Minute*4, time.Minute*6) + noError = false } else { - time.Sleep(time.Minute * 25) + jitterSleep(time.Minute*25, time.Minute*30) } } else { - time.Sleep(time.Hour) + jitterSleep(time.Hour-time.Minute*3, time.Hour+time.Minute*3) } } } |