diff options
Diffstat (limited to 'manager/ipc_server.go')
-rw-r--r-- | manager/ipc_server.go | 128 |
1 files changed, 56 insertions, 72 deletions
diff --git a/manager/ipc_server.go b/manager/ipc_server.go index 8ba050f9..e21ffaf0 100644 --- a/manager/ipc_server.go +++ b/manager/ipc_server.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package manager @@ -20,14 +20,15 @@ import ( "golang.org/x/sys/windows/svc" "golang.zx2c4.com/wireguard/windows/conf" - "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/updater" ) -var managerServices = make(map[*ManagerService]bool) -var managerServicesLock sync.RWMutex -var haveQuit uint32 -var quitManagersChan = make(chan struct{}, 1) +var ( + managerServices = make(map[*ManagerService]bool) + managerServicesLock sync.RWMutex + haveQuit uint32 + quitManagersChan = make(chan struct{}, 1) +) type ManagerService struct { events *os.File @@ -51,33 +52,18 @@ func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) if err != nil { return nil, err } - pipe, err := connectTunnelServicePipe(tunnelName) - if err != nil { - return nil, err - } - pipe.SetDeadline(time.Now().Add(time.Second * 2)) - _, err = pipe.Write([]byte("get=1\n\n")) - if err == windows.ERROR_NO_DATA { - log.Println("IPC pipe closed unexpectedly, so reopening") - pipe.Unlock() - disconnectTunnelServicePipe(tunnelName) - pipe, err = connectTunnelServicePipe(tunnelName) - if err != nil { - return nil, err - } - pipe.SetDeadline(time.Now().Add(time.Second * 2)) - _, err = pipe.Write([]byte("get=1\n\n")) - } + driverAdapter, err := findDriverAdapter(tunnelName) if err != nil { - pipe.Unlock() - disconnectTunnelServicePipe(tunnelName) return nil, err } - conf, err := conf.FromUAPI(pipe, storedConfig) - pipe.Unlock() + runtimeConfig, err := driverAdapter.Configuration() if err != nil { + driverAdapter.Unlock() + releaseDriverAdapter(tunnelName) return nil, err } + conf := conf.FromDriverConfiguration(runtimeConfig, storedConfig) + driverAdapter.Unlock() if s.elevatedToken == 0 { conf.Redact() } @@ -85,44 +71,47 @@ func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) } func (s *ManagerService) Start(tunnelName string) error { - // TODO: Rather than being lazy and gating this behind a knob (yuck!), we should instead keep track of the routes - // of each tunnel, and only deactivate in the case of a tunnel with identical routes being added. - if !conf.AdminBool("MultipleSimultaneousTunnels") { - trackedTunnelsLock.Lock() - tt := make([]string, 0, len(trackedTunnels)) - var inTransition string - for t, state := range trackedTunnels { - tt = append(tt, t) - if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) { - inTransition = t - break - } - } - trackedTunnelsLock.Unlock() - if len(inTransition) != 0 { - return fmt.Errorf("Please allow the tunnel ā%sā to finish activating", inTransition) - } - go func() { - for _, t := range tt { - s.Stop(t) - } - for _, t := range tt { - state, err := s.State(t) - if err == nil && (state == TunnelStarted || state == TunnelStarting) { - log.Printf("[%s] Trying again to stop zombie tunnel", t) - s.Stop(t) - time.Sleep(time.Millisecond * 100) - } - } - }() - } - time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces) - - // After that process is started -- it's somewhat asynchronous -- we install the new one. c, err := conf.LoadFromName(tunnelName) if err != nil { return err } + + // Figure out which tunnels have intersecting addresses/routes and stop those. + trackedTunnelsLock.Lock() + tt := make([]string, 0, len(trackedTunnels)) + var inTransition string + for t, state := range trackedTunnels { + c2, err := conf.LoadFromName(t) + if err != nil || !c.IntersectsWith(c2) { + // If we can't get the config, assume it doesn't intersect. + continue + } + tt = append(tt, t) + if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) { + inTransition = t + break + } + } + trackedTunnelsLock.Unlock() + if len(inTransition) != 0 { + return fmt.Errorf("Please allow the tunnel ā%sā to finish activating", inTransition) + } + + // Stop those intersecting tunnels asynchronously. + go func() { + for _, t := range tt { + s.Stop(t) + } + for _, t := range tt { + state, err := s.State(t) + if err == nil && (state == TunnelStarted || state == TunnelStarting) { + log.Printf("[%s] Trying again to stop zombie tunnel", t) + s.Stop(t) + time.Sleep(time.Millisecond * 100) + } + } + }() + // After the stop process has begun, but before it's finished, we install the new one. path, err := c.Path() if err != nil { return err @@ -131,8 +120,6 @@ func (s *ManagerService) Start(tunnelName string) error { } func (s *ManagerService) Stop(tunnelName string) error { - time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces) - err := UninstallTunnel(tunnelName) if err == windows.ERROR_SERVICE_DOES_NOT_EXIST { _, notExistsError := conf.LoadFromName(tunnelName) @@ -144,7 +131,7 @@ func (s *ManagerService) Stop(tunnelName string) error { } func (s *ManagerService) WaitForStop(tunnelName string) error { - serviceName, err := services.ServiceNameOfTunnel(tunnelName) + serviceName, err := conf.ServiceNameOfTunnel(tunnelName) if err != nil { return err } @@ -175,7 +162,7 @@ func (s *ManagerService) Delete(tunnelName string) error { } func (s *ManagerService) State(tunnelName string) (TunnelState, error) { - serviceName, err := services.ServiceNameOfTunnel(tunnelName) + serviceName, err := conf.ServiceNameOfTunnel(tunnelName) if err != nil { return 0, err } @@ -230,7 +217,7 @@ func (s *ManagerService) Tunnels() ([]Tunnel, error) { } tunnels := make([]Tunnel, len(names)) for i := 0; i < len(tunnels); i++ { - (tunnels)[i].Name = names[i] + tunnels[i].Name = names[i] } return tunnels, nil // TODO: account for running ones that aren't in the configuration store somehow @@ -267,9 +254,6 @@ func (s *ManagerService) Quit(stopTunnelsOnQuit bool) (alreadyQuit bool, err err } func (s *ManagerService) UpdateState() UpdateState { - if s.elevatedToken == 0 { - return UpdateStateUnknown - } return updateState } @@ -457,7 +441,7 @@ func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) { } } -func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevatedToken windows.Token) { +func IPCServerListen(reader, writer, events *os.File, elevatedToken windows.Token) { service := &ManagerService{ events: events, elevatedToken: elevatedToken, @@ -477,7 +461,7 @@ func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevated }() } -func notifyAll(notificationType NotificationType, adminOnly bool, ifaces ...interface{}) { +func notifyAll(notificationType NotificationType, adminOnly bool, ifaces ...any) { if len(managerServices) == 0 { return } @@ -528,7 +512,7 @@ func IPCServerNotifyTunnelsChange() { } func IPCServerNotifyUpdateFound(state UpdateState) { - notifyAll(UpdateFoundNotificationType, true, state) + notifyAll(UpdateFoundNotificationType, false, state) } func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) { |