aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/conf/dpapi/dpapi_windows.go
diff options
context:
space:
mode:
Diffstat (limited to 'conf/dpapi/dpapi_windows.go')
-rw-r--r--conf/dpapi/dpapi_windows.go28
1 files changed, 7 insertions, 21 deletions
diff --git a/conf/dpapi/dpapi_windows.go b/conf/dpapi/dpapi_windows.go
index b3f28a93..49a32915 100644
--- a/conf/dpapi/dpapi_windows.go
+++ b/conf/dpapi/dpapi_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package dpapi
@@ -28,16 +28,9 @@ func Encrypt(data []byte, name string) ([]byte, error) {
if err != nil {
return nil, fmt.Errorf("unable to encrypt DPAPI protected data: %w", err)
}
-
- outSlice := *(*[]byte)(unsafe.Pointer(&(struct {
- addr *byte
- len int
- cap int
- }{out.Data, int(out.Size), int(out.Size)})))
- ret := make([]byte, len(outSlice))
- copy(ret, outSlice)
+ ret := make([]byte, out.Size)
+ copy(ret, unsafe.Slice(out.Data, out.Size))
windows.LocalFree(windows.Handle(unsafe.Pointer(out.Data)))
-
return ret, nil
}
@@ -48,19 +41,12 @@ func Decrypt(data []byte, name string) ([]byte, error) {
if err != nil {
return nil, err
}
-
err = windows.CryptUnprotectData(bytesToBlob(data), &outName, nil, 0, nil, windows.CRYPTPROTECT_UI_FORBIDDEN, &out)
if err != nil {
return nil, fmt.Errorf("unable to decrypt DPAPI protected data: %w", err)
}
-
- outSlice := *(*[]byte)(unsafe.Pointer(&(struct {
- addr *byte
- len int
- cap int
- }{out.Data, int(out.Size), int(out.Size)})))
- ret := make([]byte, len(outSlice))
- copy(ret, outSlice)
+ ret := make([]byte, out.Size)
+ copy(ret, unsafe.Slice(out.Data, out.Size))
windows.LocalFree(windows.Handle(unsafe.Pointer(out.Data)))
// Note: this ridiculous open-coded strcmp is not constant time.
@@ -75,8 +61,8 @@ func Decrypt(data []byte, name string) ([]byte, error) {
if *a == 0 || *b == 0 {
break
}
- a = (*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(a)) + 2))
- b = (*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(b)) + 2))
+ a = (*uint16)(unsafe.Add(unsafe.Pointer(a), 2))
+ b = (*uint16)(unsafe.Add(unsafe.Pointer(b), 2))
}
runtime.KeepAlive(utf16Name)
windows.LocalFree(windows.Handle(unsafe.Pointer(outName)))