diff options
Diffstat (limited to 'tunnel/service.go')
-rw-r--r-- | tunnel/service.go | 229 |
1 files changed, 69 insertions, 160 deletions
diff --git a/tunnel/service.go b/tunnel/service.go index 9d5631ed..a56ed1f3 100644 --- a/tunnel/service.go +++ b/tunnel/service.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package tunnel @@ -9,7 +9,6 @@ import ( "bytes" "fmt" "log" - "net" "os" "runtime" "time" @@ -17,17 +16,12 @@ import ( "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/mgr" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/ipc" - "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/windows/conf" "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/elevate" "golang.zx2c4.com/wireguard/windows/ringlogger" "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "golang.zx2c4.com/wireguard/windows/version" ) type tunnelService struct { @@ -35,13 +29,10 @@ type tunnelService struct { } func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) { - changes <- svc.Status{State: svc.StartPending} + serviceState := svc.StartPending + changes <- svc.Status{State: serviceState} - var dev *device.Device - var uapi net.Listener var watcher *interfaceWatcher - var nativeTun *tun.NativeTun - var wintun tun.Device var adapter *driver.Adapter var luid winipcfg.LUID var config *conf.Config @@ -54,7 +45,8 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, if logErr != nil { log.Println(logErr) } - changes <- svc.Status{State: svc.StopPending} + serviceState = svc.StopPending + changes <- svc.Status{State: serviceState} stopIt := make(chan bool, 1) go func() { @@ -88,22 +80,16 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, } }() - if logErr == nil && (dev != nil || adapter != nil) && config != nil { + if logErr == nil && adapter != nil && config != nil { logErr = runScriptCommand(config.Interface.PreDown, config.Name) } if watcher != nil { watcher.Destroy() } - if uapi != nil { - uapi.Close() - } - if dev != nil { - dev.Close() - } if adapter != nil { - adapter.Delete() + adapter.Close() } - if logErr == nil && (dev != nil || adapter != nil) && config != nil { + if logErr == nil && adapter != nil && config != nil { _ = runScriptCommand(config.Interface.PostDown, config.Name) } stopIt <- true @@ -128,29 +114,29 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, return } config.DeduplicateNetworkEntries() - err = CopyConfigOwnerToIPCSecurityDescriptor(service.Path) - if err != nil { - serviceError = services.ErrorLoadConfiguration - return - } log.SetPrefix(fmt.Sprintf("[%s] ", config.Name)) - log.Println("Starting", version.UserAgent()) + services.PrintStarting() - if m, err := mgr.Connect(); err == nil { - if lockStatus, err := m.LockStatus(); err == nil && lockStatus.IsLocked { - /* If we don't do this, then the driver installation will block forever, because - * installing a network adapter starts the driver service too. Apparently at boot time, - * Windows 8.1 locks the SCM for each service start, creating a deadlock if we don't - * announce that we're running before starting additional services. - */ - log.Printf("SCM locked for %v by %s, marking service as started", lockStatus.Age, lockStatus.Owner) - changes <- svc.Status{State: svc.Running} + if services.StartedAtBoot() { + if m, err := mgr.Connect(); err == nil { + if lockStatus, err := m.LockStatus(); err == nil && lockStatus.IsLocked { + /* If we don't do this, then the driver installation will block forever, because + * installing a network adapter starts the driver service too. Apparently at boot time, + * Windows 8.1 locks the SCM for each service start, creating a deadlock if we don't + * announce that we're running before starting additional services. + */ + log.Printf("SCM locked for %v by %s, marking service as started", lockStatus.Age, lockStatus.Owner) + serviceState = svc.Running + changes <- svc.Status{State: serviceState} + } + m.Disconnect() } - m.Disconnect() } + evaluateStaticPitfalls() + log.Println("Watching network interfaces") watcher, err = watchInterface() if err != nil { @@ -166,73 +152,33 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, } log.Println("Creating network adapter") - if UseFixedGUIDInsteadOfDeterministic || !conf.AdminBool("UseUserspaceImplementation") { - // Does an adapter with this name already exist? - adapter, err = driver.DefaultPool.OpenAdapter(config.Name) - if err == nil { - // If so, we delete it, in case it has weird residual configuration. - _, err = adapter.Delete() - if err != nil { - err = fmt.Errorf("Error deleting already existing adapter: %w", err) - serviceError = services.ErrorCreateNetworkAdapter - return - } - } - for i := 0; i < 5; i++ { - if i > 0 { - time.Sleep(time.Second) - log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) - } - var rebootRequired bool - adapter, rebootRequired, err = driver.DefaultPool.CreateAdapter(config.Name, deterministicGUID(config)) - if err == nil || windows.DurationSinceBoot() > time.Minute*10 { - if rebootRequired { - log.Println("Windows indicated a reboot is required.") - } - break - } - } - if err != nil { - err = fmt.Errorf("Error creating adapter: %w", err) - serviceError = services.ErrorCreateNetworkAdapter - return - } - luid = adapter.LUID() - driverVersion, err := driver.RunningVersion() - if err != nil { - log.Printf("Warning: unable to determine driver version: %v", err) - } else { - log.Printf("Using WireGuardNT/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff) + for i := 0; i < 15; i++ { + if i > 0 { + time.Sleep(time.Second) + log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) } - err = adapter.SetLogging(driver.AdapterLogOn) - if err != nil { - err = fmt.Errorf("Error enabling adapter logging: %w", err) - serviceError = services.ErrorCreateNetworkAdapter - return + adapter, err = driver.CreateAdapter(config.Name, "WireGuard", deterministicGUID(config)) + if err == nil || !services.StartedAtBoot() { + break } + } + if err != nil { + err = fmt.Errorf("Error creating adapter: %w", err) + serviceError = services.ErrorCreateNetworkAdapter + return + } + luid = adapter.LUID() + driverVersion, err := driver.RunningVersion() + if err != nil { + log.Printf("Warning: unable to determine driver version: %v", err) } else { - for i := 0; i < 5; i++ { - if i > 0 { - time.Sleep(time.Second) - log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) - } - wintun, err = tun.CreateTUNWithRequestedGUID(config.Name, deterministicGUID(config), 0) - if err == nil || windows.DurationSinceBoot() > time.Minute*10 { - break - } - } - if err != nil { - serviceError = services.ErrorCreateNetworkAdapter - return - } - nativeTun = wintun.(*tun.NativeTun) - luid = winipcfg.LUID(nativeTun.LUID()) - driverVersion, err := nativeTun.RunningVersion() - if err != nil { - log.Printf("Warning: unable to determine driver version: %v", err) - } else { - log.Printf("Using Wintun/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff) - } + log.Printf("Using WireGuardNT/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff) + } + err = adapter.SetLogging(driver.AdapterLogOn) + if err != nil { + err = fmt.Errorf("Error enabling adapter logging: %w", err) + serviceError = services.ErrorCreateNetworkAdapter + return } err = runScriptCommand(config.Interface.PreUp, config.Name) @@ -254,54 +200,18 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, return } - if nativeTun != nil { - log.Println("Creating interface instance") - bind := conn.NewDefaultBind() - dev = device.NewDevice(wintun, bind, &device.Logger{log.Printf, log.Printf}) - - log.Println("Setting interface configuration") - uapi, err = ipc.UAPIListen(config.Name) - if err != nil { - serviceError = services.ErrorUAPIListen - return - } - err = dev.IpcSet(config.ToUAPI()) - if err != nil { - serviceError = services.ErrorDeviceSetConfig - return - } - - log.Println("Bringing peers up") - dev.Up() - - var clamper mtuClamper - clamper = nativeTun - watcher.Configure(bind.(conn.BindSocketToInterface), clamper, nil, config, luid) - - log.Println("Listening for UAPI requests") - go func() { - for { - conn, err := uapi.Accept() - if err != nil { - continue - } - go dev.IpcHandle(conn) - } - }() - } else { - log.Println("Setting interface configuration") - err = adapter.SetConfiguration(config.ToDriverConfiguration()) - if err != nil { - serviceError = services.ErrorDeviceSetConfig - return - } - err = adapter.SetAdapterState(driver.AdapterStateUp) - if err != nil { - serviceError = services.ErrorDeviceBringUp - return - } - watcher.Configure(nil, nil, adapter, config, luid) + log.Println("Setting interface configuration") + err = adapter.SetConfiguration(config.ToDriverConfiguration()) + if err != nil { + serviceError = services.ErrorDeviceSetConfig + return } + err = adapter.SetAdapterState(driver.AdapterStateUp) + if err != nil { + serviceError = services.ErrorDeviceBringUp + return + } + watcher.Configure(adapter, config, luid) err = runScriptCommand(config.Interface.PostUp, config.Name) if err != nil { @@ -309,15 +219,9 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, return } - changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown} - log.Println("Startup complete") + changes <- svc.Status{State: serviceState, Accepts: svc.AcceptStop | svc.AcceptShutdown} - var devWaitChan chan struct{} - if dev != nil { - devWaitChan = dev.Wait() - } else { - devWaitChan = make(chan struct{}) - } + var started bool for { select { case c := <-r: @@ -329,8 +233,13 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, default: log.Printf("Unexpected service control request #%d\n", c) } - case <-devWaitChan: - return + case <-watcher.started: + if !started { + serviceState = svc.Running + changes <- svc.Status{State: serviceState, Accepts: svc.AcceptStop | svc.AcceptShutdown} + log.Println("Startup complete") + started = true + } case e := <-watcher.errors: serviceError, err = e.serviceError, e.err return @@ -343,7 +252,7 @@ func Run(confPath string) error { if err != nil { return err } - serviceName, err := services.ServiceNameOfTunnel(name) + serviceName, err := conf.ServiceNameOfTunnel(name) if err != nil { return err } |