aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/conf
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2019-06-12 18:37:49 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2019-06-13 09:21:50 +0200
commit919e042c8410a927d48d26a25c9dd7df737ffa43 (patch)
tree7b4b9b1f8989d827bd691867b34a07d9a7f67055 /conf
parentmod: bump wireguard-go (diff)
downloadwireguard-windows-919e042c8410a927d48d26a25c9dd7df737ffa43.tar.xz
wireguard-windows-919e042c8410a927d48d26a25c9dd7df737ffa43.zip
conf: manually migrate from windows.old
Diffstat (limited to 'conf')
-rw-r--r--conf/path_windows.go62
-rw-r--r--conf/store.go2
-rw-r--r--conf/zsyscall_windows.go29
3 files changed, 83 insertions, 10 deletions
diff --git a/conf/path_windows.go b/conf/path_windows.go
index 34e189b2..ad2d759c 100644
--- a/conf/path_windows.go
+++ b/conf/path_windows.go
@@ -7,22 +7,74 @@ package conf
import (
"errors"
+ "log"
"os"
"path/filepath"
+ "strings"
"unsafe"
"golang.org/x/sys/windows"
)
//sys coTaskMemFree(pointer uintptr) = ole32.CoTaskMemFree
-//sys shGetKnownFolderPath(id *windows.GUID, flags uint32, token windows.Handle, path **uint16) (err error) [failretval!=0] = shell32.SHGetKnownFolderPath
+//sys shGetKnownFolderPath(id *windows.GUID, flags uint32, token windows.Handle, path **uint16) (ret error) = shell32.SHGetKnownFolderPath
+//sys getFileSecurity(fileName *uint16, securityInformation uint32, securityDescriptor *byte, descriptorLen uint32, requestedLen *uint32) (err error) = advapi32.GetFileSecurityW
+//sys getSecurityDescriptorOwner(securityDescriptor *byte, sid **windows.SID, ownerDefaulted *bool) (err error) = advapi32.GetSecurityDescriptorOwner
+
var folderIDLocalAppData = windows.GUID{0xf1b32785, 0x6fba, 0x4fcf, [8]byte{0x9d, 0x55, 0x7b, 0x8e, 0x7f, 0x15, 0x70, 0x91}}
const kfFlagCreate = 0x00008000
+const ownerSecurityInformation = 0x00000001
var cachedConfigFileDir string
var cachedRootDir string
+func maybeMigrate(c string) {
+ vol := filepath.VolumeName(c)
+ withoutVol := strings.TrimPrefix(c, vol)
+ oldRoot := filepath.Join(vol, "\\windows.old")
+ oldC := filepath.Join(oldRoot, withoutVol)
+
+ var err error
+ var sd []byte
+ reqLen := uint32(128)
+ for {
+ sd = make([]byte, reqLen)
+ //XXX: Since this takes a file path, it's technically a TOCTOU.
+ err = getFileSecurity(windows.StringToUTF16Ptr(oldRoot), ownerSecurityInformation, &sd[0], uint32(len(sd)), &reqLen)
+ if err != windows.ERROR_INSUFFICIENT_BUFFER {
+ break
+ }
+ }
+ if err == windows.ERROR_PATH_NOT_FOUND {
+ return
+ }
+ if err != nil {
+ log.Printf("Not migrating configuration from '%s' due to GetFileSecurity error: %v", oldRoot, err)
+ return
+ }
+ var defaulted bool
+ var sid *windows.SID
+ err = getSecurityDescriptorOwner(&sd[0], &sid, &defaulted)
+ if err != nil {
+ log.Printf("Not migrating configuration from '%s' due to GetSecurityDescriptorOwner error: %v", oldRoot, err)
+ return
+ }
+ if defaulted || !sid.IsWellKnown(windows.WinLocalSystemSid) {
+ sidStr, _ := sid.String()
+ log.Printf("Not migrating configuration from '%s' it is not explicitly owned by SYSTEM, but rather '%s'", oldRoot, sidStr)
+ return
+ }
+ err = windows.MoveFileEx(windows.StringToUTF16Ptr(oldC), windows.StringToUTF16Ptr(c), windows.MOVEFILE_COPY_ALLOWED)
+ if err != nil {
+ if err != windows.ERROR_FILE_NOT_FOUND && err != windows.ERROR_ALREADY_EXISTS {
+ log.Printf("Not migrating configuration from '%s' due to error when moving files: %v", oldRoot, err)
+ }
+ return
+ }
+ log.Printf("Migrated configuration from '%s'", oldRoot)
+}
+
func tunnelConfigurationsDirectory() (string, error) {
if cachedConfigFileDir != "" {
return cachedConfigFileDir, nil
@@ -32,6 +84,7 @@ func tunnelConfigurationsDirectory() (string, error) {
return "", err
}
c := filepath.Join(root, "Configurations")
+ maybeMigrate(c)
err = os.MkdirAll(c, os.ModeDir|0700)
if err != nil {
return "", err
@@ -44,13 +97,8 @@ func RootDirectory() (string, error) {
if cachedRootDir != "" {
return cachedRootDir, nil
}
- processToken, err := windows.OpenCurrentProcessToken()
- if err != nil {
- return "", err
- }
- defer processToken.Close()
var path *uint16
- err = shGetKnownFolderPath(&folderIDLocalAppData, kfFlagCreate, windows.Handle(processToken), &path)
+ err := shGetKnownFolderPath(&folderIDLocalAppData, kfFlagCreate, 0, &path)
if err != nil {
return "", err
}
diff --git a/conf/store.go b/conf/store.go
index 2886d027..b5cdd1ef 100644
--- a/conf/store.go
+++ b/conf/store.go
@@ -129,6 +129,8 @@ func LoadFromName(name string) (*Config, error) {
}
func LoadFromPath(path string) (*Config, error) {
+ tunnelConfigurationsDirectory() // Provoke migrations, if needed.
+
name, err := NameFromPath(path)
if err != nil {
return nil, err
diff --git a/conf/zsyscall_windows.go b/conf/zsyscall_windows.go
index d8984bef..7ad7bbab 100644
--- a/conf/zsyscall_windows.go
+++ b/conf/zsyscall_windows.go
@@ -40,11 +40,14 @@ var (
modwininet = windows.NewLazySystemDLL("wininet.dll")
modole32 = windows.NewLazySystemDLL("ole32.dll")
modshell32 = windows.NewLazySystemDLL("shell32.dll")
+ modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
procInternetGetConnectedState = modwininet.NewProc("InternetGetConnectedState")
procCoTaskMemFree = modole32.NewProc("CoTaskMemFree")
procSHGetKnownFolderPath = modshell32.NewProc("SHGetKnownFolderPath")
+ procGetFileSecurityW = modadvapi32.NewProc("GetFileSecurityW")
+ procGetSecurityDescriptorOwner = modadvapi32.NewProc("GetSecurityDescriptorOwner")
procFindFirstChangeNotificationW = modkernel32.NewProc("FindFirstChangeNotificationW")
procFindNextChangeNotification = modkernel32.NewProc("FindNextChangeNotification")
)
@@ -60,9 +63,29 @@ func coTaskMemFree(pointer uintptr) {
return
}
-func shGetKnownFolderPath(id *windows.GUID, flags uint32, token windows.Handle, path **uint16) (err error) {
- r1, _, e1 := syscall.Syscall6(procSHGetKnownFolderPath.Addr(), 4, uintptr(unsafe.Pointer(id)), uintptr(flags), uintptr(token), uintptr(unsafe.Pointer(path)), 0, 0)
- if r1 != 0 {
+func shGetKnownFolderPath(id *windows.GUID, flags uint32, token windows.Handle, path **uint16) (ret error) {
+ r0, _, _ := syscall.Syscall6(procSHGetKnownFolderPath.Addr(), 4, uintptr(unsafe.Pointer(id)), uintptr(flags), uintptr(token), uintptr(unsafe.Pointer(path)), 0, 0)
+ if r0 != 0 {
+ ret = syscall.Errno(r0)
+ }
+ return
+}
+
+func getFileSecurity(fileName *uint16, securityInformation uint32, securityDescriptor *byte, descriptorLen uint32, requestedLen *uint32) (err error) {
+ r1, _, e1 := syscall.Syscall6(procGetFileSecurityW.Addr(), 5, uintptr(unsafe.Pointer(fileName)), uintptr(securityInformation), uintptr(unsafe.Pointer(securityDescriptor)), uintptr(descriptorLen), uintptr(unsafe.Pointer(requestedLen)), 0)
+ if r1 == 0 {
+ if e1 != 0 {
+ err = errnoErr(e1)
+ } else {
+ err = syscall.EINVAL
+ }
+ }
+ return
+}
+
+func getSecurityDescriptorOwner(securityDescriptor *byte, sid **windows.SID, ownerDefaulted *bool) (err error) {
+ r1, _, e1 := syscall.Syscall(procGetSecurityDescriptorOwner.Addr(), 3, uintptr(unsafe.Pointer(securityDescriptor)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(ownerDefaulted)))
+ if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {