diff options
-rw-r--r-- | conf/filewriter_windows.go | 23 | ||||
-rw-r--r-- | conf/migration_windows.go | 2 | ||||
-rw-r--r-- | conf/store.go | 27 | ||||
-rw-r--r-- | manager/service.go | 2 |
4 files changed, 41 insertions, 13 deletions
diff --git a/conf/filewriter_windows.go b/conf/filewriter_windows.go index 6f7c40ad..ca45bf42 100644 --- a/conf/filewriter_windows.go +++ b/conf/filewriter_windows.go @@ -6,6 +6,8 @@ package conf import ( + "crypto/rand" + "encoding/hex" "sync/atomic" "unsafe" @@ -14,7 +16,16 @@ import ( var encryptedFileSd unsafe.Pointer -func writeEncryptedFile(destination string, contents []byte) error { +func randomFileName() string { + var randBytes [32]byte + _, err := rand.Read(randBytes[:]) + if err != nil { + panic(err) + } + return hex.EncodeToString(randBytes[:]) + ".tmp" +} + +func writeEncryptedFile(destination string, overwrite bool, contents []byte) error { var err error sa := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{}))} sa.SecurityDescriptor = (*windows.SECURITY_DESCRIPTOR)(atomic.LoadPointer(&encryptedFileSd)) @@ -29,7 +40,7 @@ func writeEncryptedFile(destination string, contents []byte) error { if err != nil { return err } - tmpDestination := destination + ".tmp" + tmpDestination := randomFileName() tmpDestination16, err := windows.UTF16PtrFromString(tmpDestination) if err != nil { return err @@ -57,7 +68,13 @@ func writeEncryptedFile(destination string, contents []byte) error { rootDirectory windows.Handle fileNameLength uint32 fileName [windows.MAX_PATH]uint16 - }{replaceIfExists: 1, fileNameLength: uint32(len(destination16) - 1)} + }{replaceIfExists: func() byte { + if overwrite { + return 1 + } else { + return 0 + } + }(), fileNameLength: uint32(len(destination16) - 1)} if len(destination16) > len(fileRenameInfo.fileName) { deleteIt() return windows.ERROR_BUFFER_OVERFLOW diff --git a/conf/migration_windows.go b/conf/migration_windows.go index 060491ed..5f1086e8 100644 --- a/conf/migration_windows.go +++ b/conf/migration_windows.go @@ -56,7 +56,7 @@ func maybeMigrateConfiguration(c string) { } newPath := filepath.Join(c, fileName) - err = writeEncryptedFile(newPath, oldConfig) + err = writeEncryptedFile(newPath, false, oldConfig) if err != nil { continue } diff --git a/conf/store.go b/conf/store.go index 0c4aa41f..4782aca0 100644 --- a/conf/store.go +++ b/conf/store.go @@ -11,6 +11,10 @@ import ( "os" "path/filepath" "strings" + "sync" + "time" + + "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/conf/dpapi" ) @@ -47,7 +51,12 @@ func ListConfigNames() ([]string, error) { return configs[:i], nil } -func MigrateUnencryptedConfigs() (int, []error) { +var migrating sync.Mutex +var lastMigrationTimer *time.Timer + +func MigrateUnencryptedConfigs(sharingBase int) (int, []error) { + migrating.Lock() + defer migrating.Unlock() configFileDir, err := tunnelConfigurationsDirectory() if err != nil { return 0, []error{err} @@ -73,6 +82,13 @@ func MigrateUnencryptedConfigs() (int, []error) { // of Windows file locking for ensuring the file is finished being written. f, err := os.OpenFile(path, os.O_RDWR, 0) if err != nil { + if sharingBase > 0 && errors.Is(err, windows.ERROR_SHARING_VIOLATION) { + if lastMigrationTimer != nil { + lastMigrationTimer.Stop() + } + lastMigrationTimer = time.AfterFunc(time.Second/time.Duration(sharingBase*sharingBase), func() { MigrateUnencryptedConfigs(sharingBase - 1) }) + sharingBase = 0 + } errs[e] = err e++ continue @@ -98,12 +114,7 @@ func MigrateUnencryptedConfigs() (int, []error) { continue } dstFile := strings.TrimSuffix(path, configFileUnencryptedSuffix) + configFileSuffix - if _, err = os.Stat(dstFile); err != nil && !os.IsNotExist(err) { - errs[e] = errors.New("Unable to migrate to " + dstFile + " as it already exists") - e++ - continue - } - err = writeEncryptedFile(dstFile, bytes) + err = writeEncryptedFile(dstFile, false, bytes) if err != nil { errs[e] = err e++ @@ -185,7 +196,7 @@ func (config *Config) Save() error { if err != nil { return err } - return writeEncryptedFile(filename, bytes) + return writeEncryptedFile(filename, true, bytes) } func (config *Config) Path() (string, error) { diff --git a/manager/service.go b/manager/service.go index b671f79e..5ae8c9df 100644 --- a/manager/service.go +++ b/manager/service.go @@ -83,7 +83,7 @@ func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest return } - conf.RegisterStoreChangeCallback(func() { conf.MigrateUnencryptedConfigs() }) // Ignore return value for now, but could be useful later. + conf.RegisterStoreChangeCallback(func() { conf.MigrateUnencryptedConfigs(3) }) conf.RegisterStoreChangeCallback(IPCServerNotifyTunnelsChange) procs := make(map[uint32]*os.Process) |