aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/manager/ipc_server.go
diff options
context:
space:
mode:
Diffstat (limited to 'manager/ipc_server.go')
-rw-r--r--manager/ipc_server.go134
1 files changed, 82 insertions, 52 deletions
diff --git a/manager/ipc_server.go b/manager/ipc_server.go
index 1367c2e9..e21ffaf0 100644
--- a/manager/ipc_server.go
+++ b/manager/ipc_server.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package manager
@@ -10,7 +10,6 @@ import (
"encoding/gob"
"fmt"
"io"
- "io/ioutil"
"log"
"os"
"sync"
@@ -20,64 +19,73 @@ import (
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
- "golang.zx2c4.com/wireguard/ipc/winpipe"
-
"golang.zx2c4.com/wireguard/windows/conf"
- "golang.zx2c4.com/wireguard/windows/services"
"golang.zx2c4.com/wireguard/windows/updater"
)
-var managerServices = make(map[*ManagerService]bool)
-var managerServicesLock sync.RWMutex
-var haveQuit uint32
-var quitManagersChan = make(chan struct{}, 1)
+var (
+ managerServices = make(map[*ManagerService]bool)
+ managerServicesLock sync.RWMutex
+ haveQuit uint32
+ quitManagersChan = make(chan struct{}, 1)
+)
type ManagerService struct {
events *os.File
+ eventLock sync.Mutex
elevatedToken windows.Token
}
func (s *ManagerService) StoredConfig(tunnelName string) (*conf.Config, error) {
- return conf.LoadFromName(tunnelName)
-}
-
-func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) {
- storedConfig, err := conf.LoadFromName(tunnelName)
+ conf, err := conf.LoadFromName(tunnelName)
if err != nil {
return nil, err
}
- pipePath, err := services.PipePathOfTunnel(storedConfig.Name)
- if err != nil {
- return nil, err
+ if s.elevatedToken == 0 {
+ conf.Redact()
}
- localSystem, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid)
+ return conf, nil
+}
+
+func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) {
+ storedConfig, err := conf.LoadFromName(tunnelName)
if err != nil {
return nil, err
}
- pipe, err := winpipe.DialPipe(pipePath, nil, localSystem)
+ driverAdapter, err := findDriverAdapter(tunnelName)
if err != nil {
return nil, err
}
- defer pipe.Close()
- pipe.SetWriteDeadline(time.Now().Add(time.Second * 2))
- _, err = pipe.Write([]byte("get=1\n\n"))
+ runtimeConfig, err := driverAdapter.Configuration()
if err != nil {
+ driverAdapter.Unlock()
+ releaseDriverAdapter(tunnelName)
return nil, err
}
- pipe.SetReadDeadline(time.Now().Add(time.Second * 2))
- resp, err := ioutil.ReadAll(pipe)
- if err != nil {
- return nil, err
+ conf := conf.FromDriverConfiguration(runtimeConfig, storedConfig)
+ driverAdapter.Unlock()
+ if s.elevatedToken == 0 {
+ conf.Redact()
}
- return conf.FromUAPI(string(resp), storedConfig)
+ return conf, nil
}
func (s *ManagerService) Start(tunnelName string) error {
- // For now, enforce only one tunnel at a time. Later we'll remove this silly restriction.
+ c, err := conf.LoadFromName(tunnelName)
+ if err != nil {
+ return err
+ }
+
+ // Figure out which tunnels have intersecting addresses/routes and stop those.
trackedTunnelsLock.Lock()
tt := make([]string, 0, len(trackedTunnels))
var inTransition string
for t, state := range trackedTunnels {
+ c2, err := conf.LoadFromName(t)
+ if err != nil || !c.IntersectsWith(c2) {
+ // If we can't get the config, assume it doesn't intersect.
+ continue
+ }
tt = append(tt, t)
if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) {
inTransition = t
@@ -88,6 +96,8 @@ func (s *ManagerService) Start(tunnelName string) error {
if len(inTransition) != 0 {
return fmt.Errorf("Please allow the tunnel ā€˜%sā€™ to finish activating", inTransition)
}
+
+ // Stop those intersecting tunnels asynchronously.
go func() {
for _, t := range tt {
s.Stop(t)
@@ -101,13 +111,7 @@ func (s *ManagerService) Start(tunnelName string) error {
}
}
}()
- time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces)
-
- // After that process is started -- it's somewhat asynchronous -- we install the new one.
- c, err := conf.LoadFromName(tunnelName)
- if err != nil {
- return err
- }
+ // After the stop process has begun, but before it's finished, we install the new one.
path, err := c.Path()
if err != nil {
return err
@@ -116,8 +120,6 @@ func (s *ManagerService) Start(tunnelName string) error {
}
func (s *ManagerService) Stop(tunnelName string) error {
- time.AfterFunc(time.Second*10, cleanupStaleWintunInterfaces)
-
err := UninstallTunnel(tunnelName)
if err == windows.ERROR_SERVICE_DOES_NOT_EXIST {
_, notExistsError := conf.LoadFromName(tunnelName)
@@ -129,7 +131,7 @@ func (s *ManagerService) Stop(tunnelName string) error {
}
func (s *ManagerService) WaitForStop(tunnelName string) error {
- serviceName, err := services.ServiceNameOfTunnel(tunnelName)
+ serviceName, err := conf.ServiceNameOfTunnel(tunnelName)
if err != nil {
return err
}
@@ -149,6 +151,9 @@ func (s *ManagerService) WaitForStop(tunnelName string) error {
}
func (s *ManagerService) Delete(tunnelName string) error {
+ if s.elevatedToken == 0 {
+ return windows.ERROR_ACCESS_DENIED
+ }
err := s.Stop(tunnelName)
if err != nil {
return err
@@ -157,7 +162,7 @@ func (s *ManagerService) Delete(tunnelName string) error {
}
func (s *ManagerService) State(tunnelName string) (TunnelState, error) {
- serviceName, err := services.ServiceNameOfTunnel(tunnelName)
+ serviceName, err := conf.ServiceNameOfTunnel(tunnelName)
if err != nil {
return 0, err
}
@@ -193,7 +198,10 @@ func (s *ManagerService) GlobalState() TunnelState {
}
func (s *ManagerService) Create(tunnelConfig *conf.Config) (*Tunnel, error) {
- err := tunnelConfig.Save()
+ if s.elevatedToken == 0 {
+ return nil, windows.ERROR_ACCESS_DENIED
+ }
+ err := tunnelConfig.Save(true)
if err != nil {
return nil, err
}
@@ -209,19 +217,25 @@ func (s *ManagerService) Tunnels() ([]Tunnel, error) {
}
tunnels := make([]Tunnel, len(names))
for i := 0; i < len(tunnels); i++ {
- (tunnels)[i].Name = names[i]
+ tunnels[i].Name = names[i]
}
return tunnels, nil
// TODO: account for running ones that aren't in the configuration store somehow
}
func (s *ManagerService) Quit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) {
+ if s.elevatedToken == 0 {
+ return false, windows.ERROR_ACCESS_DENIED
+ }
if !atomic.CompareAndSwapUint32(&haveQuit, 0, 1) {
return true, nil
}
// Work around potential race condition of delivering messages to the wrong process by removing from notifications.
managerServicesLock.Lock()
+ s.eventLock.Lock()
+ s.events = nil
+ s.eventLock.Unlock()
delete(managerServices, s)
managerServicesLock.Unlock()
@@ -244,6 +258,9 @@ func (s *ManagerService) UpdateState() UpdateState {
}
func (s *ManagerService) Update() {
+ if s.elevatedToken == 0 {
+ return
+ }
progress := updater.DownloadVerifyAndExecute(uintptr(s.elevatedToken))
go func() {
for {
@@ -374,6 +391,9 @@ func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) {
return
}
tunnel, retErr := s.Create(&config)
+ if tunnel == nil {
+ tunnel = &Tunnel{}
+ }
err = encoder.Encode(tunnel)
if err != nil {
return
@@ -421,26 +441,27 @@ func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) {
}
}
-func IPCServerListen(reader *os.File, writer *os.File, events *os.File, elevatedToken windows.Token) {
+func IPCServerListen(reader, writer, events *os.File, elevatedToken windows.Token) {
service := &ManagerService{
events: events,
elevatedToken: elevatedToken,
}
go func() {
- defer printPanic()
managerServicesLock.Lock()
managerServices[service] = true
managerServicesLock.Unlock()
service.ServeConn(reader, writer)
managerServicesLock.Lock()
+ service.eventLock.Lock()
+ service.events = nil
+ service.eventLock.Unlock()
delete(managerServices, service)
managerServicesLock.Unlock()
-
}()
}
-func notifyAll(notificationType NotificationType, ifaces ...interface{}) {
+func notifyAll(notificationType NotificationType, adminOnly bool, ifaces ...any) {
if len(managerServices) == 0 {
return
}
@@ -460,8 +481,17 @@ func notifyAll(notificationType NotificationType, ifaces ...interface{}) {
managerServicesLock.RLock()
for m := range managerServices {
- m.events.SetWriteDeadline(time.Now().Add(time.Second))
- m.events.Write(buf.Bytes())
+ if m.elevatedToken == 0 && adminOnly {
+ continue
+ }
+ go func(m *ManagerService) {
+ m.eventLock.Lock()
+ defer m.eventLock.Unlock()
+ if m.events != nil {
+ m.events.SetWriteDeadline(time.Now().Add(time.Second))
+ m.events.Write(buf.Bytes())
+ }
+ }(m)
}
managerServicesLock.RUnlock()
}
@@ -474,22 +504,22 @@ func errToString(err error) string {
}
func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) {
- notifyAll(TunnelChangeNotificationType, name, state, trackedTunnelsGlobalState(), errToString(err))
+ notifyAll(TunnelChangeNotificationType, false, name, state, trackedTunnelsGlobalState(), errToString(err))
}
func IPCServerNotifyTunnelsChange() {
- notifyAll(TunnelsChangeNotificationType)
+ notifyAll(TunnelsChangeNotificationType, false)
}
func IPCServerNotifyUpdateFound(state UpdateState) {
- notifyAll(UpdateFoundNotificationType, state)
+ notifyAll(UpdateFoundNotificationType, false, state)
}
func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) {
- notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, errToString(dp.Error), dp.Complete)
+ notifyAll(UpdateProgressNotificationType, true, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, errToString(dp.Error), dp.Complete)
}
func IPCServerNotifyManagerStopping() {
- notifyAll(ManagerStoppingNotificationType)
+ notifyAll(ManagerStoppingNotificationType, false)
time.Sleep(time.Millisecond * 200)
}