author     Jason A. Donenfeld <Jason@zx2c4.com>   2019-05-20 14:18:01 +0200
committer  Jason A. Donenfeld <Jason@zx2c4.com>   2019-05-20 14:18:01 +0200
commit     e493f911269a2dabab7b05ec28726cdaeffb660e (patch)
tree       db88ec568dfc508da863e67164de909448c66742 /manager
parent     service: move route monitor and account for changing index (diff)
service: split into tunnel and manager
Diffstat (limited to 'manager')
-rw-r--r--   manager/install.go        205
-rw-r--r--   manager/ipc_client.go     281
-rw-r--r--   manager/ipc_pipe.go        77
-rw-r--r--   manager/ipc_server.go     348
-rw-r--r--   manager/names.go           26
-rw-r--r--   manager/service.go        331
-rw-r--r--   manager/tunneltracker.go  182
-rw-r--r--   manager/updatestate.go     57
8 files changed, 1507 insertions(+), 0 deletions(-)
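
The new package keeps both services inside the same wireguard.exe binary: the service control manager is told to start it with /managerservice for the manager and with /tunnelservice plus a configuration path for each tunnel (see the CreateService calls in install.go below). The dispatcher itself lives in main.go, outside this diff; the following is only a hedged sketch of what such a dispatch could look like, with the argument handling invented for illustration.

    package main

    import (
        "log"
        "os"

        "golang.zx2c4.com/wireguard/windows/manager"
    )

    func main() {
        if len(os.Args) < 2 {
            log.Fatal("usage: wireguard.exe /managerservice | /tunnelservice <config path>")
        }
        var err error
        switch os.Args[1] {
        case "/managerservice":
            err = manager.RunManager()
        case "/tunnelservice":
            err = manager.RunTunnel(os.Args[2]) // the path that InstallTunnel registered with the SCM
        default:
            log.Fatalf("unknown flag %q", os.Args[1])
        }
        if err != nil {
            log.Fatal(err)
        }
    }
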
diff --git a/manager/install.go b/manager/install.go
new file mode 100644
index 00000000..4a570297
--- /dev/null
+++ b/manager/install.go
@@ -0,0 +1,205 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package manager
+
+import (
+ "errors"
+ "os"
+ "time"
+
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/svc"
+ "golang.org/x/sys/windows/svc/mgr"
+
+ "golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/tunnel"
+)
+
+var cachedServiceManager *mgr.Mgr
+
+func serviceManager() (*mgr.Mgr, error) {
+ if cachedServiceManager != nil {
+ return cachedServiceManager, nil
+ }
+ m, err := mgr.Connect()
+ if err != nil {
+ return nil, err
+ }
+ cachedServiceManager = m
+ return cachedServiceManager, nil
+}
+
+func InstallManager() error {
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ path, err := os.Executable()
+ if err != nil {
+ return err
+ }
+
+ //TODO: Do we want to bail if executable isn't being run from the right location?
+
+ serviceName := "WireGuardManager"
+ service, err := m.OpenService(serviceName)
+ if err == nil {
+ status, err := service.Query()
+ if err != nil {
+ service.Close()
+ return err
+ }
+ if status.State != svc.Stopped {
+ service.Close()
+ return errors.New("Manager already installed and running")
+ }
+ err = service.Delete()
+ service.Close()
+ if err != nil {
+ return err
+ }
+ for {
+ service, err = m.OpenService(serviceName)
+ if err != nil {
+ break
+ }
+ service.Close()
+ time.Sleep(time.Second / 3)
+ }
+ }
+
+ config := mgr.Config{
+ ServiceType: windows.SERVICE_WIN32_OWN_PROCESS,
+ StartType: mgr.StartAutomatic,
+ ErrorControl: mgr.ErrorNormal,
+ DisplayName: "WireGuard Manager",
+ }
+
+ service, err = m.CreateService(serviceName, path, config, "/managerservice")
+ if err != nil {
+ return err
+ }
+ service.Start()
+ return service.Close()
+}
+
+func UninstallManager() error {
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ serviceName := "WireGuardManager"
+ service, err := m.OpenService(serviceName)
+ if err != nil {
+ return err
+ }
+ service.Control(svc.Stop)
+ err = service.Delete()
+ err2 := service.Close()
+ if err != nil {
+ return err
+ }
+ return err2
+}
+
+func RunManager() error {
+ return svc.Run("WireGuardManager", &managerService{})
+}
+
+func InstallTunnel(configPath string) error {
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ path, err := os.Executable()
+ if err != nil {
+ return err
+ }
+
+ name, err := conf.NameFromPath(configPath)
+ if err != nil {
+ return err
+ }
+
+ serviceName, err := ServiceNameOfTunnel(name)
+ if err != nil {
+ return err
+ }
+ service, err := m.OpenService(serviceName)
+ if err == nil {
+ status, err := service.Query()
+ if err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE {
+ service.Close()
+ return err
+ }
+ if status.State != svc.Stopped && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE {
+ service.Close()
+ return errors.New("Tunnel already installed and running")
+ }
+ err = service.Delete()
+ service.Close()
+ if err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE {
+ return err
+ }
+ for {
+ service, err = m.OpenService(serviceName)
+ if err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE {
+ break
+ }
+ service.Close()
+ time.Sleep(time.Second / 3)
+ }
+ }
+
+ config := mgr.Config{
+ ServiceType: windows.SERVICE_WIN32_OWN_PROCESS,
+ StartType: mgr.StartAutomatic,
+ ErrorControl: mgr.ErrorNormal,
+ DisplayName: "WireGuard Tunnel: " + name,
+ }
+
+ service, err = m.CreateService(serviceName, path, config, "/tunnelservice", configPath)
+ if err != nil {
+ return err
+ }
+ err = service.Start()
+ go trackTunnelService(name, service) // Pass off reference to handle.
+ return err
+}
+
+func UninstallTunnel(name string) error {
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ serviceName, err := ServiceNameOfTunnel(name)
+ if err != nil {
+ return err
+ }
+ service, err := m.OpenService(serviceName)
+ if err != nil {
+ return err
+ }
+ service.Control(svc.Stop)
+ err = service.Delete()
+ err2 := service.Close()
+ if err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE {
+ return err
+ }
+ return err2
+}
+
+func RunTunnel(confPath string) error {
+ name, err := conf.NameFromPath(confPath)
+ if err != nil {
+ return err
+ }
+ serviceName, err := ServiceNameOfTunnel(name)
+ if err != nil {
+ return err
+ }
+ return svc.Run(serviceName, &tunnel.Service{confPath})
+}
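
install.go is a thin wrapper around the Windows service control manager: InstallManager/UninstallManager handle the single WireGuardManager service, while InstallTunnel/UninstallTunnel create and remove one WireGuardTunnel$<name> service per configuration file. A hedged usage sketch from an elevated helper process follows; the config path and the tunnel name "demo" are hypothetical.

    package main

    import (
        "log"

        "golang.zx2c4.com/wireguard/windows/manager"
    )

    func main() {
        // Sketch only: every call below needs administrator rights,
        // because each one talks to the service control manager.
        if err := manager.InstallManager(); err != nil {
            log.Fatalf("installing manager service: %v", err)
        }
        if err := manager.InstallTunnel(`C:\Users\example\demo.conf`); err != nil {
            log.Fatalf("installing tunnel service: %v", err)
        }
        // Removal is by tunnel name (derived from the file name), not by path.
        if err := manager.UninstallTunnel("demo"); err != nil {
            log.Printf("uninstalling tunnel: %v", err)
        }
    }
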
diff --git a/manager/ipc_client.go b/manager/ipc_client.go
new file mode 100644
index 00000000..a23493f0
--- /dev/null
+++ b/manager/ipc_client.go
@@ -0,0 +1,281 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package manager
+
+import (
+ "encoding/gob"
+ "errors"
+ "net/rpc"
+ "os"
+
+ "golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/updater"
+)
+
+type Tunnel struct {
+ Name string
+}
+
+type TunnelState int
+
+const (
+ TunnelUnknown TunnelState = iota
+ TunnelStarted
+ TunnelStopped
+ TunnelStarting
+ TunnelStopping
+)
+
+type NotificationType int
+
+const (
+ TunnelChangeNotificationType NotificationType = iota
+ TunnelsChangeNotificationType
+ ManagerStoppingNotificationType
+ UpdateFoundNotificationType
+ UpdateProgressNotificationType
+)
+
+var rpcClient *rpc.Client
+
+type TunnelChangeCallback struct {
+ cb func(tunnel *Tunnel, state TunnelState, globalState TunnelState, err error)
+}
+
+var tunnelChangeCallbacks = make(map[*TunnelChangeCallback]bool)
+
+type TunnelsChangeCallback struct {
+ cb func()
+}
+
+var tunnelsChangeCallbacks = make(map[*TunnelsChangeCallback]bool)
+
+type ManagerStoppingCallback struct {
+ cb func()
+}
+
+var managerStoppingCallbacks = make(map[*ManagerStoppingCallback]bool)
+
+type UpdateFoundCallback struct {
+ cb func(updateState UpdateState)
+}
+
+var updateFoundCallbacks = make(map[*UpdateFoundCallback]bool)
+
+type UpdateProgressCallback struct {
+ cb func(dp updater.DownloadProgress)
+}
+
+var updateProgressCallbacks = make(map[*UpdateProgressCallback]bool)
+
+func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) {
+ rpcClient = rpc.NewClient(&pipeRWC{reader, writer})
+ go func() {
+ decoder := gob.NewDecoder(events)
+ for {
+ var notificationType NotificationType
+ err := decoder.Decode(&notificationType)
+ if err != nil {
+ return
+ }
+ switch notificationType {
+ case TunnelChangeNotificationType:
+ var tunnel string
+ err := decoder.Decode(&tunnel)
+ if err != nil || len(tunnel) == 0 {
+ continue
+ }
+ var state TunnelState
+ err = decoder.Decode(&state)
+ if err != nil {
+ continue
+ }
+ var globalState TunnelState
+ err = decoder.Decode(&globalState)
+ if err != nil {
+ continue
+ }
+ var errStr string
+ err = decoder.Decode(&errStr)
+ if err != nil {
+ continue
+ }
+ var retErr error
+ if len(errStr) > 0 {
+ retErr = errors.New(errStr)
+ }
+ if state == TunnelUnknown {
+ continue
+ }
+ t := &Tunnel{tunnel}
+ for cb := range tunnelChangeCallbacks {
+ cb.cb(t, state, globalState, retErr)
+ }
+ case TunnelsChangeNotificationType:
+ for cb := range tunnelsChangeCallbacks {
+ cb.cb()
+ }
+ case ManagerStoppingNotificationType:
+ for cb := range managerStoppingCallbacks {
+ cb.cb()
+ }
+ case UpdateFoundNotificationType:
+ var state UpdateState
+ err = decoder.Decode(&state)
+ if err != nil {
+ continue
+ }
+ for cb := range updateFoundCallbacks {
+ cb.cb(state)
+ }
+ case UpdateProgressNotificationType:
+ var dp updater.DownloadProgress
+ err = decoder.Decode(&dp.Activity)
+ if err != nil {
+ continue
+ }
+ err = decoder.Decode(&dp.BytesDownloaded)
+ if err != nil {
+ continue
+ }
+ err = decoder.Decode(&dp.BytesTotal)
+ if err != nil {
+ continue
+ }
+ var errStr string
+ err = decoder.Decode(&errStr)
+ if err != nil {
+ continue
+ }
+ if len(errStr) > 0 {
+ dp.Error = errors.New(errStr)
+ }
+ err = decoder.Decode(&dp.Complete)
+ if err != nil {
+ continue
+ }
+ for cb := range updateProgressCallbacks {
+ cb.cb(dp)
+ }
+ }
+ }
+ }()
+}
+
+func (t *Tunnel) StoredConfig() (c conf.Config, err error) {
+ err = rpcClient.Call("ManagerService.StoredConfig", t.Name, &c)
+ return
+}
+
+func (t *Tunnel) RuntimeConfig() (c conf.Config, err error) {
+ err = rpcClient.Call("ManagerService.RuntimeConfig", t.Name, &c)
+ return
+}
+
+func (t *Tunnel) Start() error {
+ return rpcClient.Call("ManagerService.Start", t.Name, nil)
+}
+
+func (t *Tunnel) Stop() error {
+ return rpcClient.Call("ManagerService.Stop", t.Name, nil)
+}
+
+func (t *Tunnel) Toggle() (oldState TunnelState, err error) {
+ oldState, err = t.State()
+ if err != nil {
+ oldState = TunnelUnknown
+ return
+ }
+ if oldState == TunnelStarted {
+ err = t.Stop()
+ } else if oldState == TunnelStopped {
+ err = t.Start()
+ }
+ return
+}
+
+func (t *Tunnel) WaitForStop() error {
+ return rpcClient.Call("ManagerService.WaitForStop", t.Name, nil)
+}
+
+func (t *Tunnel) Delete() error {
+ return rpcClient.Call("ManagerService.Delete", t.Name, nil)
+}
+
+func (t *Tunnel) State() (TunnelState, error) {
+ var state TunnelState
+ return state, rpcClient.Call("ManagerService.State", t.Name, &state)
+}
+
+func IPCClientNewTunnel(conf *conf.Config) (Tunnel, error) {
+ var tunnel Tunnel
+ return tunnel, rpcClient.Call("ManagerService.Create", *conf, &tunnel)
+}
+
+func IPCClientTunnels() ([]Tunnel, error) {
+ var tunnels []Tunnel
+ return tunnels, rpcClient.Call("ManagerService.Tunnels", uintptr(0), &tunnels)
+}
+
+func IPCClientGlobalState() (TunnelState, error) {
+ var state TunnelState
+ return state, rpcClient.Call("ManagerService.GlobalState", uintptr(0), &state)
+}
+
+func IPCClientQuit(stopTunnelsOnQuit bool) (bool, error) {
+ var alreadyQuit bool
+ return alreadyQuit, rpcClient.Call("ManagerService.Quit", stopTunnelsOnQuit, &alreadyQuit)
+}
+
+func IPCClientUpdateState() (UpdateState, error) {
+ var state UpdateState
+ return state, rpcClient.Call("ManagerService.UpdateState", uintptr(0), &state)
+}
+
+func IPCClientUpdate() error {
+ return rpcClient.Call("ManagerService.Update", uintptr(0), nil)
+}
+
+func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state TunnelState, globalState TunnelState, err error)) *TunnelChangeCallback {
+ s := &TunnelChangeCallback{cb}
+ tunnelChangeCallbacks[s] = true
+ return s
+}
+func (cb *TunnelChangeCallback) Unregister() {
+ delete(tunnelChangeCallbacks, cb)
+}
+func IPCClientRegisterTunnelsChange(cb func()) *TunnelsChangeCallback {
+ s := &TunnelsChangeCallback{cb}
+ tunnelsChangeCallbacks[s] = true
+ return s
+}
+func (cb *TunnelsChangeCallback) Unregister() {
+ delete(tunnelsChangeCallbacks, cb)
+}
+func IPCClientRegisterManagerStopping(cb func()) *ManagerStoppingCallback {
+ s := &ManagerStoppingCallback{cb}
+ managerStoppingCallbacks[s] = true
+ return s
+}
+func (cb *ManagerStoppingCallback) Unregister() {
+ delete(managerStoppingCallbacks, cb)
+}
+func IPCClientRegisterUpdateFound(cb func(updateState UpdateState)) *UpdateFoundCallback {
+ s := &UpdateFoundCallback{cb}
+ updateFoundCallbacks[s] = true
+ return s
+}
+func (cb *UpdateFoundCallback) Unregister() {
+ delete(updateFoundCallbacks, cb)
+}
+func IPCClientRegisterUpdateProgress(cb func(dp updater.DownloadProgress)) *UpdateProgressCallback {
+ s := &UpdateProgressCallback{cb}
+ updateProgressCallbacks[s] = true
+ return s
+}
+func (cb *UpdateProgressCallback) Unregister() {
+ delete(updateProgressCallbacks, cb)
+}
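
The client half of the IPC runs inside the per-session UI process: requests travel over a net/rpc connection built on an inherited pipe pair, and notifications arrive as a gob stream on a third pipe. Below is a sketch of how a UI could wire this up, assuming the three *os.File handles have already been rebuilt from the inherited handle numbers; the package and function names are illustrative and not part of this diff.

    package ui

    import (
        "log"
        "os"

        "golang.zx2c4.com/wireguard/windows/manager"
    )

    // startClient is a hedged sketch: reader, writer and events are the *os.File ends
    // rebuilt from the handle numbers passed on the UI command line (see ipc_pipe.go).
    func startClient(reader, writer, events *os.File) {
        manager.InitializeIPCClient(reader, writer, events)

        cb := manager.IPCClientRegisterTunnelChange(func(t *manager.Tunnel, state, global manager.TunnelState, err error) {
            log.Printf("tunnel %q: %v (global %v), err=%v", t.Name, state, global, err)
        })
        defer cb.Unregister() // in a real UI this registration lives for the window's lifetime

        tunnels, err := manager.IPCClientTunnels()
        if err != nil {
            log.Fatal(err)
        }
        for _, t := range tunnels {
            if state, err := t.State(); err == nil && state == manager.TunnelStopped {
                t.Start() // errors deliberately ignored in this sketch
            }
        }
    }
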
diff --git a/manager/ipc_pipe.go b/manager/ipc_pipe.go
new file mode 100644
index 00000000..657a6275
--- /dev/null
+++ b/manager/ipc_pipe.go
@@ -0,0 +1,77 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package manager
+
+import (
+ "os"
+ "strconv"
+
+ "golang.org/x/sys/windows"
+)
+
+type pipeRWC struct {
+ reader *os.File
+ writer *os.File
+}
+
+func (p *pipeRWC) Read(b []byte) (int, error) {
+ return p.reader.Read(b)
+}
+
+func (p *pipeRWC) Write(b []byte) (int, error) {
+ return p.writer.Write(b)
+}
+
+func (p *pipeRWC) Close() error {
+ err1 := p.writer.Close()
+ err2 := p.reader.Close()
+ if err1 != nil {
+ return err1
+ }
+ return err2
+}
+
+func makeInheritableAndGetStr(f *os.File) (str string, err error) {
+ sc, err := f.SyscallConn()
+ if err != nil {
+ return
+ }
+ err2 := sc.Control(func(fd uintptr) {
+ err = windows.SetHandleInformation(windows.Handle(fd), windows.HANDLE_FLAG_INHERIT, windows.HANDLE_FLAG_INHERIT)
+ str = strconv.FormatUint(uint64(fd), 10)
+ })
+ if err2 != nil {
+ err = err2
+ }
+ return
+}
+
+func inheritableEvents() (ourEvents *os.File, theirEvents *os.File, theirEventStr string, err error) {
+ theirEvents, ourEvents, err = os.Pipe()
+ if err != nil {
+ return
+ }
+ theirEventStr, err = makeInheritableAndGetStr(theirEvents)
+ return
+}
+
+func inheritableSocketpairEmulation() (ourReader *os.File, theirReader *os.File, theirReaderStr string, ourWriter *os.File, theirWriter *os.File, theirWriterStr string, err error) {
+ ourReader, theirWriter, err = os.Pipe()
+ if err != nil {
+ return
+ }
+ theirWriterStr, err = makeInheritableAndGetStr(theirWriter)
+ if err != nil {
+ return
+ }
+
+ theirReader, ourWriter, err = os.Pipe()
+ if err != nil {
+ return
+ }
+ theirReaderStr, err = makeInheritableAndGetStr(theirReader)
+ return
+}
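
ipc_pipe.go creates anonymous pipes, marks the ends destined for the child process as inheritable, and returns each raw handle value as a decimal string so it can be passed on the UI process's command line. The receiving side is not part of this diff; the following is a hedged sketch of how the spawned process would turn such a string back into an *os.File (the package and function names are illustrative).

    package ui

    import (
        "os"
        "strconv"
    )

    // handleFromArg reverses makeInheritableAndGetStr: it parses the decimal handle
    // value from a command-line argument and rewraps it as an *os.File.
    func handleFromArg(arg string) (*os.File, error) {
        h, err := strconv.ParseUint(arg, 10, 64)
        if err != nil {
            return nil, err
        }
        return os.NewFile(uintptr(h), "inherited pipe"), nil
    }
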
diff --git a/manager/ipc_server.go b/manager/ipc_server.go
new file mode 100644
index 00000000..0accb4d3
--- /dev/null
+++ b/manager/ipc_server.go
@@ -0,0 +1,348 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package manager
+
+import (
+ "bytes"
+ "encoding/gob"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "net/rpc"
+ "os"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/Microsoft/go-winio"
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/svc"
+
+ "golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/updater"
+)
+
+var managerServices = make(map[*ManagerService]bool)
+var managerServicesLock sync.RWMutex
+var haveQuit uint32
+var quitManagersChan = make(chan struct{}, 1)
+
+type ManagerService struct {
+ events *os.File
+ elevatedToken windows.Token
+}
+
+func (s *ManagerService) StoredConfig(tunnelName string, config *conf.Config) error {
+ c, err := conf.LoadFromName(tunnelName)
+ if err != nil {
+ return err
+ }
+ *config = *c
+ return nil
+}
+
+func (s *ManagerService) RuntimeConfig(tunnelName string, config *conf.Config) error {
+ storedConfig, err := conf.LoadFromName(tunnelName)
+ if err != nil {
+ return err
+ }
+ pipePath, err := PipePathOfTunnel(storedConfig.Name)
+ if err != nil {
+ return err
+ }
+ pipe, err := winio.DialPipe(pipePath, nil)
+ if err != nil {
+ return err
+ }
+ pipe.SetWriteDeadline(time.Now().Add(time.Second * 2))
+ _, err = pipe.Write([]byte("get=1\n\n"))
+ if err != nil {
+ return err
+ }
+ pipe.SetReadDeadline(time.Now().Add(time.Second * 2))
+ resp, err := ioutil.ReadAll(pipe)
+ if err != nil {
+ return err
+ }
+ pipe.Close()
+ runtimeConfig, err := conf.FromUAPI(string(resp), storedConfig)
+ if err != nil {
+ return err
+ }
+ *config = *runtimeConfig
+ return nil
+}
+
+func (s *ManagerService) Start(tunnelName string, unused *uintptr) error {
+ // For now, enforce only one tunnel at a time. Later we'll remove this silly restriction.
+ trackedTunnelsLock.Lock()
+ tt := make([]string, 0, len(trackedTunnels))
+ var inTransition string
+ for t, state := range trackedTunnels {
+ tt = append(tt, t)
+ if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) {
+ inTransition = t
+ break
+ }
+ }
+ trackedTunnelsLock.Unlock()
+ if len(inTransition) != 0 {
+ return fmt.Errorf("Please allow the tunnel ā€˜%s' to finish activating", inTransition)
+ }
+ go func() {
+ for _, t := range tt {
+ s.Stop(t, unused)
+ }
+ for _, t := range tt {
+ var state TunnelState
+ var unused uintptr
+ if s.State(t, &state) == nil && (state == TunnelStarted || state == TunnelStarting) {
+ log.Printf("[%s] Trying again to stop zombie tunnel", t)
+ s.Stop(t, &unused)
+ time.Sleep(time.Millisecond * 100)
+ }
+ }
+ }()
+
+ // After that process is started -- it's somewhat asynchronous -- we install the new one.
+ c, err := conf.LoadFromName(tunnelName)
+ if err != nil {
+ return err
+ }
+ path, err := c.Path()
+ if err != nil {
+ return err
+ }
+ return InstallTunnel(path)
+}
+
+func (s *ManagerService) Stop(tunnelName string, _ *uintptr) error {
+ err := UninstallTunnel(tunnelName)
+ if err == windows.ERROR_SERVICE_DOES_NOT_EXIST {
+ _, notExistsError := conf.LoadFromName(tunnelName)
+ if notExistsError == nil {
+ return nil
+ }
+ }
+ return err
+}
+
+func (s *ManagerService) WaitForStop(tunnelName string, _ *uintptr) error {
+ serviceName, err := ServiceNameOfTunnel(tunnelName)
+ if err != nil {
+ return err
+ }
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ for {
+ service, err := m.OpenService(serviceName)
+ if err == nil || err == windows.ERROR_SERVICE_MARKED_FOR_DELETE {
+ service.Close()
+ time.Sleep(time.Second / 3)
+ } else {
+ return nil
+ }
+ }
+}
+
+func (s *ManagerService) Delete(tunnelName string, _ *uintptr) error {
+ err := s.Stop(tunnelName, nil)
+ if err != nil {
+ return err
+ }
+ return conf.DeleteName(tunnelName)
+}
+
+func (s *ManagerService) State(tunnelName string, state *TunnelState) error {
+ serviceName, err := ServiceNameOfTunnel(tunnelName)
+ if err != nil {
+ return err
+ }
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ service, err := m.OpenService(serviceName)
+ if err != nil {
+ *state = TunnelStopped
+ return nil
+ }
+ defer service.Close()
+ status, err := service.Query()
+ if err != nil {
+ *state = TunnelUnknown
+ return err
+ }
+ switch status.State {
+ case svc.Stopped:
+ *state = TunnelStopped
+ case svc.StopPending:
+ *state = TunnelStopping
+ case svc.Running:
+ *state = TunnelStarted
+ case svc.StartPending:
+ *state = TunnelStarting
+ default:
+ *state = TunnelUnknown
+ }
+ return nil
+}
+
+func (s *ManagerService) GlobalState(_ uintptr, state *TunnelState) error {
+ *state = trackedTunnelsGlobalState()
+ return nil
+}
+
+func (s *ManagerService) Create(tunnelConfig conf.Config, tunnel *Tunnel) error {
+ err := tunnelConfig.Save()
+ if err != nil {
+ return err
+ }
+ *tunnel = Tunnel{tunnelConfig.Name}
+ return nil
+ //TODO: handle already existing situation
+ //TODO: handle already running and existing situation
+}
+
+func (s *ManagerService) Tunnels(_ uintptr, tunnels *[]Tunnel) error {
+ names, err := conf.ListConfigNames()
+ if err != nil {
+ return err
+ }
+ *tunnels = make([]Tunnel, len(names))
+ for i := 0; i < len(*tunnels); i++ {
+ (*tunnels)[i].Name = names[i]
+ }
+ return nil
+ //TODO: account for running ones that aren't in the configuration store somehow
+}
+
+func (s *ManagerService) Quit(stopTunnelsOnQuit bool, alreadyQuit *bool) error {
+ if !atomic.CompareAndSwapUint32(&haveQuit, 0, 1) {
+ *alreadyQuit = true
+ return nil
+ }
+ *alreadyQuit = false
+
+ // Work around potential race condition of delivering messages to the wrong process by removing from notifications.
+ managerServicesLock.Lock()
+ delete(managerServices, s)
+ managerServicesLock.Unlock()
+
+ if stopTunnelsOnQuit {
+ names, err := conf.ListConfigNames()
+ if err != nil {
+ return err
+ }
+ for _, name := range names {
+ UninstallTunnel(name)
+ }
+ }
+
+ quitManagersChan <- struct{}{}
+ return nil
+}
+
+func (s *ManagerService) UpdateState(_ uintptr, state *UpdateState) error {
+ *state = updateState
+ return nil
+}
+
+func (s *ManagerService) Update(_ uintptr, _ *uintptr) error {
+ progress := updater.DownloadVerifyAndExecute(uintptr(s.elevatedToken))
+ go func() {
+ for {
+ dp := <-progress
+ IPCServerNotifyUpdateProgress(dp)
+ if dp.Complete || dp.Error != nil {
+ return
+ }
+ }
+ }()
+ return nil
+}
+
+func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevatedToken windows.Token) error {
+ service := &ManagerService{
+ events: events,
+ elevatedToken: elevatedToken,
+ }
+
+ server := rpc.NewServer()
+ err := server.Register(service)
+ if err != nil {
+ return err
+ }
+
+ go func() {
+ managerServicesLock.Lock()
+ managerServices[service] = true
+ managerServicesLock.Unlock()
+ server.ServeConn(&pipeRWC{reader, writer})
+ managerServicesLock.Lock()
+ delete(managerServices, service)
+ managerServicesLock.Unlock()
+
+ }()
+ return nil
+}
+
+func notifyAll(notificationType NotificationType, ifaces ...interface{}) {
+ if len(managerServices) == 0 {
+ return
+ }
+
+ var buf bytes.Buffer
+ encoder := gob.NewEncoder(&buf)
+ err := encoder.Encode(notificationType)
+ if err != nil {
+ return
+ }
+ for _, iface := range ifaces {
+ err = encoder.Encode(iface)
+ if err != nil {
+ return
+ }
+ }
+
+ managerServicesLock.RLock()
+ for m := range managerServices {
+ m.events.SetWriteDeadline(time.Now().Add(time.Second))
+ m.events.Write(buf.Bytes())
+ }
+ managerServicesLock.RUnlock()
+}
+
+func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) {
+ if err == nil {
+ notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), "")
+ } else {
+ notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), err.Error())
+ }
+}
+
+func IPCServerNotifyTunnelsChange() {
+ notifyAll(TunnelsChangeNotificationType)
+}
+
+func IPCServerNotifyUpdateFound(state UpdateState) {
+ notifyAll(UpdateFoundNotificationType, state)
+}
+
+func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) {
+ if dp.Error == nil {
+ notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, "", dp.Complete)
+ } else {
+ notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, dp.Error.Error(), dp.Complete)
+ }
+}
+
+func IPCServerNotifyManagerStopping() {
+ notifyAll(ManagerStoppingNotificationType)
+ time.Sleep(time.Millisecond * 200)
+}
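
ManagerService is exposed through Go's net/rpc, which only registers exported methods of the shape func (s *T) Name(arg A, reply *R) error; that constraint is why otherwise argument-less calls such as Stop, Update or GlobalState carry uintptr placeholders. The self-contained sketch below demonstrates the same calling convention over an in-memory net.Pipe instead of the project's inherited pipe pair; Demo and its State method are invented for illustration.

    package main

    import (
        "log"
        "net"
        "net/rpc"
    )

    // Demo mimics the method shape net/rpc requires of ManagerService.
    type Demo struct{}

    func (d *Demo) State(tunnelName string, state *int) error {
        *state = 2 // stands in for a TunnelState value
        return nil
    }

    func main() {
        srvConn, cliConn := net.Pipe()
        server := rpc.NewServer()
        if err := server.Register(&Demo{}); err != nil {
            log.Fatal(err)
        }
        go server.ServeConn(srvConn) // mirrors server.ServeConn(&pipeRWC{...}) above

        client := rpc.NewClient(cliConn) // mirrors rpc.NewClient(&pipeRWC{...}) in ipc_client.go
        var state int
        if err := client.Call("Demo.State", "demo", &state); err != nil {
            log.Fatal(err)
        }
        log.Printf("state = %d", state)
    }
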
diff --git a/manager/names.go b/manager/names.go
new file mode 100644
index 00000000..bebf0cae
--- /dev/null
+++ b/manager/names.go
@@ -0,0 +1,26 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package manager
+
+import (
+ "errors"
+
+ "golang.zx2c4.com/wireguard/windows/conf"
+)
+
+func ServiceNameOfTunnel(tunnelName string) (string, error) {
+ if !conf.TunnelNameIsValid(tunnelName) {
+ return "", errors.New("Tunnel name is not valid")
+ }
+ return "WireGuardTunnel$" + tunnelName, nil
+}
+
+func PipePathOfTunnel(tunnelName string) (string, error) {
+ if !conf.TunnelNameIsValid(tunnelName) {
+ return "", errors.New("Tunnel name is not valid")
+ }
+ return "\\\\.\\pipe\\WireGuard\\" + tunnelName, nil
+}
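
names.go centralizes the mapping from a validated tunnel name to its service name and control pipe path. For a hypothetical tunnel called "office" the derived strings look like this:

    package main

    import (
        "fmt"

        "golang.zx2c4.com/wireguard/windows/manager"
    )

    func main() {
        // "office" is an example name; names failing conf.TunnelNameIsValid return an error instead.
        serviceName, _ := manager.ServiceNameOfTunnel("office") // "WireGuardTunnel$office"
        pipePath, _ := manager.PipePathOfTunnel("office")       // `\\.\pipe\WireGuard\office`
        fmt.Println(serviceName, pipePath)
    }
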
diff --git a/manager/service.go b/manager/service.go
new file mode 100644
index 00000000..ba7208d8
--- /dev/null
+++ b/manager/service.go
@@ -0,0 +1,331 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package manager
+
+import (
+ "errors"
+ "fmt"
+ "log"
+ "os"
+ "runtime"
+ "runtime/debug"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/svc"
+
+ "golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/ringlogger"
+ "golang.zx2c4.com/wireguard/windows/services"
+ "golang.zx2c4.com/wireguard/windows/version"
+)
+
+type managerService struct{}
+
+func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) {
+ changes <- svc.Status{State: svc.StartPending}
+
+ var err error
+ serviceError := services.ErrorSuccess
+
+ defer func() {
+ svcSpecificEC, exitCode = services.DetermineErrorCode(err, serviceError)
+ logErr := services.CombineErrors(err, serviceError)
+ if logErr != nil {
+ log.Print(logErr)
+ }
+ changes <- svc.Status{State: svc.StopPending}
+ }()
+
+ err = ringlogger.InitGlobalLogger("MGR")
+ if err != nil {
+ serviceError = services.ErrorRingloggerOpen
+ return
+ }
+ defer func() {
+ if x := recover(); x != nil {
+ for _, line := range append([]string{fmt.Sprint(x)}, strings.Split(string(debug.Stack()), "\n")...) {
+ if len(strings.TrimSpace(line)) > 0 {
+ log.Println(line)
+ }
+ }
+ panic(x)
+ }
+ }()
+
+ log.Println("Starting", version.UserAgent())
+
+ path, err := os.Executable()
+ if err != nil {
+ serviceError = services.ErrorDetermineExecutablePath
+ return
+ }
+
+ devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0)
+ if err != nil {
+ serviceError = services.ErrorOpenNULFile
+ return
+ }
+
+ err = trackExistingTunnels()
+ if err != nil {
+ serviceError = services.ErrorTrackTunnels
+ return
+ }
+
+ conf.RegisterStoreChangeCallback(func() { conf.MigrateUnencryptedConfigs() }) // Ignore return value for now, but could be useful later.
+ conf.RegisterStoreChangeCallback(IPCServerNotifyTunnelsChange)
+
+ procs := make(map[uint32]*os.Process)
+ aliveSessions := make(map[uint32]bool)
+ procsLock := sync.Mutex{}
+ var startProcess func(session uint32)
+ stoppingManager := false
+
+ startProcess = func(session uint32) {
+ defer func() {
+ runtime.UnlockOSThread()
+ procsLock.Lock()
+ delete(aliveSessions, session)
+ procsLock.Unlock()
+ }()
+
+ var userToken windows.Token
+ err := windows.WTSQueryUserToken(session, &userToken)
+ if err != nil {
+ return
+ }
+ if !services.TokenIsMemberOfBuiltInAdministrator(userToken) {
+ userToken.Close()
+ return
+ }
+ user, err := userToken.GetTokenUser()
+ if err != nil {
+ log.Printf("Unable to lookup user from token: %v", err)
+ userToken.Close()
+ return
+ }
+ username, domain, accType, err := user.User.Sid.LookupAccount("")
+ if err != nil {
+ log.Printf("Unable to lookup username from sid: %v", err)
+ userToken.Close()
+ return
+ }
+ if accType != windows.SidTypeUser {
+ userToken.Close()
+ return
+ }
+ var elevatedToken windows.Token
+ if userToken.IsElevated() {
+ elevatedToken = userToken
+ } else {
+ elevatedToken, err = userToken.GetLinkedToken()
+ userToken.Close()
+ if err != nil {
+ log.Printf("Unable to elevate token: %v", err)
+ return
+ }
+ if !elevatedToken.IsElevated() {
+ elevatedToken.Close()
+ log.Println("Linked token is not elevated")
+ return
+ }
+ }
+ defer elevatedToken.Close()
+ userToken = 0
+ first := true
+ for {
+ if stoppingManager {
+ return
+ }
+
+ procsLock.Lock()
+ if alive := aliveSessions[session]; !alive {
+ procsLock.Unlock()
+ return
+ }
+ procsLock.Unlock()
+
+ if !first {
+ time.Sleep(time.Second)
+ } else {
+ first = false
+ }
+
+ //TODO: we lock the OS thread so that these inheritable handles don't escape into other processes that
+ // might be running in parallel Go routines. But the Go runtime is strange and who knows what's really
+ // happening with these or what is inherited. We need to do some analysis to be certain of what's going on.
+ runtime.LockOSThread()
+ ourReader, theirReader, theirReaderStr, ourWriter, theirWriter, theirWriterStr, err := inheritableSocketpairEmulation()
+ if err != nil {
+ log.Printf("Unable to create two inheritable pipes: %v", err)
+ return
+ }
+ ourEvents, theirEvents, theirEventStr, err := inheritableEvents()
+ err = IPCServerListen(ourReader, ourWriter, ourEvents, elevatedToken)
+ if err != nil {
+ log.Printf("Unable to listen on IPC pipes: %v", err)
+ return
+ }
+ theirLogMapping, theirLogMappingHandle, err := ringlogger.Global.ExportInheritableMappingHandleStr()
+ if err != nil {
+ log.Printf("Unable to export inheritable mapping handle for logging: %v", err)
+ return
+ }
+
+ log.Printf("Starting UI process for user '%s@%s' for session %d", username, domain, session)
+ attr := &os.ProcAttr{
+ Sys: &syscall.SysProcAttr{
+ Token: syscall.Token(elevatedToken),
+ },
+ Files: []*os.File{devNull, devNull, devNull},
+ }
+ procsLock.Lock()
+ var proc *os.Process
+ if alive := aliveSessions[session]; alive {
+ proc, err = os.StartProcess(path, []string{path, "/ui", theirReaderStr, theirWriterStr, theirEventStr, theirLogMapping}, attr)
+ } else {
+ err = errors.New("Session has logged out")
+ }
+ procsLock.Unlock()
+ theirReader.Close()
+ theirWriter.Close()
+ theirEvents.Close()
+ windows.Close(theirLogMappingHandle)
+ runtime.UnlockOSThread()
+ if err != nil {
+ ourReader.Close()
+ ourWriter.Close()
+ ourEvents.Close()
+ log.Printf("Unable to start manager UI process for user '%s@%s' for session %d: %v", username, domain, session, err)
+ return
+ }
+
+ procsLock.Lock()
+ procs[session] = proc
+ procsLock.Unlock()
+
+ sessionIsDead := false
+ processStatus, err := proc.Wait()
+ if err == nil {
+ exitCode := processStatus.Sys().(syscall.WaitStatus).ExitCode
+ log.Printf("Exited UI process for user '%s@%s' for session %d with status %x", username, domain, session, exitCode)
+ const STATUS_DLL_INIT_FAILED_LOGOFF = 0xC000026B
+ sessionIsDead = exitCode == STATUS_DLL_INIT_FAILED_LOGOFF
+ } else {
+ log.Printf("Unable to wait for UI process for user '%s@%s' for session %d: %v", username, domain, session, err)
+ }
+
+ procsLock.Lock()
+ delete(procs, session)
+ procsLock.Unlock()
+ ourReader.Close()
+ ourWriter.Close()
+ ourEvents.Close()
+
+ if sessionIsDead {
+ return
+ }
+ }
+ }
+
+ go checkForUpdates()
+
+ var sessionsPointer *windows.WTS_SESSION_INFO
+ var count uint32
+ err = windows.WTSEnumerateSessions(0, 0, 1, &sessionsPointer, &count)
+ if err != nil {
+ serviceError = services.ErrorEnumerateSessions
+ return
+ }
+ sessions := *(*[]windows.WTS_SESSION_INFO)(unsafe.Pointer(&struct {
+ addr *windows.WTS_SESSION_INFO
+ len int
+ cap int
+ }{sessionsPointer, int(count), int(count)}))
+ for _, session := range sessions {
+ if session.State != windows.WTSActive && session.State != windows.WTSDisconnected {
+ continue
+ }
+ procsLock.Lock()
+ if alive := aliveSessions[session.SessionID]; !alive {
+ aliveSessions[session.SessionID] = true
+ if _, ok := procs[session.SessionID]; !ok {
+ go startProcess(session.SessionID)
+ }
+ }
+ procsLock.Unlock()
+ }
+ windows.WTSFreeMemory(uintptr(unsafe.Pointer(sessionsPointer)))
+
+ changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptSessionChange}
+
+ uninstall := false
+loop:
+ for {
+ select {
+ case <-quitManagersChan:
+ uninstall = true
+ break loop
+ case c := <-r:
+ switch c.Cmd {
+ case svc.Stop:
+ break loop
+ case svc.Interrogate:
+ changes <- c.CurrentStatus
+ case svc.SessionChange:
+ if c.EventType != windows.WTS_SESSION_LOGON && c.EventType != windows.WTS_SESSION_LOGOFF {
+ continue
+ }
+ sessionNotification := (*windows.WTSSESSION_NOTIFICATION)(unsafe.Pointer(c.EventData))
+ if uintptr(sessionNotification.Size) != unsafe.Sizeof(*sessionNotification) {
+ log.Printf("Unexpected size of WTSSESSION_NOTIFICATION: %d", sessionNotification.Size)
+ continue
+ }
+ if c.EventType == windows.WTS_SESSION_LOGOFF {
+ procsLock.Lock()
+ delete(aliveSessions, sessionNotification.SessionID)
+ if proc, ok := procs[sessionNotification.SessionID]; ok {
+ proc.Kill()
+ }
+ procsLock.Unlock()
+ } else if c.EventType == windows.WTS_SESSION_LOGON {
+ procsLock.Lock()
+ if alive := aliveSessions[sessionNotification.SessionID]; !alive {
+ aliveSessions[sessionNotification.SessionID] = true
+ if _, ok := procs[sessionNotification.SessionID]; !ok {
+ go startProcess(sessionNotification.SessionID)
+ }
+ }
+ procsLock.Unlock()
+ }
+
+ default:
+ log.Printf("Unexpected service control request #%d", c.Cmd)
+ }
+ }
+ }
+
+ changes <- svc.Status{State: svc.StopPending}
+ procsLock.Lock()
+ stoppingManager = true
+ IPCServerNotifyManagerStopping()
+ for _, proc := range procs {
+ proc.Kill()
+ }
+ procsLock.Unlock()
+ if uninstall {
+ err = UninstallManager()
+ if err != nil {
+ log.Printf("Unable to uninstall manager when quitting: %v", err)
+ }
+ }
+ return
+}
diff --git a/manager/tunneltracker.go b/manager/tunneltracker.go
new file mode 100644
index 00000000..1cde98e2
--- /dev/null
+++ b/manager/tunneltracker.go
@@ -0,0 +1,182 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package manager
+
+import (
+ "fmt"
+ "log"
+ "runtime"
+ "sync"
+ "syscall"
+ "time"
+
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/svc"
+ "golang.org/x/sys/windows/svc/mgr"
+
+ "golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/services"
+)
+
+func trackExistingTunnels() error {
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ names, err := conf.ListConfigNames()
+ if err != nil {
+ return err
+ }
+ for _, name := range names {
+ serviceName, err := ServiceNameOfTunnel(name)
+ if err != nil {
+ continue
+ }
+ service, err := m.OpenService(serviceName)
+ if err != nil {
+ continue
+ }
+ go trackTunnelService(name, service)
+ }
+ return nil
+}
+
+var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr {
+ return 0
+})
+
+var trackedTunnels = make(map[string]TunnelState)
+var trackedTunnelsLock = sync.Mutex{}
+
+func svcStateToTunState(s svc.State) TunnelState {
+ switch s {
+ case svc.StartPending:
+ return TunnelStarting
+ case svc.Running:
+ return TunnelStarted
+ case svc.StopPending:
+ return TunnelStopping
+ case svc.Stopped:
+ return TunnelStopped
+ default:
+ return TunnelUnknown
+ }
+}
+
+func trackedTunnelsGlobalState() (state TunnelState) {
+ state = TunnelStopped
+ trackedTunnelsLock.Lock()
+ defer trackedTunnelsLock.Unlock()
+ for _, s := range trackedTunnels {
+ if s == TunnelStarting {
+ return TunnelStarting
+ } else if s == TunnelStopping {
+ return TunnelStopping
+ } else if s == TunnelStarted || s == TunnelUnknown {
+ state = TunnelStarted
+ }
+ }
+ return
+}
+
+func trackTunnelService(tunnelName string, service *mgr.Service) {
+ defer func() {
+ service.Close()
+ log.Printf("[%s] Tunnel service tracker finished", tunnelName)
+ }()
+
+ trackedTunnelsLock.Lock()
+ if _, found := trackedTunnels[tunnelName]; found {
+ trackedTunnelsLock.Unlock()
+ return
+ }
+ trackedTunnels[tunnelName] = TunnelUnknown
+ trackedTunnelsLock.Unlock()
+ defer func() {
+ trackedTunnelsLock.Lock()
+ delete(trackedTunnels, tunnelName)
+ trackedTunnelsLock.Unlock()
+ }()
+
+ const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING
+ notifier := &windows.SERVICE_NOTIFY{
+ Version: windows.SERVICE_NOTIFY_STATUS_CHANGE,
+ NotifyCallback: serviceTrackerCallbackPtr,
+ }
+
+ checkForDisabled := func() (shouldReturn bool) {
+ config, err := service.Config()
+ if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || config.StartType == windows.SERVICE_DISABLED {
+ log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName)
+ service.Delete()
+ trackedTunnelsLock.Lock()
+ trackedTunnels[tunnelName] = TunnelStopped
+ trackedTunnelsLock.Unlock()
+ IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil)
+ return true
+ }
+ return false
+ }
+ if checkForDisabled() {
+ return
+ }
+
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+ lastState := TunnelUnknown
+ for {
+ err := windows.NotifyServiceStatusChange(service.Handle, serviceNotifications, notifier)
+ switch err {
+ case nil:
+ for {
+ if windows.SleepEx(uint32(time.Second*3/time.Millisecond), true) == windows.WAIT_IO_COMPLETION {
+ break
+ } else if checkForDisabled() {
+ return
+ }
+ }
+ case windows.ERROR_SERVICE_MARKED_FOR_DELETE:
+ trackedTunnelsLock.Lock()
+ trackedTunnels[tunnelName] = TunnelStopped
+ trackedTunnelsLock.Unlock()
+ IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil)
+ return
+ case windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING:
+ continue
+ default:
+ trackedTunnelsLock.Lock()
+ trackedTunnels[tunnelName] = TunnelStopped
+ trackedTunnelsLock.Unlock()
+ IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %v", err))
+ service.Control(svc.Stop)
+ return
+ }
+
+ state := svcStateToTunState(svc.State(notifier.ServiceStatus.CurrentState))
+ var tunnelError error
+ if state == TunnelStopped {
+ if notifier.ServiceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) {
+ maybeErr := services.Error(notifier.ServiceStatus.ServiceSpecificExitCode)
+ if maybeErr != services.ErrorSuccess {
+ tunnelError = maybeErr
+ }
+ } else {
+ switch notifier.ServiceStatus.Win32ExitCode {
+ case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED):
+ default:
+ tunnelError = syscall.Errno(notifier.ServiceStatus.Win32ExitCode)
+ }
+ }
+ }
+ if state != lastState {
+ trackedTunnelsLock.Lock()
+ trackedTunnels[tunnelName] = state
+ trackedTunnelsLock.Unlock()
+ IPCServerNotifyTunnelChange(tunnelName, state, tunnelError)
+ lastState = state
+ }
+ }
+}
diff --git a/manager/updatestate.go b/manager/updatestate.go
new file mode 100644
index 00000000..2e82baf8
--- /dev/null
+++ b/manager/updatestate.go
@@ -0,0 +1,57 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package manager
+
+import (
+ "log"
+ "time"
+
+ "golang.zx2c4.com/wireguard/windows/updater"
+ "golang.zx2c4.com/wireguard/windows/version"
+)
+
+type UpdateState uint32
+
+const (
+ UpdateStateUnknown UpdateState = iota
+ UpdateStateFoundUpdate
+ UpdateStateUpdatesDisabledUnofficialBuild
+)
+
+var updateState = UpdateStateUnknown
+
+func checkForUpdates() {
+ if !version.IsRunningOfficialVersion() {
+ log.Println("Build is not official, so updates are disabled")
+ updateState = UpdateStateUpdatesDisabledUnofficialBuild
+ IPCServerNotifyUpdateFound(updateState)
+ return
+ }
+
+ time.Sleep(time.Second * 10)
+
+ first := true
+ for {
+ update, err := updater.CheckForUpdate()
+ if err == nil && update != nil {
+ log.Println("An update is available")
+ updateState = UpdateStateFoundUpdate
+ IPCServerNotifyUpdateFound(updateState)
+ return
+ }
+ if err != nil {
+ log.Printf("Update checker: %v", err)
+ if first {
+ time.Sleep(time.Minute * 4)
+ first = false
+ } else {
+ time.Sleep(time.Minute * 25)
+ }
+ } else {
+ time.Sleep(time.Hour)
+ }
+ }
+}