From e493f911269a2dabab7b05ec28726cdaeffb660e Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 20 May 2019 14:18:01 +0200 Subject: service: split into tunnel and manager --- manager/install.go | 205 ++++++++++++++++++++++++++++ manager/ipc_client.go | 281 ++++++++++++++++++++++++++++++++++++++ manager/ipc_pipe.go | 77 +++++++++++ manager/ipc_server.go | 348 +++++++++++++++++++++++++++++++++++++++++++++++ manager/names.go | 26 ++++ manager/service.go | 331 ++++++++++++++++++++++++++++++++++++++++++++ manager/tunneltracker.go | 182 +++++++++++++++++++++++++ manager/updatestate.go | 57 ++++++++ 8 files changed, 1507 insertions(+) create mode 100644 manager/install.go create mode 100644 manager/ipc_client.go create mode 100644 manager/ipc_pipe.go create mode 100644 manager/ipc_server.go create mode 100644 manager/names.go create mode 100644 manager/service.go create mode 100644 manager/tunneltracker.go create mode 100644 manager/updatestate.go (limited to 'manager') diff --git a/manager/install.go b/manager/install.go new file mode 100644 index 00000000..4a570297 --- /dev/null +++ b/manager/install.go @@ -0,0 +1,205 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package manager + +import ( + "errors" + "os" + "time" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" + + "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/tunnel" +) + +var cachedServiceManager *mgr.Mgr + +func serviceManager() (*mgr.Mgr, error) { + if cachedServiceManager != nil { + return cachedServiceManager, nil + } + m, err := mgr.Connect() + if err != nil { + return nil, err + } + cachedServiceManager = m + return cachedServiceManager, nil +} + +func InstallManager() error { + m, err := serviceManager() + if err != nil { + return err + } + path, err := os.Executable() + if err != nil { + return nil + } + + //TODO: Do we want to bail if executable isn't being run from the right location? + + serviceName := "WireGuardManager" + service, err := m.OpenService(serviceName) + if err == nil { + status, err := service.Query() + if err != nil { + service.Close() + return err + } + if status.State != svc.Stopped { + service.Close() + return errors.New("Manager already installed and running") + } + err = service.Delete() + service.Close() + if err != nil { + return err + } + for { + service, err = m.OpenService(serviceName) + if err != nil { + break + } + service.Close() + time.Sleep(time.Second / 3) + } + } + + config := mgr.Config{ + ServiceType: windows.SERVICE_WIN32_OWN_PROCESS, + StartType: mgr.StartAutomatic, + ErrorControl: mgr.ErrorNormal, + DisplayName: "WireGuard Manager", + } + + service, err = m.CreateService(serviceName, path, config, "/managerservice") + if err != nil { + return err + } + service.Start() + return service.Close() +} + +func UninstallManager() error { + m, err := serviceManager() + if err != nil { + return err + } + serviceName := "WireGuardManager" + service, err := m.OpenService(serviceName) + if err != nil { + return err + } + service.Control(svc.Stop) + err = service.Delete() + err2 := service.Close() + if err != nil { + return err + } + return err2 +} + +func RunManager() error { + return svc.Run("WireGuardManager", &managerService{}) +} + +func InstallTunnel(configPath string) error { + m, err := serviceManager() + if err != nil { + return err + } + path, err := os.Executable() + if err != nil { + return nil + } + + name, err := conf.NameFromPath(configPath) + if err != nil { + return err + } + + serviceName, err := ServiceNameOfTunnel(name) + if err != nil { + return err + } + service, err := m.OpenService(serviceName) + if err == nil { + status, err := service.Query() + if err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE { + service.Close() + return err + } + if status.State != svc.Stopped && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE { + service.Close() + return errors.New("Tunnel already installed and running") + } + err = service.Delete() + service.Close() + if err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE { + return err + } + for { + service, err = m.OpenService(serviceName) + if err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE { + break + } + service.Close() + time.Sleep(time.Second / 3) + } + } + + config := mgr.Config{ + ServiceType: windows.SERVICE_WIN32_OWN_PROCESS, + StartType: mgr.StartAutomatic, + ErrorControl: mgr.ErrorNormal, + DisplayName: "WireGuard Tunnel: " + name, + } + + service, err = m.CreateService(serviceName, path, config, "/tunnelservice", configPath) + if err != nil { + return err + } + err = service.Start() + go trackTunnelService(name, service) // Pass off reference to handle. + return err +} + +func UninstallTunnel(name string) error { + m, err := serviceManager() + if err != nil { + return err + } + serviceName, err := ServiceNameOfTunnel(name) + if err != nil { + return err + } + service, err := m.OpenService(serviceName) + if err != nil { + return err + } + service.Control(svc.Stop) + err = service.Delete() + err2 := service.Close() + if err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE { + return err + } + return err2 +} + +func RunTunnel(confPath string) error { + name, err := conf.NameFromPath(confPath) + if err != nil { + return err + } + serviceName, err := ServiceNameOfTunnel(name) + if err != nil { + return err + } + return svc.Run(serviceName, &tunnel.Service{confPath}) +} diff --git a/manager/ipc_client.go b/manager/ipc_client.go new file mode 100644 index 00000000..a23493f0 --- /dev/null +++ b/manager/ipc_client.go @@ -0,0 +1,281 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package manager + +import ( + "encoding/gob" + "errors" + "net/rpc" + "os" + + "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/updater" +) + +type Tunnel struct { + Name string +} + +type TunnelState int + +const ( + TunnelUnknown TunnelState = iota + TunnelStarted + TunnelStopped + TunnelStarting + TunnelStopping +) + +type NotificationType int + +const ( + TunnelChangeNotificationType NotificationType = iota + TunnelsChangeNotificationType + ManagerStoppingNotificationType + UpdateFoundNotificationType + UpdateProgressNotificationType +) + +var rpcClient *rpc.Client + +type TunnelChangeCallback struct { + cb func(tunnel *Tunnel, state TunnelState, globalState TunnelState, err error) +} + +var tunnelChangeCallbacks = make(map[*TunnelChangeCallback]bool) + +type TunnelsChangeCallback struct { + cb func() +} + +var tunnelsChangeCallbacks = make(map[*TunnelsChangeCallback]bool) + +type ManagerStoppingCallback struct { + cb func() +} + +var managerStoppingCallbacks = make(map[*ManagerStoppingCallback]bool) + +type UpdateFoundCallback struct { + cb func(updateState UpdateState) +} + +var updateFoundCallbacks = make(map[*UpdateFoundCallback]bool) + +type UpdateProgressCallback struct { + cb func(dp updater.DownloadProgress) +} + +var updateProgressCallbacks = make(map[*UpdateProgressCallback]bool) + +func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) { + rpcClient = rpc.NewClient(&pipeRWC{reader, writer}) + go func() { + decoder := gob.NewDecoder(events) + for { + var notificationType NotificationType + err := decoder.Decode(¬ificationType) + if err != nil { + return + } + switch notificationType { + case TunnelChangeNotificationType: + var tunnel string + err := decoder.Decode(&tunnel) + if err != nil || len(tunnel) == 0 { + continue + } + var state TunnelState + err = decoder.Decode(&state) + if err != nil { + continue + } + var globalState TunnelState + err = decoder.Decode(&globalState) + if err != nil { + continue + } + var errStr string + err = decoder.Decode(&errStr) + if err != nil { + continue + } + var retErr error + if len(errStr) > 0 { + retErr = errors.New(errStr) + } + if state == TunnelUnknown { + continue + } + t := &Tunnel{tunnel} + for cb := range tunnelChangeCallbacks { + cb.cb(t, state, globalState, retErr) + } + case TunnelsChangeNotificationType: + for cb := range tunnelsChangeCallbacks { + cb.cb() + } + case ManagerStoppingNotificationType: + for cb := range managerStoppingCallbacks { + cb.cb() + } + case UpdateFoundNotificationType: + var state UpdateState + err = decoder.Decode(&state) + if err != nil { + continue + } + for cb := range updateFoundCallbacks { + cb.cb(state) + } + case UpdateProgressNotificationType: + var dp updater.DownloadProgress + err = decoder.Decode(&dp.Activity) + if err != nil { + continue + } + err = decoder.Decode(&dp.BytesDownloaded) + if err != nil { + continue + } + err = decoder.Decode(&dp.BytesTotal) + if err != nil { + continue + } + var errStr string + err = decoder.Decode(&errStr) + if err != nil { + continue + } + if len(errStr) > 0 { + dp.Error = errors.New(errStr) + } + err = decoder.Decode(&dp.Complete) + if err != nil { + continue + } + for cb := range updateProgressCallbacks { + cb.cb(dp) + } + } + } + }() +} + +func (t *Tunnel) StoredConfig() (c conf.Config, err error) { + err = rpcClient.Call("ManagerService.StoredConfig", t.Name, &c) + return +} + +func (t *Tunnel) RuntimeConfig() (c conf.Config, err error) { + err = rpcClient.Call("ManagerService.RuntimeConfig", t.Name, &c) + return +} + +func (t *Tunnel) Start() error { + return rpcClient.Call("ManagerService.Start", t.Name, nil) +} + +func (t *Tunnel) Stop() error { + return rpcClient.Call("ManagerService.Stop", t.Name, nil) +} + +func (t *Tunnel) Toggle() (oldState TunnelState, err error) { + oldState, err = t.State() + if err != nil { + oldState = TunnelUnknown + return + } + if oldState == TunnelStarted { + err = t.Stop() + } else if oldState == TunnelStopped { + err = t.Start() + } + return +} + +func (t *Tunnel) WaitForStop() error { + return rpcClient.Call("ManagerService.WaitForStop", t.Name, nil) +} + +func (t *Tunnel) Delete() error { + return rpcClient.Call("ManagerService.Delete", t.Name, nil) +} + +func (t *Tunnel) State() (TunnelState, error) { + var state TunnelState + return state, rpcClient.Call("ManagerService.State", t.Name, &state) +} + +func IPCClientNewTunnel(conf *conf.Config) (Tunnel, error) { + var tunnel Tunnel + return tunnel, rpcClient.Call("ManagerService.Create", *conf, &tunnel) +} + +func IPCClientTunnels() ([]Tunnel, error) { + var tunnels []Tunnel + return tunnels, rpcClient.Call("ManagerService.Tunnels", uintptr(0), &tunnels) +} + +func IPCClientGlobalState() (TunnelState, error) { + var state TunnelState + return state, rpcClient.Call("ManagerService.GlobalState", uintptr(0), &state) +} + +func IPCClientQuit(stopTunnelsOnQuit bool) (bool, error) { + var alreadyQuit bool + return alreadyQuit, rpcClient.Call("ManagerService.Quit", stopTunnelsOnQuit, &alreadyQuit) +} + +func IPCClientUpdateState() (UpdateState, error) { + var state UpdateState + return state, rpcClient.Call("ManagerService.UpdateState", uintptr(0), &state) +} + +func IPCClientUpdate() error { + return rpcClient.Call("ManagerService.Update", uintptr(0), nil) +} + +func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state TunnelState, globalState TunnelState, err error)) *TunnelChangeCallback { + s := &TunnelChangeCallback{cb} + tunnelChangeCallbacks[s] = true + return s +} +func (cb *TunnelChangeCallback) Unregister() { + delete(tunnelChangeCallbacks, cb) +} +func IPCClientRegisterTunnelsChange(cb func()) *TunnelsChangeCallback { + s := &TunnelsChangeCallback{cb} + tunnelsChangeCallbacks[s] = true + return s +} +func (cb *TunnelsChangeCallback) Unregister() { + delete(tunnelsChangeCallbacks, cb) +} +func IPCClientRegisterManagerStopping(cb func()) *ManagerStoppingCallback { + s := &ManagerStoppingCallback{cb} + managerStoppingCallbacks[s] = true + return s +} +func (cb *ManagerStoppingCallback) Unregister() { + delete(managerStoppingCallbacks, cb) +} +func IPCClientRegisterUpdateFound(cb func(updateState UpdateState)) *UpdateFoundCallback { + s := &UpdateFoundCallback{cb} + updateFoundCallbacks[s] = true + return s +} +func (cb *UpdateFoundCallback) Unregister() { + delete(updateFoundCallbacks, cb) +} +func IPCClientRegisterUpdateProgress(cb func(dp updater.DownloadProgress)) *UpdateProgressCallback { + s := &UpdateProgressCallback{cb} + updateProgressCallbacks[s] = true + return s +} +func (cb *UpdateProgressCallback) Unregister() { + delete(updateProgressCallbacks, cb) +} diff --git a/manager/ipc_pipe.go b/manager/ipc_pipe.go new file mode 100644 index 00000000..657a6275 --- /dev/null +++ b/manager/ipc_pipe.go @@ -0,0 +1,77 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package manager + +import ( + "os" + "strconv" + + "golang.org/x/sys/windows" +) + +type pipeRWC struct { + reader *os.File + writer *os.File +} + +func (p *pipeRWC) Read(b []byte) (int, error) { + return p.reader.Read(b) +} + +func (p *pipeRWC) Write(b []byte) (int, error) { + return p.writer.Write(b) +} + +func (p *pipeRWC) Close() error { + err1 := p.writer.Close() + err2 := p.reader.Close() + if err1 != nil { + return err1 + } + return err2 +} + +func makeInheritableAndGetStr(f *os.File) (str string, err error) { + sc, err := f.SyscallConn() + if err != nil { + return + } + err2 := sc.Control(func(fd uintptr) { + err = windows.SetHandleInformation(windows.Handle(fd), windows.HANDLE_FLAG_INHERIT, windows.HANDLE_FLAG_INHERIT) + str = strconv.FormatUint(uint64(fd), 10) + }) + if err2 != nil { + err = err2 + } + return +} + +func inheritableEvents() (ourEvents *os.File, theirEvents *os.File, theirEventStr string, err error) { + theirEvents, ourEvents, err = os.Pipe() + if err != nil { + return + } + theirEventStr, err = makeInheritableAndGetStr(theirEvents) + return +} + +func inheritableSocketpairEmulation() (ourReader *os.File, theirReader *os.File, theirReaderStr string, ourWriter *os.File, theirWriter *os.File, theirWriterStr string, err error) { + ourReader, theirWriter, err = os.Pipe() + if err != nil { + return + } + theirWriterStr, err = makeInheritableAndGetStr(theirWriter) + if err != nil { + return + } + + theirReader, ourWriter, err = os.Pipe() + if err != nil { + return + } + theirReaderStr, err = makeInheritableAndGetStr(theirReader) + return +} diff --git a/manager/ipc_server.go b/manager/ipc_server.go new file mode 100644 index 00000000..0accb4d3 --- /dev/null +++ b/manager/ipc_server.go @@ -0,0 +1,348 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package manager + +import ( + "bytes" + "encoding/gob" + "fmt" + "io/ioutil" + "log" + "net/rpc" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/Microsoft/go-winio" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + + "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/updater" +) + +var managerServices = make(map[*ManagerService]bool) +var managerServicesLock sync.RWMutex +var haveQuit uint32 +var quitManagersChan = make(chan struct{}, 1) + +type ManagerService struct { + events *os.File + elevatedToken windows.Token +} + +func (s *ManagerService) StoredConfig(tunnelName string, config *conf.Config) error { + c, err := conf.LoadFromName(tunnelName) + if err != nil { + return err + } + *config = *c + return nil +} + +func (s *ManagerService) RuntimeConfig(tunnelName string, config *conf.Config) error { + storedConfig, err := conf.LoadFromName(tunnelName) + if err != nil { + return err + } + pipePath, err := PipePathOfTunnel(storedConfig.Name) + if err != nil { + return err + } + pipe, err := winio.DialPipe(pipePath, nil) + if err != nil { + return err + } + pipe.SetWriteDeadline(time.Now().Add(time.Second * 2)) + _, err = pipe.Write([]byte("get=1\n\n")) + if err != nil { + return err + } + pipe.SetReadDeadline(time.Now().Add(time.Second * 2)) + resp, err := ioutil.ReadAll(pipe) + if err != nil { + return err + } + pipe.Close() + runtimeConfig, err := conf.FromUAPI(string(resp), storedConfig) + if err != nil { + return err + } + *config = *runtimeConfig + return nil +} + +func (s *ManagerService) Start(tunnelName string, unused *uintptr) error { + // For now, enforce only one tunnel at a time. Later we'll remove this silly restriction. + trackedTunnelsLock.Lock() + tt := make([]string, 0, len(trackedTunnels)) + var inTransition string + for t, state := range trackedTunnels { + tt = append(tt, t) + if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) { + inTransition = t + break + } + } + trackedTunnelsLock.Unlock() + if len(inTransition) != 0 { + return fmt.Errorf("Please allow the tunnel ā€˜%sā€™ to finish activating", inTransition) + } + go func() { + for _, t := range tt { + s.Stop(t, unused) + } + for _, t := range tt { + var state TunnelState + var unused uintptr + if s.State(t, &state) == nil && (state == TunnelStarted || state == TunnelStarting) { + log.Printf("[%s] Trying again to stop zombie tunnel", t) + s.Stop(t, &unused) + time.Sleep(time.Millisecond * 100) + } + } + }() + + // After that process is started -- it's somewhat asynchronous -- we install the new one. + c, err := conf.LoadFromName(tunnelName) + if err != nil { + return err + } + path, err := c.Path() + if err != nil { + return err + } + return InstallTunnel(path) +} + +func (s *ManagerService) Stop(tunnelName string, _ *uintptr) error { + err := UninstallTunnel(tunnelName) + if err == windows.ERROR_SERVICE_DOES_NOT_EXIST { + _, notExistsError := conf.LoadFromName(tunnelName) + if notExistsError == nil { + return nil + } + } + return err +} + +func (s *ManagerService) WaitForStop(tunnelName string, _ *uintptr) error { + serviceName, err := ServiceNameOfTunnel(tunnelName) + if err != nil { + return err + } + m, err := serviceManager() + if err != nil { + return err + } + for { + service, err := m.OpenService(serviceName) + if err == nil || err == windows.ERROR_SERVICE_MARKED_FOR_DELETE { + service.Close() + time.Sleep(time.Second / 3) + } else { + return nil + } + } +} + +func (s *ManagerService) Delete(tunnelName string, _ *uintptr) error { + err := s.Stop(tunnelName, nil) + if err != nil { + return err + } + return conf.DeleteName(tunnelName) +} + +func (s *ManagerService) State(tunnelName string, state *TunnelState) error { + serviceName, err := ServiceNameOfTunnel(tunnelName) + if err != nil { + return err + } + m, err := serviceManager() + if err != nil { + return err + } + service, err := m.OpenService(serviceName) + if err != nil { + *state = TunnelStopped + return nil + } + defer service.Close() + status, err := service.Query() + if err != nil { + *state = TunnelUnknown + return err + } + switch status.State { + case svc.Stopped: + *state = TunnelStopped + case svc.StopPending: + *state = TunnelStopping + case svc.Running: + *state = TunnelStarted + case svc.StartPending: + *state = TunnelStarting + default: + *state = TunnelUnknown + } + return nil +} + +func (s *ManagerService) GlobalState(_ uintptr, state *TunnelState) error { + *state = trackedTunnelsGlobalState() + return nil +} + +func (s *ManagerService) Create(tunnelConfig conf.Config, tunnel *Tunnel) error { + err := tunnelConfig.Save() + if err != nil { + return err + } + *tunnel = Tunnel{tunnelConfig.Name} + return nil + //TODO: handle already existing situation + //TODO: handle already running and existing situation +} + +func (s *ManagerService) Tunnels(_ uintptr, tunnels *[]Tunnel) error { + names, err := conf.ListConfigNames() + if err != nil { + return err + } + *tunnels = make([]Tunnel, len(names)) + for i := 0; i < len(*tunnels); i++ { + (*tunnels)[i].Name = names[i] + } + return nil + //TODO: account for running ones that aren't in the configuration store somehow +} + +func (s *ManagerService) Quit(stopTunnelsOnQuit bool, alreadyQuit *bool) error { + if !atomic.CompareAndSwapUint32(&haveQuit, 0, 1) { + *alreadyQuit = true + return nil + } + *alreadyQuit = false + + // Work around potential race condition of delivering messages to the wrong process by removing from notifications. + managerServicesLock.Lock() + delete(managerServices, s) + managerServicesLock.Unlock() + + if stopTunnelsOnQuit { + names, err := conf.ListConfigNames() + if err != nil { + return err + } + for _, name := range names { + UninstallTunnel(name) + } + } + + quitManagersChan <- struct{}{} + return nil +} + +func (s *ManagerService) UpdateState(_ uintptr, state *UpdateState) error { + *state = updateState + return nil +} + +func (s *ManagerService) Update(_ uintptr, _ *uintptr) error { + progress := updater.DownloadVerifyAndExecute(uintptr(s.elevatedToken)) + go func() { + for { + dp := <-progress + IPCServerNotifyUpdateProgress(dp) + if dp.Complete || dp.Error != nil { + return + } + } + }() + return nil +} + +func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevatedToken windows.Token) error { + service := &ManagerService{ + events: events, + elevatedToken: elevatedToken, + } + + server := rpc.NewServer() + err := server.Register(service) + if err != nil { + return err + } + + go func() { + managerServicesLock.Lock() + managerServices[service] = true + managerServicesLock.Unlock() + server.ServeConn(&pipeRWC{reader, writer}) + managerServicesLock.Lock() + delete(managerServices, service) + managerServicesLock.Unlock() + + }() + return nil +} + +func notifyAll(notificationType NotificationType, ifaces ...interface{}) { + if len(managerServices) == 0 { + return + } + + var buf bytes.Buffer + encoder := gob.NewEncoder(&buf) + err := encoder.Encode(notificationType) + if err != nil { + return + } + for _, iface := range ifaces { + err = encoder.Encode(iface) + if err != nil { + return + } + } + + managerServicesLock.RLock() + for m := range managerServices { + m.events.SetWriteDeadline(time.Now().Add(time.Second)) + m.events.Write(buf.Bytes()) + } + managerServicesLock.RUnlock() +} + +func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) { + if err == nil { + notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), "") + } else { + notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), err.Error()) + } +} + +func IPCServerNotifyTunnelsChange() { + notifyAll(TunnelsChangeNotificationType) +} + +func IPCServerNotifyUpdateFound(state UpdateState) { + notifyAll(UpdateFoundNotificationType, state) +} + +func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) { + if dp.Error == nil { + notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, "", dp.Complete) + } else { + notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, dp.Error.Error(), dp.Complete) + } +} + +func IPCServerNotifyManagerStopping() { + notifyAll(ManagerStoppingNotificationType) + time.Sleep(time.Millisecond * 200) +} diff --git a/manager/names.go b/manager/names.go new file mode 100644 index 00000000..bebf0cae --- /dev/null +++ b/manager/names.go @@ -0,0 +1,26 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package manager + +import ( + "errors" + + "golang.zx2c4.com/wireguard/windows/conf" +) + +func ServiceNameOfTunnel(tunnelName string) (string, error) { + if !conf.TunnelNameIsValid(tunnelName) { + return "", errors.New("Tunnel name is not valid") + } + return "WireGuardTunnel$" + tunnelName, nil +} + +func PipePathOfTunnel(tunnelName string) (string, error) { + if !conf.TunnelNameIsValid(tunnelName) { + return "", errors.New("Tunnel name is not valid") + } + return "\\\\.\\pipe\\WireGuard\\" + tunnelName, nil +} diff --git a/manager/service.go b/manager/service.go new file mode 100644 index 00000000..ba7208d8 --- /dev/null +++ b/manager/service.go @@ -0,0 +1,331 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package manager + +import ( + "errors" + "fmt" + "log" + "os" + "runtime" + "runtime/debug" + "strings" + "sync" + "syscall" + "time" + "unsafe" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + + "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/ringlogger" + "golang.zx2c4.com/wireguard/windows/services" + "golang.zx2c4.com/wireguard/windows/version" +) + +type managerService struct{} + +func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) { + changes <- svc.Status{State: svc.StartPending} + + var err error + serviceError := services.ErrorSuccess + + defer func() { + svcSpecificEC, exitCode = services.DetermineErrorCode(err, serviceError) + logErr := services.CombineErrors(err, serviceError) + if logErr != nil { + log.Print(logErr) + } + changes <- svc.Status{State: svc.StopPending} + }() + + err = ringlogger.InitGlobalLogger("MGR") + if err != nil { + serviceError = services.ErrorRingloggerOpen + return + } + defer func() { + if x := recover(); x != nil { + for _, line := range append([]string{fmt.Sprint(x)}, strings.Split(string(debug.Stack()), "\n")...) { + if len(strings.TrimSpace(line)) > 0 { + log.Println(line) + } + } + panic(x) + } + }() + + log.Println("Starting", version.UserAgent()) + + path, err := os.Executable() + if err != nil { + serviceError = services.ErrorDetermineExecutablePath + return + } + + devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0) + if err != nil { + serviceError = services.ErrorOpenNULFile + return + } + + err = trackExistingTunnels() + if err != nil { + serviceError = services.ErrorTrackTunnels + return + } + + conf.RegisterStoreChangeCallback(func() { conf.MigrateUnencryptedConfigs() }) // Ignore return value for now, but could be useful later. + conf.RegisterStoreChangeCallback(IPCServerNotifyTunnelsChange) + + procs := make(map[uint32]*os.Process) + aliveSessions := make(map[uint32]bool) + procsLock := sync.Mutex{} + var startProcess func(session uint32) + stoppingManager := false + + startProcess = func(session uint32) { + defer func() { + runtime.UnlockOSThread() + procsLock.Lock() + delete(aliveSessions, session) + procsLock.Unlock() + }() + + var userToken windows.Token + err := windows.WTSQueryUserToken(session, &userToken) + if err != nil { + return + } + if !services.TokenIsMemberOfBuiltInAdministrator(userToken) { + userToken.Close() + return + } + user, err := userToken.GetTokenUser() + if err != nil { + log.Printf("Unable to lookup user from token: %v", err) + userToken.Close() + return + } + username, domain, accType, err := user.User.Sid.LookupAccount("") + if err != nil { + log.Printf("Unable to lookup username from sid: %v", err) + userToken.Close() + return + } + if accType != windows.SidTypeUser { + userToken.Close() + return + } + var elevatedToken windows.Token + if userToken.IsElevated() { + elevatedToken = userToken + } else { + elevatedToken, err = userToken.GetLinkedToken() + userToken.Close() + if err != nil { + log.Printf("Unable to elevate token: %v", err) + return + } + if !elevatedToken.IsElevated() { + elevatedToken.Close() + log.Println("Linked token is not elevated") + return + } + } + defer elevatedToken.Close() + userToken = 0 + first := true + for { + if stoppingManager { + return + } + + procsLock.Lock() + if alive := aliveSessions[session]; !alive { + procsLock.Unlock() + return + } + procsLock.Unlock() + + if !first { + time.Sleep(time.Second) + } else { + first = false + } + + //TODO: we lock the OS thread so that these inheritable handles don't escape into other processes that + // might be running in parallel Go routines. But the Go runtime is strange and who knows what's really + // happening with these or what is inherited. We need to do some analysis to be certain of what's going on. + runtime.LockOSThread() + ourReader, theirReader, theirReaderStr, ourWriter, theirWriter, theirWriterStr, err := inheritableSocketpairEmulation() + if err != nil { + log.Printf("Unable to create two inheritable pipes: %v", err) + return + } + ourEvents, theirEvents, theirEventStr, err := inheritableEvents() + err = IPCServerListen(ourReader, ourWriter, ourEvents, elevatedToken) + if err != nil { + log.Printf("Unable to listen on IPC pipes: %v", err) + return + } + theirLogMapping, theirLogMappingHandle, err := ringlogger.Global.ExportInheritableMappingHandleStr() + if err != nil { + log.Printf("Unable to export inheritable mapping handle for logging: %v", err) + return + } + + log.Printf("Starting UI process for user '%s@%s' for session %d", username, domain, session) + attr := &os.ProcAttr{ + Sys: &syscall.SysProcAttr{ + Token: syscall.Token(elevatedToken), + }, + Files: []*os.File{devNull, devNull, devNull}, + } + procsLock.Lock() + var proc *os.Process + if alive := aliveSessions[session]; alive { + proc, err = os.StartProcess(path, []string{path, "/ui", theirReaderStr, theirWriterStr, theirEventStr, theirLogMapping}, attr) + } else { + err = errors.New("Session has logged out") + } + procsLock.Unlock() + theirReader.Close() + theirWriter.Close() + theirEvents.Close() + windows.Close(theirLogMappingHandle) + runtime.UnlockOSThread() + if err != nil { + ourReader.Close() + ourWriter.Close() + ourEvents.Close() + log.Printf("Unable to start manager UI process for user '%s@%s' for session %d: %v", username, domain, session, err) + return + } + + procsLock.Lock() + procs[session] = proc + procsLock.Unlock() + + sessionIsDead := false + processStatus, err := proc.Wait() + if err == nil { + exitCode := processStatus.Sys().(syscall.WaitStatus).ExitCode + log.Printf("Exited UI process for user '%s@%s' for session %d with status %x", username, domain, session, exitCode) + const STATUS_DLL_INIT_FAILED_LOGOFF = 0xC000026B + sessionIsDead = exitCode == STATUS_DLL_INIT_FAILED_LOGOFF + } else { + log.Printf("Unable to wait for UI process for user '%s@%s' for session %d: %v", username, domain, session, err) + } + + procsLock.Lock() + delete(procs, session) + procsLock.Unlock() + ourReader.Close() + ourWriter.Close() + ourEvents.Close() + + if sessionIsDead { + return + } + } + } + + go checkForUpdates() + + var sessionsPointer *windows.WTS_SESSION_INFO + var count uint32 + err = windows.WTSEnumerateSessions(0, 0, 1, &sessionsPointer, &count) + if err != nil { + serviceError = services.ErrorEnumerateSessions + return + } + sessions := *(*[]windows.WTS_SESSION_INFO)(unsafe.Pointer(&struct { + addr *windows.WTS_SESSION_INFO + len int + cap int + }{sessionsPointer, int(count), int(count)})) + for _, session := range sessions { + if session.State != windows.WTSActive && session.State != windows.WTSDisconnected { + continue + } + procsLock.Lock() + if alive := aliveSessions[session.SessionID]; !alive { + aliveSessions[session.SessionID] = true + if _, ok := procs[session.SessionID]; !ok { + go startProcess(session.SessionID) + } + } + procsLock.Unlock() + } + windows.WTSFreeMemory(uintptr(unsafe.Pointer(sessionsPointer))) + + changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptSessionChange} + + uninstall := false +loop: + for { + select { + case <-quitManagersChan: + uninstall = true + break loop + case c := <-r: + switch c.Cmd { + case svc.Stop: + break loop + case svc.Interrogate: + changes <- c.CurrentStatus + case svc.SessionChange: + if c.EventType != windows.WTS_SESSION_LOGON && c.EventType != windows.WTS_SESSION_LOGOFF { + continue + } + sessionNotification := (*windows.WTSSESSION_NOTIFICATION)(unsafe.Pointer(c.EventData)) + if uintptr(sessionNotification.Size) != unsafe.Sizeof(*sessionNotification) { + log.Printf("Unexpected size of WTSSESSION_NOTIFICATION: %d", sessionNotification.Size) + continue + } + if c.EventType == windows.WTS_SESSION_LOGOFF { + procsLock.Lock() + delete(aliveSessions, sessionNotification.SessionID) + if proc, ok := procs[sessionNotification.SessionID]; ok { + proc.Kill() + } + procsLock.Unlock() + } else if c.EventType == windows.WTS_SESSION_LOGON { + procsLock.Lock() + if alive := aliveSessions[sessionNotification.SessionID]; !alive { + aliveSessions[sessionNotification.SessionID] = true + if _, ok := procs[sessionNotification.SessionID]; !ok { + go startProcess(sessionNotification.SessionID) + } + } + procsLock.Unlock() + } + + default: + log.Printf("Unexpected service control request #%d", c) + } + } + } + + changes <- svc.Status{State: svc.StopPending} + procsLock.Lock() + stoppingManager = true + IPCServerNotifyManagerStopping() + for _, proc := range procs { + proc.Kill() + } + procsLock.Unlock() + if uninstall { + err = UninstallManager() + if err != nil { + log.Printf("Unable to uninstaller manager when quitting: %v", err) + } + } + return +} diff --git a/manager/tunneltracker.go b/manager/tunneltracker.go new file mode 100644 index 00000000..1cde98e2 --- /dev/null +++ b/manager/tunneltracker.go @@ -0,0 +1,182 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package manager + +import ( + "fmt" + "log" + "runtime" + "sync" + "syscall" + "time" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" + + "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/services" +) + +func trackExistingTunnels() error { + m, err := serviceManager() + if err != nil { + return err + } + names, err := conf.ListConfigNames() + if err != nil { + return err + } + for _, name := range names { + serviceName, err := ServiceNameOfTunnel(name) + if err != nil { + continue + } + service, err := m.OpenService(serviceName) + if err != nil { + continue + } + go trackTunnelService(name, service) + } + return nil +} + +var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr { + return 0 +}) + +var trackedTunnels = make(map[string]TunnelState) +var trackedTunnelsLock = sync.Mutex{} + +func svcStateToTunState(s svc.State) TunnelState { + switch s { + case svc.StartPending: + return TunnelStarting + case svc.Running: + return TunnelStarted + case svc.StopPending: + return TunnelStopping + case svc.Stopped: + return TunnelStopped + default: + return TunnelUnknown + } +} + +func trackedTunnelsGlobalState() (state TunnelState) { + state = TunnelStopped + trackedTunnelsLock.Lock() + defer trackedTunnelsLock.Unlock() + for _, s := range trackedTunnels { + if s == TunnelStarting { + return TunnelStarting + } else if s == TunnelStopping { + return TunnelStopping + } else if s == TunnelStarted || s == TunnelUnknown { + state = TunnelStarted + } + } + return +} + +func trackTunnelService(tunnelName string, service *mgr.Service) { + defer func() { + service.Close() + log.Printf("[%s] Tunnel managerService tracker finished", tunnelName) + }() + + trackedTunnelsLock.Lock() + if _, found := trackedTunnels[tunnelName]; found { + trackedTunnelsLock.Unlock() + return + } + trackedTunnels[tunnelName] = TunnelUnknown + trackedTunnelsLock.Unlock() + defer func() { + trackedTunnelsLock.Lock() + delete(trackedTunnels, tunnelName) + trackedTunnelsLock.Unlock() + }() + + const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING + notifier := &windows.SERVICE_NOTIFY{ + Version: windows.SERVICE_NOTIFY_STATUS_CHANGE, + NotifyCallback: serviceTrackerCallbackPtr, + } + + checkForDisabled := func() (shouldReturn bool) { + config, err := service.Config() + if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || config.StartType == windows.SERVICE_DISABLED { + log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName) + service.Delete() + trackedTunnelsLock.Lock() + trackedTunnels[tunnelName] = TunnelStopped + trackedTunnelsLock.Unlock() + IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil) + return true + } + return false + } + if checkForDisabled() { + return + } + + runtime.LockOSThread() + defer runtime.UnlockOSThread() + lastState := TunnelUnknown + for { + err := windows.NotifyServiceStatusChange(service.Handle, serviceNotifications, notifier) + switch err { + case nil: + for { + if windows.SleepEx(uint32(time.Second*3/time.Millisecond), true) == windows.WAIT_IO_COMPLETION { + break + } else if checkForDisabled() { + return + } + } + case windows.ERROR_SERVICE_MARKED_FOR_DELETE: + trackedTunnelsLock.Lock() + trackedTunnels[tunnelName] = TunnelStopped + trackedTunnelsLock.Unlock() + IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil) + return + case windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING: + continue + default: + trackedTunnelsLock.Lock() + trackedTunnels[tunnelName] = TunnelStopped + trackedTunnelsLock.Unlock() + IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring managerService, so stopping: %v", err)) + service.Control(svc.Stop) + return + } + + state := svcStateToTunState(svc.State(notifier.ServiceStatus.CurrentState)) + var tunnelError error + if state == TunnelStopped { + if notifier.ServiceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) { + maybeErr := services.Error(notifier.ServiceStatus.ServiceSpecificExitCode) + if maybeErr != services.ErrorSuccess { + tunnelError = maybeErr + } + } else { + switch notifier.ServiceStatus.Win32ExitCode { + case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED): + default: + tunnelError = syscall.Errno(notifier.ServiceStatus.Win32ExitCode) + } + } + } + if state != lastState { + trackedTunnelsLock.Lock() + trackedTunnels[tunnelName] = state + trackedTunnelsLock.Unlock() + IPCServerNotifyTunnelChange(tunnelName, state, tunnelError) + lastState = state + } + } +} diff --git a/manager/updatestate.go b/manager/updatestate.go new file mode 100644 index 00000000..2e82baf8 --- /dev/null +++ b/manager/updatestate.go @@ -0,0 +1,57 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package manager + +import ( + "log" + "time" + + "golang.zx2c4.com/wireguard/windows/updater" + "golang.zx2c4.com/wireguard/windows/version" +) + +type UpdateState uint32 + +const ( + UpdateStateUnknown UpdateState = iota + UpdateStateFoundUpdate + UpdateStateUpdatesDisabledUnofficialBuild +) + +var updateState = UpdateStateUnknown + +func checkForUpdates() { + if !version.IsRunningOfficialVersion() { + log.Println("Build is not official, so updates are disabled") + updateState = UpdateStateUpdatesDisabledUnofficialBuild + IPCServerNotifyUpdateFound(updateState) + return + } + + time.Sleep(time.Second * 10) + + first := true + for { + update, err := updater.CheckForUpdate() + if err == nil && update != nil { + log.Println("An update is available") + updateState = UpdateStateFoundUpdate + IPCServerNotifyUpdateFound(updateState) + return + } + if err != nil { + log.Printf("Update checker: %v", err) + if first { + time.Sleep(time.Minute * 4) + first = false + } else { + time.Sleep(time.Minute * 25) + } + } else { + time.Sleep(time.Hour) + } + } +} -- cgit v1.2.3-59-g8ed1b