diff options
Diffstat (limited to 'manager/ipc_server.go')
-rw-r--r-- | manager/ipc_server.go | 303 |
1 files changed, 217 insertions, 86 deletions
diff --git a/manager/ipc_server.go b/manager/ipc_server.go index 0a3bceae..3afa3651 100644 --- a/manager/ipc_server.go +++ b/manager/ipc_server.go @@ -9,9 +9,9 @@ import ( "bytes" "encoding/gob" "fmt" + "io" "io/ioutil" "log" - "net/rpc" "os" "sync" "sync/atomic" @@ -37,52 +37,42 @@ type ManagerService struct { elevatedToken windows.Token } -func (s *ManagerService) StoredConfig(tunnelName string, config *conf.Config) error { - c, err := conf.LoadFromName(tunnelName) - if err != nil { - return err - } - *config = *c - return nil +func (s *ManagerService) StoredConfig(tunnelName string) (*conf.Config, error) { + return conf.LoadFromName(tunnelName) } -func (s *ManagerService) RuntimeConfig(tunnelName string, config *conf.Config) error { +func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) { storedConfig, err := conf.LoadFromName(tunnelName) if err != nil { - return err + return nil, err } pipePath, err := services.PipePathOfTunnel(storedConfig.Name) if err != nil { - return err + return nil, err } localSystem, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid) if err != nil { - return err + return nil, err } pipe, err := winpipe.DialPipe(pipePath, nil, localSystem) if err != nil { - return err + return nil, err } defer pipe.Close() pipe.SetWriteDeadline(time.Now().Add(time.Second * 2)) _, err = pipe.Write([]byte("get=1\n\n")) if err != nil { - return err + return nil, err } pipe.SetReadDeadline(time.Now().Add(time.Second * 2)) resp, err := ioutil.ReadAll(pipe) if err != nil { - return err + return nil, err } - runtimeConfig, err := conf.FromUAPI(string(resp), storedConfig) - if err != nil { - return err - } - *config = *runtimeConfig - return nil + return conf.FromUAPI(string(resp), storedConfig) } -func (s *ManagerService) Start(tunnelName string, unused *uintptr) error { +func (s *ManagerService) Start(tunnelName string) error { // For now, enforce only one tunnel at a time. Later we'll remove this silly restriction. trackedTunnelsLock.Lock() tt := make([]string, 0, len(trackedTunnels)) @@ -100,14 +90,13 @@ func (s *ManagerService) Start(tunnelName string, unused *uintptr) error { } go func() { for _, t := range tt { - s.Stop(t, unused) + s.Stop(t) } for _, t := range tt { - var state TunnelState - var unused uintptr - if s.State(t, &state) == nil && (state == TunnelStarted || state == TunnelStarting) { + state, err := s.State(t) + if err == nil && (state == TunnelStarted || state == TunnelStarting) { log.Printf("[%s] Trying again to stop zombie tunnel", t) - s.Stop(t, &unused) + s.Stop(t) time.Sleep(time.Millisecond * 100) } } @@ -126,7 +115,7 @@ func (s *ManagerService) Start(tunnelName string, unused *uintptr) error { return InstallTunnel(path) } -func (s *ManagerService) Stop(tunnelName string, _ *uintptr) error { +func (s *ManagerService) Stop(tunnelName string) error { time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces) err := UninstallTunnel(tunnelName) @@ -139,7 +128,7 @@ func (s *ManagerService) Stop(tunnelName string, _ *uintptr) error { return err } -func (s *ManagerService) WaitForStop(tunnelName string, _ *uintptr) error { +func (s *ManagerService) WaitForStop(tunnelName string) error { serviceName, err := services.ServiceNameOfTunnel(tunnelName) if err != nil { return err @@ -159,84 +148,77 @@ func (s *ManagerService) WaitForStop(tunnelName string, _ *uintptr) error { } } -func (s *ManagerService) Delete(tunnelName string, _ *uintptr) error { - err := s.Stop(tunnelName, nil) +func (s *ManagerService) Delete(tunnelName string) error { + err := s.Stop(tunnelName) if err != nil { return err } return conf.DeleteName(tunnelName) } -func (s *ManagerService) State(tunnelName string, state *TunnelState) error { +func (s *ManagerService) State(tunnelName string) (TunnelState, error) { serviceName, err := services.ServiceNameOfTunnel(tunnelName) if err != nil { - return err + return 0, err } m, err := serviceManager() if err != nil { - return err + return 0, err } service, err := m.OpenService(serviceName) if err != nil { - *state = TunnelStopped - return nil + return TunnelStopped, nil } defer service.Close() status, err := service.Query() if err != nil { - *state = TunnelUnknown - return err + return TunnelUnknown, nil } switch status.State { case svc.Stopped: - *state = TunnelStopped + return TunnelStopped, nil case svc.StopPending: - *state = TunnelStopping + return TunnelStopping, nil case svc.Running: - *state = TunnelStarted + return TunnelStarted, nil case svc.StartPending: - *state = TunnelStarting + return TunnelStarting, nil default: - *state = TunnelUnknown + return TunnelUnknown, nil } - return nil } -func (s *ManagerService) GlobalState(_ uintptr, state *TunnelState) error { - *state = trackedTunnelsGlobalState() - return nil +func (s *ManagerService) GlobalState() TunnelState { + return trackedTunnelsGlobalState() } -func (s *ManagerService) Create(tunnelConfig conf.Config, tunnel *Tunnel) error { +func (s *ManagerService) Create(tunnelConfig *conf.Config) (*Tunnel, error) { err := tunnelConfig.Save() if err != nil { - return err + return nil, err } - *tunnel = Tunnel{tunnelConfig.Name} - return nil + return &Tunnel{tunnelConfig.Name}, nil // TODO: handle already existing situation // TODO: handle already running and existing situation } -func (s *ManagerService) Tunnels(_ uintptr, tunnels *[]Tunnel) error { +func (s *ManagerService) Tunnels() ([]Tunnel, error) { names, err := conf.ListConfigNames() if err != nil { - return err + return nil, err } - *tunnels = make([]Tunnel, len(names)) - for i := 0; i < len(*tunnels); i++ { - (*tunnels)[i].Name = names[i] + tunnels := make([]Tunnel, len(names)) + for i := 0; i < len(tunnels); i++ { + (tunnels)[i].Name = names[i] } - return nil + return tunnels, nil // TODO: account for running ones that aren't in the configuration store somehow } -func (s *ManagerService) Quit(stopTunnelsOnQuit bool, alreadyQuit *bool) error { +func (s *ManagerService) Quit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) { if !atomic.CompareAndSwapUint32(&haveQuit, 0, 1) { - *alreadyQuit = true - return nil + return true, nil } - *alreadyQuit = false // Work around potential race condition of delivering messages to the wrong process by removing from notifications. managerServicesLock.Lock() @@ -246,7 +228,7 @@ func (s *ManagerService) Quit(stopTunnelsOnQuit bool, alreadyQuit *bool) error { if stopTunnelsOnQuit { names, err := conf.ListConfigNames() if err != nil { - return err + return false, err } for _, name := range names { UninstallTunnel(name) @@ -254,15 +236,14 @@ func (s *ManagerService) Quit(stopTunnelsOnQuit bool, alreadyQuit *bool) error { } quitManagersChan <- struct{}{} - return nil + return false, nil } -func (s *ManagerService) UpdateState(_ uintptr, state *UpdateState) error { - *state = updateState - return nil +func (s *ManagerService) UpdateState() UpdateState { + return updateState } -func (s *ManagerService) Update(_ uintptr, _ *uintptr) error { +func (s *ManagerService) Update() { progress := updater.DownloadVerifyAndExecute(uintptr(s.elevatedToken)) go func() { for { @@ -273,32 +254,183 @@ func (s *ManagerService) Update(_ uintptr, _ *uintptr) error { } } }() - return nil } -func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevatedToken windows.Token) error { +func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) { + decoder := gob.NewDecoder(reader) + encoder := gob.NewEncoder(writer) + for { + var methodType MethodType + err := decoder.Decode(&methodType) + if err != nil { + return + } + switch methodType { + case StoredConfigMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + config, retErr := s.StoredConfig(tunnelName) + err = encoder.Encode(*config) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case RuntimeConfigMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + config, retErr := s.RuntimeConfig(tunnelName) + err = encoder.Encode(*config) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case StartMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + retErr := s.Start(tunnelName) + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case StopMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + retErr := s.Stop(tunnelName) + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case WaitForStopMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + retErr := s.WaitForStop(tunnelName) + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case DeleteMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + retErr := s.Delete(tunnelName) + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case StateMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + state, retErr := s.State(tunnelName) + err = encoder.Encode(state) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case GlobalStateMethodType: + state := s.GlobalState() + err = encoder.Encode(state) + if err != nil { + return + } + case CreateMethodType: + var config conf.Config + err := decoder.Decode(&config) + if err != nil { + return + } + tunnel, retErr := s.Create(&config) + err = encoder.Encode(tunnel) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case TunnelsMethodType: + tunnels, retErr := s.Tunnels() + err = encoder.Encode(tunnels) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case QuitMethodType: + var stopTunnelsOnQuit bool + err := decoder.Decode(&stopTunnelsOnQuit) + if err != nil { + return + } + alreadyQuit, retErr := s.Quit(stopTunnelsOnQuit) + err = encoder.Encode(alreadyQuit) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case UpdateStateMethodType: + updateState := s.UpdateState() + err = encoder.Encode(updateState) + if err != nil { + return + } + case UpdateMethodType: + s.Update() + default: + return + } + } +} + +func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevatedToken windows.Token) { service := &ManagerService{ events: events, elevatedToken: elevatedToken, } - server := rpc.NewServer() - err := server.Register(service) - if err != nil { - return err - } - go func() { managerServicesLock.Lock() managerServices[service] = true managerServicesLock.Unlock() - server.ServeConn(&pipeRWC{reader, writer}) + service.ServeConn(reader, writer) managerServicesLock.Lock() delete(managerServices, service) managerServicesLock.Unlock() }() - return nil } func notifyAll(notificationType NotificationType, ifaces ...interface{}) { @@ -327,12 +459,15 @@ func notifyAll(notificationType NotificationType, ifaces ...interface{}) { managerServicesLock.RUnlock() } -func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) { +func errToString(err error) string { if err == nil { - notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), "") - } else { - notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), err.Error()) + return "" } + return err.Error() +} + +func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) { + notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), errToString(err)) } func IPCServerNotifyTunnelsChange() { @@ -344,11 +479,7 @@ func IPCServerNotifyUpdateFound(state UpdateState) { } func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) { - if dp.Error == nil { - notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, "", dp.Complete) - } else { - notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, dp.Error.Error(), dp.Complete) - } + notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, errToString(dp.Error), dp.Complete) } func IPCServerNotifyManagerStopping() { |