aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2019-05-06 09:46:10 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2019-05-06 15:55:02 +0200
commitc3488b9382b08a7fde24955f6342403576847d6a (patch)
treedd08e02b053b0cf27a75ed2e428701b1799e3bf1
parentui: do tray click action when popup clicked (diff)
downloadwireguard-windows-c3488b9382b08a7fde24955f6342403576847d6a.tar.xz
wireguard-windows-c3488b9382b08a7fde24955f6342403576847d6a.zip
updater: move into manager
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to '')
-rw-r--r--attacksurface.md2
-rw-r--r--service/errors.go3
-rw-r--r--service/ipc_client.go78
-rw-r--r--service/ipc_server.go58
-rw-r--r--service/securityapi.go157
-rw-r--r--service/service_manager.go182
-rw-r--r--service/updatestate.go56
-rw-r--r--ui/ui.go45
-rw-r--r--ui/updatepage.go118
-rw-r--r--updater/downloader.go57
-rw-r--r--updater/msirunner_linux.go8
-rw-r--r--updater/msirunner_windows.go71
-rw-r--r--updater/updater_test.go2
-rw-r--r--version/debugging_linux.go14
14 files changed, 562 insertions, 289 deletions
diff --git a/attacksurface.md b/attacksurface.md
index f8b16bf1..1e65ff32 100644
--- a/attacksurface.md
+++ b/attacksurface.md
@@ -50,4 +50,4 @@ $ signify -S -e -s release.sec -m list
$ upload ./list.sec
```
-The MSIs in that list are only the latest ones available, and filenames fit the form `wireguard-${arch}-${version}.msi`. The updater downloads this list over TLS and verifies the signify Ed25519 signature of it. If it validates, then it finds the first MSI in it for its architecture that has a greater version. It then downloads this MSI from a predefined URL, and verifies the BLAKE2b-256 signature. If it validates, then it calls `WinTrustVerify(WINTRUST_ACTION_GENERIC_VERIFY_V2, WTD_REVOKE_WHOLECHAIN)` on the MSI. If it validates, then it executes the installer with `msiexec.exe /qb- /i`.
+The MSIs in that list are only the latest ones available, and filenames fit the form `wireguard-${arch}-${version}.msi`. The updater downloads this list over TLS and verifies the signify Ed25519 signature of it. If it validates, then it finds the first MSI in it for its architecture that has a greater version. It then downloads this MSI from a predefined URL, and verifies the BLAKE2b-256 signature. If it validates, then it calls `WinTrustVerify(WINTRUST_ACTION_GENERIC_VERIFY_V2, WTD_REVOKE_WHOLECHAIN)` on the MSI. If it validates, then it executes the installer with `msiexec.exe /qb!- /i`, using the elevated token linked to the IPC UI session that requested the update.
diff --git a/service/errors.go b/service/errors.go
index ecc0283a..eb31bb4e 100644
--- a/service/errors.go
+++ b/service/errors.go
@@ -26,7 +26,6 @@ const (
ErrorBindSocketsToDefaultRoutes
ErrorSetNetConfig
ErrorDetermineExecutablePath
- ErrorFindAdministratorsSID
ErrorCreateSecurityDescriptor
ErrorOpenNULFile
ErrorTrackTunnels
@@ -60,8 +59,6 @@ func (e Error) Error() string {
return "Unable to bind sockets to default route"
case ErrorSetNetConfig:
return "Unable to set interface addresses, routes, dns, and/or adapter settings"
- case ErrorFindAdministratorsSID:
- return "Unable to find Administrators SID"
case ErrorCreateSecurityDescriptor:
return "Unable to determine security descriptor"
case ErrorOpenNULFile:
diff --git a/service/ipc_client.go b/service/ipc_client.go
index 41e71f22..adfd456c 100644
--- a/service/ipc_client.go
+++ b/service/ipc_client.go
@@ -9,6 +9,7 @@ import (
"encoding/gob"
"errors"
"golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/updater"
"net/rpc"
"os"
)
@@ -33,6 +34,8 @@ const (
TunnelChangeNotificationType NotificationType = iota
TunnelsChangeNotificationType
ManagerStoppingNotificationType
+ UpdateFoundNotificationType
+ UpdateProgressNotificationType
)
var rpcClient *rpc.Client
@@ -55,6 +58,18 @@ type ManagerStoppingCallback struct {
var managerStoppingCallbacks = make(map[*ManagerStoppingCallback]bool)
+type UpdateFoundCallback struct {
+ cb func(updateState UpdateState)
+}
+
+var updateFoundCallbacks = make(map[*UpdateFoundCallback]bool)
+
+type UpdateProgressCallback struct {
+ cb func(dp updater.DownloadProgress)
+}
+
+var updateProgressCallbacks = make(map[*UpdateProgressCallback]bool)
+
func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) {
rpcClient = rpc.NewClient(&pipeRWC{reader, writer})
go func() {
@@ -106,6 +121,44 @@ func InitializeIPCClient(reader *os.File, writer *os.File, events *os.File) {
for cb := range managerStoppingCallbacks {
cb.cb()
}
+ case UpdateFoundNotificationType:
+ var state UpdateState
+ err = decoder.Decode(&state)
+ if err != nil {
+ continue
+ }
+ for cb := range updateFoundCallbacks {
+ cb.cb(state)
+ }
+ case UpdateProgressNotificationType:
+ var dp updater.DownloadProgress
+ err = decoder.Decode(&dp.Activity)
+ if err != nil {
+ continue
+ }
+ err = decoder.Decode(&dp.BytesDownloaded)
+ if err != nil {
+ continue
+ }
+ err = decoder.Decode(&dp.BytesTotal)
+ if err != nil {
+ continue
+ }
+ var errStr string
+ err = decoder.Decode(&errStr)
+ if err != nil {
+ continue
+ }
+ if len(errStr) > 0 {
+ dp.Error = errors.New(errStr)
+ }
+ err = decoder.Decode(&dp.Complete)
+ if err != nil {
+ continue
+ }
+ for cb := range updateProgressCallbacks {
+ cb.cb(dp)
+ }
}
}
}()
@@ -176,6 +229,15 @@ func IPCClientQuit(stopTunnelsOnQuit bool) (bool, error) {
return alreadyQuit, rpcClient.Call("ManagerService.Quit", stopTunnelsOnQuit, &alreadyQuit)
}
+func IPCClientUpdateState() (UpdateState, error) {
+ var state UpdateState
+ return state, rpcClient.Call("ManagerService.UpdateState", uintptr(0), &state)
+}
+
+func IPCClientUpdate() error {
+ return rpcClient.Call("ManagerService.Update", uintptr(0), nil)
+}
+
func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state TunnelState, globalState TunnelState, err error)) *TunnelChangeCallback {
s := &TunnelChangeCallback{cb}
tunnelChangeCallbacks[s] = true
@@ -200,3 +262,19 @@ func IPCClientRegisterManagerStopping(cb func()) *ManagerStoppingCallback {
func (cb *ManagerStoppingCallback) Unregister() {
delete(managerStoppingCallbacks, cb)
}
+func IPCClientRegisterUpdateFound(cb func(updateState UpdateState)) *UpdateFoundCallback {
+ s := &UpdateFoundCallback{cb}
+ updateFoundCallbacks[s] = true
+ return s
+}
+func (cb *UpdateFoundCallback) Unregister() {
+ delete(updateFoundCallbacks, cb)
+}
+func IPCClientRegisterUpdateProgress(cb func(dp updater.DownloadProgress)) *UpdateProgressCallback {
+ s := &UpdateProgressCallback{cb}
+ updateProgressCallbacks[s] = true
+ return s
+}
+func (cb *UpdateProgressCallback) Unregister() {
+ delete(updateProgressCallbacks, cb)
+}
diff --git a/service/ipc_server.go b/service/ipc_server.go
index 9ccea8ef..becad3ee 100644
--- a/service/ipc_server.go
+++ b/service/ipc_server.go
@@ -10,8 +10,10 @@ import (
"encoding/gob"
"fmt"
"github.com/Microsoft/go-winio"
+ "golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/updater"
"io/ioutil"
"log"
"net/rpc"
@@ -27,8 +29,14 @@ var managerServicesLock sync.RWMutex
var haveQuit uint32
var quitManagersChan = make(chan struct{}, 1)
+type UserTokenInfo struct {
+ elevatedToken windows.Token
+ elevatedEnvironment []string
+}
+
type ManagerService struct {
- events *os.File
+ events *os.File
+ userTokenInfo *UserTokenInfo
}
func (s *ManagerService) StoredConfig(tunnelName string, config *conf.Config) error {
@@ -115,7 +123,7 @@ func (s *ManagerService) Start(tunnelName string, unused *uintptr) error {
return InstallTunnel(path)
}
-func (s *ManagerService) Stop(tunnelName string, unused *uintptr) error {
+func (s *ManagerService) Stop(tunnelName string, _ *uintptr) error {
err := UninstallTunnel(tunnelName)
if err == syscall.Errno(serviceDOES_NOT_EXIST) {
_, notExistsError := conf.LoadFromName(tunnelName)
@@ -126,7 +134,7 @@ func (s *ManagerService) Stop(tunnelName string, unused *uintptr) error {
return err
}
-func (s *ManagerService) WaitForStop(tunnelName string, unused *uintptr) error {
+func (s *ManagerService) WaitForStop(tunnelName string, _ *uintptr) error {
serviceName, err := ServiceNameOfTunnel(tunnelName)
if err != nil {
return err
@@ -146,7 +154,7 @@ func (s *ManagerService) WaitForStop(tunnelName string, unused *uintptr) error {
}
}
-func (s *ManagerService) Delete(tunnelName string, unused *uintptr) error {
+func (s *ManagerService) Delete(tunnelName string, _ *uintptr) error {
err := s.Stop(tunnelName, nil)
if err != nil {
return err
@@ -189,7 +197,7 @@ func (s *ManagerService) State(tunnelName string, state *TunnelState) error {
return nil
}
-func (s *ManagerService) GlobalState(unused uintptr, state *TunnelState) error {
+func (s *ManagerService) GlobalState(_ uintptr, state *TunnelState) error {
*state = trackedTunnelsGlobalState()
return nil
}
@@ -205,7 +213,7 @@ func (s *ManagerService) Create(tunnelConfig conf.Config, tunnel *Tunnel) error
//TODO: handle already running and existing situation
}
-func (s *ManagerService) Tunnels(unused uintptr, tunnels *[]Tunnel) error {
+func (s *ManagerService) Tunnels(_ uintptr, tunnels *[]Tunnel) error {
names, err := conf.ListConfigNames()
if err != nil {
return err
@@ -244,8 +252,30 @@ func (s *ManagerService) Quit(stopTunnelsOnQuit bool, alreadyQuit *bool) error {
return nil
}
-func IPCServerListen(reader *os.File, writer *os.File, events *os.File) error {
- service := &ManagerService{events: events}
+func (s *ManagerService) UpdateState(_ uintptr, state *UpdateState) error {
+ *state = updateState
+ return nil
+}
+
+func (s *ManagerService) Update(_ uintptr, _ *uintptr) error {
+ progress := updater.DownloadVerifyAndExecute(uintptr(s.userTokenInfo.elevatedToken), s.userTokenInfo.elevatedEnvironment)
+ go func() {
+ for {
+ dp := <-progress
+ IPCServerNotifyUpdateProgress(dp)
+ if dp.Complete || dp.Error != nil {
+ return
+ }
+ }
+ }()
+ return nil
+}
+
+func IPCServerListen(reader *os.File, writer *os.File, events *os.File, userTokenInfo *UserTokenInfo) error {
+ service := &ManagerService{
+ events: events,
+ userTokenInfo: userTokenInfo,
+ }
server := rpc.NewServer()
err := server.Register(service)
@@ -304,6 +334,18 @@ func IPCServerNotifyTunnelsChange() {
notifyAll(TunnelsChangeNotificationType)
}
+func IPCServerNotifyUpdateFound(state UpdateState) {
+ notifyAll(UpdateFoundNotificationType, state)
+}
+
+func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) {
+ if dp.Error == nil {
+ notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, "", dp.Complete)
+ } else {
+ notifyAll(UpdateProgressNotificationType, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, dp.Error.Error(), dp.Complete)
+ }
+}
+
func IPCServerNotifyManagerStopping() {
notifyAll(ManagerStoppingNotificationType)
time.Sleep(time.Millisecond * 200)
diff --git a/service/securityapi.go b/service/securityapi.go
new file mode 100644
index 00000000..6c5f7844
--- /dev/null
+++ b/service/securityapi.go
@@ -0,0 +1,157 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package service
+
+import (
+ "errors"
+ "golang.org/x/sys/windows"
+ "syscall"
+ "unicode/utf16"
+ "unsafe"
+)
+
+const (
+ wtsSessionLogon uint32 = 5
+ wtsSessionLogoff uint32 = 6
+)
+
+type wtsState int
+
+const (
+ wtsActive wtsState = iota
+ wtsConnected
+ wtsConnectQuery
+ wtsShadow
+ wtsDisconnected
+ wtsIdle
+ wtsListen
+ wtsReset
+ wtsDown
+ wtsInit
+)
+
+type wtsSessionNotification struct {
+ size uint32
+ sessionID uint32
+}
+
+type wtsSessionInfo struct {
+ sessionID uint32
+ windowStationName *uint16
+ state wtsState
+}
+
+//sys wtsQueryUserToken(session uint32, token *windows.Token) (err error) = wtsapi32.WTSQueryUserToken
+//sys wtsEnumerateSessions(handle windows.Handle, reserved uint32, version uint32, sessions **wtsSessionInfo, count *uint32) (err error) = wtsapi32.WTSEnumerateSessionsW
+//sys wtsFreeMemory(ptr uintptr) = wtsapi32.WTSFreeMemory
+
+const (
+ SE_KERNEL_OBJECT = 6
+
+ DACL_SECURITY_INFORMATION = 4
+ ATTRIBUTE_SECURITY_INFORMATION = 16
+)
+
+//sys getSecurityInfo(handle windows.Handle, objectType uint32, si uint32, sidOwner *windows.SID, sidGroup *windows.SID, dacl *uintptr, sacl *uintptr, securityDescriptor *uintptr) (err error) [failretval!=0] = advapi32.GetSecurityInfo
+//sys getSecurityDescriptorLength(securityDescriptor uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
+
+//sys createEnvironmentBlock(block *uintptr, token windows.Token, inheritExisting bool) (err error) = userenv.CreateEnvironmentBlock
+//sys destroyEnvironmentBlock(block uintptr) (err error) = userenv.DestroyEnvironmentBlock
+
+func userEnviron(token windows.Token) (env []string, err error) {
+ var block uintptr
+ err = createEnvironmentBlock(&block, token, false)
+ if err != nil {
+ return
+ }
+ offset := uintptr(0)
+ for {
+ entry := (*[(1 << 30) - 1]uint16)(unsafe.Pointer(block + offset))[:]
+ for i, v := range entry {
+ if v == 0 {
+ entry = entry[:i]
+ break
+ }
+ }
+ if len(entry) == 0 {
+ break
+ }
+ env = append(env, string(utf16.Decode(entry)))
+ offset += 2 * (uintptr(len(entry)) + 1)
+ }
+ destroyEnvironmentBlock(block)
+ return
+}
+
+func tokenIsElevated(token windows.Token) bool {
+ var isElevated uint32
+ var outLen uint32
+ err := windows.GetTokenInformation(token, windows.TokenElevation, (*byte)(unsafe.Pointer(&isElevated)), uint32(unsafe.Sizeof(isElevated)), &outLen)
+ if err != nil {
+ return false
+ }
+ return outLen == uint32(unsafe.Sizeof(isElevated)) && isElevated != 0
+
+}
+
+func getElevatedToken(token windows.Token) (windows.Token, error) {
+ if tokenIsElevated(token) {
+ return token, nil
+ }
+ var linkedToken windows.Token
+ var outLen uint32
+ err := windows.GetTokenInformation(token, windows.TokenLinkedToken, (*byte)(unsafe.Pointer(&linkedToken)), uint32(unsafe.Sizeof(linkedToken)), &outLen)
+ if err != nil {
+ return windows.Token(0), err
+ }
+ if tokenIsElevated(linkedToken) {
+ return linkedToken, nil
+ }
+ linkedToken.Close()
+ return windows.Token(0), errors.New("the linked token is not elevated")
+}
+
+func tokenIsMemberOfBuiltInAdministrator(token windows.Token) bool {
+ //TODO: SECURITY CRITICIAL!
+ //TODO: Isn't it better to use an impersonation token or userToken.IsMember instead?
+ adminSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
+ if err != nil {
+ return false
+ }
+ gs, err := token.GetTokenGroups()
+ if err != nil {
+ return false
+ }
+ p := unsafe.Pointer(&gs.Groups[0])
+ groups := (*[(1 << 28) - 1]windows.SIDAndAttributes)(p)[:gs.GroupCount]
+ isAdmin := false
+ for _, g := range groups {
+ if windows.EqualSid(g.Sid, adminSid) {
+ isAdmin = true
+ break
+ }
+ }
+ return isAdmin
+}
+
+func getCurrentSecurityAttributes() (*syscall.SecurityAttributes, error) {
+ currentProcess, err := windows.GetCurrentProcess()
+ if err != nil {
+ return nil, err
+ }
+ securityAttributes := &syscall.SecurityAttributes{}
+ err = getSecurityInfo(currentProcess, SE_KERNEL_OBJECT, DACL_SECURITY_INFORMATION, nil, nil, nil, nil, &securityAttributes.SecurityDescriptor)
+ if err != nil {
+ return nil, err
+ }
+ windows.LocalFree(windows.Handle(securityAttributes.SecurityDescriptor))
+ securityAttributes.Length = getSecurityDescriptorLength(securityAttributes.SecurityDescriptor)
+ if securityAttributes.Length == 0 {
+ windows.LocalFree(windows.Handle(securityAttributes.SecurityDescriptor))
+ return nil, err
+ }
+ return securityAttributes, nil
+} \ No newline at end of file
diff --git a/service/service_manager.go b/service/service_manager.go
index 715b4257..e8818ae4 100644
--- a/service/service_manager.go
+++ b/service/service_manager.go
@@ -16,82 +16,9 @@ import (
"runtime/debug"
"sync"
"syscall"
- "unicode/utf16"
"unsafe"
)
-const (
- wtsSessionLogon uint32 = 5
- wtsSessionLogoff uint32 = 6
-)
-
-type wtsState int
-
-const (
- wtsActive wtsState = iota
- wtsConnected
- wtsConnectQuery
- wtsShadow
- wtsDisconnected
- wtsIdle
- wtsListen
- wtsReset
- wtsDown
- wtsInit
-)
-
-type wtsSessionNotification struct {
- size uint32
- sessionID uint32
-}
-
-type wtsSessionInfo struct {
- sessionID uint32
- windowStationName *uint16
- state wtsState
-}
-
-//sys wtsQueryUserToken(session uint32, token *windows.Token) (err error) = wtsapi32.WTSQueryUserToken
-//sys wtsEnumerateSessions(handle windows.Handle, reserved uint32, version uint32, sessions **wtsSessionInfo, count *uint32) (err error) = wtsapi32.WTSEnumerateSessionsW
-//sys wtsFreeMemory(ptr uintptr) = wtsapi32.WTSFreeMemory
-
-const (
- SE_KERNEL_OBJECT = 6
- DACL_SECURITY_INFORMATION = 4
- ATTRIBUTE_SECURITY_INFORMATION = 16
-)
-
-//sys getSecurityInfo(handle windows.Handle, objectType uint32, si uint32, sidOwner *windows.SID, sidGroup *windows.SID, dacl *uintptr, sacl *uintptr, securityDescriptor *uintptr) (err error) [failretval!=0] = advapi32.GetSecurityInfo
-//sys getSecurityDescriptorLength(securityDescriptor uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
-
-//sys createEnvironmentBlock(block *uintptr, token windows.Token, inheritExisting bool) (err error) = userenv.CreateEnvironmentBlock
-//sys destroyEnvironmentBlock(block uintptr) (err error) = userenv.DestroyEnvironmentBlock
-
-func userEnviron(token windows.Token) (env []string, err error) {
- var block uintptr
- err = createEnvironmentBlock(&block, token, false)
- if err != nil {
- return
- }
- offset := uintptr(0)
- for {
- entry := (*[(1 << 30) - 1]uint16)(unsafe.Pointer(block + offset))[:]
- for i, v := range entry {
- if v == 0 {
- entry = entry[:i]
- break
- }
- }
- if len(entry) == 0 {
- break
- }
- env = append(env, string(utf16.Decode(entry)))
- offset += 2 * (uintptr(len(entry)) + 1)
- }
- destroyEnvironmentBlock(block)
- return
-}
-
type managerService struct{}
func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) {
@@ -126,29 +53,12 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
serviceError = ErrorDetermineExecutablePath
return
}
-
- adminSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
- if err != nil {
- serviceError = ErrorFindAdministratorsSID
- return
- }
-
- currentProcess, err := windows.GetCurrentProcess()
- if err != nil {
- panic(err)
- }
- var securityAttributes syscall.SecurityAttributes
- err = getSecurityInfo(currentProcess, SE_KERNEL_OBJECT, DACL_SECURITY_INFORMATION, nil, nil, nil, nil, &securityAttributes.SecurityDescriptor)
+ securityAttributes, err := getCurrentSecurityAttributes()
if err != nil {
serviceError = ErrorCreateSecurityDescriptor
return
}
defer windows.LocalFree(windows.Handle(securityAttributes.SecurityDescriptor))
- securityAttributes.Length = getSecurityDescriptorLength(securityAttributes.SecurityDescriptor)
- if securityAttributes.Length == 0 {
- serviceError = ErrorCreateSecurityDescriptor
- return
- }
devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0)
if err != nil {
@@ -172,49 +82,51 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
startProcess = func(session uint32) {
defer runtime.UnlockOSThread()
+
+ var userToken windows.Token
+ err := wtsQueryUserToken(session, &userToken)
+ if err != nil {
+ return
+ }
+ defer userToken.Close()
+ if !tokenIsMemberOfBuiltInAdministrator(userToken) {
+ return
+ }
+ user, err := userToken.GetTokenUser()
+ if err != nil {
+ log.Printf("Unable to lookup user from token: %v", err)
+ return
+ }
+ username, domain, accType, err := user.User.Sid.LookupAccount("")
+ if err != nil {
+ log.Printf("Unable to lookup username from sid: %v", err)
+ return
+ }
+ if accType != windows.SidTypeUser {
+ return
+ }
+ env, err := userEnviron(userToken)
+ if err != nil {
+ log.Printf("Unable to determine user environment: %v", err)
+ return
+ }
+ userTokenInfo := &UserTokenInfo{}
+ userTokenInfo.elevatedToken, err = getElevatedToken(userToken)
+ if err != nil {
+ log.Printf("Unable to elevate token: %v", err)
+ }
+ if userTokenInfo.elevatedToken != userToken {
+ defer userTokenInfo.elevatedToken.Close()
+ }
+ userTokenInfo.elevatedEnvironment, err = userEnviron(userTokenInfo.elevatedToken)
+ if err != nil {
+ log.Printf("Unable to determine elevated environment: %v", err)
+ return
+ }
for {
if stoppingManager {
return
}
- var userToken windows.Token
- err := wtsQueryUserToken(session, &userToken)
- if err != nil {
- return
- }
-
- //TODO: SECURITY CRITICIAL!
- //TODO: Isn't it better to use an impersonation token and userToken.IsMember instead?
- gs, err := userToken.GetTokenGroups()
- if err != nil {
- log.Printf("Unable to lookup user groups from token: %v", err)
- return
- }
- p := unsafe.Pointer(&gs.Groups[0])
- groups := (*[(1 << 28) - 1]windows.SIDAndAttributes)(p)[:gs.GroupCount]
- isAdmin := false
- for _, g := range groups {
- if windows.EqualSid(g.Sid, adminSid) {
- isAdmin = true
- break
- }
- }
- if !isAdmin {
- return
- }
-
- user, err := userToken.GetTokenUser()
- if err != nil {
- log.Printf("Unable to lookup user from token: %v", err)
- return
- }
- username, domain, accType, err := user.User.Sid.LookupAccount("")
- if err != nil {
- log.Printf("Unable to lookup username from sid: %v", err)
- return
- }
- if accType != windows.SidTypeUser {
- return
- }
//TODO: we lock the OS thread so that these inheritable handles don't escape into other processes that
// might be running in parallel Go routines. But the Go runtime is strange and who knows what's really
@@ -226,7 +138,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
return
}
ourEvents, theirEvents, theirEventStr, err := inheritableEvents()
- err = IPCServerListen(ourReader, ourWriter, ourEvents)
+ err = IPCServerListen(ourReader, ourWriter, ourEvents, userTokenInfo)
if err != nil {
log.Printf("Unable to listen on IPC pipes: %v", err)
return
@@ -237,12 +149,6 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
return
}
- env, err := userEnviron(userToken)
- if err != nil {
- log.Printf("Unable to determine user environment: %v", err)
- return
- }
-
log.Printf("Starting UI process for user: '%s@%s'", username, domain)
attr := &os.ProcAttr{
Sys: &syscall.SysProcAttr{
@@ -286,6 +192,8 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest
}
}
+ go checkForUpdates()
+
var sessionsPointer *wtsSessionInfo
var count uint32
err = wtsEnumerateSessions(0, 0, 1, &sessionsPointer, &count)
diff --git a/service/updatestate.go b/service/updatestate.go
new file mode 100644
index 00000000..c046edb2
--- /dev/null
+++ b/service/updatestate.go
@@ -0,0 +1,56 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package service
+
+import (
+ "golang.zx2c4.com/wireguard/windows/updater"
+ "golang.zx2c4.com/wireguard/windows/version"
+ "log"
+ "time"
+)
+
+type UpdateState uint32
+
+const (
+ UpdateStateUnknown UpdateState = iota
+ UpdateStateFoundUpdate
+ UpdateStateUpdatesDisabledUnofficialBuild
+)
+
+var updateState = UpdateStateUnknown
+
+func checkForUpdates() {
+ if !version.IsRunningOfficialVersion() {
+ log.Println("Build is not official, so updates are disabled")
+ updateState = UpdateStateUpdatesDisabledUnofficialBuild
+ IPCServerNotifyUpdateFound(updateState)
+ return
+ }
+
+ time.Sleep(time.Second * 10)
+
+ first := true
+ for {
+ update, err := updater.CheckForUpdate()
+ if err == nil && update != nil {
+ log.Println("An update is available")
+ updateState = UpdateStateFoundUpdate
+ IPCServerNotifyUpdateFound(updateState)
+ return
+ }
+ if err != nil {
+ log.Printf("Update checker: %v", err)
+ if first {
+ time.Sleep(time.Minute * 4)
+ first = false
+ } else {
+ time.Sleep(time.Minute * 25)
+ }
+ } else {
+ time.Sleep(time.Hour)
+ }
+ }
+}
diff --git a/ui/ui.go b/ui/ui.go
index f9519282..2e2b56ec 100644
--- a/ui/ui.go
+++ b/ui/ui.go
@@ -9,9 +9,6 @@ import (
"fmt"
"github.com/lxn/walk"
"golang.zx2c4.com/wireguard/windows/service"
- "golang.zx2c4.com/wireguard/windows/updater"
- "golang.zx2c4.com/wireguard/windows/version"
- "log"
"runtime"
"runtime/debug"
"time"
@@ -58,35 +55,25 @@ func RunUI() {
})
})
- go func() {
- if !version.IsRunningOfficialVersion() {
- mtw.Synchronize(func() {
- mtw.SetTitle(mtw.Title() + " (unsigned build)")
- })
+ onUpdateNotification := func(updateState service.UpdateState) {
+ if updateState == service.UpdateStateUnknown {
return
}
-
- first := true
- for {
- update, err := updater.CheckForUpdate()
- if err == nil && update != nil {
- mtw.Synchronize(func() {
- mtw.UpdateFound()
- tray.UpdateFound()
- })
- return
- }
- if err != nil {
- log.Printf("Update checker: %v", err)
- if first {
- time.Sleep(time.Minute * 4)
- first = false
- } else {
- time.Sleep(time.Minute * 25)
- }
- } else {
- time.Sleep(time.Hour)
+ mtw.Synchronize(func() {
+ switch updateState {
+ case service.UpdateStateFoundUpdate:
+ mtw.UpdateFound()
+ tray.UpdateFound()
+ case service.UpdateStateUpdatesDisabledUnofficialBuild:
+ mtw.SetTitle(mtw.Title() + " (unsigned build, no updates)")
}
+ })
+ }
+ service.IPCClientRegisterUpdateFound(onUpdateNotification)
+ go func() {
+ updateState, err := service.IPCClientUpdateState()
+ if err == nil {
+ onUpdateNotification(updateState)
}
}()
diff --git a/ui/updatepage.go b/ui/updatepage.go
index 04cc39ca..b5226017 100644
--- a/ui/updatepage.go
+++ b/ui/updatepage.go
@@ -7,9 +7,9 @@ package ui
import (
"fmt"
- "golang.zx2c4.com/wireguard/windows/updater"
-
"github.com/lxn/walk"
+ "golang.zx2c4.com/wireguard/windows/service"
+ "golang.zx2c4.com/wireguard/windows/updater"
)
type UpdatePage struct {
@@ -52,63 +52,67 @@ func NewUpdatePage() (*UpdatePage, error) {
walk.NewVSpacer(up)
+ switchToUpdatingState := func() {
+ if !bar.Visible() {
+ up.SetSuspended(true)
+ button.SetEnabled(false)
+ button.SetVisible(false)
+ bar.SetVisible(true)
+ bar.SetMarqueeMode(true)
+ up.SetSuspended(false)
+ status.SetText("Status: Waiting for updater service")
+ }
+ }
+
+ switchToReadyState := func() {
+ if bar.Visible() {
+ up.SetSuspended(true)
+ bar.SetVisible(false)
+ bar.SetValue(0)
+ bar.SetRange(0, 1)
+ bar.SetMarqueeMode(false)
+ button.SetVisible(true)
+ button.SetEnabled(true)
+ up.SetSuspended(false)
+ }
+ }
+
button.Clicked().Attach(func() {
- up.SetSuspended(true)
- button.SetEnabled(false)
- button.SetVisible(false)
- bar.SetVisible(true)
- bar.SetMarqueeMode(true)
- up.SetSuspended(false)
- progress := updater.DownloadVerifyAndExecute()
- go func() {
- for {
- dp := <-progress
- retNow := false
- up.Synchronize(func() {
- if dp.Error != nil {
- up.SetSuspended(true)
- bar.SetVisible(false)
- bar.SetValue(0)
- bar.SetRange(0, 1)
- bar.SetMarqueeMode(false)
- button.SetVisible(true)
- button.SetEnabled(true)
- status.SetText(fmt.Sprintf("Error: %v. Please try again.", dp.Error))
- up.SetSuspended(false)
- retNow = true
- return
- }
- if len(dp.Activity) > 0 {
- status.SetText(fmt.Sprintf("Status: %s", dp.Activity))
- }
- if dp.BytesTotal > 0 {
- bar.SetMarqueeMode(false)
- bar.SetRange(0, int(dp.BytesTotal))
- bar.SetValue(int(dp.BytesDownloaded))
- } else {
- bar.SetMarqueeMode(true)
- bar.SetValue(0)
- bar.SetRange(0, 1)
- }
- if dp.Complete {
- up.SetSuspended(true)
- bar.SetVisible(false)
- bar.SetValue(0)
- bar.SetRange(0, 0)
- bar.SetMarqueeMode(false)
- button.SetVisible(true)
- button.SetEnabled(true)
- status.SetText("Status: Complete!")
- up.SetSuspended(false)
- retNow = true
- return
- }
- })
- if retNow {
- return
- }
+ switchToUpdatingState()
+ err := service.IPCClientUpdate()
+ if err != nil {
+ switchToReadyState()
+ status.SetText(fmt.Sprintf("Error: %v. Please try again.", err))
+ }
+ })
+
+ service.IPCClientRegisterUpdateProgress(func(dp updater.DownloadProgress) {
+ up.Synchronize(func() {
+ switchToUpdatingState()
+ if dp.Error != nil {
+ switchToReadyState()
+ status.SetText(fmt.Sprintf("Error: %v. Please try again.", dp.Error))
+ return
+ }
+ if len(dp.Activity) > 0 {
+ status.SetText(fmt.Sprintf("Status: %s", dp.Activity))
}
- }()
+ if dp.BytesTotal > 0 {
+ bar.SetMarqueeMode(false)
+ bar.SetRange(0, int(dp.BytesTotal))
+ bar.SetValue(int(dp.BytesDownloaded))
+ } else {
+ bar.SetMarqueeMode(true)
+ bar.SetValue(0)
+ bar.SetRange(0, 1)
+ }
+ if dp.Complete {
+ switchToReadyState()
+ status.SetText("Status: Complete!")
+ return
+ }
+ })
})
+
return up, nil
}
diff --git a/updater/downloader.go b/updater/downloader.go
index 382d284b..2f83b9b2 100644
--- a/updater/downloader.go
+++ b/updater/downloader.go
@@ -15,7 +15,6 @@ import (
"io"
"net/http"
"os"
- "path"
"sync/atomic"
)
@@ -71,7 +70,7 @@ func CheckForUpdate() (*UpdateFound, error) {
var updateInProgress = uint32(0)
-func DownloadVerifyAndExecute() (progress chan DownloadProgress) {
+func DownloadVerifyAndExecute(userToken uintptr, userEnvironment []string) (progress chan DownloadProgress) {
progress = make(chan DownloadProgress, 128)
progress <- DownloadProgress{Activity: "Initializing"}
@@ -94,33 +93,19 @@ func DownloadVerifyAndExecute() (progress chan DownloadProgress) {
return
}
- progress <- DownloadProgress{Activity: "Creating update file"}
- updateDir, err := msiSaveDirectory()
- if err != nil {
- progress <- DownloadProgress{Error: err}
- return
- }
- // Clean up old updates the brutal way:
- os.RemoveAll(updateDir)
-
- err = os.MkdirAll(updateDir, 0700)
- if err != nil {
- progress <- DownloadProgress{Error: err}
- return
- }
- destinationFilename := path.Join(updateDir, update.name)
- unverifiedDestinationFilename := destinationFilename + ".unverified"
- out, err := os.Create(unverifiedDestinationFilename)
+ progress <- DownloadProgress{Activity: "Creating temporary file"}
+ file, err := msiTempFile()
if err != nil {
progress <- DownloadProgress{Error: err}
return
}
defer func() {
- if out != nil {
- out.Seek(0, io.SeekStart)
- out.Truncate(0)
- out.Close()
- os.Remove(unverifiedDestinationFilename)
+ if file != nil {
+ name := file.Name()
+ file.Seek(0, io.SeekStart)
+ file.Truncate(0)
+ file.Close()
+ os.Remove(name) //TODO: Do we have any sort of TOCTOU here?
}
}()
@@ -149,7 +134,7 @@ func DownloadVerifyAndExecute() (progress chan DownloadProgress) {
return
}
pm := &progressHashWatcher{&dp, progress, hasher}
- _, err = io.Copy(out, io.TeeReader(io.LimitReader(response.Body, 1024*1024*100 /* 100 MiB */), pm))
+ _, err = io.Copy(file, io.TeeReader(io.LimitReader(response.Body, 1024*1024*100 /* 100 MiB */), pm))
if err != nil {
progress <- DownloadProgress{Error: err}
return
@@ -158,25 +143,23 @@ func DownloadVerifyAndExecute() (progress chan DownloadProgress) {
progress <- DownloadProgress{Error: errors.New("The downloaded update has the wrong hash")}
return
}
- out.Close()
- out = nil
+
+ //TODO: it would be nice to rename in place from "file.msi.unverified" to "file.msi", but Windows TOCTOU stuff
+ // is hard, so we'll come back to this later.
+ name := file.Name()
+ file.Close()
+ file = nil
progress <- DownloadProgress{Activity: "Verifying authenticode signature"}
- if !version.VerifyAuthenticode(unverifiedDestinationFilename) {
- os.Remove(unverifiedDestinationFilename)
+ if !version.VerifyAuthenticode(name) {
+ os.Remove(name) //TODO: Do we have any sort of TOCTOU here?
progress <- DownloadProgress{Error: errors.New("The downloaded update does not have an authentic authenticode signature")}
return
}
progress <- DownloadProgress{Activity: "Installing update"}
- err = os.Rename(unverifiedDestinationFilename, destinationFilename)
- if err != nil {
- os.Remove(unverifiedDestinationFilename)
- progress <- DownloadProgress{Error: err}
- return
- }
- err = runMsi(destinationFilename)
- os.Remove(unverifiedDestinationFilename)
+ err = runMsi(name, userToken, userEnvironment)
+ os.Remove(name) //TODO: Do we have any sort of TOCTOU here?
if err != nil {
progress <- DownloadProgress{Error: err}
return
diff --git a/updater/msirunner_linux.go b/updater/msirunner_linux.go
index cbb52cf6..6550025c 100644
--- a/updater/msirunner_linux.go
+++ b/updater/msirunner_linux.go
@@ -7,15 +7,17 @@ package updater
import (
"fmt"
+ "io/ioutil"
+ "os"
"os/exec"
)
// This isn't a Linux program, yes, but having the updater package work across platforms is quite helpful for testing.
-func runMsi(msiPath string) error {
+func runMsi(msiPath string, userToken uintptr, env []string) error {
return exec.Command("qarma", "--info", "--text", fmt.Sprintf("It seems to be working! Were we on Windows, ā€˜%sā€™ would be executed.", msiPath)).Run()
}
-func msiSaveDirectory() (string, error) {
- return "/tmp/wireguard-update-test-msi-directory", nil
+func msiTempFile() (*os.File, error) {
+ return ioutil.TempFile(os.TempDir(), "")
}
diff --git a/updater/msirunner_windows.go b/updater/msirunner_windows.go
index dfa921ee..de3fb58e 100644
--- a/updater/msirunner_windows.go
+++ b/updater/msirunner_windows.go
@@ -6,26 +6,79 @@
package updater
import (
+ "crypto/rand"
+ "encoding/hex"
+ "errors"
+ "github.com/Microsoft/go-winio"
"golang.org/x/sys/windows"
- "golang.zx2c4.com/wireguard/windows/conf"
+ "os"
"os/exec"
"path"
+ "runtime"
+ "syscall"
+ "unsafe"
)
-func runMsi(msiPath string) error {
+func runMsi(msiPath string, userToken uintptr, env []string) error {
system32, err := windows.GetSystemDirectory()
if err != nil {
return err
}
- cmd := exec.Command(path.Join(system32, "msiexec.exe"), "/qb!-", "/i", path.Base(msiPath))
- cmd.Dir = path.Dir(msiPath)
- return cmd.Run()
+ devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0)
+ if err != nil {
+ return err
+ }
+ defer devNull.Close()
+ attr := &os.ProcAttr{
+ Sys: &syscall.SysProcAttr{
+ Token: syscall.Token(userToken),
+ },
+ Files: []*os.File{devNull, devNull, devNull},
+ Env: env,
+ Dir: path.Dir(msiPath),
+ }
+ msiexec := path.Join(system32, "msiexec.exe")
+ proc, err := os.StartProcess(msiexec, []string{msiexec, "/qb!-", "/i", path.Base(msiPath)}, attr)
+ if err != nil {
+ return err
+ }
+ state, err := proc.Wait()
+ if err != nil {
+ return err
+ }
+ if !state.Success() {
+ return &exec.ExitError{ProcessState: state}
+ }
+ return nil
}
-func msiSaveDirectory() (string, error) {
- configRootDir, err := conf.RootDirectory()
+func msiTempFile() (*os.File, error) {
+ var randBytes [32]byte
+ n, err := rand.Read(randBytes[:])
+ if err != nil {
+ return nil, err
+ }
+ if n != int(len(randBytes)) {
+ return nil, errors.New("Unable to generate random bytes")
+ }
+ sd, err := winio.SddlToSecurityDescriptor("O:SYD:PAI(A;;FA;;;SY)(A;;FR;;;BA)")
+ if err != nil {
+ return nil, err
+ }
+ sa := &windows.SecurityAttributes{
+ Length: uint32(len(sd)),
+ SecurityDescriptor: uintptr(unsafe.Pointer(&sd[0])),
+ }
+ //TODO: os.TempDir() returns C:\windows\temp when calling from this context. Supposedly this is mostly secure
+ // against TOCTOU, but who knows! Look into this!
+ name := path.Join(os.TempDir(), hex.EncodeToString(randBytes[:]))
+ name16 := windows.StringToUTF16Ptr(name)
+ //TODO: it would be nice to specify delete_on_close, but msiexec.exe doesn't open its files with read sharing.
+ fileHandle, err := windows.CreateFile(name16, windows.GENERIC_WRITE, windows.FILE_SHARE_READ, sa, windows.CREATE_NEW, windows.FILE_ATTRIBUTE_NORMAL, 0)
+ runtime.KeepAlive(sd)
if err != nil {
- return "", err
+ return nil, err
}
- return path.Join(configRootDir, "Updates"), nil
+ windows.MoveFileEx(name16, nil, windows.MOVEFILE_DELAY_UNTIL_REBOOT)
+ return os.NewFile(uintptr(fileHandle), name), nil
}
diff --git a/updater/updater_test.go b/updater/updater_test.go
index 7bc4df8e..fbd1080d 100644
--- a/updater/updater_test.go
+++ b/updater/updater_test.go
@@ -20,7 +20,7 @@ func TestUpdate(t *testing.T) {
return
}
t.Log("Found update")
- progress := DownloadVerifyAndExecute()
+ progress := DownloadVerifyAndExecute(0, nil)
for {
dp := <-progress
if dp.Error != nil {
diff --git a/version/debugging_linux.go b/version/debugging_linux.go
index df5dbd2f..2a7164ec 100644
--- a/version/debugging_linux.go
+++ b/version/debugging_linux.go
@@ -5,11 +5,13 @@
package version
-// For testing the updater package from linux. Debug stuff only.
+import (
+ "bytes"
+ "fmt"
+ "golang.org/x/sys/unix"
+)
-func IsOfficialPath(path string) bool {
- return true
-}
+// For testing the updater package from linux. Debug stuff only.
func utsToStr(u [65]byte) string {
i := bytes.IndexByte(u[:], 0)
@@ -30,3 +32,7 @@ func OsName() string {
func RunningVersion() string {
return "0.0.0.0"
}
+
+func VerifyAuthenticode(path string) bool {
+ return true
+} \ No newline at end of file