diff options
Diffstat (limited to 'service')
-rw-r--r-- | service/errors.go | 3 | ||||
-rw-r--r-- | service/ipc_client.go | 78 | ||||
-rw-r--r-- | service/ipc_server.go | 58 | ||||
-rw-r--r-- | service/securityapi.go | 157 | ||||
-rw-r--r-- | service/service_manager.go | 182 | ||||
-rw-r--r-- | service/updatestate.go | 56 |
6 files changed, 386 insertions, 148 deletions
diff --git a/service/errors.go b/service/errors.go index ecc0283a..eb31bb4e 100644 --- a/service/errors.go +++ b/service/errors.go @@ -26,7 +26,6 @@ const ( ErrorBindSocketsToDefaultRoutes ErrorSetNetConfig ErrorDetermineExecutablePath - ErrorFindAdministratorsSID ErrorCreateSecurityDescriptor ErrorOpenNULFile ErrorTrackTunnels @@ -60,8 +59,6 @@ func (e Error) Error() string { return "Unable to bind sockets to default route" case ErrorSetNetConfig: return "Unable to set interface addresses, routes, dns, and/or adapter settings" - case ErrorFindAdministratorsSID: - return "Unable to find Administrators SID" case ErrorCreateSecurityDescriptor: return "Unable to determine security descriptor" case ErrorOpenNULFile: diff --git a/service/ipc_client.go b/service/ipc_client.go index 41e71f22..adfd456c 100644 --- a/service/ipc_client.go +++ b/service/ipc_client.go @@ -9,6 +9,7 @@ import ( "encoding/gob" "errors" "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/updater" "net/rpc" "os" ) @@ -33,6 +34,8 @@ const ( TunnelChangeNotificationType NotificationType = iota TunnelsChangeNotificationType ManagerStoppingNotificationType + UpdateFoundNotificationType + UpdateProgressNotificationType ) var rpcClient *rpc.Client @@ -55,6 +58,18 @@ type ManagerStoppingCallback struct { var managerStoppingCallbacks = make(map[*ManagerStoppingCallback]bool) +type UpdateFoundCallback struct { + cb func(updateState UpdateState) +} + +var updateFoundCallbacks = make(map[*UpdateFoundCallback]bool) + +type UpdateProgressCallback struct { + cb func(dp updater.DownloadProgress) +} + +var updateProgressCallbacks = make(map[*UpdateProgressCallback]bool) + func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) { rpcClient = rpc.NewClient(&pipeRWC{reader, writer}) go func() { @@ -106,6 +121,44 @@ func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) { for cb := range managerStoppingCallbacks { cb.cb() } + case UpdateFoundNotificationType: + var state UpdateState + err = decoder.Decode(&state) + if err != nil { + continue + } + for cb := range updateFoundCallbacks { + cb.cb(state) + } + case UpdateProgressNotificationType: + var dp updater.DownloadProgress + err = decoder.Decode(&dp.Activity) + if err != nil { + continue + } + err = decoder.Decode(&dp.BytesDownloaded) + if err != nil { + continue + } + err = decoder.Decode(&dp.BytesTotal) + if err != nil { + continue + } + var errStr string + err = decoder.Decode(&errStr) + if err != nil { + continue + } + if len(errStr) > 0 { + dp.Error = errors.New(errStr) + } + err = decoder.Decode(&dp.Complete) + if err != nil { + continue + } + for cb := range updateProgressCallbacks { + cb.cb(dp) + } } } }() @@ -176,6 +229,15 @@ func IPCClientQuit(stopTunnelsOnQuit bool) (bool, error) { return alreadyQuit, rpcClient.Call("ManagerService.Quit", stopTunnelsOnQuit, &alreadyQuit) } +func IPCClientUpdateState() (UpdateState, error) { + var state UpdateState + return state, rpcClient.Call("ManagerService.UpdateState", uintptr(0), &state) +} + +func IPCClientUpdate() error { + return rpcClient.Call("ManagerService.Update", uintptr(0), nil) +} + func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state TunnelState, globalState TunnelState, err error)) *TunnelChangeCallback { s := &TunnelChangeCallback{cb} tunnelChangeCallbacks[s] = true @@ -200,3 +262,19 @@ func IPCClientRegisterManagerStopping(cb func()) *ManagerStoppingCallback { func (cb *ManagerStoppingCallback) Unregister() { delete(managerStoppingCallbacks, cb) } +func IPCClientRegisterUpdateFound(cb func(updateState UpdateState)) *UpdateFoundCallback { + s := &UpdateFoundCallback{cb} + updateFoundCallbacks[s] = true + return s +} +func (cb *UpdateFoundCallback) Unregister() { + delete(updateFoundCallbacks, cb) +} +func IPCClientRegisterUpdateProgress(cb func(dp updater.DownloadProgress)) *UpdateProgressCallback { + s := &UpdateProgressCallback{cb} + updateProgressCallbacks[s] = true + return s +} +func (cb *UpdateProgressCallback) Unregister() { + delete(updateProgressCallbacks, cb) +} diff --git a/service/ipc_server.go b/service/ipc_server.go index 9ccea8ef..becad3ee 100644 --- a/service/ipc_server.go +++ b/service/ipc_server.go @@ -10,8 +10,10 @@ import ( "encoding/gob" "fmt" "github.com/Microsoft/go-winio" + "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/updater" "io/ioutil" "log" "net/rpc" @@ -27,8 +29,14 @@ var managerServicesLock sync.RWMutex var haveQuit uint32 var quitManagersChan = make(chan struct{}, 1) +type UserTokenInfo struct { + elevatedToken windows.Token + elevatedEnvironment []string +} + type ManagerService struct { - events *os.File + events *os.File + userTokenInfo *UserTokenInfo } func (s *ManagerService) StoredConfig(tunnelName string, config *conf.Config) error { @@ -115,7 +123,7 @@ func (s *ManagerService) Start(tunnelName string, unused *uintptr) error { return InstallTunnel(path) } -func (s *ManagerService) Stop(tunnelName string, unused *uintptr) error { +func (s *ManagerService) Stop(tunnelName string, _ *uintptr) error { err := UninstallTunnel(tunnelName) if err == syscall.Errno(serviceDOES_NOT_EXIST) { _, notExistsError := conf.LoadFromName(tunnelName) @@ -126,7 +134,7 @@ func (s *ManagerService) Stop(tunnelName string, unused *uintptr) error { return err } -func (s *ManagerService) WaitForStop(tunnelName string, unused *uintptr) error { +func (s *ManagerService) WaitForStop(tunnelName string, _ *uintptr) error { serviceName, err := ServiceNameOfTunnel(tunnelName) if err != nil { return err @@ -146,7 +154,7 @@ func (s *ManagerService) WaitForStop(tunnelName string, unused *uintptr) error { } } -func (s *ManagerService) Delete(tunnelName string, unused *uintptr) error { +func (s *ManagerService) Delete(tunnelName string, _ *uintptr) error { err := s.Stop(tunnelName, nil) if err != nil { return err @@ -189,7 +197,7 @@ func (s *ManagerService) State(tunnelName string, state *TunnelState) error { return nil } -func (s *ManagerService) GlobalState(unused uintptr, state *TunnelState) error { +func (s *ManagerService) GlobalState(_ uintptr, state *TunnelState) error { *state = trackedTunnelsGlobalState() return nil } @@ -205,7 +213,7 @@ func (s *ManagerService) Create(tunnelConfig conf.Config, tunnel *Tunnel) error //TODO: handle already running and existing situation } -func (s *ManagerService) Tunnels(unused uintptr, tunnels *[]Tunnel) error { +func (s *ManagerService) Tunnels(_ uintptr, tunnels *[]Tunnel) error { names, err := conf.ListConfigNames() if err != nil { return err @@ -244,8 +252,30 @@ func (s *ManagerService) Quit(stopTunnelsOnQuit bool, alreadyQuit *bool) error { return nil } -func IPCServerListen(reader *os.File, writer *os.File, events *os.File) error { - service := &ManagerService{events: events} +func (s *ManagerService) UpdateState(_ uintptr, state *UpdateState) error { + *state = updateState + return nil +} + +func (s *ManagerService) Update(_ uintptr, _ *uintptr) error { + progress := updater.DownloadVerifyAndExecute(uintptr(s.userTokenInfo.elevatedToken), s.userTokenInfo.elevatedEnvironment) + go func() { + for { + dp := <-progress + IPCServerNotifyUpdateProgress(dp) + if dp.Complete || dp.Error != nil { + return + } + } + }() + return nil +} + +func IPCServerListen(reader *os.File, writer *os.File, events *os.File, userTokenInfo *UserTokenInfo) error { + service := &ManagerService{ + events: events, + userTokenInfo: userTokenInfo, + } server := rpc.NewServer() err := server.Register(service) @@ -304,6 +334,18 @@ func IPCServerNotifyTunnelsChange() { notifyAll(TunnelsChangeNotificationType) } +func IPCServerNotifyUpdateFound(state UpdateState) { + notifyAll(UpdateFoundNotificationType, state) +} + +func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) { + if dp.Error == nil { + notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, "", dp.Complete) + } else { + notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, dp.Error.Error(), dp.Complete) + } +} + func IPCServerNotifyManagerStopping() { notifyAll(ManagerStoppingNotificationType) time.Sleep(time.Millisecond * 200) diff --git a/service/securityapi.go b/service/securityapi.go new file mode 100644 index 00000000..6c5f7844 --- /dev/null +++ b/service/securityapi.go @@ -0,0 +1,157 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package service + +import ( + "errors" + "golang.org/x/sys/windows" + "syscall" + "unicode/utf16" + "unsafe" +) + +const ( + wtsSessionLogon uint32 = 5 + wtsSessionLogoff uint32 = 6 +) + +type wtsState int + +const ( + wtsActive wtsState = iota + wtsConnected + wtsConnectQuery + wtsShadow + wtsDisconnected + wtsIdle + wtsListen + wtsReset + wtsDown + wtsInit +) + +type wtsSessionNotification struct { + size uint32 + sessionID uint32 +} + +type wtsSessionInfo struct { + sessionID uint32 + windowStationName *uint16 + state wtsState +} + +//sys wtsQueryUserToken(session uint32, token *windows.Token) (err error) = wtsapi32.WTSQueryUserToken +//sys wtsEnumerateSessions(handle windows.Handle, reserved uint32, version uint32, sessions **wtsSessionInfo, count *uint32) (err error) = wtsapi32.WTSEnumerateSessionsW +//sys wtsFreeMemory(ptr uintptr) = wtsapi32.WTSFreeMemory + +const ( + SE_KERNEL_OBJECT = 6 + + DACL_SECURITY_INFORMATION = 4 + ATTRIBUTE_SECURITY_INFORMATION = 16 +) + +//sys getSecurityInfo(handle windows.Handle, objectType uint32, si uint32, sidOwner *windows.SID, sidGroup *windows.SID, dacl *uintptr, sacl *uintptr, securityDescriptor *uintptr) (err error) [failretval!=0] = advapi32.GetSecurityInfo +//sys getSecurityDescriptorLength(securityDescriptor uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength + +//sys createEnvironmentBlock(block *uintptr, token windows.Token, inheritExisting bool) (err error) = userenv.CreateEnvironmentBlock +//sys destroyEnvironmentBlock(block uintptr) (err error) = userenv.DestroyEnvironmentBlock + +func userEnviron(token windows.Token) (env []string, err error) { + var block uintptr + err = createEnvironmentBlock(&block, token, false) + if err != nil { + return + } + offset := uintptr(0) + for { + entry := (*[(1 << 30) - 1]uint16)(unsafe.Pointer(block + offset))[:] + for i, v := range entry { + if v == 0 { + entry = entry[:i] + break + } + } + if len(entry) == 0 { + break + } + env = append(env, string(utf16.Decode(entry))) + offset += 2 * (uintptr(len(entry)) + 1) + } + destroyEnvironmentBlock(block) + return +} + +func tokenIsElevated(token windows.Token) bool { + var isElevated uint32 + var outLen uint32 + err := windows.GetTokenInformation(token, windows.TokenElevation, (*byte)(unsafe.Pointer(&isElevated)), uint32(unsafe.Sizeof(isElevated)), &outLen) + if err != nil { + return false + } + return outLen == uint32(unsafe.Sizeof(isElevated)) && isElevated != 0 + +} + +func getElevatedToken(token windows.Token) (windows.Token, error) { + if tokenIsElevated(token) { + return token, nil + } + var linkedToken windows.Token + var outLen uint32 + err := windows.GetTokenInformation(token, windows.TokenLinkedToken, (*byte)(unsafe.Pointer(&linkedToken)), uint32(unsafe.Sizeof(linkedToken)), &outLen) + if err != nil { + return windows.Token(0), err + } + if tokenIsElevated(linkedToken) { + return linkedToken, nil + } + linkedToken.Close() + return windows.Token(0), errors.New("the linked token is not elevated") +} + +func tokenIsMemberOfBuiltInAdministrator(token windows.Token) bool { + //TODO: SECURITY CRITICIAL! + //TODO: Isn't it better to use an impersonation token or userToken.IsMember instead? + adminSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid) + if err != nil { + return false + } + gs, err := token.GetTokenGroups() + if err != nil { + return false + } + p := unsafe.Pointer(&gs.Groups[0]) + groups := (*[(1 << 28) - 1]windows.SIDAndAttributes)(p)[:gs.GroupCount] + isAdmin := false + for _, g := range groups { + if windows.EqualSid(g.Sid, adminSid) { + isAdmin = true + break + } + } + return isAdmin +} + +func getCurrentSecurityAttributes() (*syscall.SecurityAttributes, error) { + currentProcess, err := windows.GetCurrentProcess() + if err != nil { + return nil, err + } + securityAttributes := &syscall.SecurityAttributes{} + err = getSecurityInfo(currentProcess, SE_KERNEL_OBJECT, DACL_SECURITY_INFORMATION, nil, nil, nil, nil, &securityAttributes.SecurityDescriptor) + if err != nil { + return nil, err + } + windows.LocalFree(windows.Handle(securityAttributes.SecurityDescriptor)) + securityAttributes.Length = getSecurityDescriptorLength(securityAttributes.SecurityDescriptor) + if securityAttributes.Length == 0 { + windows.LocalFree(windows.Handle(securityAttributes.SecurityDescriptor)) + return nil, err + } + return securityAttributes, nil +}
\ No newline at end of file diff --git a/service/service_manager.go b/service/service_manager.go index 715b4257..e8818ae4 100644 --- a/service/service_manager.go +++ b/service/service_manager.go @@ -16,82 +16,9 @@ import ( "runtime/debug" "sync" "syscall" - "unicode/utf16" "unsafe" ) -const ( - wtsSessionLogon uint32 = 5 - wtsSessionLogoff uint32 = 6 -) - -type wtsState int - -const ( - wtsActive wtsState = iota - wtsConnected - wtsConnectQuery - wtsShadow - wtsDisconnected - wtsIdle - wtsListen - wtsReset - wtsDown - wtsInit -) - -type wtsSessionNotification struct { - size uint32 - sessionID uint32 -} - -type wtsSessionInfo struct { - sessionID uint32 - windowStationName *uint16 - state wtsState -} - -//sys wtsQueryUserToken(session uint32, token *windows.Token) (err error) = wtsapi32.WTSQueryUserToken -//sys wtsEnumerateSessions(handle windows.Handle, reserved uint32, version uint32, sessions **wtsSessionInfo, count *uint32) (err error) = wtsapi32.WTSEnumerateSessionsW -//sys wtsFreeMemory(ptr uintptr) = wtsapi32.WTSFreeMemory - -const ( - SE_KERNEL_OBJECT = 6 - DACL_SECURITY_INFORMATION = 4 - ATTRIBUTE_SECURITY_INFORMATION = 16 -) - -//sys getSecurityInfo(handle windows.Handle, objectType uint32, si uint32, sidOwner *windows.SID, sidGroup *windows.SID, dacl *uintptr, sacl *uintptr, securityDescriptor *uintptr) (err error) [failretval!=0] = advapi32.GetSecurityInfo -//sys getSecurityDescriptorLength(securityDescriptor uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength - -//sys createEnvironmentBlock(block *uintptr, token windows.Token, inheritExisting bool) (err error) = userenv.CreateEnvironmentBlock -//sys destroyEnvironmentBlock(block uintptr) (err error) = userenv.DestroyEnvironmentBlock - -func userEnviron(token windows.Token) (env []string, err error) { - var block uintptr - err = createEnvironmentBlock(&block, token, false) - if err != nil { - return - } - offset := uintptr(0) - for { - entry := (*[(1 << 30) - 1]uint16)(unsafe.Pointer(block + offset))[:] - for i, v := range entry { - if v == 0 { - entry = entry[:i] - break - } - } - if len(entry) == 0 { - break - } - env = append(env, string(utf16.Decode(entry))) - offset += 2 * (uintptr(len(entry)) + 1) - } - destroyEnvironmentBlock(block) - return -} - type managerService struct{} func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) { @@ -126,29 +53,12 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest serviceError = ErrorDetermineExecutablePath return } - - adminSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid) - if err != nil { - serviceError = ErrorFindAdministratorsSID - return - } - - currentProcess, err := windows.GetCurrentProcess() - if err != nil { - panic(err) - } - var securityAttributes syscall.SecurityAttributes - err = getSecurityInfo(currentProcess, SE_KERNEL_OBJECT, DACL_SECURITY_INFORMATION, nil, nil, nil, nil, &securityAttributes.SecurityDescriptor) + securityAttributes, err := getCurrentSecurityAttributes() if err != nil { serviceError = ErrorCreateSecurityDescriptor return } defer windows.LocalFree(windows.Handle(securityAttributes.SecurityDescriptor)) - securityAttributes.Length = getSecurityDescriptorLength(securityAttributes.SecurityDescriptor) - if securityAttributes.Length == 0 { - serviceError = ErrorCreateSecurityDescriptor - return - } devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0) if err != nil { @@ -172,49 +82,51 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest startProcess = func(session uint32) { defer runtime.UnlockOSThread() + + var userToken windows.Token + err := wtsQueryUserToken(session, &userToken) + if err != nil { + return + } + defer userToken.Close() + if !tokenIsMemberOfBuiltInAdministrator(userToken) { + return + } + user, err := userToken.GetTokenUser() + if err != nil { + log.Printf("Unable to lookup user from token: %v", err) + return + } + username, domain, accType, err := user.User.Sid.LookupAccount("") + if err != nil { + log.Printf("Unable to lookup username from sid: %v", err) + return + } + if accType != windows.SidTypeUser { + return + } + env, err := userEnviron(userToken) + if err != nil { + log.Printf("Unable to determine user environment: %v", err) + return + } + userTokenInfo := &UserTokenInfo{} + userTokenInfo.elevatedToken, err = getElevatedToken(userToken) + if err != nil { + log.Printf("Unable to elevate token: %v", err) + } + if userTokenInfo.elevatedToken != userToken { + defer userTokenInfo.elevatedToken.Close() + } + userTokenInfo.elevatedEnvironment, err = userEnviron(userTokenInfo.elevatedToken) + if err != nil { + log.Printf("Unable to determine elevated environment: %v", err) + return + } for { if stoppingManager { return } - var userToken windows.Token - err := wtsQueryUserToken(session, &userToken) - if err != nil { - return - } - - //TODO: SECURITY CRITICIAL! - //TODO: Isn't it better to use an impersonation token and userToken.IsMember instead? - gs, err := userToken.GetTokenGroups() - if err != nil { - log.Printf("Unable to lookup user groups from token: %v", err) - return - } - p := unsafe.Pointer(&gs.Groups[0]) - groups := (*[(1 << 28) - 1]windows.SIDAndAttributes)(p)[:gs.GroupCount] - isAdmin := false - for _, g := range groups { - if windows.EqualSid(g.Sid, adminSid) { - isAdmin = true - break - } - } - if !isAdmin { - return - } - - user, err := userToken.GetTokenUser() - if err != nil { - log.Printf("Unable to lookup user from token: %v", err) - return - } - username, domain, accType, err := user.User.Sid.LookupAccount("") - if err != nil { - log.Printf("Unable to lookup username from sid: %v", err) - return - } - if accType != windows.SidTypeUser { - return - } //TODO: we lock the OS thread so that these inheritable handles don't escape into other processes that // might be running in parallel Go routines. But the Go runtime is strange and who knows what's really @@ -226,7 +138,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest return } ourEvents, theirEvents, theirEventStr, err := inheritableEvents() - err = IPCServerListen(ourReader, ourWriter, ourEvents) + err = IPCServerListen(ourReader, ourWriter, ourEvents, userTokenInfo) if err != nil { log.Printf("Unable to listen on IPC pipes: %v", err) return @@ -237,12 +149,6 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest return } - env, err := userEnviron(userToken) - if err != nil { - log.Printf("Unable to determine user environment: %v", err) - return - } - log.Printf("Starting UI process for user: '%s@%s'", username, domain) attr := &os.ProcAttr{ Sys: &syscall.SysProcAttr{ @@ -286,6 +192,8 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest } } + go checkForUpdates() + var sessionsPointer *wtsSessionInfo var count uint32 err = wtsEnumerateSessions(0, 0, 1, &sessionsPointer, &count) diff --git a/service/updatestate.go b/service/updatestate.go new file mode 100644 index 00000000..c046edb2 --- /dev/null +++ b/service/updatestate.go @@ -0,0 +1,56 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package service + +import ( + "golang.zx2c4.com/wireguard/windows/updater" + "golang.zx2c4.com/wireguard/windows/version" + "log" + "time" +) + +type UpdateState uint32 + +const ( + UpdateStateUnknown UpdateState = iota + UpdateStateFoundUpdate + UpdateStateUpdatesDisabledUnofficialBuild +) + +var updateState = UpdateStateUnknown + +func checkForUpdates() { + if !version.IsRunningOfficialVersion() { + log.Println("Build is not official, so updates are disabled") + updateState = UpdateStateUpdatesDisabledUnofficialBuild + IPCServerNotifyUpdateFound(updateState) + return + } + + time.Sleep(time.Second * 10) + + first := true + for { + update, err := updater.CheckForUpdate() + if err == nil && update != nil { + log.Println("An update is available") + updateState = UpdateStateFoundUpdate + IPCServerNotifyUpdateFound(updateState) + return + } + if err != nil { + log.Printf("Update checker: %v", err) + if first { + time.Sleep(time.Minute * 4) + first = false + } else { + time.Sleep(time.Minute * 25) + } + } else { + time.Sleep(time.Hour) + } + } +} |