aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/conf
diff options
context:
space:
mode:
Diffstat (limited to 'conf')
-rw-r--r--conf/filewriter_windows.go72
-rw-r--r--conf/migration_windows.go10
-rw-r--r--conf/path_windows.go2
-rw-r--r--conf/store.go13
4 files changed, 76 insertions, 21 deletions
diff --git a/conf/filewriter_windows.go b/conf/filewriter_windows.go
new file mode 100644
index 00000000..9fb1f566
--- /dev/null
+++ b/conf/filewriter_windows.go
@@ -0,0 +1,72 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
+ */
+
+package conf
+
+import (
+ "sync/atomic"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+var encryptedFileSd unsafe.Pointer
+
+func writeEncryptedFile(destination string, contents []byte) error {
+ var err error
+ sa := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{}))}
+ sa.SecurityDescriptor = (*windows.SECURITY_DESCRIPTOR)(atomic.LoadPointer(&encryptedFileSd))
+ if sa.SecurityDescriptor == nil {
+ sa.SecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYG:SYD:PAI(A;;FA;;;SY)(A;;SD;;;BA)")
+ if err != nil {
+ return err
+ }
+ atomic.StorePointer(&encryptedFileSd, unsafe.Pointer(sa.SecurityDescriptor))
+ }
+ destination16, err := windows.UTF16FromString(destination)
+ if err != nil {
+ return err
+ }
+ tmpDestination := destination + ".tmp"
+ tmpDestination16, err := windows.UTF16PtrFromString(tmpDestination)
+ if err != nil {
+ return err
+ }
+ handle, err := windows.CreateFile(tmpDestination16, windows.GENERIC_WRITE|windows.DELETE, 0, sa, windows.CREATE_ALWAYS, windows.FILE_ATTRIBUTE_NORMAL, 0)
+ if err != nil {
+ return err
+ }
+ defer windows.CloseHandle(handle)
+ deleteIt := func() {
+ yes := byte(1)
+ windows.SetFileInformationByHandle(handle, windows.FileDispositionInfo, &yes, 1)
+ }
+ n, err := windows.Write(handle, contents)
+ if err != nil {
+ deleteIt()
+ return err
+ }
+ if n != len(contents) {
+ deleteIt()
+ return windows.ERROR_IO_INCOMPLETE
+ }
+ fileRenameInfo := &struct {
+ replaceIfExists byte
+ rootDirectory windows.Handle
+ fileNameLength uint32
+ fileName [windows.MAX_PATH]uint16
+ }{replaceIfExists: 1, fileNameLength: uint32(len(destination16) - 1)}
+ if len(destination16) > len(fileRenameInfo.fileName) {
+ deleteIt()
+ return windows.ERROR_BUFFER_OVERFLOW
+ }
+ copy(fileRenameInfo.fileName[:], destination16[:])
+ err = windows.SetFileInformationByHandle(handle, windows.FileRenameInfo, (*byte)(unsafe.Pointer(fileRenameInfo)), uint32(unsafe.Sizeof(*fileRenameInfo)))
+ if err != nil {
+ deleteIt()
+ return err
+ }
+ return nil
+} \ No newline at end of file
diff --git a/conf/migration_windows.go b/conf/migration_windows.go
index 643d05c9..d2f0c765 100644
--- a/conf/migration_windows.go
+++ b/conf/migration_windows.go
@@ -9,7 +9,6 @@ import (
"fmt"
"io/ioutil"
"log"
- "os"
"path/filepath"
"regexp"
"strings"
@@ -57,17 +56,10 @@ func maybeMigrateConfiguration(c string) {
}
newPath := filepath.Join(c, fileName)
- newFile, err := os.OpenFile(newPath, os.O_EXCL|os.O_CREATE|os.O_WRONLY, 0600)
+ err = writeEncryptedFile(newPath, oldConfig)
if err != nil {
continue
}
- _, err = newFile.Write(oldConfig)
- if err != nil {
- newFile.Close()
- os.Remove(newPath)
- continue
- }
- newFile.Close()
oldPath16, err := windows.UTF16PtrFromString(oldPath)
if err == nil {
windows.MoveFileEx(oldPath16, nil, windows.MOVEFILE_DELAY_UNTIL_REBOOT)
diff --git a/conf/path_windows.go b/conf/path_windows.go
index 526aeba0..b328364b 100644
--- a/conf/path_windows.go
+++ b/conf/path_windows.go
@@ -102,7 +102,7 @@ func RootDirectory(create bool) (string, error) {
return "", err
}
- dataDirectorySd, err := windows.SecurityDescriptorFromString("O:SYG:SYD:PAI(A;OICI;FA;;;SY)(A;OICI;FR;;;BA)")
+ dataDirectorySd, err := windows.SecurityDescriptorFromString("O:SYG:SYD:PAI(A;OICI;FA;;;SY)(A;OICI;FA;;;BA)")
if err != nil {
return "", err
}
diff --git a/conf/store.go b/conf/store.go
index 21bd3a22..e79e24b8 100644
--- a/conf/store.go
+++ b/conf/store.go
@@ -103,7 +103,7 @@ func MigrateUnencryptedConfigs() (int, []error) {
e++
continue
}
- err = ioutil.WriteFile(dstFile, bytes, 0600)
+ err = writeEncryptedFile(dstFile, bytes)
if err != nil {
errs[e] = err
e++
@@ -185,16 +185,7 @@ func (config *Config) Save() error {
if err != nil {
return err
}
- err = ioutil.WriteFile(filename+".tmp", bytes, 0600)
- if err != nil {
- return err
- }
- err = os.Rename(filename+".tmp", filename)
- if err != nil {
- os.Remove(filename + ".tmp")
- return err
- }
- return nil
+ return writeEncryptedFile(filename, bytes)
}
func (config *Config) Path() (string, error) {