From e493f911269a2dabab7b05ec28726cdaeffb660e Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 20 May 2019 14:18:01 +0200 Subject: service: split into tunnel and manager --- manager/ipc_server.go | 348 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 348 insertions(+) create mode 100644 manager/ipc_server.go (limited to 'manager/ipc_server.go') diff --git a/manager/ipc_server.go b/manager/ipc_server.go new file mode 100644 index 00000000..0accb4d3 --- /dev/null +++ b/manager/ipc_server.go @@ -0,0 +1,348 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package manager + +import ( + "bytes" + "encoding/gob" + "fmt" + "io/ioutil" + "log" + "net/rpc" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/Microsoft/go-winio" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + + "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/updater" +) + +var managerServices = make(map[*ManagerService]bool) +var managerServicesLock sync.RWMutex +var haveQuit uint32 +var quitManagersChan = make(chan struct{}, 1) + +type ManagerService struct { + events *os.File + elevatedToken windows.Token +} + +func (s *ManagerService) StoredConfig(tunnelName string, config *conf.Config) error { + c, err := conf.LoadFromName(tunnelName) + if err != nil { + return err + } + *config = *c + return nil +} + +func (s *ManagerService) RuntimeConfig(tunnelName string, config *conf.Config) error { + storedConfig, err := conf.LoadFromName(tunnelName) + if err != nil { + return err + } + pipePath, err := PipePathOfTunnel(storedConfig.Name) + if err != nil { + return err + } + pipe, err := winio.DialPipe(pipePath, nil) + if err != nil { + return err + } + pipe.SetWriteDeadline(time.Now().Add(time.Second * 2)) + _, err = pipe.Write([]byte("get=1\n\n")) + if err != nil { + return err + } + pipe.SetReadDeadline(time.Now().Add(time.Second * 2)) + resp, err := ioutil.ReadAll(pipe) + if err != nil { + return err + } + pipe.Close() + runtimeConfig, err := conf.FromUAPI(string(resp), storedConfig) + if err != nil { + return err + } + *config = *runtimeConfig + return nil +} + +func (s *ManagerService) Start(tunnelName string, unused *uintptr) error { + // For now, enforce only one tunnel at a time. Later we'll remove this silly restriction. + trackedTunnelsLock.Lock() + tt := make([]string, 0, len(trackedTunnels)) + var inTransition string + for t, state := range trackedTunnels { + tt = append(tt, t) + if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) { + inTransition = t + break + } + } + trackedTunnelsLock.Unlock() + if len(inTransition) != 0 { + return fmt.Errorf("Please allow the tunnel ā€˜%sā€™ to finish activating", inTransition) + } + go func() { + for _, t := range tt { + s.Stop(t, unused) + } + for _, t := range tt { + var state TunnelState + var unused uintptr + if s.State(t, &state) == nil && (state == TunnelStarted || state == TunnelStarting) { + log.Printf("[%s] Trying again to stop zombie tunnel", t) + s.Stop(t, &unused) + time.Sleep(time.Millisecond * 100) + } + } + }() + + // After that process is started -- it's somewhat asynchronous -- we install the new one. + c, err := conf.LoadFromName(tunnelName) + if err != nil { + return err + } + path, err := c.Path() + if err != nil { + return err + } + return InstallTunnel(path) +} + +func (s *ManagerService) Stop(tunnelName string, _ *uintptr) error { + err := UninstallTunnel(tunnelName) + if err == windows.ERROR_SERVICE_DOES_NOT_EXIST { + _, notExistsError := conf.LoadFromName(tunnelName) + if notExistsError == nil { + return nil + } + } + return err +} + +func (s *ManagerService) WaitForStop(tunnelName string, _ *uintptr) error { + serviceName, err := ServiceNameOfTunnel(tunnelName) + if err != nil { + return err + } + m, err := serviceManager() + if err != nil { + return err + } + for { + service, err := m.OpenService(serviceName) + if err == nil || err == windows.ERROR_SERVICE_MARKED_FOR_DELETE { + service.Close() + time.Sleep(time.Second / 3) + } else { + return nil + } + } +} + +func (s *ManagerService) Delete(tunnelName string, _ *uintptr) error { + err := s.Stop(tunnelName, nil) + if err != nil { + return err + } + return conf.DeleteName(tunnelName) +} + +func (s *ManagerService) State(tunnelName string, state *TunnelState) error { + serviceName, err := ServiceNameOfTunnel(tunnelName) + if err != nil { + return err + } + m, err := serviceManager() + if err != nil { + return err + } + service, err := m.OpenService(serviceName) + if err != nil { + *state = TunnelStopped + return nil + } + defer service.Close() + status, err := service.Query() + if err != nil { + *state = TunnelUnknown + return err + } + switch status.State { + case svc.Stopped: + *state = TunnelStopped + case svc.StopPending: + *state = TunnelStopping + case svc.Running: + *state = TunnelStarted + case svc.StartPending: + *state = TunnelStarting + default: + *state = TunnelUnknown + } + return nil +} + +func (s *ManagerService) GlobalState(_ uintptr, state *TunnelState) error { + *state = trackedTunnelsGlobalState() + return nil +} + +func (s *ManagerService) Create(tunnelConfig conf.Config, tunnel *Tunnel) error { + err := tunnelConfig.Save() + if err != nil { + return err + } + *tunnel = Tunnel{tunnelConfig.Name} + return nil + //TODO: handle already existing situation + //TODO: handle already running and existing situation +} + +func (s *ManagerService) Tunnels(_ uintptr, tunnels *[]Tunnel) error { + names, err := conf.ListConfigNames() + if err != nil { + return err + } + *tunnels = make([]Tunnel, len(names)) + for i := 0; i < len(*tunnels); i++ { + (*tunnels)[i].Name = names[i] + } + return nil + //TODO: account for running ones that aren't in the configuration store somehow +} + +func (s *ManagerService) Quit(stopTunnelsOnQuit bool, alreadyQuit *bool) error { + if !atomic.CompareAndSwapUint32(&haveQuit, 0, 1) { + *alreadyQuit = true + return nil + } + *alreadyQuit = false + + // Work around potential race condition of delivering messages to the wrong process by removing from notifications. + managerServicesLock.Lock() + delete(managerServices, s) + managerServicesLock.Unlock() + + if stopTunnelsOnQuit { + names, err := conf.ListConfigNames() + if err != nil { + return err + } + for _, name := range names { + UninstallTunnel(name) + } + } + + quitManagersChan <- struct{}{} + return nil +} + +func (s *ManagerService) UpdateState(_ uintptr, state *UpdateState) error { + *state = updateState + return nil +} + +func (s *ManagerService) Update(_ uintptr, _ *uintptr) error { + progress := updater.DownloadVerifyAndExecute(uintptr(s.elevatedToken)) + go func() { + for { + dp := <-progress + IPCServerNotifyUpdateProgress(dp) + if dp.Complete || dp.Error != nil { + return + } + } + }() + return nil +} + +func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevatedToken windows.Token) error { + service := &ManagerService{ + events: events, + elevatedToken: elevatedToken, + } + + server := rpc.NewServer() + err := server.Register(service) + if err != nil { + return err + } + + go func() { + managerServicesLock.Lock() + managerServices[service] = true + managerServicesLock.Unlock() + server.ServeConn(&pipeRWC{reader, writer}) + managerServicesLock.Lock() + delete(managerServices, service) + managerServicesLock.Unlock() + + }() + return nil +} + +func notifyAll(notificationType NotificationType, ifaces ...interface{}) { + if len(managerServices) == 0 { + return + } + + var buf bytes.Buffer + encoder := gob.NewEncoder(&buf) + err := encoder.Encode(notificationType) + if err != nil { + return + } + for _, iface := range ifaces { + err = encoder.Encode(iface) + if err != nil { + return + } + } + + managerServicesLock.RLock() + for m := range managerServices { + m.events.SetWriteDeadline(time.Now().Add(time.Second)) + m.events.Write(buf.Bytes()) + } + managerServicesLock.RUnlock() +} + +func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) { + if err == nil { + notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), "") + } else { + notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), err.Error()) + } +} + +func IPCServerNotifyTunnelsChange() { + notifyAll(TunnelsChangeNotificationType) +} + +func IPCServerNotifyUpdateFound(state UpdateState) { + notifyAll(UpdateFoundNotificationType, state) +} + +func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) { + if dp.Error == nil { + notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, "", dp.Complete) + } else { + notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, dp.Error.Error(), dp.Complete) + } +} + +func IPCServerNotifyManagerStopping() { + notifyAll(ManagerStoppingNotificationType) + time.Sleep(time.Millisecond * 200) +} -- cgit v1.2.3-59-g8ed1b