/* SPDX-License-Identifier: MIT * * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. */ package manager import ( "bytes" "encoding/gob" "fmt" "io" "log" "os" "sync" "sync/atomic" "time" "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" "golang.zx2c4.com/wireguard/windows/conf" "golang.zx2c4.com/wireguard/windows/updater" ) var ( managerServices = make(map[*ManagerService]bool) managerServicesLock sync.RWMutex haveQuit uint32 quitManagersChan = make(chan struct{}, 1) ) type ManagerService struct { events *os.File eventLock sync.Mutex elevatedToken windows.Token } func (s *ManagerService) StoredConfig(tunnelName string) (*conf.Config, error) { conf, err := conf.LoadFromName(tunnelName) if err != nil { return nil, err } if s.elevatedToken == 0 { conf.Redact() } return conf, nil } func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) { storedConfig, err := conf.LoadFromName(tunnelName) if err != nil { return nil, err } driverAdapter, err := findDriverAdapter(tunnelName) if err != nil { return nil, err } runtimeConfig, err := driverAdapter.Configuration() if err != nil { driverAdapter.Unlock() releaseDriverAdapter(tunnelName) return nil, err } conf := conf.FromDriverConfiguration(runtimeConfig, storedConfig) driverAdapter.Unlock() if s.elevatedToken == 0 { conf.Redact() } return conf, nil } func (s *ManagerService) Start(tunnelName string) error { c, err := conf.LoadFromName(tunnelName) if err != nil { return err } // Figure out which tunnels have intersecting addresses/routes and stop those. trackedTunnelsLock.Lock() tt := make([]string, 0, len(trackedTunnels)) var inTransition string for t, state := range trackedTunnels { c2, err := conf.LoadFromName(t) if err != nil || !c.IntersectsWith(c2) { // If we can't get the config, assume it doesn't intersect. continue } tt = append(tt, t) if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) { inTransition = t break } } trackedTunnelsLock.Unlock() if len(inTransition) != 0 { return fmt.Errorf("Please allow the tunnel ā€˜%sā€™ to finish activating", inTransition) } // Stop those intersecting tunnels asynchronously. go func() { for _, t := range tt { s.Stop(t) } for _, t := range tt { state, err := s.State(t) if err == nil && (state == TunnelStarted || state == TunnelStarting) { log.Printf("[%s] Trying again to stop zombie tunnel", t) s.Stop(t) time.Sleep(time.Millisecond * 100) } } }() // After the stop process has begun, but before it's finished, we install the new one. path, err := c.Path() if err != nil { return err } return InstallTunnel(path) } func (s *ManagerService) Stop(tunnelName string) error { err := UninstallTunnel(tunnelName) if err == windows.ERROR_SERVICE_DOES_NOT_EXIST { _, notExistsError := conf.LoadFromName(tunnelName) if notExistsError == nil { return nil } } return err } func (s *ManagerService) WaitForStop(tunnelName string) error { serviceName, err := conf.ServiceNameOfTunnel(tunnelName) if err != nil { return err } m, err := serviceManager() if err != nil { return err } for { service, err := m.OpenService(serviceName) if err == nil || err == windows.ERROR_SERVICE_MARKED_FOR_DELETE { service.Close() time.Sleep(time.Second / 3) } else { return nil } } } func (s *ManagerService) Delete(tunnelName string) error { if s.elevatedToken == 0 { return windows.ERROR_ACCESS_DENIED } err := s.Stop(tunnelName) if err != nil { return err } return conf.DeleteName(tunnelName) } func (s *ManagerService) State(tunnelName string) (TunnelState, error) { serviceName, err := conf.ServiceNameOfTunnel(tunnelName) if err != nil { return 0, err } m, err := serviceManager() if err != nil { return 0, err } service, err := m.OpenService(serviceName) if err != nil { return TunnelStopped, nil } defer service.Close() status, err := service.Query() if err != nil { return TunnelUnknown, nil } switch status.State { case svc.Stopped: return TunnelStopped, nil case svc.StopPending: return TunnelStopping, nil case svc.Running: return TunnelStarted, nil case svc.StartPending: return TunnelStarting, nil default: return TunnelUnknown, nil } } func (s *ManagerService) GlobalState() TunnelState { return trackedTunnelsGlobalState() } func (s *ManagerService) Create(tunnelConfig *conf.Config) (*Tunnel, error) { if s.elevatedToken == 0 { return nil, windows.ERROR_ACCESS_DENIED } err := tunnelConfig.Save(true) if err != nil { return nil, err } return &Tunnel{tunnelConfig.Name}, nil // TODO: handle already existing situation // TODO: handle already running and existing situation } func (s *ManagerService) Tunnels() ([]Tunnel, error) { names, err := conf.ListConfigNames() if err != nil { return nil, err } tunnels := make([]Tunnel, len(names)) for i := 0; i < len(tunnels); i++ { tunnels[i].Name = names[i] } return tunnels, nil // TODO: account for running ones that aren't in the configuration store somehow } func (s *ManagerService) Quit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) { if s.elevatedToken == 0 { return false, windows.ERROR_ACCESS_DENIED } if !atomic.CompareAndSwapUint32(&haveQuit, 0, 1) { return true, nil } // Work around potential race condition of delivering messages to the wrong process by removing from notifications. managerServicesLock.Lock() s.eventLock.Lock() s.events = nil s.eventLock.Unlock() delete(managerServices, s) managerServicesLock.Unlock() if stopTunnelsOnQuit { names, err := conf.ListConfigNames() if err != nil { return false, err } for _, name := range names { UninstallTunnel(name) } } quitManagersChan <- struct{}{} return false, nil } func (s *ManagerService) UpdateState() UpdateState { return updateState } func (s *ManagerService) Update() { if s.elevatedToken == 0 { return } progress := updater.DownloadVerifyAndExecute(uintptr(s.elevatedToken)) go func() { for { dp := <-progress IPCServerNotifyUpdateProgress(dp) if dp.Complete || dp.Error != nil { return } } }() } func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) { decoder := gob.NewDecoder(reader) encoder := gob.NewEncoder(writer) for { var methodType MethodType err := decoder.Decode(&methodType) if err != nil { return } switch methodType { case StoredConfigMethodType: var tunnelName string err := decoder.Decode(&tunnelName) if err != nil { return } config, retErr := s.StoredConfig(tunnelName) if config == nil { config = &conf.Config{} } err = encoder.Encode(*config) if err != nil { return } err = encoder.Encode(errToString(retErr)) if err != nil { return } case RuntimeConfigMethodType: var tunnelName string err := decoder.Decode(&tunnelName) if err != nil { return } config, retErr := s.RuntimeConfig(tunnelName) if config == nil { config = &conf.Config{} } err = encoder.Encode(*config) if err != nil { return } err = encoder.Encode(errToString(retErr)) if err != nil { return } case StartMethodType: var tunnelName string err := decoder.Decode(&tunnelName) if err != nil { return } retErr := s.Start(tunnelName) err = encoder.Encode(errToString(retErr)) if err != nil { return } case StopMethodType: var tunnelName string err := decoder.Decode(&tunnelName) if err != nil { return } retErr := s.Stop(tunnelName) err = encoder.Encode(errToString(retErr)) if err != nil { return } case WaitForStopMethodType: var tunnelName string err := decoder.Decode(&tunnelName) if err != nil { return } retErr := s.WaitForStop(tunnelName) err = encoder.Encode(errToString(retErr)) if err != nil { return } case DeleteMethodType: var tunnelName string err := decoder.Decode(&tunnelName) if err != nil { return } retErr := s.Delete(tunnelName) err = encoder.Encode(errToString(retErr)) if err != nil { return } case StateMethodType: var tunnelName string err := decoder.Decode(&tunnelName) if err != nil { return } state, retErr := s.State(tunnelName) err = encoder.Encode(state) if err != nil { return } err = encoder.Encode(errToString(retErr)) if err != nil { return } case GlobalStateMethodType: state := s.GlobalState() err = encoder.Encode(state) if err != nil { return } case CreateMethodType: var config conf.Config err := decoder.Decode(&config) if err != nil { return } tunnel, retErr := s.Create(&config) if tunnel == nil { tunnel = &Tunnel{} } err = encoder.Encode(tunnel) if err != nil { return } err = encoder.Encode(errToString(retErr)) if err != nil { return } case TunnelsMethodType: tunnels, retErr := s.Tunnels() err = encoder.Encode(tunnels) if err != nil { return } err = encoder.Encode(errToString(retErr)) if err != nil { return } case QuitMethodType: var stopTunnelsOnQuit bool err := decoder.Decode(&stopTunnelsOnQuit) if err != nil { return } alreadyQuit, retErr := s.Quit(stopTunnelsOnQuit) err = encoder.Encode(alreadyQuit) if err != nil { return } err = encoder.Encode(errToString(retErr)) if err != nil { return } case UpdateStateMethodType: updateState := s.UpdateState() err = encoder.Encode(updateState) if err != nil { return } case UpdateMethodType: s.Update() default: return } } } func IPCServerListen(reader, writer, events *os.File, elevatedToken windows.Token) { service := &ManagerService{ events: events, elevatedToken: elevatedToken, } go func() { managerServicesLock.Lock() managerServices[service] = true managerServicesLock.Unlock() service.ServeConn(reader, writer) managerServicesLock.Lock() service.eventLock.Lock() service.events = nil service.eventLock.Unlock() delete(managerServices, service) managerServicesLock.Unlock() }() } func notifyAll(notificationType NotificationType, adminOnly bool, ifaces ...interface{}) { if len(managerServices) == 0 { return } var buf bytes.Buffer encoder := gob.NewEncoder(&buf) err := encoder.Encode(notificationType) if err != nil { return } for _, iface := range ifaces { err = encoder.Encode(iface) if err != nil { return } } managerServicesLock.RLock() for m := range managerServices { if m.elevatedToken == 0 && adminOnly { continue } go func(m *ManagerService) { m.eventLock.Lock() defer m.eventLock.Unlock() if m.events != nil { m.events.SetWriteDeadline(time.Now().Add(time.Second)) m.events.Write(buf.Bytes()) } }(m) } managerServicesLock.RUnlock() } func errToString(err error) string { if err == nil { return "" } return err.Error() } func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) { notifyAll(TunnelChangeNotificationType, false, name, state, trackedTunnelsGlobalState(), errToString(err)) } func IPCServerNotifyTunnelsChange() { notifyAll(TunnelsChangeNotificationType, false) } func IPCServerNotifyUpdateFound(state UpdateState) { notifyAll(UpdateFoundNotificationType, false, state) } func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) { notifyAll(UpdateProgressNotificationType, true, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, errToString(dp.Error), dp.Complete) } func IPCServerNotifyManagerStopping() { notifyAll(ManagerStoppingNotificationType, false) time.Sleep(time.Millisecond * 200) }