From 0712ec69c70d5065447f7ffdbd907e3d7ae50ae9 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Thu, 28 Feb 2019 03:58:43 +0100 Subject: ipc: implement event system with pipes Also use Go 1.12's Sysconn Signed-off-by: Jason A. Donenfeld --- main.go | 10 +++++-- service/ipc_client.go | 67 ++++++++++++++++++++++++++++++++++++++++----- service/ipc_event.go | 30 -------------------- service/ipc_pipe.go | 32 +++++++++++++++++----- service/ipc_server.go | 59 ++++++++++++++++++--------------------- service/mksyscall.go | 2 +- service/service_manager.go | 7 +++-- service/zsyscall_windows.go | 36 +++--------------------- ui/ui.go | 9 +++--- 9 files changed, 132 insertions(+), 120 deletions(-) delete mode 100644 service/ipc_event.go diff --git a/main.go b/main.go index d377bb58..2620b0ce 100644 --- a/main.go +++ b/main.go @@ -25,7 +25,7 @@ var flags = [...]string{ "/managerservice", "/tunnelservice CONFIG_PATH", "/tunneldebug CONFIG_PATH", - "/ui CMD_READ_HANDLE CMD_WRITE_HANDLE", + "/ui CMD_READ_HANDLE CMD_WRITE_HANDLE CMD_EVENT_HANDLE", } //sys messageBoxEx(hwnd windows.Handle, text *uint16, title *uint16, typ uint, languageId uint16) = user32.MessageBoxExW @@ -146,7 +146,7 @@ func main() { } return case "/ui": - if len(os.Args) != 4 { + if len(os.Args) != 5 { usage() } readPipe, err := pipeFromHandleArgument(os.Args[2]) @@ -157,7 +157,11 @@ func main() { if err != nil { fatal(err) } - service.InitializeIPCClient(readPipe, writePipe) + eventPipe, err := pipeFromHandleArgument(os.Args[4]) + if err != nil { + fatal(err) + } + service.InitializeIPCClient(readPipe, writePipe, eventPipe) ui.RunUI() return } diff --git a/service/ipc_client.go b/service/ipc_client.go index 25575014..c3d08897 100644 --- a/service/ipc_client.go +++ b/service/ipc_client.go @@ -6,7 +6,7 @@ package service import ( - "golang.org/x/sys/windows" + "encoding/gob" "golang.zx2c4.com/wireguard/windows/conf" "net/rpc" "os" @@ -27,10 +27,54 @@ const ( TunnelDeleting ) +type NotificationType int + +const ( + TunnelChangeNotificationType NotificationType = iota + TunnelsChangeNotificationType +) + var rpcClient *rpc.Client -func InitializeIPCClient(reader *os.File, writer *os.File) { +type tunnelChangeCallback struct { + cb func(tunnel string) +} + +var tunnelChangeCallbacks = make(map[*tunnelChangeCallback]bool) + +type tunnelsChangeCallback struct { + cb func() +} + +var tunnelsChangeCallbacks = make(map[*tunnelsChangeCallback]bool) + +func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) { rpcClient = rpc.NewClient(&pipeRWC{reader, writer}) + go func() { + decoder := gob.NewDecoder(events) + for { + var notificationType NotificationType + err := decoder.Decode(¬ificationType) + if err != nil { + return + } + switch notificationType { + case TunnelChangeNotificationType: + var tunnel string + err := decoder.Decode(&tunnel) + if err != nil || len(tunnel) == 0 { + continue + } + for cb := range tunnelChangeCallbacks { + cb.cb(tunnel) + } + case TunnelsChangeNotificationType: + for cb := range tunnelsChangeCallbacks { + cb.cb() + } + } + } + }() } func (t *Tunnel) StoredConfig() (c conf.Config, err error) { @@ -78,10 +122,19 @@ func IPCClientQuit(stopTunnelsOnQuit bool) (bool, error) { return alreadyQuit, rpcClient.Call("ManagerService.Quit", stopTunnelsOnQuit, &alreadyQuit) } -func IPCClientRegisterAsNotificationThread() error { - return rpcClient.Call("ManagerService.RegisterAsNotificationThread", windows.GetCurrentThreadId(), nil) +func IPCClientRegisterTunnelChange(cb func(tunnel string)) *tunnelChangeCallback { + s := &tunnelChangeCallback{cb} + tunnelChangeCallbacks[s] = true + return s } - -func IPCClientUnregisterAsNotificationThread() error { - return rpcClient.Call("ManagerService.UnregisterAsNotificationThread", windows.GetCurrentThreadId(), nil) +func IPCClientUnregisterTunnelChange(cb *tunnelChangeCallback) { + delete(tunnelChangeCallbacks, cb) +} +func IPCClientRegisterTunnelsChange(cb func()) *tunnelsChangeCallback { + s := &tunnelsChangeCallback{cb} + tunnelsChangeCallbacks[s] = true + return s +} +func IPCClientUnregisterTunnelsChange(cb *tunnelsChangeCallback) { + delete(tunnelsChangeCallbacks, cb) } diff --git a/service/ipc_event.go b/service/ipc_event.go deleted file mode 100644 index f56f289d..00000000 --- a/service/ipc_event.go +++ /dev/null @@ -1,30 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package service - -import "golang.org/x/sys/windows" - -//sys registerWindowMessage(name *uint16) (message uint, err error) = user32.RegisterWindowMessageW - -var ( - tunnelsChangedMessage uint - tunnelChangedMessage uint -) -func IPCRegisterEventMessages() error { - m, err := registerWindowMessage(windows.StringToUTF16Ptr("WireGuard Manager Event - Tunnels Changed")) - if err != nil { - return err - } - tunnelsChangedMessage = m - - m, err = registerWindowMessage(windows.StringToUTF16Ptr("WireGuard Manager Event - Tunnel Changed")) - if err != nil { - return err - } - tunnelChangedMessage = m - - return nil -} diff --git a/service/ipc_pipe.go b/service/ipc_pipe.go index ee63f2d4..00f54bf7 100644 --- a/service/ipc_pipe.go +++ b/service/ipc_pipe.go @@ -33,26 +33,44 @@ func (p *pipeRWC) Close() error { return err2 } -func inheritableSocketpairEmulation() (ourReader *os.File, theirReader *os.File, theirReaderStr string, ourWriter *os.File, theirWriter *os.File, theirWriterStr string, err error) { - ourReader, theirWriter, err = os.Pipe() +func makeInheritableAndGetStr(f *os.File) (str string, err error) { + sc, err := f.SyscallConn() if err != nil { return } - err = windows.SetHandleInformation(windows.Handle(theirWriter.Fd()), windows.HANDLE_FLAG_INHERIT, windows.HANDLE_FLAG_INHERIT) + err2 := sc.Control(func(fd uintptr) { + err = windows.SetHandleInformation(windows.Handle(fd), windows.HANDLE_FLAG_INHERIT, windows.HANDLE_FLAG_INHERIT) + str = strconv.FormatUint(uint64(fd), 10) + }) + if err2 != nil { + err = err2 + } + return +} + +func inheritableEvents() (ourEvents *os.File, theirEvents *os.File, theirEventStr string, err error) { + theirEvents, ourEvents, err = os.Pipe() if err != nil { return } - theirWriterStr = strconv.FormatUint(uint64(theirWriter.Fd()), 10) + theirEventStr, err = makeInheritableAndGetStr(theirEvents) + return +} - theirReader, ourWriter, err = os.Pipe() +func inheritableSocketpairEmulation() (ourReader *os.File, theirReader *os.File, theirReaderStr string, ourWriter *os.File, theirWriter *os.File, theirWriterStr string, err error) { + ourReader, theirWriter, err = os.Pipe() if err != nil { return } - err = windows.SetHandleInformation(windows.Handle(theirReader.Fd()), windows.HANDLE_FLAG_INHERIT, windows.HANDLE_FLAG_INHERIT) + theirWriterStr, err = makeInheritableAndGetStr(theirWriter) if err != nil { return } - theirReaderStr = strconv.FormatUint(uint64(theirReader.Fd()), 10) + theirReader, ourWriter, err = os.Pipe() + if err != nil { + return + } + theirReaderStr, err = makeInheritableAndGetStr(theirReader) return } diff --git a/service/ipc_server.go b/service/ipc_server.go index a2a4c9ee..73c2916e 100644 --- a/service/ipc_server.go +++ b/service/ipc_server.go @@ -6,13 +6,15 @@ package service import ( + "bytes" + "encoding/gob" "errors" - "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/conf" "net/rpc" "os" "sync" "sync/atomic" + "time" ) var managerServices = make(map[*ManagerService]bool) @@ -21,8 +23,7 @@ var haveQuit uint32 var quitManagersChan = make(chan struct{}, 1) type ManagerService struct { - notifierHandles map[windows.Handle]bool - notifierHandlesLock sync.RWMutex + events *os.File } func (s *ManagerService) StoredConfig(tunnelName string, config *conf.Config) error { @@ -127,22 +128,8 @@ func (s *ManagerService) Quit(stopTunnelsOnQuit bool, alreadyQuit *bool) error { return nil } -func (s *ManagerService) RegisterAsNotificationThread(handle windows.Handle, unused *uintptr) error { - s.notifierHandlesLock.Lock() - s.notifierHandles[handle] = true - s.notifierHandlesLock.Unlock() - return nil -} - -func (s *ManagerService) UnregisterAsNotificationThread(handle windows.Handle, unused *uintptr) error { - s.notifierHandlesLock.Lock() - delete(s.notifierHandles, handle) - s.notifierHandlesLock.Unlock() - return nil -} - -func IPCServerListen(reader *os.File, writer *os.File) error { - service := &ManagerService{notifierHandles: make(map[windows.Handle]bool)} +func IPCServerListen(reader *os.File, writer *os.File, events *os.File) error { + service := &ManagerService{events: events} server := rpc.NewServer() err := server.Register(service) @@ -163,28 +150,34 @@ func IPCServerListen(reader *os.File, writer *os.File) error { return nil } -//sys postMessage(hwnd windows.Handle, msg uint, wparam uintptr, lparam uintptr) (err error) = user32.PostMessageW +func notifyAll(notificationType NotificationType, iface interface{}) { + var buf bytes.Buffer + encoder := gob.NewEncoder(&buf) + err := encoder.Encode(notificationType) + if err != nil { + return + } + if iface != nil { + err = encoder.Encode(iface) + if err != nil { + return + } + } -func notifyAll(f func(handle windows.Handle)) { managerServicesLock.RLock() - for m, _ := range managerServices { - m.notifierHandlesLock.RLock() - for handle, _ := range m.notifierHandles { - f(handle) - } - m.notifierHandlesLock.RUnlock() + for m := range managerServices { + go func() { + m.events.SetWriteDeadline(time.Now().Add(time.Second)) + m.events.Write(buf.Bytes()) + }() } managerServicesLock.RUnlock() } func IPCServerNotifyTunnelChange(name string) { - notifyAll(func(handle windows.Handle) { - //TODO: postthreadmessage - }) + notifyAll(TunnelChangeNotificationType, name) } func IPCServerNotifyTunnelsChange() { - notifyAll(func(handle windows.Handle) { - postMessage(handle, tunnelsChangedMessage, 0, 0) - }) + notifyAll(TunnelsChangeNotificationType, nil) } diff --git a/service/mksyscall.go b/service/mksyscall.go index f80b9d1a..e96640aa 100644 --- a/service/mksyscall.go +++ b/service/mksyscall.go @@ -5,4 +5,4 @@ package service -//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go service_manager.go ipc_server.go ipc_event.go +//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go service_manager.go diff --git a/service/service_manager.go b/service/service_manager.go index 9f529bd1..eb28e833 100644 --- a/service/service_manager.go +++ b/service/service_manager.go @@ -180,7 +180,8 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest elog.Error(1, "Unable to create two inheritable pipes: "+err.Error()) return } - err = IPCServerListen(ourReader, ourWriter) + ourEvents, theirEvents, theirEventStr, err := inheritableEvents() + err = IPCServerListen(ourReader, ourWriter, ourEvents) if err != nil { elog.Error(1, "Unable to listen on IPC pipes: "+err.Error()) return @@ -193,9 +194,10 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest }, Files: []*os.File{devNull, devNull, devNull}, } - proc, err := os.StartProcess(path, []string{path, "/ui", theirReaderStr, theirWriterStr}, attr) + proc, err := os.StartProcess(path, []string{path, "/ui", theirReaderStr, theirWriterStr, theirEventStr}, attr) theirReader.Close() theirWriter.Close() + theirEvents.Close() if err != nil { elog.Error(1, "Unable to start manager UI process: "+err.Error()) return @@ -210,6 +212,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest procsLock.Unlock() ourReader.Close() ourWriter.Close() + ourEvents.Close() } } diff --git a/service/zsyscall_windows.go b/service/zsyscall_windows.go index 79d4ccf5..649e3581 100644 --- a/service/zsyscall_windows.go +++ b/service/zsyscall_windows.go @@ -39,14 +39,11 @@ func errnoErr(e syscall.Errno) error { var ( modwtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll") modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") - moduser32 = windows.NewLazySystemDLL("user32.dll") - procWTSQueryUserToken = modwtsapi32.NewProc("WTSQueryUserToken") - procWTSEnumerateSessionsW = modwtsapi32.NewProc("WTSEnumerateSessionsW") - procWTSFreeMemory = modwtsapi32.NewProc("WTSFreeMemory") - procCreateWellKnownSid = modadvapi32.NewProc("CreateWellKnownSid") - procPostMessageW = moduser32.NewProc("PostMessageW") - procRegisterWindowMessageW = moduser32.NewProc("RegisterWindowMessageW") + procWTSQueryUserToken = modwtsapi32.NewProc("WTSQueryUserToken") + procWTSEnumerateSessionsW = modwtsapi32.NewProc("WTSEnumerateSessionsW") + procWTSFreeMemory = modwtsapi32.NewProc("WTSFreeMemory") + procCreateWellKnownSid = modadvapi32.NewProc("CreateWellKnownSid") ) func wtfQueryUserToken(session uint32, token *windows.Token) (err error) { @@ -89,28 +86,3 @@ func createWellKnownSid(sidType wellKnownSidType, domainSid *windows.SID, sid *w } return } - -func postMessage(hwnd windows.Handle, msg uint, wparam uintptr, lparam uintptr) (err error) { - r1, _, e1 := syscall.Syscall6(procPostMessageW.Addr(), 4, uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func registerWindowMessage(name *uint16) (message uint, err error) { - r0, _, e1 := syscall.Syscall(procRegisterWindowMessageW.Addr(), 1, uintptr(unsafe.Pointer(name)), 0, 0) - message = uint(r0) - if message == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} diff --git a/ui/ui.go b/ui/ui.go index 39d2ba23..abb528b3 100644 --- a/ui/ui.go +++ b/ui/ui.go @@ -132,10 +132,9 @@ func RunUI() { } }) - err := service.IPCClientRegisterAsNotificationThread() - if err != nil { - walk.MsgBox(mw, "Unable to register for notifications", err.Error(), walk.MsgBoxIconError) - os.Exit(1) - } + service.IPCClientRegisterTunnelChange(func(tunnel string) { + walk.MsgBox(mw, "Tunnel Changed", "The tunnel that changed is: "+tunnel, walk.MsgBoxIconInformation) + }) + mw.Run() } -- cgit v1.2.3-59-g8ed1b