From eb267fcc71239adadf484c4e2eb1be7bca280df7 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Wed, 18 Sep 2019 22:44:46 -0600 Subject: manager: switch to vanilla gob from rpc to remove reflection bloat Signed-off-by: Jason A. Donenfeld --- manager/ipc_client.go | 256 ++++++++++++++++++++++++++++++++++++------ manager/ipc_pipe.go | 22 ---- manager/ipc_server.go | 303 ++++++++++++++++++++++++++++++++++++-------------- manager/service.go | 6 +- 4 files changed, 444 insertions(+), 143 deletions(-) (limited to 'manager') diff --git a/manager/ipc_client.go b/manager/ipc_client.go index a23493f0..c8b2f852 100644 --- a/manager/ipc_client.go +++ b/manager/ipc_client.go @@ -8,8 +8,8 @@ package manager import ( "encoding/gob" "errors" - "net/rpc" "os" + "sync" "golang.zx2c4.com/wireguard/windows/conf" "golang.zx2c4.com/wireguard/windows/updater" @@ -39,7 +39,29 @@ const ( UpdateProgressNotificationType ) -var rpcClient *rpc.Client +type MethodType int + +const ( + StoredConfigMethodType MethodType = iota + RuntimeConfigMethodType + StartMethodType + StopMethodType + WaitForStopMethodType + DeleteMethodType + StateMethodType + GlobalStateMethodType + CreateMethodType + TunnelsMethodType + QuitMethodType + UpdateStateMethodType + UpdateMethodType +) + +var ( + rpcEncoder *gob.Encoder + rpcDecoder *gob.Decoder + rpcMutex sync.Mutex +) type TunnelChangeCallback struct { cb func(tunnel *Tunnel, state TunnelState, globalState TunnelState, err error) @@ -72,7 +94,8 @@ type UpdateProgressCallback struct { var updateProgressCallbacks = make(map[*UpdateProgressCallback]bool) func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) { - rpcClient = rpc.NewClient(&pipeRWC{reader, writer}) + rpcDecoder = gob.NewDecoder(reader) + rpcEncoder = gob.NewEncoder(writer) go func() { decoder := gob.NewDecoder(events) for { @@ -165,22 +188,88 @@ func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) { }() } +func rpcDecodeError() error { + var str string + err := rpcDecoder.Decode(&str) + if err != nil { + return err + } + if len(str) == 0 { + return nil + } + return errors.New(str) +} + func (t *Tunnel) StoredConfig() (c conf.Config, err error) { - err = rpcClient.Call("ManagerService.StoredConfig", t.Name, &c) + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(StoredConfigMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(t.Name) + if err != nil { + return + } + err = rpcDecoder.Decode(&c) + if err != nil { + return + } + err = rpcDecodeError() return } func (t *Tunnel) RuntimeConfig() (c conf.Config, err error) { - err = rpcClient.Call("ManagerService.RuntimeConfig", t.Name, &c) + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(RuntimeConfigMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(t.Name) + if err != nil { + return + } + err = rpcDecoder.Decode(&c) + if err != nil { + return + } + err = rpcDecodeError() return } -func (t *Tunnel) Start() error { - return rpcClient.Call("ManagerService.Start", t.Name, nil) +func (t *Tunnel) Start() (err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(StartMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(t.Name) + if err != nil { + return + } + err = rpcDecodeError() + return } -func (t *Tunnel) Stop() error { - return rpcClient.Call("ManagerService.Stop", t.Name, nil) +func (t *Tunnel) Stop() (err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(StopMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(t.Name) + if err != nil { + return + } + err = rpcDecodeError() + return } func (t *Tunnel) Toggle() (oldState TunnelState, err error) { @@ -197,46 +286,149 @@ func (t *Tunnel) Toggle() (oldState TunnelState, err error) { return } -func (t *Tunnel) WaitForStop() error { - return rpcClient.Call("ManagerService.WaitForStop", t.Name, nil) +func (t *Tunnel) WaitForStop() (err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(WaitForStopMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(t.Name) + if err != nil { + return + } + err = rpcDecodeError() + return } -func (t *Tunnel) Delete() error { - return rpcClient.Call("ManagerService.Delete", t.Name, nil) +func (t *Tunnel) Delete() (err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(DeleteMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(t.Name) + if err != nil { + return + } + err = rpcDecodeError() + return } -func (t *Tunnel) State() (TunnelState, error) { - var state TunnelState - return state, rpcClient.Call("ManagerService.State", t.Name, &state) +func (t *Tunnel) State() (tunnelState TunnelState, err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(StateMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(t.Name) + if err != nil { + return + } + err = rpcDecoder.Decode(&tunnelState) + if err != nil { + return + } + err = rpcDecodeError() + return } -func IPCClientNewTunnel(conf *conf.Config) (Tunnel, error) { - var tunnel Tunnel - return tunnel, rpcClient.Call("ManagerService.Create", *conf, &tunnel) +func IPCClientGlobalState() (tunnelState TunnelState, err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(GlobalStateMethodType) + if err != nil { + return + } + err = rpcDecoder.Decode(&tunnelState) + if err != nil { + return + } + return } -func IPCClientTunnels() ([]Tunnel, error) { - var tunnels []Tunnel - return tunnels, rpcClient.Call("ManagerService.Tunnels", uintptr(0), &tunnels) +func IPCClientNewTunnel(conf *conf.Config) (tunnel Tunnel, err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(CreateMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(*conf) + if err != nil { + return + } + err = rpcDecoder.Decode(&tunnel) + if err != nil { + return + } + err = rpcDecodeError() + return } -func IPCClientGlobalState() (TunnelState, error) { - var state TunnelState - return state, rpcClient.Call("ManagerService.GlobalState", uintptr(0), &state) +func IPCClientTunnels() (tunnels []Tunnel, err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(TunnelsMethodType) + if err != nil { + return + } + err = rpcDecoder.Decode(&tunnels) + if err != nil { + return + } + err = rpcDecodeError() + return } -func IPCClientQuit(stopTunnelsOnQuit bool) (bool, error) { - var alreadyQuit bool - return alreadyQuit, rpcClient.Call("ManagerService.Quit", stopTunnelsOnQuit, &alreadyQuit) +func IPCClientQuit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(QuitMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(stopTunnelsOnQuit) + if err != nil { + return + } + err = rpcDecoder.Decode(&alreadyQuit) + if err != nil { + return + } + err = rpcDecodeError() + return } -func IPCClientUpdateState() (UpdateState, error) { - var state UpdateState - return state, rpcClient.Call("ManagerService.UpdateState", uintptr(0), &state) +func IPCClientUpdateState() (updateState UpdateState, err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(UpdateStateMethodType) + if err != nil { + return + } + err = rpcDecoder.Decode(&updateState) + if err != nil { + return + } + return } func IPCClientUpdate() error { - return rpcClient.Call("ManagerService.Update", uintptr(0), nil) + rpcMutex.Lock() + defer rpcMutex.Unlock() + + return rpcEncoder.Encode(UpdateMethodType) } func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state TunnelState, globalState TunnelState, err error)) *TunnelChangeCallback { diff --git a/manager/ipc_pipe.go b/manager/ipc_pipe.go index 657a6275..d4214ac0 100644 --- a/manager/ipc_pipe.go +++ b/manager/ipc_pipe.go @@ -12,28 +12,6 @@ import ( "golang.org/x/sys/windows" ) -type pipeRWC struct { - reader *os.File - writer *os.File -} - -func (p *pipeRWC) Read(b []byte) (int, error) { - return p.reader.Read(b) -} - -func (p *pipeRWC) Write(b []byte) (int, error) { - return p.writer.Write(b) -} - -func (p *pipeRWC) Close() error { - err1 := p.writer.Close() - err2 := p.reader.Close() - if err1 != nil { - return err1 - } - return err2 -} - func makeInheritableAndGetStr(f *os.File) (str string, err error) { sc, err := f.SyscallConn() if err != nil { diff --git a/manager/ipc_server.go b/manager/ipc_server.go index 0a3bceae..3afa3651 100644 --- a/manager/ipc_server.go +++ b/manager/ipc_server.go @@ -9,9 +9,9 @@ import ( "bytes" "encoding/gob" "fmt" + "io" "io/ioutil" "log" - "net/rpc" "os" "sync" "sync/atomic" @@ -37,52 +37,42 @@ type ManagerService struct { elevatedToken windows.Token } -func (s *ManagerService) StoredConfig(tunnelName string, config *conf.Config) error { - c, err := conf.LoadFromName(tunnelName) - if err != nil { - return err - } - *config = *c - return nil +func (s *ManagerService) StoredConfig(tunnelName string) (*conf.Config, error) { + return conf.LoadFromName(tunnelName) } -func (s *ManagerService) RuntimeConfig(tunnelName string, config *conf.Config) error { +func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) { storedConfig, err := conf.LoadFromName(tunnelName) if err != nil { - return err + return nil, err } pipePath, err := services.PipePathOfTunnel(storedConfig.Name) if err != nil { - return err + return nil, err } localSystem, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid) if err != nil { - return err + return nil, err } pipe, err := winpipe.DialPipe(pipePath, nil, localSystem) if err != nil { - return err + return nil, err } defer pipe.Close() pipe.SetWriteDeadline(time.Now().Add(time.Second * 2)) _, err = pipe.Write([]byte("get=1\n\n")) if err != nil { - return err + return nil, err } pipe.SetReadDeadline(time.Now().Add(time.Second * 2)) resp, err := ioutil.ReadAll(pipe) if err != nil { - return err + return nil, err } - runtimeConfig, err := conf.FromUAPI(string(resp), storedConfig) - if err != nil { - return err - } - *config = *runtimeConfig - return nil + return conf.FromUAPI(string(resp), storedConfig) } -func (s *ManagerService) Start(tunnelName string, unused *uintptr) error { +func (s *ManagerService) Start(tunnelName string) error { // For now, enforce only one tunnel at a time. Later we'll remove this silly restriction. trackedTunnelsLock.Lock() tt := make([]string, 0, len(trackedTunnels)) @@ -100,14 +90,13 @@ func (s *ManagerService) Start(tunnelName string, unused *uintptr) error { } go func() { for _, t := range tt { - s.Stop(t, unused) + s.Stop(t) } for _, t := range tt { - var state TunnelState - var unused uintptr - if s.State(t, &state) == nil && (state == TunnelStarted || state == TunnelStarting) { + state, err := s.State(t) + if err == nil && (state == TunnelStarted || state == TunnelStarting) { log.Printf("[%s] Trying again to stop zombie tunnel", t) - s.Stop(t, &unused) + s.Stop(t) time.Sleep(time.Millisecond * 100) } } @@ -126,7 +115,7 @@ func (s *ManagerService) Start(tunnelName string, unused *uintptr) error { return InstallTunnel(path) } -func (s *ManagerService) Stop(tunnelName string, _ *uintptr) error { +func (s *ManagerService) Stop(tunnelName string) error { time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces) err := UninstallTunnel(tunnelName) @@ -139,7 +128,7 @@ func (s *ManagerService) Stop(tunnelName string, _ *uintptr) error { return err } -func (s *ManagerService) WaitForStop(tunnelName string, _ *uintptr) error { +func (s *ManagerService) WaitForStop(tunnelName string) error { serviceName, err := services.ServiceNameOfTunnel(tunnelName) if err != nil { return err @@ -159,84 +148,77 @@ func (s *ManagerService) WaitForStop(tunnelName string, _ *uintptr) error { } } -func (s *ManagerService) Delete(tunnelName string, _ *uintptr) error { - err := s.Stop(tunnelName, nil) +func (s *ManagerService) Delete(tunnelName string) error { + err := s.Stop(tunnelName) if err != nil { return err } return conf.DeleteName(tunnelName) } -func (s *ManagerService) State(tunnelName string, state *TunnelState) error { +func (s *ManagerService) State(tunnelName string) (TunnelState, error) { serviceName, err := services.ServiceNameOfTunnel(tunnelName) if err != nil { - return err + return 0, err } m, err := serviceManager() if err != nil { - return err + return 0, err } service, err := m.OpenService(serviceName) if err != nil { - *state = TunnelStopped - return nil + return TunnelStopped, nil } defer service.Close() status, err := service.Query() if err != nil { - *state = TunnelUnknown - return err + return TunnelUnknown, nil } switch status.State { case svc.Stopped: - *state = TunnelStopped + return TunnelStopped, nil case svc.StopPending: - *state = TunnelStopping + return TunnelStopping, nil case svc.Running: - *state = TunnelStarted + return TunnelStarted, nil case svc.StartPending: - *state = TunnelStarting + return TunnelStarting, nil default: - *state = TunnelUnknown + return TunnelUnknown, nil } - return nil } -func (s *ManagerService) GlobalState(_ uintptr, state *TunnelState) error { - *state = trackedTunnelsGlobalState() - return nil +func (s *ManagerService) GlobalState() TunnelState { + return trackedTunnelsGlobalState() } -func (s *ManagerService) Create(tunnelConfig conf.Config, tunnel *Tunnel) error { +func (s *ManagerService) Create(tunnelConfig *conf.Config) (*Tunnel, error) { err := tunnelConfig.Save() if err != nil { - return err + return nil, err } - *tunnel = Tunnel{tunnelConfig.Name} - return nil + return &Tunnel{tunnelConfig.Name}, nil // TODO: handle already existing situation // TODO: handle already running and existing situation } -func (s *ManagerService) Tunnels(_ uintptr, tunnels *[]Tunnel) error { +func (s *ManagerService) Tunnels() ([]Tunnel, error) { names, err := conf.ListConfigNames() if err != nil { - return err + return nil, err } - *tunnels = make([]Tunnel, len(names)) - for i := 0; i < len(*tunnels); i++ { - (*tunnels)[i].Name = names[i] + tunnels := make([]Tunnel, len(names)) + for i := 0; i < len(tunnels); i++ { + (tunnels)[i].Name = names[i] } - return nil + return tunnels, nil // TODO: account for running ones that aren't in the configuration store somehow } -func (s *ManagerService) Quit(stopTunnelsOnQuit bool, alreadyQuit *bool) error { +func (s *ManagerService) Quit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) { if !atomic.CompareAndSwapUint32(&haveQuit, 0, 1) { - *alreadyQuit = true - return nil + return true, nil } - *alreadyQuit = false // Work around potential race condition of delivering messages to the wrong process by removing from notifications. managerServicesLock.Lock() @@ -246,7 +228,7 @@ func (s *ManagerService) Quit(stopTunnelsOnQuit bool, alreadyQuit *bool) error { if stopTunnelsOnQuit { names, err := conf.ListConfigNames() if err != nil { - return err + return false, err } for _, name := range names { UninstallTunnel(name) @@ -254,15 +236,14 @@ func (s *ManagerService) Quit(stopTunnelsOnQuit bool, alreadyQuit *bool) error { } quitManagersChan <- struct{}{} - return nil + return false, nil } -func (s *ManagerService) UpdateState(_ uintptr, state *UpdateState) error { - *state = updateState - return nil +func (s *ManagerService) UpdateState() UpdateState { + return updateState } -func (s *ManagerService) Update(_ uintptr, _ *uintptr) error { +func (s *ManagerService) Update() { progress := updater.DownloadVerifyAndExecute(uintptr(s.elevatedToken)) go func() { for { @@ -273,32 +254,183 @@ func (s *ManagerService) Update(_ uintptr, _ *uintptr) error { } } }() - return nil } -func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevatedToken windows.Token) error { +func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) { + decoder := gob.NewDecoder(reader) + encoder := gob.NewEncoder(writer) + for { + var methodType MethodType + err := decoder.Decode(&methodType) + if err != nil { + return + } + switch methodType { + case StoredConfigMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + config, retErr := s.StoredConfig(tunnelName) + err = encoder.Encode(*config) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case RuntimeConfigMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + config, retErr := s.RuntimeConfig(tunnelName) + err = encoder.Encode(*config) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case StartMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + retErr := s.Start(tunnelName) + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case StopMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + retErr := s.Stop(tunnelName) + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case WaitForStopMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + retErr := s.WaitForStop(tunnelName) + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case DeleteMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + retErr := s.Delete(tunnelName) + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case StateMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + state, retErr := s.State(tunnelName) + err = encoder.Encode(state) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case GlobalStateMethodType: + state := s.GlobalState() + err = encoder.Encode(state) + if err != nil { + return + } + case CreateMethodType: + var config conf.Config + err := decoder.Decode(&config) + if err != nil { + return + } + tunnel, retErr := s.Create(&config) + err = encoder.Encode(tunnel) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case TunnelsMethodType: + tunnels, retErr := s.Tunnels() + err = encoder.Encode(tunnels) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case QuitMethodType: + var stopTunnelsOnQuit bool + err := decoder.Decode(&stopTunnelsOnQuit) + if err != nil { + return + } + alreadyQuit, retErr := s.Quit(stopTunnelsOnQuit) + err = encoder.Encode(alreadyQuit) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case UpdateStateMethodType: + updateState := s.UpdateState() + err = encoder.Encode(updateState) + if err != nil { + return + } + case UpdateMethodType: + s.Update() + default: + return + } + } +} + +func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevatedToken windows.Token) { service := &ManagerService{ events: events, elevatedToken: elevatedToken, } - server := rpc.NewServer() - err := server.Register(service) - if err != nil { - return err - } - go func() { managerServicesLock.Lock() managerServices[service] = true managerServicesLock.Unlock() - server.ServeConn(&pipeRWC{reader, writer}) + service.ServeConn(reader, writer) managerServicesLock.Lock() delete(managerServices, service) managerServicesLock.Unlock() }() - return nil } func notifyAll(notificationType NotificationType, ifaces ...interface{}) { @@ -327,12 +459,15 @@ func notifyAll(notificationType NotificationType, ifaces ...interface{}) { managerServicesLock.RUnlock() } -func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) { +func errToString(err error) string { if err == nil { - notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), "") - } else { - notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), err.Error()) + return "" } + return err.Error() +} + +func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) { + notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), errToString(err)) } func IPCServerNotifyTunnelsChange() { @@ -344,11 +479,7 @@ func IPCServerNotifyUpdateFound(state UpdateState) { } func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) { - if dp.Error == nil { - notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, "", dp.Complete) - } else { - notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, dp.Error.Error(), dp.Complete) - } + notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, errToString(dp.Error), dp.Complete) } func IPCServerNotifyManagerStopping() { diff --git a/manager/service.go b/manager/service.go index 47ba51bc..e493e7cb 100644 --- a/manager/service.go +++ b/manager/service.go @@ -167,15 +167,15 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest runtime.LockOSThread() ourReader, theirReader, theirReaderStr, ourWriter, theirWriter, theirWriterStr, err := inheritableSocketpairEmulation() if err != nil { - log.Printf("Unable to create two inheritable pipes: %v", err) + log.Printf("Unable to create two inheritable RPC pipes: %v", err) return } ourEvents, theirEvents, theirEventStr, err := inheritableEvents() - err = IPCServerListen(ourReader, ourWriter, ourEvents, elevatedToken) if err != nil { - log.Printf("Unable to listen on IPC pipes: %v", err) + log.Printf("Unable to create one inheritable events pipe: %v", err) return } + IPCServerListen(ourReader, ourWriter, ourEvents, elevatedToken) theirLogMapping, theirLogMappingHandle, err := ringlogger.Global.ExportInheritableMappingHandleStr() if err != nil { log.Printf("Unable to export inheritable mapping handle for logging: %v", err) -- cgit v1.2.3-59-g8ed1b