aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/manager/ipc_server.go
diff options
context:
space:
mode:
Diffstat (limited to 'manager/ipc_server.go')
-rw-r--r--manager/ipc_server.go303
1 files changed, 217 insertions, 86 deletions
diff --git a/manager/ipc_server.go b/manager/ipc_server.go
index 0a3bceae..3afa3651 100644
--- a/manager/ipc_server.go
+++ b/manager/ipc_server.go
@@ -9,9 +9,9 @@ import (
"bytes"
"encoding/gob"
"fmt"
+ "io"
"io/ioutil"
"log"
- "net/rpc"
"os"
"sync"
"sync/atomic"
@@ -37,52 +37,42 @@ type ManagerService struct {
elevatedToken windows.Token
}
-func (s *ManagerService) StoredConfig(tunnelName string, config *conf.Config) error {
- c, err := conf.LoadFromName(tunnelName)
- if err != nil {
- return err
- }
- *config = *c
- return nil
+func (s *ManagerService) StoredConfig(tunnelName string) (*conf.Config, error) {
+ return conf.LoadFromName(tunnelName)
}
-func (s *ManagerService) RuntimeConfig(tunnelName string, config *conf.Config) error {
+func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) {
storedConfig, err := conf.LoadFromName(tunnelName)
if err != nil {
- return err
+ return nil, err
}
pipePath, err := services.PipePathOfTunnel(storedConfig.Name)
if err != nil {
- return err
+ return nil, err
}
localSystem, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid)
if err != nil {
- return err
+ return nil, err
}
pipe, err := winpipe.DialPipe(pipePath, nil, localSystem)
if err != nil {
- return err
+ return nil, err
}
defer pipe.Close()
pipe.SetWriteDeadline(time.Now().Add(time.Second * 2))
_, err = pipe.Write([]byte("get=1\n\n"))
if err != nil {
- return err
+ return nil, err
}
pipe.SetReadDeadline(time.Now().Add(time.Second * 2))
resp, err := ioutil.ReadAll(pipe)
if err != nil {
- return err
+ return nil, err
}
- runtimeConfig, err := conf.FromUAPI(string(resp), storedConfig)
- if err != nil {
- return err
- }
- *config = *runtimeConfig
- return nil
+ return conf.FromUAPI(string(resp), storedConfig)
}
-func (s *ManagerService) Start(tunnelName string, unused *uintptr) error {
+func (s *ManagerService) Start(tunnelName string) error {
// For now, enforce only one tunnel at a time. Later we'll remove this silly restriction.
trackedTunnelsLock.Lock()
tt := make([]string, 0, len(trackedTunnels))
@@ -100,14 +90,13 @@ func (s *ManagerService) Start(tunnelName string, unused *uintptr) error {
}
go func() {
for _, t := range tt {
- s.Stop(t, unused)
+ s.Stop(t)
}
for _, t := range tt {
- var state TunnelState
- var unused uintptr
- if s.State(t, &state) == nil && (state == TunnelStarted || state == TunnelStarting) {
+ state, err := s.State(t)
+ if err == nil && (state == TunnelStarted || state == TunnelStarting) {
log.Printf("[%s] Trying again to stop zombie tunnel", t)
- s.Stop(t, &unused)
+ s.Stop(t)
time.Sleep(time.Millisecond * 100)
}
}
@@ -126,7 +115,7 @@ func (s *ManagerService) Start(tunnelName string, unused *uintptr) error {
return InstallTunnel(path)
}
-func (s *ManagerService) Stop(tunnelName string, _ *uintptr) error {
+func (s *ManagerService) Stop(tunnelName string) error {
time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces)
err := UninstallTunnel(tunnelName)
@@ -139,7 +128,7 @@ func (s *ManagerService) Stop(tunnelName string, _ *uintptr) error {
return err
}
-func (s *ManagerService) WaitForStop(tunnelName string, _ *uintptr) error {
+func (s *ManagerService) WaitForStop(tunnelName string) error {
serviceName, err := services.ServiceNameOfTunnel(tunnelName)
if err != nil {
return err
@@ -159,84 +148,77 @@ func (s *ManagerService) WaitForStop(tunnelName string, _ *uintptr) error {
}
}
-func (s *ManagerService) Delete(tunnelName string, _ *uintptr) error {
- err := s.Stop(tunnelName, nil)
+func (s *ManagerService) Delete(tunnelName string) error {
+ err := s.Stop(tunnelName)
if err != nil {
return err
}
return conf.DeleteName(tunnelName)
}
-func (s *ManagerService) State(tunnelName string, state *TunnelState) error {
+func (s *ManagerService) State(tunnelName string) (TunnelState, error) {
serviceName, err := services.ServiceNameOfTunnel(tunnelName)
if err != nil {
- return err
+ return 0, err
}
m, err := serviceManager()
if err != nil {
- return err
+ return 0, err
}
service, err := m.OpenService(serviceName)
if err != nil {
- *state = TunnelStopped
- return nil
+ return TunnelStopped, nil
}
defer service.Close()
status, err := service.Query()
if err != nil {
- *state = TunnelUnknown
- return err
+ return TunnelUnknown, nil
}
switch status.State {
case svc.Stopped:
- *state = TunnelStopped
+ return TunnelStopped, nil
case svc.StopPending:
- *state = TunnelStopping
+ return TunnelStopping, nil
case svc.Running:
- *state = TunnelStarted
+ return TunnelStarted, nil
case svc.StartPending:
- *state = TunnelStarting
+ return TunnelStarting, nil
default:
- *state = TunnelUnknown
+ return TunnelUnknown, nil
}
- return nil
}
-func (s *ManagerService) GlobalState(_ uintptr, state *TunnelState) error {
- *state = trackedTunnelsGlobalState()
- return nil
+func (s *ManagerService) GlobalState() TunnelState {
+ return trackedTunnelsGlobalState()
}
-func (s *ManagerService) Create(tunnelConfig conf.Config, tunnel *Tunnel) error {
+func (s *ManagerService) Create(tunnelConfig *conf.Config) (*Tunnel, error) {
err := tunnelConfig.Save()
if err != nil {
- return err
+ return nil, err
}
- *tunnel = Tunnel{tunnelConfig.Name}
- return nil
+ return &Tunnel{tunnelConfig.Name}, nil
// TODO: handle already existing situation
// TODO: handle already running and existing situation
}
-func (s *ManagerService) Tunnels(_ uintptr, tunnels *[]Tunnel) error {
+func (s *ManagerService) Tunnels() ([]Tunnel, error) {
names, err := conf.ListConfigNames()
if err != nil {
- return err
+ return nil, err
}
- *tunnels = make([]Tunnel, len(names))
- for i := 0; i < len(*tunnels); i++ {
- (*tunnels)[i].Name = names[i]
+ tunnels := make([]Tunnel, len(names))
+ for i := 0; i < len(tunnels); i++ {
+ (tunnels)[i].Name = names[i]
}
- return nil
+ return tunnels, nil
// TODO: account for running ones that aren't in the configuration store somehow
}
-func (s *ManagerService) Quit(stopTunnelsOnQuit bool, alreadyQuit *bool) error {
+func (s *ManagerService) Quit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) {
if !atomic.CompareAndSwapUint32(&haveQuit, 0, 1) {
- *alreadyQuit = true
- return nil
+ return true, nil
}
- *alreadyQuit = false
// Work around potential race condition of delivering messages to the wrong process by removing from notifications.
managerServicesLock.Lock()
@@ -246,7 +228,7 @@ func (s *ManagerService) Quit(stopTunnelsOnQuit bool, alreadyQuit *bool) error {
if stopTunnelsOnQuit {
names, err := conf.ListConfigNames()
if err != nil {
- return err
+ return false, err
}
for _, name := range names {
UninstallTunnel(name)
@@ -254,15 +236,14 @@ func (s *ManagerService) Quit(stopTunnelsOnQuit bool, alreadyQuit *bool) error {
}
quitManagersChan <- struct{}{}
- return nil
+ return false, nil
}
-func (s *ManagerService) UpdateState(_ uintptr, state *UpdateState) error {
- *state = updateState
- return nil
+func (s *ManagerService) UpdateState() UpdateState {
+ return updateState
}
-func (s *ManagerService) Update(_ uintptr, _ *uintptr) error {
+func (s *ManagerService) Update() {
progress := updater.DownloadVerifyAndExecute(uintptr(s.elevatedToken))
go func() {
for {
@@ -273,32 +254,183 @@ func (s *ManagerService) Update(_ uintptr, _ *uintptr) error {
}
}
}()
- return nil
}
-func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevatedToken windows.Token) error {
+func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) {
+ decoder := gob.NewDecoder(reader)
+ encoder := gob.NewEncoder(writer)
+ for {
+ var methodType MethodType
+ err := decoder.Decode(&methodType)
+ if err != nil {
+ return
+ }
+ switch methodType {
+ case StoredConfigMethodType:
+ var tunnelName string
+ err := decoder.Decode(&tunnelName)
+ if err != nil {
+ return
+ }
+ config, retErr := s.StoredConfig(tunnelName)
+ err = encoder.Encode(*config)
+ if err != nil {
+ return
+ }
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case RuntimeConfigMethodType:
+ var tunnelName string
+ err := decoder.Decode(&tunnelName)
+ if err != nil {
+ return
+ }
+ config, retErr := s.RuntimeConfig(tunnelName)
+ err = encoder.Encode(*config)
+ if err != nil {
+ return
+ }
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case StartMethodType:
+ var tunnelName string
+ err := decoder.Decode(&tunnelName)
+ if err != nil {
+ return
+ }
+ retErr := s.Start(tunnelName)
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case StopMethodType:
+ var tunnelName string
+ err := decoder.Decode(&tunnelName)
+ if err != nil {
+ return
+ }
+ retErr := s.Stop(tunnelName)
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case WaitForStopMethodType:
+ var tunnelName string
+ err := decoder.Decode(&tunnelName)
+ if err != nil {
+ return
+ }
+ retErr := s.WaitForStop(tunnelName)
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case DeleteMethodType:
+ var tunnelName string
+ err := decoder.Decode(&tunnelName)
+ if err != nil {
+ return
+ }
+ retErr := s.Delete(tunnelName)
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case StateMethodType:
+ var tunnelName string
+ err := decoder.Decode(&tunnelName)
+ if err != nil {
+ return
+ }
+ state, retErr := s.State(tunnelName)
+ err = encoder.Encode(state)
+ if err != nil {
+ return
+ }
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case GlobalStateMethodType:
+ state := s.GlobalState()
+ err = encoder.Encode(state)
+ if err != nil {
+ return
+ }
+ case CreateMethodType:
+ var config conf.Config
+ err := decoder.Decode(&config)
+ if err != nil {
+ return
+ }
+ tunnel, retErr := s.Create(&config)
+ err = encoder.Encode(tunnel)
+ if err != nil {
+ return
+ }
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case TunnelsMethodType:
+ tunnels, retErr := s.Tunnels()
+ err = encoder.Encode(tunnels)
+ if err != nil {
+ return
+ }
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case QuitMethodType:
+ var stopTunnelsOnQuit bool
+ err := decoder.Decode(&stopTunnelsOnQuit)
+ if err != nil {
+ return
+ }
+ alreadyQuit, retErr := s.Quit(stopTunnelsOnQuit)
+ err = encoder.Encode(alreadyQuit)
+ if err != nil {
+ return
+ }
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case UpdateStateMethodType:
+ updateState := s.UpdateState()
+ err = encoder.Encode(updateState)
+ if err != nil {
+ return
+ }
+ case UpdateMethodType:
+ s.Update()
+ default:
+ return
+ }
+ }
+}
+
+func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevatedToken windows.Token) {
service := &ManagerService{
events: events,
elevatedToken: elevatedToken,
}
- server := rpc.NewServer()
- err := server.Register(service)
- if err != nil {
- return err
- }
-
go func() {
managerServicesLock.Lock()
managerServices[service] = true
managerServicesLock.Unlock()
- server.ServeConn(&pipeRWC{reader, writer})
+ service.ServeConn(reader, writer)
managerServicesLock.Lock()
delete(managerServices, service)
managerServicesLock.Unlock()
}()
- return nil
}
func notifyAll(notificationType NotificationType, ifaces ...interface{}) {
@@ -327,12 +459,15 @@ func notifyAll(notificationType NotificationType, ifaces ...interface{}) {
managerServicesLock.RUnlock()
}
-func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) {
+func errToString(err error) string {
if err == nil {
- notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), "")
- } else {
- notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), err.Error())
+ return ""
}
+ return err.Error()
+}
+
+func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) {
+ notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), errToString(err))
}
func IPCServerNotifyTunnelsChange() {
@@ -344,11 +479,7 @@ func IPCServerNotifyUpdateFound(state UpdateState) {
}
func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) {
- if dp.Error == nil {
- notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, "", dp.Complete)
- } else {
- notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, dp.Error.Error(), dp.Complete)
- }
+ notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, errToString(dp.Error), dp.Complete)
}
func IPCServerNotifyManagerStopping() {