diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-10-25 14:42:00 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-10-26 11:54:56 +0200 |
commit | e9c7358b01ea282646d0f09b201072096aae810f (patch) | |
tree | d9e555cddb09544ff5711869d615cb7c9b5a567f | |
parent | services: remove unused pipe path helper (diff) | |
download | wireguard-windows-e9c7358b01ea282646d0f09b201072096aae810f.tar.xz wireguard-windows-e9c7358b01ea282646d0f09b201072096aae810f.zip |
services: use more reliable method of detecting boot-up
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r-- | conf/dnsresolver_windows.go | 6 | ||||
-rw-r--r-- | conf/name.go | 8 | ||||
-rw-r--r-- | manager/install.go | 7 | ||||
-rw-r--r-- | manager/ipc_server.go | 5 | ||||
-rw-r--r-- | manager/service.go | 4 | ||||
-rw-r--r-- | manager/tunneltracker.go | 2 | ||||
-rw-r--r-- | services/boot.go | 47 | ||||
-rw-r--r-- | services/names.go | 19 | ||||
-rw-r--r-- | tunnel/addressconfig.go | 13 | ||||
-rw-r--r-- | tunnel/service.go | 32 |
10 files changed, 88 insertions, 55 deletions
diff --git a/conf/dnsresolver_windows.go b/conf/dnsresolver_windows.go index 85923950..094b1029 100644 --- a/conf/dnsresolver_windows.go +++ b/conf/dnsresolver_windows.go @@ -14,14 +14,14 @@ import ( "unsafe" "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/windows/services" ) //sys internetGetConnectedState(flags *uint32, reserved uint32) (connected bool) = wininet.InternetGetConnectedState func resolveHostname(name string) (resolvedIPString string, err error) { maxTries := 10 - systemJustBooted := windows.DurationSinceBoot() <= time.Minute*10 - if systemJustBooted { + if services.StartedAtBoot() { maxTries *= 4 } for i := 0; i < maxTries; i++ { @@ -37,7 +37,7 @@ func resolveHostname(name string) (resolvedIPString string, err error) { continue } var state uint32 - if err == windows.WSAHOST_NOT_FOUND && systemJustBooted && !internetGetConnectedState(&state, 0) { + if err == windows.WSAHOST_NOT_FOUND && services.StartedAtBoot() && !internetGetConnectedState(&state, 0) { log.Printf("Host not found when resolving %s, but no Internet connection available, sleeping for 4 seconds", name) continue } diff --git a/conf/name.go b/conf/name.go index e3287128..2b42c0e9 100644 --- a/conf/name.go +++ b/conf/name.go @@ -6,6 +6,7 @@ package conf import ( + "errors" "regexp" "strconv" "strings" @@ -114,3 +115,10 @@ func TunnelNameIsLess(a, b string) bool { } return false } + +func ServiceNameOfTunnel(tunnelName string) (string, error) { + if !TunnelNameIsValid(tunnelName) { + return "", errors.New("Tunnel name is not valid") + } + return "WireGuardTunnel$" + tunnelName, nil +} diff --git a/manager/install.go b/manager/install.go index f74fecf0..2ab94bc1 100644 --- a/manager/install.go +++ b/manager/install.go @@ -17,7 +17,6 @@ import ( "golang.org/x/sys/windows/svc/mgr" "golang.zx2c4.com/wireguard/windows/conf" - "golang.zx2c4.com/wireguard/windows/services" ) var cachedServiceManager *mgr.Mgr @@ -130,7 +129,7 @@ func InstallTunnel(configPath string) error { return err } - serviceName, err := services.ServiceNameOfTunnel(name) + serviceName, err := conf.ServiceNameOfTunnel(name) if err != nil { return err } @@ -183,7 +182,7 @@ func UninstallTunnel(name string) error { if err != nil { return err } - serviceName, err := services.ServiceNameOfTunnel(name) + serviceName, err := conf.ServiceNameOfTunnel(name) if err != nil { return err } @@ -211,7 +210,7 @@ func changeTunnelServiceConfigFilePath(name, oldPath, newPath string) { if err != nil { return } - serviceName, err := services.ServiceNameOfTunnel(name) + serviceName, err := conf.ServiceNameOfTunnel(name) if err != nil { return } diff --git a/manager/ipc_server.go b/manager/ipc_server.go index 504b5cce..dccbce28 100644 --- a/manager/ipc_server.go +++ b/manager/ipc_server.go @@ -20,7 +20,6 @@ import ( "golang.org/x/sys/windows/svc" "golang.zx2c4.com/wireguard/windows/conf" - "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/updater" ) @@ -130,7 +129,7 @@ func (s *ManagerService) Stop(tunnelName string) error { } func (s *ManagerService) WaitForStop(tunnelName string) error { - serviceName, err := services.ServiceNameOfTunnel(tunnelName) + serviceName, err := conf.ServiceNameOfTunnel(tunnelName) if err != nil { return err } @@ -161,7 +160,7 @@ func (s *ManagerService) Delete(tunnelName string) error { } func (s *ManagerService) State(tunnelName string) (TunnelState, error) { - serviceName, err := services.ServiceNameOfTunnel(tunnelName) + serviceName, err := conf.ServiceNameOfTunnel(tunnelName) if err != nil { return 0, err } diff --git a/manager/service.go b/manager/service.go index 15ea88f0..db03c3a8 100644 --- a/manager/service.go +++ b/manager/service.go @@ -22,7 +22,6 @@ import ( "golang.zx2c4.com/wireguard/windows/elevate" "golang.zx2c4.com/wireguard/windows/ringlogger" "golang.zx2c4.com/wireguard/windows/services" - "golang.zx2c4.com/wireguard/windows/version" ) type managerService struct{} @@ -54,8 +53,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest return } - log.Println("Starting", version.UserAgent()) - + services.PrintStarting() checkForPitfalls() path, err := os.Executable() diff --git a/manager/tunneltracker.go b/manager/tunneltracker.go index 103388f2..96020635 100644 --- a/manager/tunneltracker.go +++ b/manager/tunneltracker.go @@ -278,7 +278,7 @@ func trackExistingTunnels() error { continue } trackedTunnelsLock.Unlock() - serviceName, err := services.ServiceNameOfTunnel(name) + serviceName, err := conf.ServiceNameOfTunnel(name) if err != nil { continue } diff --git a/services/boot.go b/services/boot.go new file mode 100644 index 00000000..bdcaac2f --- /dev/null +++ b/services/boot.go @@ -0,0 +1,47 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package services + +import ( + "errors" + "log" + "sync" + "time" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + "golang.zx2c4.com/wireguard/windows/version" +) + +var ( + startedAtBoot bool + startedAtBootOnce sync.Once +) + +func StartedAtBoot() bool { + startedAtBootOnce.Do(func() { + if isService, err := svc.IsWindowsService(); err == nil && !isService { + return + } + if reason, err := svc.DynamicStartReason(); err == nil { + startedAtBoot = (reason & svc.StartReasonAuto) != 0 || (reason & svc.StartReasonDelayedAuto) != 0 + } else if errors.Is(err, windows.ERROR_PROC_NOT_FOUND) { + // This is an ugly hack for Windows 7, which hopefully we'll be able to remove down the road. + startedAtBoot = windows.DurationSinceBoot() < time.Minute*10 + } else { + log.Printf("Unable to determine service start reason: %v", err) + } + }) + return startedAtBoot +} + +func PrintStarting() { + boot := "" + if StartedAtBoot() { + boot = " at boot" + } + log.Printf("Starting%s %s", boot, version.UserAgent()) +}
\ No newline at end of file diff --git a/services/names.go b/services/names.go deleted file mode 100644 index bf58550f..00000000 --- a/services/names.go +++ /dev/null @@ -1,19 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. - */ - -package services - -import ( - "errors" - - "golang.zx2c4.com/wireguard/windows/conf" -) - -func ServiceNameOfTunnel(tunnelName string) (string, error) { - if !conf.TunnelNameIsValid(tunnelName) { - return "", errors.New("Tunnel name is not valid") - } - return "WireGuardTunnel$" + tunnelName, nil -} diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go index 86b82097..f315cd15 100644 --- a/tunnel/addressconfig.go +++ b/tunnel/addressconfig.go @@ -15,6 +15,7 @@ import ( "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/tunnel/firewall" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) @@ -53,14 +54,14 @@ func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, add } func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID) error { - systemJustBooted := windows.DurationSinceBoot() <= time.Minute*10 + retryOnFailure := services.StartedAtBoot() tryTimes := 0 startOver: var err error if tryTimes > 0 { log.Printf("Retrying interface configuration after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) time.Sleep(time.Second) - systemJustBooted = systemJustBooted && tryTimes < 15 + retryOnFailure = retryOnFailure && tryTimes < 15 } tryTimes++ @@ -135,7 +136,7 @@ startOver: if !conf.Interface.TableOff { err = luid.SetRoutesForFamily(family, deduplicatedRoutes) - if err == windows.ERROR_NOT_FOUND && systemJustBooted { + if err == windows.ERROR_NOT_FOUND && retryOnFailure { goto startOver } else if err != nil { return fmt.Errorf("unable to set routes: %w", err) @@ -147,7 +148,7 @@ startOver: cleanupAddressesOnDisconnectedInterfaces(family, addresses) err = luid.SetIPAddressesForFamily(family, addresses) } - if err == windows.ERROR_NOT_FOUND && systemJustBooted { + if err == windows.ERROR_NOT_FOUND && retryOnFailure { goto startOver } else if err != nil { return fmt.Errorf("unable to set ips: %w", err) @@ -170,14 +171,14 @@ startOver: ipif.Metric = 0 } err = ipif.Set() - if err == windows.ERROR_NOT_FOUND && systemJustBooted { + if err == windows.ERROR_NOT_FOUND && retryOnFailure { goto startOver } else if err != nil { return fmt.Errorf("unable to set metric and MTU: %w", err) } err = luid.SetDNS(family, conf.Interface.DNS, conf.Interface.DNSSearch) - if err == windows.ERROR_NOT_FOUND && systemJustBooted { + if err == windows.ERROR_NOT_FOUND && retryOnFailure { goto startOver } else if err != nil { return fmt.Errorf("unable to set DNS: %w", err) diff --git a/tunnel/service.go b/tunnel/service.go index 24df3dfe..4a4f5e40 100644 --- a/tunnel/service.go +++ b/tunnel/service.go @@ -16,14 +16,12 @@ import ( "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/mgr" - "golang.zx2c4.com/wireguard/windows/conf" "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/elevate" "golang.zx2c4.com/wireguard/windows/ringlogger" "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "golang.zx2c4.com/wireguard/windows/version" ) type tunnelService struct { @@ -119,20 +117,22 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, log.SetPrefix(fmt.Sprintf("[%s] ", config.Name)) - log.Println("Starting", version.UserAgent()) + services.PrintStarting() - if m, err := mgr.Connect(); err == nil { - if lockStatus, err := m.LockStatus(); err == nil && lockStatus.IsLocked { - /* If we don't do this, then the driver installation will block forever, because - * installing a network adapter starts the driver service too. Apparently at boot time, - * Windows 8.1 locks the SCM for each service start, creating a deadlock if we don't - * announce that we're running before starting additional services. - */ - log.Printf("SCM locked for %v by %s, marking service as started", lockStatus.Age, lockStatus.Owner) - serviceState = svc.Running - changes <- svc.Status{State: serviceState} + if services.StartedAtBoot() { + if m, err := mgr.Connect(); err == nil { + if lockStatus, err := m.LockStatus(); err == nil && lockStatus.IsLocked { + /* If we don't do this, then the driver installation will block forever, because + * installing a network adapter starts the driver service too. Apparently at boot time, + * Windows 8.1 locks the SCM for each service start, creating a deadlock if we don't + * announce that we're running before starting additional services. + */ + log.Printf("SCM locked for %v by %s, marking service as started", lockStatus.Age, lockStatus.Owner) + serviceState = svc.Running + changes <- svc.Status{State: serviceState} + } + m.Disconnect() } - m.Disconnect() } log.Println("Watching network interfaces") @@ -156,7 +156,7 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) } adapter, err = driver.CreateAdapter(config.Name, "WireGuard", deterministicGUID(config)) - if err == nil || windows.DurationSinceBoot() > time.Minute*10 { + if err == nil || !services.StartedAtBoot() { break } } @@ -250,7 +250,7 @@ func Run(confPath string) error { if err != nil { return err } - serviceName, err := services.ServiceNameOfTunnel(name) + serviceName, err := conf.ServiceNameOfTunnel(name) if err != nil { return err } |