diff options
Diffstat (limited to 'tunnel/service.go')
-rw-r--r-- | tunnel/service.go | 175 |
1 files changed, 123 insertions, 52 deletions
diff --git a/tunnel/service.go b/tunnel/service.go index 63cd243f..a595994c 100644 --- a/tunnel/service.go +++ b/tunnel/service.go @@ -21,11 +21,12 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/elevate" "golang.zx2c4.com/wireguard/windows/ringlogger" "golang.zx2c4.com/wireguard/windows/services" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "golang.zx2c4.com/wireguard/windows/version" ) @@ -40,6 +41,9 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, var uapi net.Listener var watcher *interfaceWatcher var nativeTun *tun.NativeTun + var wintun tun.Device + var adapter *driver.Adapter + var luid winipcfg.LUID var config *conf.Config var err error serviceError := services.ErrorSuccess @@ -127,10 +131,10 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, if m, err := mgr.Connect(); err == nil { if lockStatus, err := m.LockStatus(); err == nil && lockStatus.IsLocked { - /* If we don't do this, then the Wintun installation will block forever, because - * installing a Wintun device starts a service too. Apparently at boot time, Windows - * 8.1 locks the SCM for each service start, creating a deadlock if we don't announce - * that we're running before starting additional services. + /* If we don't do this, then the driver installation will block forever, because + * installing a network adapter starts the driver service too. Apparently at boot time, + * Windows 8.1 locks the SCM for each service start, creating a deadlock if we don't + * announce that we're running before starting additional services. */ log.Printf("SCM locked for %v by %s, marking service as started", lockStatus.Age, lockStatus.Owner) changes <- svc.Status{State: svc.Running} @@ -146,34 +150,81 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, } log.Println("Resolving DNS names") - uapiConf, err := config.ToUAPI() + err = config.ResolveEndpoints() if err != nil { serviceError = services.ErrorDNSLookup return } - log.Println("Creating Wintun interface") - var wintun tun.Device - for i := 0; i < 5; i++ { - if i > 0 { - time.Sleep(time.Second) - log.Printf("Retrying Wintun creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) + log.Println("Creating network adapter") + if conf.AdminBool("ExperimentalKernelDriver") { + // Does an adapter with this name already exist? + adapter, err = driver.DefaultPool.OpenAdapter(config.Name) + if err == nil { + // If so, we delete it, in case it has weird residual configuration. + _, err = adapter.Delete() + if err != nil { + err = fmt.Errorf("Error deleting already existing adapter: %w", err) + serviceError = services.ErrorCreateNetworkAdapter + return + } } - wintun, err = tun.CreateTUNWithRequestedGUID(config.Name, deterministicGUID(config), 0) - if err == nil || windows.DurationSinceBoot() > time.Minute*4 { - break + for i := 0; i < 5; i++ { + if i > 0 { + time.Sleep(time.Second) + log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) + } + var rebootRequired bool + adapter, rebootRequired, err = driver.DefaultPool.CreateAdapter(config.Name, deterministicGUID(config)) + if err == nil || windows.DurationSinceBoot() > time.Minute*4 { + if rebootRequired { + log.Println("Windows indicated a reboot is required.") + } + break + } + } + if err != nil { + err = fmt.Errorf("Error creating adapter: %w", err) + serviceError = services.ErrorCreateNetworkAdapter + return + } + defer adapter.Delete() + luid = adapter.LUID() + driverVersion, err := driver.RunningVersion() + if err != nil { + log.Printf("Warning: unable to determine driver version: %v", err) + } else { + log.Printf("Using WireGuardNT/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff) + } + err = adapter.SetLogging(driver.AdapterLogOn) + if err != nil { + err = fmt.Errorf("Error enabling adapter logging: %w", err) + serviceError = services.ErrorCreateNetworkAdapter + return } - } - if err != nil { - serviceError = services.ErrorCreateWintun - return - } - nativeTun = wintun.(*tun.NativeTun) - wintunVersion, err := nativeTun.RunningVersion() - if err != nil { - log.Printf("Warning: unable to determine Wintun version: %v", err) } else { - log.Printf("Using Wintun/%d.%d", (wintunVersion>>16)&0xffff, wintunVersion&0xffff) + for i := 0; i < 5; i++ { + if i > 0 { + time.Sleep(time.Second) + log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) + } + wintun, err = tun.CreateTUNWithRequestedGUID(config.Name, deterministicGUID(config), 0) + if err == nil || windows.DurationSinceBoot() > time.Minute*4 { + break + } + } + if err != nil { + serviceError = services.ErrorCreateNetworkAdapter + return + } + nativeTun = wintun.(*tun.NativeTun) + luid = winipcfg.LUID(nativeTun.LUID()) + driverVersion, err := nativeTun.RunningVersion() + if err != nil { + log.Printf("Warning: unable to determine driver version: %v", err) + } else { + log.Printf("Using Wintun/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff) + } } err = runScriptCommand(config.Interface.PreUp, config.Name) @@ -182,7 +233,7 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, return } - err = enableFirewall(config, nativeTun) + err = enableFirewall(config, luid) if err != nil { serviceError = services.ErrorFirewall return @@ -195,37 +246,51 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, return } - log.Println("Creating interface instance") - bind := conn.NewDefaultBind() - dev = device.NewDevice(wintun, bind, &device.Logger{log.Printf, log.Printf}) + if nativeTun != nil { + log.Println("Creating interface instance") + bind := conn.NewDefaultBind() + dev = device.NewDevice(wintun, bind, &device.Logger{log.Printf, log.Printf}) - log.Println("Setting interface configuration") - uapi, err = ipc.UAPIListen(config.Name) - if err != nil { - serviceError = services.ErrorUAPIListen - return - } - err = dev.IpcSet(uapiConf) - if err != nil { - serviceError = services.ErrorDeviceSetConfig - return - } + log.Println("Setting interface configuration") + uapi, err = ipc.UAPIListen(config.Name) + if err != nil { + serviceError = services.ErrorUAPIListen + return + } + err = dev.IpcSet(config.ToUAPI()) + if err != nil { + serviceError = services.ErrorDeviceSetConfig + return + } - log.Println("Bringing peers up") - dev.Up() + log.Println("Bringing peers up") + dev.Up() - watcher.Configure(bind.(conn.BindSocketToInterface), config, nativeTun) + var clamper mtuClamper + clamper = nativeTun + watcher.Configure(bind.(conn.BindSocketToInterface), clamper, config, luid) - log.Println("Listening for UAPI requests") - go func() { - for { - conn, err := uapi.Accept() - if err != nil { - continue + log.Println("Listening for UAPI requests") + go func() { + for { + conn, err := uapi.Accept() + if err != nil { + continue + } + go dev.IpcHandle(conn) } - go dev.IpcHandle(conn) + }() + } else { + err = adapter.SetConfiguration(config.ToDriverConfiguration()) + if err != nil { + serviceError = services.ErrorDeviceSetConfig } - }() + err = adapter.SetAdapterState(driver.AdapterStateUp) + if err != nil { + serviceError = services.ErrorDeviceBringUp + } + watcher.Configure(nil, nil, config, luid) + } err = runScriptCommand(config.Interface.PostUp, config.Name) if err != nil { @@ -236,6 +301,12 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown} log.Println("Startup complete") + var devWaitChan chan struct{} + if dev != nil { + devWaitChan = dev.Wait() + } else { + devWaitChan = make(chan struct{}) + } for { select { case c := <-r: @@ -247,7 +318,7 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, default: log.Printf("Unexpected service control request #%d\n", c) } - case <-dev.Wait(): + case <-devWaitChan: return case e := <-watcher.errors: serviceError, err = e.serviceError, e.err |