aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/service
diff options
context:
space:
mode:
Diffstat (limited to 'service')
-rw-r--r--service/ipc_client.go67
-rw-r--r--service/ipc_event.go30
-rw-r--r--service/ipc_pipe.go32
-rw-r--r--service/ipc_server.go59
-rw-r--r--service/mksyscall.go2
-rw-r--r--service/service_manager.go7
-rw-r--r--service/zsyscall_windows.go36
7 files changed, 121 insertions, 112 deletions
diff --git a/service/ipc_client.go b/service/ipc_client.go
index 25575014..c3d08897 100644
--- a/service/ipc_client.go
+++ b/service/ipc_client.go
@@ -6,7 +6,7 @@
package service
import (
- "golang.org/x/sys/windows"
+ "encoding/gob"
"golang.zx2c4.com/wireguard/windows/conf"
"net/rpc"
"os"
@@ -27,10 +27,54 @@ const (
TunnelDeleting
)
+type NotificationType int
+
+const (
+ TunnelChangeNotificationType NotificationType = iota
+ TunnelsChangeNotificationType
+)
+
var rpcClient *rpc.Client
-func InitializeIPCClient(reader *os.File, writer *os.File) {
+type tunnelChangeCallback struct {
+ cb func(tunnel string)
+}
+
+var tunnelChangeCallbacks = make(map[*tunnelChangeCallback]bool)
+
+type tunnelsChangeCallback struct {
+ cb func()
+}
+
+var tunnelsChangeCallbacks = make(map[*tunnelsChangeCallback]bool)
+
+func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) {
rpcClient = rpc.NewClient(&pipeRWC{reader, writer})
+ go func() {
+ decoder := gob.NewDecoder(events)
+ for {
+ var notificationType NotificationType
+ err := decoder.Decode(&notificationType)
+ if err != nil {
+ return
+ }
+ switch notificationType {
+ case TunnelChangeNotificationType:
+ var tunnel string
+ err := decoder.Decode(&tunnel)
+ if err != nil || len(tunnel) == 0 {
+ continue
+ }
+ for cb := range tunnelChangeCallbacks {
+ cb.cb(tunnel)
+ }
+ case TunnelsChangeNotificationType:
+ for cb := range tunnelsChangeCallbacks {
+ cb.cb()
+ }
+ }
+ }
+ }()
}
func (t *Tunnel) StoredConfig() (c conf.Config, err error) {
@@ -78,10 +122,19 @@ func IPCClientQuit(stopTunnelsOnQuit bool) (bool, error) {
return alreadyQuit, rpcClient.Call("ManagerService.Quit", stopTunnelsOnQuit, &alreadyQuit)
}
-func IPCClientRegisterAsNotificationThread() error {
- return rpcClient.Call("ManagerService.RegisterAsNotificationThread", windows.GetCurrentThreadId(), nil)
+func IPCClientRegisterTunnelChange(cb func(tunnel string)) *tunnelChangeCallback {
+ s := &tunnelChangeCallback{cb}
+ tunnelChangeCallbacks[s] = true
+ return s
}
-
-func IPCClientUnregisterAsNotificationThread() error {
- return rpcClient.Call("ManagerService.UnregisterAsNotificationThread", windows.GetCurrentThreadId(), nil)
+func IPCClientUnregisterTunnelChange(cb *tunnelChangeCallback) {
+ delete(tunnelChangeCallbacks, cb)
+}
+func IPCClientRegisterTunnelsChange(cb func()) *tunnelsChangeCallback {
+ s := &tunnelsChangeCallback{cb}
+ tunnelsChangeCallbacks[s] = true
+ return s
+}
+func IPCClientUnregisterTunnelsChange(cb *tunnelsChangeCallback) {
+ delete(tunnelsChangeCallbacks, cb)
}
diff --git a/service/ipc_event.go b/service/ipc_event.go
deleted file mode 100644
index f56f289d..00000000
--- a/service/ipc_event.go
+++ /dev/null
@@ -1,30 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package service
-
-import "golang.org/x/sys/windows"
-
-//sys registerWindowMessage(name *uint16) (message uint, err error) = user32.RegisterWindowMessageW
-
-var (
- tunnelsChangedMessage uint
- tunnelChangedMessage uint
-)
-func IPCRegisterEventMessages() error {
- m, err := registerWindowMessage(windows.StringToUTF16Ptr("WireGuard Manager Event - Tunnels Changed"))
- if err != nil {
- return err
- }
- tunnelsChangedMessage = m
-
- m, err = registerWindowMessage(windows.StringToUTF16Ptr("WireGuard Manager Event - Tunnel Changed"))
- if err != nil {
- return err
- }
- tunnelChangedMessage = m
-
- return nil
-}
diff --git a/service/ipc_pipe.go b/service/ipc_pipe.go
index ee63f2d4..00f54bf7 100644
--- a/service/ipc_pipe.go
+++ b/service/ipc_pipe.go
@@ -33,26 +33,44 @@ func (p *pipeRWC) Close() error {
return err2
}
-func inheritableSocketpairEmulation() (ourReader *os.File, theirReader *os.File, theirReaderStr string, ourWriter *os.File, theirWriter *os.File, theirWriterStr string, err error) {
- ourReader, theirWriter, err = os.Pipe()
+func makeInheritableAndGetStr(f *os.File) (str string, err error) {
+ sc, err := f.SyscallConn()
if err != nil {
return
}
- err = windows.SetHandleInformation(windows.Handle(theirWriter.Fd()), windows.HANDLE_FLAG_INHERIT, windows.HANDLE_FLAG_INHERIT)
+ err2 := sc.Control(func(fd uintptr) {
+ err = windows.SetHandleInformation(windows.Handle(fd), windows.HANDLE_FLAG_INHERIT, windows.HANDLE_FLAG_INHERIT)
+ str = strconv.FormatUint(uint64(fd), 10)
+ })
+ if err2 != nil {
+ err = err2
+ }
+ return
+}
+
+func inheritableEvents() (ourEvents *os.File, theirEvents *os.File, theirEventStr string, err error) {
+ theirEvents, ourEvents, err = os.Pipe()
if err != nil {
return
}
- theirWriterStr = strconv.FormatUint(uint64(theirWriter.Fd()), 10)
+ theirEventStr, err = makeInheritableAndGetStr(theirEvents)
+ return
+}
- theirReader, ourWriter, err = os.Pipe()
+func inheritableSocketpairEmulation() (ourReader *os.File, theirReader *os.File, theirReaderStr string, ourWriter *os.File, theirWriter *os.File, theirWriterStr string, err error) {
+ ourReader, theirWriter, err = os.Pipe()
if err != nil {
return
}
- err = windows.SetHandleInformation(windows.Handle(theirReader.Fd()), windows.HANDLE_FLAG_INHERIT, windows.HANDLE_FLAG_INHERIT)
+ theirWriterStr, err = makeInheritableAndGetStr(theirWriter)
if err != nil {
return
}
- theirReaderStr = strconv.FormatUint(uint64(theirReader.Fd()), 10)
+ theirReader, ourWriter, err = os.Pipe()
+ if err != nil {
+ return
+ }
+ theirReaderStr, err = makeInheritableAndGetStr(theirReader)
return
}
diff --git a/service/ipc_server.go b/service/ipc_server.go
index a2a4c9ee..73c2916e 100644
--- a/service/ipc_server.go
+++ b/service/ipc_server.go
@@ -6,13 +6,15 @@
package service
import (
+ "bytes"
+ "encoding/gob"
"errors"
- "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/conf"
"net/rpc"
"os"
"sync"
"sync/atomic"
+ "time"
)
var managerServices = make(map[*ManagerService]bool)
@@ -21,8 +23,7 @@ var haveQuit uint32
var quitManagersChan = make(chan struct{}, 1)
type ManagerService struct {
- notifierHandles map[windows.Handle]bool
- notifierHandlesLock sync.RWMutex
+ events *os.File
}
func (s *ManagerService) StoredConfig(tunnelName string, config *conf.Config) error {
@@ -127,22 +128,8 @@ func (s *ManagerService) Quit(stopTunnelsOnQuit bool, alreadyQuit *bool) error {
return nil
}
-func (s *ManagerService) RegisterAsNotificationThread(handle windows.Handle, unused *uintptr) error {
- s.notifierHandlesLock.Lock()
- s.notifierHandles[handle] = true
- s.notifierHandlesLock.Unlock()
- return nil
-}
-
-func (s *ManagerService) UnregisterAsNotificationThread(handle windows.Handle, unused *uintptr) error {
- s.notifierHandlesLock.Lock()
- delete(s.notifierHandles, handle)
- s.notifierHandlesLock.Unlock()
- return nil
-}
-
-func IPCServerListen(reader *os.File, writer *os.File) error {
- service := &ManagerService{notifierHandles: make(map[windows.Handle]bool)}
+func IPCServerListen(reader *os.File, writer *os.File, events *os.File) error {
+ service := &ManagerService{events: events}
server := rpc.NewServer()
err := server.Register(service)
@@ -163,28 +150,34 @@ func IPCServerListen(reader *os.File, writer *os.File) error {
return nil
}
-//sys postMessage(hwnd windows.Handle, msg uint, wparam uintptr, lparam uintptr) (err error) = user32.PostMessageW
+func notifyAll(notificationType NotificationType, iface interface{}) {
+ var buf bytes.Buffer
+ encoder := gob.NewEncoder(&buf)
+ err := encoder.Encode(notificationType)
+ if err != nil {
+ return
+ }
+ if iface != nil {
+ err = encoder.Encode(iface)
+ if err != nil {
+ return
+ }
+ }
-func notifyAll(f func(handle windows.Handle)) {
managerServicesLock.RLock()
- for m, _ := range managerServices {
- m.notifierHandlesLock.RLock()
- for handle, _ := range m.notifierHandles {
- f(handle)
- }
- m.notifierHandlesLock.RUnlock()
+ for m := range managerServices {
+ go func() {
+ m.events.SetWriteDeadline(time.Now().Add(time.Second))
+ m.events.Write(buf.Bytes())
+ }()
}
managerServicesLock.RUnlock()
}
func IPCServerNotifyTunnelChange(name string) {
- notifyAll(func(handle windows.Handle) {
- //TODO: postthreadmessage
- })
+ notifyAll(TunnelChangeNotificationType, name)
}
func IPCServerNotifyTunnelsChange() {
- notifyAll(func(handle windows.Handle) {
- postMessage(handle, tunnelsChangedMessage, 0, 0)
- })
+ notifyAll(TunnelsChangeNotificationType, nil)
}
diff --git a/service/mksyscall.go b/service/mksyscall.go
index f80b9d1a..e96640aa 100644
--- a/service/mksyscall.go
+++ b/service/mksyscall.go
@@ -5,4 +5,4 @@
package service
-//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go service_manager.go ipc_server.go ipc_event.go
+//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go service_manager.go
diff --git a/service/service_manager.go b/service/service_manager.go
index 9f529bd1..eb28e833 100644
--- a/service/service_manager.go
+++ b/service/service_manager.go
@@ -180,7 +180,8 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
elog.Error(1, "Unable to create two inheritable pipes: "+err.Error())
return
}
- err = IPCServerListen(ourReader, ourWriter)
+ ourEvents, theirEvents, theirEventStr, err := inheritableEvents()
+ err = IPCServerListen(ourReader, ourWriter, ourEvents)
if err != nil {
elog.Error(1, "Unable to listen on IPC pipes: "+err.Error())
return
@@ -193,9 +194,10 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
},
Files: []*os.File{devNull, devNull, devNull},
}
- proc, err := os.StartProcess(path, []string{path, "/ui", theirReaderStr, theirWriterStr}, attr)
+ proc, err := os.StartProcess(path, []string{path, "/ui", theirReaderStr, theirWriterStr, theirEventStr}, attr)
theirReader.Close()
theirWriter.Close()
+ theirEvents.Close()
if err != nil {
elog.Error(1, "Unable to start manager UI process: "+err.Error())
return
@@ -210,6 +212,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
procsLock.Unlock()
ourReader.Close()
ourWriter.Close()
+ ourEvents.Close()
}
}
diff --git a/service/zsyscall_windows.go b/service/zsyscall_windows.go
index 79d4ccf5..649e3581 100644
--- a/service/zsyscall_windows.go
+++ b/service/zsyscall_windows.go
@@ -39,14 +39,11 @@ func errnoErr(e syscall.Errno) error {
var (
modwtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll")
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
- moduser32 = windows.NewLazySystemDLL("user32.dll")
- procWTSQueryUserToken = modwtsapi32.NewProc("WTSQueryUserToken")
- procWTSEnumerateSessionsW = modwtsapi32.NewProc("WTSEnumerateSessionsW")
- procWTSFreeMemory = modwtsapi32.NewProc("WTSFreeMemory")
- procCreateWellKnownSid = modadvapi32.NewProc("CreateWellKnownSid")
- procPostMessageW = moduser32.NewProc("PostMessageW")
- procRegisterWindowMessageW = moduser32.NewProc("RegisterWindowMessageW")
+ procWTSQueryUserToken = modwtsapi32.NewProc("WTSQueryUserToken")
+ procWTSEnumerateSessionsW = modwtsapi32.NewProc("WTSEnumerateSessionsW")
+ procWTSFreeMemory = modwtsapi32.NewProc("WTSFreeMemory")
+ procCreateWellKnownSid = modadvapi32.NewProc("CreateWellKnownSid")
)
func wtfQueryUserToken(session uint32, token *windows.Token) (err error) {
@@ -89,28 +86,3 @@ func createWellKnownSid(sidType wellKnownSidType, domainSid *windows.SID, sid *w
}
return
}
-
-func postMessage(hwnd windows.Handle, msg uint, wparam uintptr, lparam uintptr) (err error) {
- r1, _, e1 := syscall.Syscall6(procPostMessageW.Addr(), 4, uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam), 0, 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func registerWindowMessage(name *uint16) (message uint, err error) {
- r0, _, e1 := syscall.Syscall(procRegisterWindowMessageW.Addr(), 1, uintptr(unsafe.Pointer(name)), 0, 0)
- message = uint(r0)
- if message == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}