From 919e042c8410a927d48d26a25c9dd7df737ffa43 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Wed, 12 Jun 2019 18:37:49 +0200 Subject: conf: manually migrate from windows.old --- conf/path_windows.go | 62 ++++++++++++++++++++++++++++++++++++++++++------ conf/store.go | 2 ++ conf/zsyscall_windows.go | 29 +++++++++++++++++++--- 3 files changed, 83 insertions(+), 10 deletions(-) (limited to 'conf') diff --git a/conf/path_windows.go b/conf/path_windows.go index 34e189b2..ad2d759c 100644 --- a/conf/path_windows.go +++ b/conf/path_windows.go @@ -7,22 +7,74 @@ package conf import ( "errors" + "log" "os" "path/filepath" + "strings" "unsafe" "golang.org/x/sys/windows" ) //sys coTaskMemFree(pointer uintptr) = ole32.CoTaskMemFree -//sys shGetKnownFolderPath(id *windows.GUID, flags uint32, token windows.Handle, path **uint16) (err error) [failretval!=0] = shell32.SHGetKnownFolderPath +//sys shGetKnownFolderPath(id *windows.GUID, flags uint32, token windows.Handle, path **uint16) (ret error) = shell32.SHGetKnownFolderPath +//sys getFileSecurity(fileName *uint16, securityInformation uint32, securityDescriptor *byte, descriptorLen uint32, requestedLen *uint32) (err error) = advapi32.GetFileSecurityW +//sys getSecurityDescriptorOwner(securityDescriptor *byte, sid **windows.SID, ownerDefaulted *bool) (err error) = advapi32.GetSecurityDescriptorOwner + var folderIDLocalAppData = windows.GUID{0xf1b32785, 0x6fba, 0x4fcf, [8]byte{0x9d, 0x55, 0x7b, 0x8e, 0x7f, 0x15, 0x70, 0x91}} const kfFlagCreate = 0x00008000 +const ownerSecurityInformation = 0x00000001 var cachedConfigFileDir string var cachedRootDir string +func maybeMigrate(c string) { + vol := filepath.VolumeName(c) + withoutVol := strings.TrimPrefix(c, vol) + oldRoot := filepath.Join(vol, "\\windows.old") + oldC := filepath.Join(oldRoot, withoutVol) + + var err error + var sd []byte + reqLen := uint32(128) + for { + sd = make([]byte, reqLen) + //XXX: Since this takes a file path, it's technically a TOCTOU. + err = getFileSecurity(windows.StringToUTF16Ptr(oldRoot), ownerSecurityInformation, &sd[0], uint32(len(sd)), &reqLen) + if err != windows.ERROR_INSUFFICIENT_BUFFER { + break + } + } + if err == windows.ERROR_PATH_NOT_FOUND { + return + } + if err != nil { + log.Printf("Not migrating configuration from '%s' due to GetFileSecurity error: %v", oldRoot, err) + return + } + var defaulted bool + var sid *windows.SID + err = getSecurityDescriptorOwner(&sd[0], &sid, &defaulted) + if err != nil { + log.Printf("Not migrating configuration from '%s' due to GetSecurityDescriptorOwner error: %v", oldRoot, err) + return + } + if defaulted || !sid.IsWellKnown(windows.WinLocalSystemSid) { + sidStr, _ := sid.String() + log.Printf("Not migrating configuration from '%s' it is not explicitly owned by SYSTEM, but rather '%s'", oldRoot, sidStr) + return + } + err = windows.MoveFileEx(windows.StringToUTF16Ptr(oldC), windows.StringToUTF16Ptr(c), windows.MOVEFILE_COPY_ALLOWED) + if err != nil { + if err != windows.ERROR_FILE_NOT_FOUND && err != windows.ERROR_ALREADY_EXISTS { + log.Printf("Not migrating configuration from '%s' due to error when moving files: %v", oldRoot, err) + } + return + } + log.Printf("Migrated configuration from '%s'", oldRoot) +} + func tunnelConfigurationsDirectory() (string, error) { if cachedConfigFileDir != "" { return cachedConfigFileDir, nil @@ -32,6 +84,7 @@ func tunnelConfigurationsDirectory() (string, error) { return "", err } c := filepath.Join(root, "Configurations") + maybeMigrate(c) err = os.MkdirAll(c, os.ModeDir|0700) if err != nil { return "", err @@ -44,13 +97,8 @@ func RootDirectory() (string, error) { if cachedRootDir != "" { return cachedRootDir, nil } - processToken, err := windows.OpenCurrentProcessToken() - if err != nil { - return "", err - } - defer processToken.Close() var path *uint16 - err = shGetKnownFolderPath(&folderIDLocalAppData, kfFlagCreate, windows.Handle(processToken), &path) + err := shGetKnownFolderPath(&folderIDLocalAppData, kfFlagCreate, 0, &path) if err != nil { return "", err } diff --git a/conf/store.go b/conf/store.go index 2886d027..b5cdd1ef 100644 --- a/conf/store.go +++ b/conf/store.go @@ -129,6 +129,8 @@ func LoadFromName(name string) (*Config, error) { } func LoadFromPath(path string) (*Config, error) { + tunnelConfigurationsDirectory() // Provoke migrations, if needed. + name, err := NameFromPath(path) if err != nil { return nil, err diff --git a/conf/zsyscall_windows.go b/conf/zsyscall_windows.go index d8984bef..7ad7bbab 100644 --- a/conf/zsyscall_windows.go +++ b/conf/zsyscall_windows.go @@ -40,11 +40,14 @@ var ( modwininet = windows.NewLazySystemDLL("wininet.dll") modole32 = windows.NewLazySystemDLL("ole32.dll") modshell32 = windows.NewLazySystemDLL("shell32.dll") + modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") modkernel32 = windows.NewLazySystemDLL("kernel32.dll") procInternetGetConnectedState = modwininet.NewProc("InternetGetConnectedState") procCoTaskMemFree = modole32.NewProc("CoTaskMemFree") procSHGetKnownFolderPath = modshell32.NewProc("SHGetKnownFolderPath") + procGetFileSecurityW = modadvapi32.NewProc("GetFileSecurityW") + procGetSecurityDescriptorOwner = modadvapi32.NewProc("GetSecurityDescriptorOwner") procFindFirstChangeNotificationW = modkernel32.NewProc("FindFirstChangeNotificationW") procFindNextChangeNotification = modkernel32.NewProc("FindNextChangeNotification") ) @@ -60,9 +63,29 @@ func coTaskMemFree(pointer uintptr) { return } -func shGetKnownFolderPath(id *windows.GUID, flags uint32, token windows.Handle, path **uint16) (err error) { - r1, _, e1 := syscall.Syscall6(procSHGetKnownFolderPath.Addr(), 4, uintptr(unsafe.Pointer(id)), uintptr(flags), uintptr(token), uintptr(unsafe.Pointer(path)), 0, 0) - if r1 != 0 { +func shGetKnownFolderPath(id *windows.GUID, flags uint32, token windows.Handle, path **uint16) (ret error) { + r0, _, _ := syscall.Syscall6(procSHGetKnownFolderPath.Addr(), 4, uintptr(unsafe.Pointer(id)), uintptr(flags), uintptr(token), uintptr(unsafe.Pointer(path)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func getFileSecurity(fileName *uint16, securityInformation uint32, securityDescriptor *byte, descriptorLen uint32, requestedLen *uint32) (err error) { + r1, _, e1 := syscall.Syscall6(procGetFileSecurityW.Addr(), 5, uintptr(unsafe.Pointer(fileName)), uintptr(securityInformation), uintptr(unsafe.Pointer(securityDescriptor)), uintptr(descriptorLen), uintptr(unsafe.Pointer(requestedLen)), 0) + if r1 == 0 { + if e1 != 0 { + err = errnoErr(e1) + } else { + err = syscall.EINVAL + } + } + return +} + +func getSecurityDescriptorOwner(securityDescriptor *byte, sid **windows.SID, ownerDefaulted *bool) (err error) { + r1, _, e1 := syscall.Syscall(procGetSecurityDescriptorOwner.Addr(), 3, uintptr(unsafe.Pointer(securityDescriptor)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(ownerDefaulted))) + if r1 == 0 { if e1 != 0 { err = errnoErr(e1) } else { -- cgit v1.2.3-59-g8ed1b