aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/conf
diff options
context:
space:
mode:
Diffstat (limited to 'conf')
-rw-r--r--conf/admin_windows.go36
-rw-r--r--conf/config.go83
-rw-r--r--conf/dnsresolver_windows.go71
-rw-r--r--conf/dpapi/dpapi_windows.go73
-rw-r--r--conf/dpapi/dpapi_windows_test.go8
-rw-r--r--conf/dpapi/mksyscall.go8
-rw-r--r--conf/dpapi/zdpapi_windows.go68
-rw-r--r--conf/filewriter_windows.go90
-rw-r--r--conf/migration_windows.go108
-rw-r--r--conf/mksyscall.go8
-rw-r--r--conf/name.go33
-rw-r--r--conf/parser.go281
-rw-r--r--conf/parser_test.go17
-rw-r--r--conf/path_windows.go106
-rw-r--r--conf/store.go112
-rw-r--r--conf/store_test.go10
-rw-r--r--conf/storewatcher.go2
-rw-r--r--conf/storewatcher_windows.go26
-rw-r--r--conf/writer.go104
-rw-r--r--conf/zsyscall_windows.go83
20 files changed, 638 insertions, 689 deletions
diff --git a/conf/admin_windows.go b/conf/admin_windows.go
new file mode 100644
index 00000000..91d4bdc9
--- /dev/null
+++ b/conf/admin_windows.go
@@ -0,0 +1,36 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
+ */
+
+package conf
+
+import "golang.org/x/sys/windows/registry"
+
+const adminRegKey = `Software\WireGuard`
+
+var adminKey registry.Key
+
+func openAdminKey() (registry.Key, error) {
+ if adminKey != 0 {
+ return adminKey, nil
+ }
+ var err error
+ adminKey, err = registry.OpenKey(registry.LOCAL_MACHINE, adminRegKey, registry.QUERY_VALUE|registry.WOW64_64KEY)
+ if err != nil {
+ return 0, err
+ }
+ return adminKey, nil
+}
+
+func AdminBool(name string) bool {
+ key, err := openAdminKey()
+ if err != nil {
+ return false
+ }
+ val, _, err := key.GetIntegerValue(name)
+ if err != nil {
+ return false
+ }
+ return val != 0
+}
diff --git a/conf/config.go b/conf/config.go
index 9f5dbcc1..74ffacf6 100644
--- a/conf/config.go
+++ b/conf/config.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package conf
@@ -9,9 +9,8 @@ import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
- "encoding/hex"
"fmt"
- "net"
+ "net/netip"
"strings"
"time"
@@ -22,19 +21,16 @@ import (
const KeyLength = 32
-type IPCidr struct {
- IP net.IP
- Cidr uint8
-}
-
type Endpoint struct {
Host string
Port uint16
}
-type Key [KeyLength]byte
-type HandshakeTime time.Duration
-type Bytes uint64
+type (
+ Key [KeyLength]byte
+ HandshakeTime time.Duration
+ Bytes uint64
+)
type Config struct {
Name string
@@ -44,16 +40,22 @@ type Config struct {
type Interface struct {
PrivateKey Key
- Addresses []IPCidr
+ Addresses []netip.Prefix
ListenPort uint16
MTU uint16
- DNS []net.IP
+ DNS []netip.Addr
+ DNSSearch []string
+ PreUp string
+ PostUp string
+ PreDown string
+ PostDown string
+ TableOff bool
}
type Peer struct {
PublicKey Key
PresharedKey Key
- AllowedIPs []IPCidr
+ AllowedIPs []netip.Prefix
Endpoint Endpoint
PersistentKeepalive uint16
@@ -62,26 +64,37 @@ type Peer struct {
LastHandshakeTime HandshakeTime
}
-func (r *IPCidr) String() string {
- return fmt.Sprintf("%s/%d", r.IP.String(), r.Cidr)
-}
-
-func (r *IPCidr) Bits() uint8 {
- if r.IP.To4() != nil {
- return 32
+func (conf *Config) IntersectsWith(other *Config) bool {
+ allRoutes := make(map[netip.Prefix]bool, len(conf.Interface.Addresses)*2+len(conf.Peers)*3)
+ for _, a := range conf.Interface.Addresses {
+ allRoutes[netip.PrefixFrom(a.Addr(), a.Addr().BitLen())] = true
+ allRoutes[a.Masked()] = true
}
- return 128
-}
-
-func (r *IPCidr) IPNet() net.IPNet {
- return net.IPNet{
- IP: r.IP,
- Mask: net.CIDRMask(int(r.Cidr), int(r.Bits())),
+ for i := range conf.Peers {
+ for _, a := range conf.Peers[i].AllowedIPs {
+ allRoutes[a.Masked()] = true
+ }
}
+ for _, a := range other.Interface.Addresses {
+ if allRoutes[netip.PrefixFrom(a.Addr(), a.Addr().BitLen())] {
+ return true
+ }
+ if allRoutes[a.Masked()] {
+ return true
+ }
+ }
+ for i := range other.Peers {
+ for _, a := range other.Peers[i].AllowedIPs {
+ if allRoutes[a.Masked()] {
+ return true
+ }
+ }
+ }
+ return false
}
func (e *Endpoint) String() string {
- if strings.IndexByte(e.Host, ':') > 0 {
+ if strings.IndexByte(e.Host, ':') != -1 {
return fmt.Sprintf("[%s]:%d", e.Host, e.Port)
}
return fmt.Sprintf("%s:%d", e.Host, e.Port)
@@ -95,10 +108,6 @@ func (k *Key) String() string {
return base64.StdEncoding.EncodeToString(k[:])
}
-func (k *Key) HexString() string {
- return hex.EncodeToString(k[:])
-}
-
func (k *Key) IsZero() bool {
var zeros Key
return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1
@@ -229,3 +238,11 @@ func (conf *Config) DeduplicateNetworkEntries() {
peer.AllowedIPs = peer.AllowedIPs[:i]
}
}
+
+func (conf *Config) Redact() {
+ conf.Interface.PrivateKey = Key{}
+ for i := range conf.Peers {
+ conf.Peers[i].PublicKey = Key{}
+ conf.Peers[i].PresharedKey = Key{}
+ }
+}
diff --git a/conf/dnsresolver_windows.go b/conf/dnsresolver_windows.go
index d6c2f1c7..a299c475 100644
--- a/conf/dnsresolver_windows.go
+++ b/conf/dnsresolver_windows.go
@@ -1,43 +1,41 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
- "fmt"
"log"
- "net"
- "syscall"
+ "net/netip"
"time"
"unsafe"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
+
"golang.org/x/sys/windows"
+ "golang.zx2c4.com/wireguard/windows/services"
)
-//sys internetGetConnectedState(flags *uint32, reserved uint32) (connected bool) = wininet.InternetGetConnectedState
-
func resolveHostname(name string) (resolvedIPString string, err error) {
maxTries := 10
- systemJustBooted := windows.DurationSinceBoot() <= time.Minute*4
- if systemJustBooted {
- maxTries *= 4
+ if services.StartedAtBoot() {
+ maxTries *= 3
}
for i := 0; i < maxTries; i++ {
+ if i > 0 {
+ time.Sleep(time.Second * 4)
+ }
resolvedIPString, err = resolveHostnameOnce(name)
if err == nil {
return
}
if err == windows.WSATRY_AGAIN {
- log.Printf("Temporary DNS error when resolving %s, sleeping for 4 seconds", name)
- time.Sleep(time.Second * 4)
+ log.Printf("Temporary DNS error when resolving %s, so sleeping for 4 seconds", name)
continue
}
- var state uint32
- if err == windows.WSAHOST_NOT_FOUND && systemJustBooted && !internetGetConnectedState(&state, 0) {
- log.Printf("Host not found when resolving %s, but no Internet connection available, sleeping for 4 seconds", name)
- time.Sleep(time.Second * 4)
+ if err == windows.WSAHOST_NOT_FOUND && services.StartedAtBoot() {
+ log.Printf("Host not found when resolving %s at boot time, so sleeping for 4 seconds", name)
continue
}
return
@@ -65,28 +63,35 @@ func resolveHostnameOnce(name string) (resolvedIPString string, err error) {
return
}
defer windows.FreeAddrInfoW(result)
- ipv6 := ""
+ var v6 netip.Addr
for ; result != nil; result = result.Next {
- addr := unsafe.Pointer(result.Addr)
- switch result.Family {
- case windows.AF_INET:
- a := (*syscall.RawSockaddrInet4)(addr).Addr
- return net.IP{a[0], a[1], a[2], a[3]}.String(), nil
- case windows.AF_INET6:
- if len(ipv6) != 0 {
- continue
- }
- a := (*syscall.RawSockaddrInet6)(addr).Addr
- ipv6 = net.IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}.String()
- scope := uint32((*syscall.RawSockaddrInet6)(addr).Scope_id)
- if scope != 0 {
- ipv6 += fmt.Sprintf("%%%d", scope)
- }
+ if result.Family != windows.AF_INET && result.Family != windows.AF_INET6 {
+ continue
+ }
+ addr := (*winipcfg.RawSockaddrInet)(unsafe.Pointer(result.Addr)).Addr()
+ if addr.Is4() {
+ return addr.String(), nil
+ } else if !v6.IsValid() && addr.Is6() {
+ v6 = addr
}
}
- if len(ipv6) != 0 {
- return ipv6, nil
+ if v6.IsValid() {
+ return v6.String(), nil
}
err = windows.WSAHOST_NOT_FOUND
return
}
+
+func (config *Config) ResolveEndpoints() error {
+ for i := range config.Peers {
+ if config.Peers[i].Endpoint.IsEmpty() {
+ continue
+ }
+ var err error
+ config.Peers[i].Endpoint.Host, err = resolveHostname(config.Peers[i].Endpoint.Host)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/conf/dpapi/dpapi_windows.go b/conf/dpapi/dpapi_windows.go
index 851ec1ee..49a32915 100644
--- a/conf/dpapi/dpapi_windows.go
+++ b/conf/dpapi/dpapi_windows.go
@@ -1,86 +1,53 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package dpapi
import (
"errors"
+ "fmt"
"runtime"
"unsafe"
"golang.org/x/sys/windows"
)
-const (
- dpCRYPTPROTECT_UI_FORBIDDEN uint32 = 0x1
- dpCRYPTPROTECT_LOCAL_MACHINE uint32 = 0x4
- dpCRYPTPROTECT_CRED_SYNC uint32 = 0x8
- dpCRYPTPROTECT_AUDIT uint32 = 0x10
- dpCRYPTPROTECT_NO_RECOVERY uint32 = 0x20
- dpCRYPTPROTECT_VERIFY_PROTECTION uint32 = 0x40
- dpCRYPTPROTECT_CRED_REGENERATE uint32 = 0x80
-)
-
-type dpBlob struct {
- len uint32
- data uintptr
-}
-
-func bytesToBlob(bytes []byte) *dpBlob {
- blob := &dpBlob{}
- blob.len = uint32(len(bytes))
+func bytesToBlob(bytes []byte) *windows.DataBlob {
+ blob := &windows.DataBlob{Size: uint32(len(bytes))}
if len(bytes) > 0 {
- blob.data = uintptr(unsafe.Pointer(&bytes[0]))
+ blob.Data = &bytes[0]
}
return blob
}
-//sys cryptProtectData(dataIn *dpBlob, name *uint16, optionalEntropy *dpBlob, reserved uintptr, promptStruct uintptr, flags uint32, dataOut *dpBlob) (err error) = crypt32.CryptProtectData
-
func Encrypt(data []byte, name string) ([]byte, error) {
- out := dpBlob{}
- err := cryptProtectData(bytesToBlob(data), windows.StringToUTF16Ptr(name), nil, 0, 0, dpCRYPTPROTECT_UI_FORBIDDEN, &out)
+ out := windows.DataBlob{}
+ err := windows.CryptProtectData(bytesToBlob(data), windows.StringToUTF16Ptr(name), nil, 0, nil, windows.CRYPTPROTECT_UI_FORBIDDEN, &out)
if err != nil {
- return nil, errors.New("Unable to encrypt DPAPI protected data: " + err.Error())
+ return nil, fmt.Errorf("unable to encrypt DPAPI protected data: %w", err)
}
-
- outSlice := *(*[]byte)(unsafe.Pointer(&(struct {
- addr uintptr
- len int
- cap int
- }{out.data, int(out.len), int(out.len)})))
- ret := make([]byte, len(outSlice))
- copy(ret, outSlice)
- windows.LocalFree(windows.Handle(out.data))
-
+ ret := make([]byte, out.Size)
+ copy(ret, unsafe.Slice(out.Data, out.Size))
+ windows.LocalFree(windows.Handle(unsafe.Pointer(out.Data)))
return ret, nil
}
-//sys cryptUnprotectData(dataIn *dpBlob, name **uint16, optionalEntropy *dpBlob, reserved uintptr, promptStruct uintptr, flags uint32, dataOut *dpBlob) (err error) = crypt32.CryptUnprotectData
-
func Decrypt(data []byte, name string) ([]byte, error) {
- out := dpBlob{}
+ out := windows.DataBlob{}
var outName *uint16
utf16Name, err := windows.UTF16PtrFromString(name)
if err != nil {
return nil, err
}
-
- err = cryptUnprotectData(bytesToBlob(data), &outName, nil, 0, 0, dpCRYPTPROTECT_UI_FORBIDDEN, &out)
+ err = windows.CryptUnprotectData(bytesToBlob(data), &outName, nil, 0, nil, windows.CRYPTPROTECT_UI_FORBIDDEN, &out)
if err != nil {
- return nil, errors.New("Unable to decrypt DPAPI protected data: " + err.Error())
+ return nil, fmt.Errorf("unable to decrypt DPAPI protected data: %w", err)
}
-
- outSlice := *(*[]byte)(unsafe.Pointer(&(struct {
- addr uintptr
- len int
- cap int
- }{out.data, int(out.len), int(out.len)})))
- ret := make([]byte, len(outSlice))
- copy(ret, outSlice)
- windows.LocalFree(windows.Handle(out.data))
+ ret := make([]byte, out.Size)
+ copy(ret, unsafe.Slice(out.Data, out.Size))
+ windows.LocalFree(windows.Handle(unsafe.Pointer(out.Data)))
// Note: this ridiculous open-coded strcmp is not constant time.
different := false
@@ -94,14 +61,14 @@ func Decrypt(data []byte, name string) ([]byte, error) {
if *a == 0 || *b == 0 {
break
}
- a = (*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(a)) + 2))
- b = (*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(b)) + 2))
+ a = (*uint16)(unsafe.Add(unsafe.Pointer(a), 2))
+ b = (*uint16)(unsafe.Add(unsafe.Pointer(b), 2))
}
runtime.KeepAlive(utf16Name)
windows.LocalFree(windows.Handle(unsafe.Pointer(outName)))
if different {
- return nil, errors.New("The input name does not match the stored name")
+ return nil, errors.New("input name does not match the stored name")
}
return ret, nil
diff --git a/conf/dpapi/dpapi_windows_test.go b/conf/dpapi/dpapi_windows_test.go
index 8356f2d4..fd7307e6 100644
--- a/conf/dpapi/dpapi_windows_test.go
+++ b/conf/dpapi/dpapi_windows_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package dpapi
@@ -53,11 +53,7 @@ func TestRoundTrip(t *testing.T) {
if err != nil {
t.Errorf("Unable to get utf16 chars for name: %s", err)
}
- nameUtf16Bytes := *(*[]byte)(unsafe.Pointer(&struct {
- addr *byte
- len int
- cap int
- }{(*byte)(unsafe.Pointer(&nameUtf16[0])), len(nameUtf16) * 2, cap(nameUtf16) * 2}))
+ nameUtf16Bytes := unsafe.Slice((*byte)(unsafe.Pointer(&nameUtf16[0])), len(nameUtf16)*2)
i := bytes.Index(eCorrupt, nameUtf16Bytes)
if i == -1 {
t.Error("Unable to find ad in blob")
diff --git a/conf/dpapi/mksyscall.go b/conf/dpapi/mksyscall.go
deleted file mode 100644
index 3d467f76..00000000
--- a/conf/dpapi/mksyscall.go
+++ /dev/null
@@ -1,8 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package dpapi
-
-//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zdpapi_windows.go dpapi_windows.go
diff --git a/conf/dpapi/zdpapi_windows.go b/conf/dpapi/zdpapi_windows.go
deleted file mode 100644
index e48d36b2..00000000
--- a/conf/dpapi/zdpapi_windows.go
+++ /dev/null
@@ -1,68 +0,0 @@
-// Code generated by 'go generate'; DO NOT EDIT.
-
-package dpapi
-
-import (
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-var _ unsafe.Pointer
-
-// Do the interface allocations only once for common
-// Errno values.
-const (
- errnoERROR_IO_PENDING = 997
-)
-
-var (
- errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
-)
-
-// errnoErr returns common boxed Errno values, to prevent
-// allocations at runtime.
-func errnoErr(e syscall.Errno) error {
- switch e {
- case 0:
- return nil
- case errnoERROR_IO_PENDING:
- return errERROR_IO_PENDING
- }
- // TODO: add more here, after collecting data on the common
- // error values see on Windows. (perhaps when running
- // all.bat?)
- return e
-}
-
-var (
- modcrypt32 = windows.NewLazySystemDLL("crypt32.dll")
-
- procCryptProtectData = modcrypt32.NewProc("CryptProtectData")
- procCryptUnprotectData = modcrypt32.NewProc("CryptUnprotectData")
-)
-
-func cryptProtectData(dataIn *dpBlob, name *uint16, optionalEntropy *dpBlob, reserved uintptr, promptStruct uintptr, flags uint32, dataOut *dpBlob) (err error) {
- r1, _, e1 := syscall.Syscall9(procCryptProtectData.Addr(), 7, uintptr(unsafe.Pointer(dataIn)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(optionalEntropy)), uintptr(reserved), uintptr(promptStruct), uintptr(flags), uintptr(unsafe.Pointer(dataOut)), 0, 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func cryptUnprotectData(dataIn *dpBlob, name **uint16, optionalEntropy *dpBlob, reserved uintptr, promptStruct uintptr, flags uint32, dataOut *dpBlob) (err error) {
- r1, _, e1 := syscall.Syscall9(procCryptUnprotectData.Addr(), 7, uintptr(unsafe.Pointer(dataIn)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(optionalEntropy)), uintptr(reserved), uintptr(promptStruct), uintptr(flags), uintptr(unsafe.Pointer(dataOut)), 0, 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
diff --git a/conf/filewriter_windows.go b/conf/filewriter_windows.go
new file mode 100644
index 00000000..c6bb2b45
--- /dev/null
+++ b/conf/filewriter_windows.go
@@ -0,0 +1,90 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
+ */
+
+package conf
+
+import (
+ "crypto/rand"
+ "encoding/hex"
+ "path/filepath"
+ "sync/atomic"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+var encryptedFileSd unsafe.Pointer
+
+func randomFileName() string {
+ var randBytes [32]byte
+ _, err := rand.Read(randBytes[:])
+ if err != nil {
+ panic(err)
+ }
+ return hex.EncodeToString(randBytes[:]) + ".tmp"
+}
+
+func writeLockedDownFile(destination string, overwrite bool, contents []byte) error {
+ var err error
+ sa := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{}))}
+ sa.SecurityDescriptor = (*windows.SECURITY_DESCRIPTOR)(atomic.LoadPointer(&encryptedFileSd))
+ if sa.SecurityDescriptor == nil {
+ sa.SecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYG:SYD:PAI(A;;FA;;;SY)(A;;SD;;;BA)")
+ if err != nil {
+ return err
+ }
+ atomic.StorePointer(&encryptedFileSd, unsafe.Pointer(sa.SecurityDescriptor))
+ }
+ destination16, err := windows.UTF16FromString(destination)
+ if err != nil {
+ return err
+ }
+ tmpDestination := filepath.Join(filepath.Dir(destination), randomFileName())
+ tmpDestination16, err := windows.UTF16PtrFromString(tmpDestination)
+ if err != nil {
+ return err
+ }
+ handle, err := windows.CreateFile(tmpDestination16, windows.GENERIC_WRITE|windows.DELETE, windows.FILE_SHARE_READ, sa, windows.CREATE_ALWAYS, windows.FILE_ATTRIBUTE_NORMAL, 0)
+ if err != nil {
+ return err
+ }
+ defer windows.CloseHandle(handle)
+ deleteIt := func() {
+ yes := byte(1)
+ windows.SetFileInformationByHandle(handle, windows.FileDispositionInfo, &yes, 1)
+ }
+ n, err := windows.Write(handle, contents)
+ if err != nil {
+ deleteIt()
+ return err
+ }
+ if n != len(contents) {
+ deleteIt()
+ return windows.ERROR_IO_INCOMPLETE
+ }
+ fileRenameInfo := &struct {
+ replaceIfExists byte
+ rootDirectory windows.Handle
+ fileNameLength uint32
+ fileName [windows.MAX_PATH]uint16
+ }{replaceIfExists: func() byte {
+ if overwrite {
+ return 1
+ } else {
+ return 0
+ }
+ }(), fileNameLength: uint32(len(destination16) - 1)}
+ if len(destination16) > len(fileRenameInfo.fileName) {
+ deleteIt()
+ return windows.ERROR_BUFFER_OVERFLOW
+ }
+ copy(fileRenameInfo.fileName[:], destination16[:])
+ err = windows.SetFileInformationByHandle(handle, windows.FileRenameInfo, (*byte)(unsafe.Pointer(fileRenameInfo)), uint32(unsafe.Sizeof(*fileRenameInfo)))
+ if err != nil {
+ deleteIt()
+ return err
+ }
+ return nil
+}
diff --git a/conf/migration_windows.go b/conf/migration_windows.go
index 72b298b6..ed288f3c 100644
--- a/conf/migration_windows.go
+++ b/conf/migration_windows.go
@@ -1,51 +1,109 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
+ "errors"
+ "io"
"log"
+ "os"
"path/filepath"
"strings"
+ "sync"
+ "time"
"golang.org/x/sys/windows"
)
-func maybeMigrate(c string) {
- if disableAutoMigration {
- return
- }
+var (
+ migrating sync.Mutex
+ lastMigrationTimer *time.Timer
+)
- vol := filepath.VolumeName(c)
- withoutVol := strings.TrimPrefix(c, vol)
- oldRoot := filepath.Join(vol, "\\windows.old")
- oldC := filepath.Join(oldRoot, withoutVol)
+type MigrationCallback func(name, oldPath, newPath string)
- sd, err := windows.GetNamedSecurityInfo(oldRoot, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
- if err == windows.ERROR_PATH_NOT_FOUND || err == windows.ERROR_FILE_NOT_FOUND {
- return
+func MigrateUnencryptedConfigs(migrated MigrationCallback) { migrateUnencryptedConfigs(3, migrated) }
+
+func migrateUnencryptedConfigs(sharingBase int, migrated MigrationCallback) {
+ if migrated == nil {
+ migrated = func(_, _, _ string) {}
}
+ migrating.Lock()
+ defer migrating.Unlock()
+ configFileDir, err := tunnelConfigurationsDirectory()
if err != nil {
- log.Printf("Not migrating configuration from ‘%s’ due to GetNamedSecurityInfo error: %v", oldRoot, err)
return
}
- owner, defaulted, err := sd.Owner()
+ files, err := os.ReadDir(configFileDir)
if err != nil {
- log.Printf("Not migrating configuration from ‘%s’ due to GetSecurityDescriptorOwner error: %v", oldRoot, err)
return
}
- if defaulted || (!owner.IsWellKnown(windows.WinLocalSystemSid) && !owner.IsWellKnown(windows.WinBuiltinAdministratorsSid)) {
- log.Printf("Not migrating configuration from ‘%s’, as it is not explicitly owned by SYSTEM or Built-in Administrators, but rather ‘%v’", oldRoot, owner)
- return
- }
- err = windows.MoveFileEx(windows.StringToUTF16Ptr(oldC), windows.StringToUTF16Ptr(c), windows.MOVEFILE_COPY_ALLOWED)
- if err != nil {
- if err != windows.ERROR_FILE_NOT_FOUND && err != windows.ERROR_ALREADY_EXISTS {
- log.Printf("Not migrating configuration from ‘%s’ due to error when moving files: %v", oldRoot, err)
+ ignoreSharingViolations := false
+ for _, file := range files {
+ path := filepath.Join(configFileDir, file.Name())
+ name := filepath.Base(file.Name())
+ if len(name) <= len(configFileUnencryptedSuffix) || !strings.HasSuffix(name, configFileUnencryptedSuffix) {
+ continue
}
- return
+ if !file.Type().IsRegular() {
+ continue
+ }
+ info, err := file.Info()
+ if err != nil {
+ continue
+ }
+ if info.Mode().Perm()&0o444 == 0 {
+ continue
+ }
+
+ var bytes []byte
+ var config *Config
+ var newPath string
+ // We don't use os.ReadFile, because we actually want RDWR, so that we can take advantage
+ // of Windows file locking for ensuring the file is finished being written.
+ f, err := os.OpenFile(path, os.O_RDWR, 0)
+ if err != nil {
+ if errors.Is(err, windows.ERROR_SHARING_VIOLATION) {
+ if ignoreSharingViolations {
+ continue
+ } else if sharingBase > 0 {
+ if lastMigrationTimer != nil {
+ lastMigrationTimer.Stop()
+ }
+ lastMigrationTimer = time.AfterFunc(time.Second/time.Duration(sharingBase*sharingBase), func() { migrateUnencryptedConfigs(sharingBase-1, migrated) })
+ ignoreSharingViolations = true
+ continue
+ }
+ }
+ goto error
+ }
+ bytes, err = io.ReadAll(f)
+ f.Close()
+ if err != nil {
+ goto error
+ }
+ config, err = FromWgQuickWithUnknownEncoding(string(bytes), strings.TrimSuffix(name, configFileUnencryptedSuffix))
+ if err != nil {
+ goto error
+ }
+ err = config.Save(false)
+ if err != nil {
+ goto error
+ }
+ err = os.Remove(path)
+ if err != nil {
+ goto error
+ }
+ newPath, err = config.Path()
+ if err != nil {
+ goto error
+ }
+ migrated(config.Name, path, newPath)
+ continue
+ error:
+ log.Printf("Unable to ingest and encrypt %#q: %v", path, err)
}
- log.Printf("Migrated configuration from ‘%s’", oldRoot)
}
diff --git a/conf/mksyscall.go b/conf/mksyscall.go
deleted file mode 100644
index 2706c304..00000000
--- a/conf/mksyscall.go
+++ /dev/null
@@ -1,8 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package conf
-
-//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go dnsresolver_windows.go migration_windows.go storewatcher_windows.go
diff --git a/conf/name.go b/conf/name.go
index 87c463af..0d084070 100644
--- a/conf/name.go
+++ b/conf/name.go
@@ -1,11 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
+ "errors"
"regexp"
"strconv"
"strings"
@@ -17,15 +18,13 @@ var reservedNames = []string{
"LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
}
-const serviceNameForbidden = "$"
-const netshellDllForbidden = "\\/:*?\"<>|\t"
-const specialChars = "/\\<>:\"|?*\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x00"
-
-var allowedNameFormat *regexp.Regexp
+const (
+ serviceNameForbidden = "$"
+ netshellDllForbidden = "\\/:*?\"<>|\t"
+ specialChars = "/\\<>:\"|?*\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x00"
+)
-func init() {
- allowedNameFormat = regexp.MustCompile("^[a-zA-Z0-9_=+.-]{1,32}$")
-}
+var allowedNameFormat = regexp.MustCompile("^[a-zA-Z0-9_=+.-]{1,32}$")
func isReserved(name string) bool {
if len(name) == 0 {
@@ -35,6 +34,14 @@ func isReserved(name string) bool {
if strings.EqualFold(name, reserved) {
return true
}
+ for i := len(name) - 1; i >= 0; i-- {
+ if name[i] == '.' {
+ if strings.EqualFold(name[:i], reserved) {
+ return true
+ }
+ break
+ }
+ }
}
return false
}
@@ -55,6 +62,7 @@ type naturalSortToken struct {
maybeString string
maybeNumber int
}
+
type naturalSortString struct {
originalString string
tokens []naturalSortToken
@@ -110,3 +118,10 @@ func TunnelNameIsLess(a, b string) bool {
}
return false
}
+
+func ServiceNameOfTunnel(tunnelName string) (string, error) {
+ if !TunnelNameIsValid(tunnelName) {
+ return "", errors.New("Tunnel name is not valid")
+ }
+ return "WireGuardTunnel$" + tunnelName, nil
+}
diff --git a/conf/parser.go b/conf/parser.go
index 5f44edb2..b1da9816 100644
--- a/conf/parser.go
+++ b/conf/parser.go
@@ -1,20 +1,20 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
"encoding/base64"
- "encoding/hex"
- "net"
+ "net/netip"
"strconv"
"strings"
- "time"
+ "golang.org/x/sys/windows"
"golang.org/x/text/encoding/unicode"
+ "golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/l18n"
)
@@ -27,43 +27,16 @@ func (e *ParseError) Error() string {
return l18n.Sprintf("%s: %q", e.why, e.offender)
}
-func parseIPCidr(s string) (ipcidr *IPCidr, err error) {
- var addrStr, cidrStr string
- var cidr int
-
- i := strings.IndexByte(s, '/')
- if i < 0 {
- addrStr = s
- } else {
- addrStr, cidrStr = s[:i], s[i+1:]
- }
-
- err = &ParseError{l18n.Sprintf("Invalid IP address"), s}
- addr := net.ParseIP(addrStr)
- if addr == nil {
- return
+func parseIPCidr(s string) (netip.Prefix, error) {
+ ipcidr, err := netip.ParsePrefix(s)
+ if err == nil {
+ return ipcidr, nil
}
- maybeV4 := addr.To4()
- if maybeV4 != nil {
- addr = maybeV4
- }
- if len(cidrStr) > 0 {
- err = &ParseError{l18n.Sprintf("Invalid network prefix length"), s}
- cidr, err = strconv.Atoi(cidrStr)
- if err != nil || cidr < 0 || cidr > 128 {
- return
- }
- if cidr > 32 && maybeV4 != nil {
- return
- }
- } else {
- if maybeV4 != nil {
- cidr = 32
- } else {
- cidr = 128
- }
+ addr, err := netip.ParseAddr(s)
+ if err != nil {
+ return netip.Prefix{}, &ParseError{l18n.Sprintf("Invalid IP address: "), s}
}
- return &IPCidr{addr, uint8(cidr)}, nil
+ return netip.PrefixFrom(addr, addr.BitLen()), nil
}
func parseEndpoint(s string) (*Endpoint, error) {
@@ -87,8 +60,8 @@ func parseEndpoint(s string) (*Endpoint, error) {
if i := strings.LastIndexByte(host, '%'); i > 1 {
end = i
}
- maybeV6 := net.ParseIP(host[1:end])
- if maybeV6 == nil || len(maybeV6) != net.IPv6len {
+ maybeV6, err2 := netip.ParseAddr(host[1:end])
+ if err2 != nil || !maybeV6.Is6() {
return nil, err
}
} else {
@@ -96,7 +69,7 @@ func parseEndpoint(s string) (*Endpoint, error) {
}
host = host[1 : len(host)-1]
}
- return &Endpoint{host, uint16(port)}, nil
+ return &Endpoint{host, port}, nil
}
func parseMTU(s string) (uint16, error) {
@@ -135,21 +108,18 @@ func parsePersistentKeepalive(s string) (uint16, error) {
return uint16(m), nil
}
-func parseKeyBase64(s string) (*Key, error) {
- k, err := base64.StdEncoding.DecodeString(s)
- if err != nil {
- return nil, &ParseError{l18n.Sprintf("Invalid key: %v", err), s}
- }
- if len(k) != KeyLength {
- return nil, &ParseError{l18n.Sprintf("Keys must decode to exactly 32 bytes"), s}
+func parseTableOff(s string) (bool, error) {
+ if s == "off" {
+ return true, nil
+ } else if s == "auto" || s == "main" {
+ return false, nil
}
- var key Key
- copy(key[:], k)
- return &key, nil
+ _, err := strconv.ParseUint(s, 10, 32)
+ return false, err
}
-func parseKeyHex(s string) (*Key, error) {
- k, err := hex.DecodeString(s)
+func parseKeyBase64(s string) (*Key, error) {
+ k, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return nil, &ParseError{l18n.Sprintf("Invalid key: %v", err), s}
}
@@ -161,14 +131,6 @@ func parseKeyHex(s string) (*Key, error) {
return &key, nil
}
-func parseBytesOrStamp(s string) (uint64, error) {
- b, err := strconv.ParseUint(s, 10, 64)
- if err != nil {
- return 0, &ParseError{l18n.Sprintf("Number must be a number between 0 and 2^64-1: %v", err), s}
- }
- return b, nil
-}
-
func splitList(s string) ([]string, error) {
var out []string
for _, split := range strings.Split(s, ",") {
@@ -195,7 +157,7 @@ func (c *Config) maybeAddPeer(p *Peer) {
}
}
-func FromWgQuick(s string, name string) (*Config, error) {
+func FromWgQuick(s, name string) (*Config, error) {
if !TunnelNameIsValid(name) {
return nil, &ParseError{l18n.Sprintf("Tunnel name is not valid"), name}
}
@@ -205,10 +167,7 @@ func FromWgQuick(s string, name string) (*Config, error) {
sawPrivateKey := false
var peer *Peer
for _, line := range lines {
- pound := strings.IndexByte(line, '#')
- if pound >= 0 {
- line = line[:pound]
- }
+ line, _, _ = strings.Cut(line, "#")
line = strings.TrimSpace(line)
lineLower := strings.ToLower(line)
if len(line) == 0 {
@@ -230,7 +189,7 @@ func FromWgQuick(s string, name string) (*Config, error) {
}
equals := strings.IndexByte(line, '=')
if equals < 0 {
- return nil, &ParseError{l18n.Sprintf("Invalid config key is missing an equals separator"), line}
+ return nil, &ParseError{l18n.Sprintf("Config key is missing an equals separator"), line}
}
key, val := strings.TrimSpace(lineLower[:equals]), strings.TrimSpace(line[equals+1:])
if len(val) == 0 {
@@ -267,7 +226,7 @@ func FromWgQuick(s string, name string) (*Config, error) {
if err != nil {
return nil, err
}
- conf.Interface.Addresses = append(conf.Interface.Addresses, *a)
+ conf.Interface.Addresses = append(conf.Interface.Addresses, a)
}
case "dns":
addresses, err := splitList(val)
@@ -275,12 +234,27 @@ func FromWgQuick(s string, name string) (*Config, error) {
return nil, err
}
for _, address := range addresses {
- a := net.ParseIP(address)
- if a == nil {
- return nil, &ParseError{l18n.Sprintf("Invalid IP address"), address}
+ a, err := netip.ParseAddr(address)
+ if err != nil {
+ conf.Interface.DNSSearch = append(conf.Interface.DNSSearch, address)
+ } else {
+ conf.Interface.DNS = append(conf.Interface.DNS, a)
}
- conf.Interface.DNS = append(conf.Interface.DNS, a)
}
+ case "preup":
+ conf.Interface.PreUp = val
+ case "postup":
+ conf.Interface.PostUp = val
+ case "predown":
+ conf.Interface.PreDown = val
+ case "postdown":
+ conf.Interface.PostDown = val
+ case "table":
+ tableOff, err := parseTableOff(val)
+ if err != nil {
+ return nil, err
+ }
+ conf.Interface.TableOff = tableOff
default:
return nil, &ParseError{l18n.Sprintf("Invalid key for [Interface] section"), key}
}
@@ -308,7 +282,7 @@ func FromWgQuick(s string, name string) (*Config, error) {
if err != nil {
return nil, err
}
- peer.AllowedIPs = append(peer.AllowedIPs, *a)
+ peer.AllowedIPs = append(peer.AllowedIPs, a)
}
case "persistentkeepalive":
p, err := parsePersistentKeepalive(val)
@@ -341,7 +315,7 @@ func FromWgQuick(s string, name string) (*Config, error) {
return &conf, nil
}
-func FromWgQuickWithUnknownEncoding(s string, name string) (*Config, error) {
+func FromWgQuickWithUnknownEncoding(s, name string) (*Config, error) {
c, firstErr := FromWgQuick(s, name)
if firstErr == nil {
return c, nil
@@ -358,128 +332,69 @@ func FromWgQuickWithUnknownEncoding(s string, name string) (*Config, error) {
return nil, firstErr
}
-func FromUAPI(s string, existingConfig *Config) (*Config, error) {
- lines := strings.Split(s, "\n")
- parserState := inInterfaceSection
+func FromDriverConfiguration(interfaze *driver.Interface, existingConfig *Config) *Config {
conf := Config{
Name: existingConfig.Name,
Interface: Interface{
Addresses: existingConfig.Interface.Addresses,
DNS: existingConfig.Interface.DNS,
+ DNSSearch: existingConfig.Interface.DNSSearch,
MTU: existingConfig.Interface.MTU,
+ PreUp: existingConfig.Interface.PreUp,
+ PostUp: existingConfig.Interface.PostUp,
+ PreDown: existingConfig.Interface.PreDown,
+ PostDown: existingConfig.Interface.PostDown,
+ TableOff: existingConfig.Interface.TableOff,
},
}
- var peer *Peer
- for _, line := range lines {
- if len(line) == 0 {
- continue
+ if interfaze.Flags&driver.InterfaceHasPrivateKey != 0 {
+ conf.Interface.PrivateKey = interfaze.PrivateKey
+ }
+ if interfaze.Flags&driver.InterfaceHasListenPort != 0 {
+ conf.Interface.ListenPort = interfaze.ListenPort
+ }
+ var p *driver.Peer
+ for i := uint32(0); i < interfaze.PeerCount; i++ {
+ if p == nil {
+ p = interfaze.FirstPeer()
+ } else {
+ p = p.NextPeer()
}
- equals := strings.IndexByte(line, '=')
- if equals < 0 {
- return nil, &ParseError{l18n.Sprintf("Invalid config key is missing an equals separator"), line}
+ peer := Peer{}
+ if p.Flags&driver.PeerHasPublicKey != 0 {
+ peer.PublicKey = p.PublicKey
}
- key, val := line[:equals], line[equals+1:]
- if len(val) == 0 {
- return nil, &ParseError{l18n.Sprintf("Key must have a value"), line}
+ if p.Flags&driver.PeerHasPresharedKey != 0 {
+ peer.PresharedKey = p.PresharedKey
}
- switch key {
- case "public_key":
- conf.maybeAddPeer(peer)
- peer = &Peer{}
- parserState = inPeerSection
- case "errno":
- if val == "0" {
- continue
- } else {
- return nil, &ParseError{l18n.Sprintf("Error in getting configuration"), val}
- }
+ if p.Flags&driver.PeerHasEndpoint != 0 {
+ peer.Endpoint.Port = p.Endpoint.Port()
+ peer.Endpoint.Host = p.Endpoint.Addr().String()
}
- if parserState == inInterfaceSection {
- switch key {
- case "private_key":
- k, err := parseKeyHex(val)
- if err != nil {
- return nil, err
- }
- conf.Interface.PrivateKey = *k
- case "listen_port":
- p, err := parsePort(val)
- if err != nil {
- return nil, err
- }
- conf.Interface.ListenPort = p
- case "fwmark":
- // Ignored for now.
-
- default:
- return nil, &ParseError{l18n.Sprintf("Invalid key for interface section"), key}
+ if p.Flags&driver.PeerHasPersistentKeepalive != 0 {
+ peer.PersistentKeepalive = p.PersistentKeepalive
+ }
+ peer.TxBytes = Bytes(p.TxBytes)
+ peer.RxBytes = Bytes(p.RxBytes)
+ if p.LastHandshake != 0 {
+ peer.LastHandshakeTime = HandshakeTime((p.LastHandshake - 116444736000000000) * 100)
+ }
+ var a *driver.AllowedIP
+ for j := uint32(0); j < p.AllowedIPsCount; j++ {
+ if a == nil {
+ a = p.FirstAllowedIP()
+ } else {
+ a = a.NextAllowedIP()
}
- } else if parserState == inPeerSection {
- switch key {
- case "public_key":
- k, err := parseKeyHex(val)
- if err != nil {
- return nil, err
- }
- peer.PublicKey = *k
- case "preshared_key":
- k, err := parseKeyHex(val)
- if err != nil {
- return nil, err
- }
- peer.PresharedKey = *k
- case "protocol_version":
- if val != "1" {
- return nil, &ParseError{l18n.Sprintf("Protocol version must be 1"), val}
- }
- case "allowed_ip":
- a, err := parseIPCidr(val)
- if err != nil {
- return nil, err
- }
- peer.AllowedIPs = append(peer.AllowedIPs, *a)
- case "persistent_keepalive_interval":
- p, err := parsePersistentKeepalive(val)
- if err != nil {
- return nil, err
- }
- peer.PersistentKeepalive = p
- case "endpoint":
- e, err := parseEndpoint(val)
- if err != nil {
- return nil, err
- }
- peer.Endpoint = *e
- case "tx_bytes":
- b, err := parseBytesOrStamp(val)
- if err != nil {
- return nil, err
- }
- peer.TxBytes = Bytes(b)
- case "rx_bytes":
- b, err := parseBytesOrStamp(val)
- if err != nil {
- return nil, err
- }
- peer.RxBytes = Bytes(b)
- case "last_handshake_time_sec":
- t, err := parseBytesOrStamp(val)
- if err != nil {
- return nil, err
- }
- peer.LastHandshakeTime += HandshakeTime(time.Duration(t) * time.Second)
- case "last_handshake_time_nsec":
- t, err := parseBytesOrStamp(val)
- if err != nil {
- return nil, err
- }
- peer.LastHandshakeTime += HandshakeTime(time.Duration(t) * time.Nanosecond)
- default:
- return nil, &ParseError{l18n.Sprintf("Invalid key for peer section"), key}
+ var ip netip.Addr
+ if a.AddressFamily == windows.AF_INET {
+ ip = netip.AddrFrom4(*(*[4]byte)(a.Address[:4]))
+ } else if a.AddressFamily == windows.AF_INET6 {
+ ip = netip.AddrFrom16(*(*[16]byte)(a.Address[:16]))
}
+ peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(ip, int(a.Cidr)))
}
+ conf.Peers = append(conf.Peers, peer)
}
- conf.maybeAddPeer(peer)
-
- return &conf, nil
+ return &conf
}
diff --git a/conf/parser_test.go b/conf/parser_test.go
index a6afbf53..25d906fd 100644
--- a/conf/parser_test.go
+++ b/conf/parser_test.go
@@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
- "net"
+ "net/netip"
"reflect"
"runtime"
"testing"
@@ -45,7 +45,7 @@ func noError(t *testing.T, err error) bool {
return false
}
-func equal(t *testing.T, expected, actual interface{}) bool {
+func equal(t *testing.T, expected, actual any) bool {
if reflect.DeepEqual(expected, actual) {
return true
}
@@ -53,7 +53,8 @@ func equal(t *testing.T, expected, actual interface{}) bool {
t.Errorf("Failed equals at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected)
return false
}
-func lenTest(t *testing.T, actualO interface{}, expected int) bool {
+
+func lenTest(t *testing.T, actualO any, expected int) bool {
actual := reflect.ValueOf(actualO).Len()
if reflect.DeepEqual(expected, actual) {
return true
@@ -62,7 +63,8 @@ func lenTest(t *testing.T, actualO interface{}, expected int) bool {
t.Errorf("Wrong length at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected)
return false
}
-func contains(t *testing.T, list, element interface{}) bool {
+
+func contains(t *testing.T, list, element any) bool {
listValue := reflect.ValueOf(list)
for i := 0; i < listValue.Len(); i++ {
if reflect.DeepEqual(listValue.Index(i).Interface(), element) {
@@ -77,10 +79,9 @@ func contains(t *testing.T, list, element interface{}) bool {
func TestFromWgQuick(t *testing.T) {
conf, err := FromWgQuick(testInput, "test")
if noError(t, err) {
-
lenTest(t, conf.Interface.Addresses, 2)
- contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 10, 0, 1), uint8(16)})
- contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 192, 122, 1), uint8(24)})
+ contains(t, conf.Interface.Addresses, netip.PrefixFrom(netip.AddrFrom4([4]byte{0, 10, 0, 1}), 16))
+ contains(t, conf.Interface.Addresses, netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 192, 122, 1}), 24))
equal(t, "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", conf.Interface.PrivateKey.String())
equal(t, uint16(51820), conf.Interface.ListenPort)
diff --git a/conf/path_windows.go b/conf/path_windows.go
index a53968c5..0ff0a057 100644
--- a/conf/path_windows.go
+++ b/conf/path_windows.go
@@ -1,33 +1,36 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
+ "errors"
"os"
"path/filepath"
+ "strings"
+ "unsafe"
"golang.org/x/sys/windows"
)
-var cachedConfigFileDir string
-var cachedRootDir string
-var disableAutoMigration bool
+var (
+ cachedConfigFileDir string
+ cachedRootDir string
+)
func tunnelConfigurationsDirectory() (string, error) {
if cachedConfigFileDir != "" {
return cachedConfigFileDir, nil
}
- root, err := RootDirectory()
+ root, err := RootDirectory(true)
if err != nil {
return "", err
}
c := filepath.Join(root, "Configurations")
- maybeMigrate(c)
- err = os.MkdirAll(c, os.ModeDir|0700)
- if err != nil {
+ err = os.Mkdir(c, os.ModeDir|0o700)
+ if err != nil && !os.IsExist(err) {
return "", err
}
cachedConfigFileDir = c
@@ -39,22 +42,97 @@ func tunnelConfigurationsDirectory() (string, error) {
// consumers of our libraries who might want to do strange things.
func PresetRootDirectory(root string) {
cachedRootDir = root
- disableAutoMigration = true
}
-func RootDirectory() (string, error) {
+func RootDirectory(create bool) (string, error) {
if cachedRootDir != "" {
return cachedRootDir, nil
}
- root, err := windows.KnownFolderPath(windows.FOLDERID_LocalAppData, windows.KF_FLAG_CREATE)
+ root, err := windows.KnownFolderPath(windows.FOLDERID_ProgramFiles, windows.KF_FLAG_DEFAULT)
+ if err != nil {
+ return "", err
+ }
+ root = filepath.Join(root, "WireGuard")
+ if !create {
+ return filepath.Join(root, "Data"), nil
+ }
+ root16, err := windows.UTF16PtrFromString(root)
+ if err != nil {
+ return "", err
+ }
+
+ // The root directory inherits its ACL from Program Files; we don't want to mess with that
+ err = windows.CreateDirectory(root16, nil)
+ if err != nil && err != windows.ERROR_ALREADY_EXISTS {
+ return "", err
+ }
+
+ dataDirectorySd, err := windows.SecurityDescriptorFromString("O:SYG:SYD:PAI(A;OICI;FA;;;SY)(A;OICI;FA;;;BA)")
if err != nil {
return "", err
}
- c := filepath.Join(root, "WireGuard")
- err = os.MkdirAll(c, os.ModeDir|0700)
+ dataDirectorySa := &windows.SecurityAttributes{
+ Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
+ SecurityDescriptor: dataDirectorySd,
+ }
+
+ data := filepath.Join(root, "Data")
+ data16, err := windows.UTF16PtrFromString(data)
if err != nil {
return "", err
}
- cachedRootDir = c
+ var dataHandle windows.Handle
+ for {
+ err = windows.CreateDirectory(data16, dataDirectorySa)
+ if err != nil && err != windows.ERROR_ALREADY_EXISTS {
+ return "", err
+ }
+ dataHandle, err = windows.CreateFile(data16, windows.READ_CONTROL|windows.WRITE_OWNER|windows.WRITE_DAC, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_BACKUP_SEMANTICS|windows.FILE_FLAG_OPEN_REPARSE_POINT|windows.FILE_ATTRIBUTE_DIRECTORY, 0)
+ if err != nil && err != windows.ERROR_FILE_NOT_FOUND {
+ return "", err
+ }
+ if err == nil {
+ break
+ }
+ }
+ defer windows.CloseHandle(dataHandle)
+ var fileInfo windows.ByHandleFileInformation
+ err = windows.GetFileInformationByHandle(dataHandle, &fileInfo)
+ if err != nil {
+ return "", err
+ }
+ if fileInfo.FileAttributes&windows.FILE_ATTRIBUTE_DIRECTORY == 0 {
+ return "", errors.New("Data directory is actually a file")
+ }
+ if fileInfo.FileAttributes&windows.FILE_ATTRIBUTE_REPARSE_POINT != 0 {
+ return "", errors.New("Data directory is reparse point")
+ }
+ buf := make([]uint16, windows.MAX_PATH+4)
+ for {
+ bufLen, err := windows.GetFinalPathNameByHandle(dataHandle, &buf[0], uint32(len(buf)), 0)
+ if err != nil {
+ return "", err
+ }
+ if bufLen < uint32(len(buf)) {
+ break
+ }
+ buf = make([]uint16, bufLen)
+ }
+ if !strings.EqualFold(`\\?\`+data, windows.UTF16ToString(buf[:])) {
+ return "", errors.New("Data directory jumped to unexpected location")
+ }
+ err = windows.SetKernelObjectSecurity(dataHandle, windows.DACL_SECURITY_INFORMATION|windows.GROUP_SECURITY_INFORMATION|windows.OWNER_SECURITY_INFORMATION|windows.PROTECTED_DACL_SECURITY_INFORMATION, dataDirectorySd)
+ if err != nil {
+ return "", err
+ }
+ cachedRootDir = data
return cachedRootDir, nil
}
+
+func LogFile(createRoot bool) (string, error) {
+ root, err := RootDirectory(createRoot)
+ if err != nil {
+ return "", err
+ }
+ return filepath.Join(root, "log.bin"), nil
+}
diff --git a/conf/store.go b/conf/store.go
index 21bd3a22..02807b77 100644
--- a/conf/store.go
+++ b/conf/store.go
@@ -1,13 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
"errors"
- "io/ioutil"
"os"
"path/filepath"
"strings"
@@ -15,109 +14,41 @@ import (
"golang.zx2c4.com/wireguard/windows/conf/dpapi"
)
-const configFileSuffix = ".conf.dpapi"
-const configFileUnencryptedSuffix = ".conf"
+const (
+ configFileSuffix = ".conf.dpapi"
+ configFileUnencryptedSuffix = ".conf"
+)
func ListConfigNames() ([]string, error) {
configFileDir, err := tunnelConfigurationsDirectory()
if err != nil {
return nil, err
}
- files, err := ioutil.ReadDir(configFileDir)
+ files, err := os.ReadDir(configFileDir)
if err != nil {
return nil, err
}
configs := make([]string, len(files))
i := 0
for _, file := range files {
- name := filepath.Base(file.Name())
- if len(name) <= len(configFileSuffix) || !strings.HasSuffix(name, configFileSuffix) {
- continue
- }
- if !file.Mode().IsRegular() || file.Mode().Perm()&0444 == 0 {
- continue
- }
- name = strings.TrimSuffix(name, configFileSuffix)
- if !TunnelNameIsValid(name) {
- continue
- }
- configs[i] = name
- i++
- }
- return configs[:i], nil
-}
-
-func MigrateUnencryptedConfigs() (int, []error) {
- configFileDir, err := tunnelConfigurationsDirectory()
- if err != nil {
- return 0, []error{err}
- }
- files, err := ioutil.ReadDir(configFileDir)
- if err != nil {
- return 0, []error{err}
- }
- errs := make([]error, len(files))
- i := 0
- e := 0
- for _, file := range files {
- path := filepath.Join(configFileDir, file.Name())
- name := filepath.Base(file.Name())
- if len(name) <= len(configFileUnencryptedSuffix) || !strings.HasSuffix(name, configFileUnencryptedSuffix) {
- continue
- }
- if !file.Mode().IsRegular() || file.Mode().Perm()&0444 == 0 {
- continue
- }
-
- // We don't use ioutil's ReadFile, because we actually want RDWR, so that we can take advantage
- // of Windows file locking for ensuring the file is finished being written.
- f, err := os.OpenFile(path, os.O_RDWR, 0)
- if err != nil {
- errs[e] = err
- e++
- continue
- }
- bytes, err := ioutil.ReadAll(f)
- f.Close()
+ name, err := NameFromPath(file.Name())
if err != nil {
- errs[e] = err
- e++
continue
}
- _, err = FromWgQuickWithUnknownEncoding(string(bytes), "input")
- if err != nil {
- errs[e] = err
- e++
+ if !file.Type().IsRegular() {
continue
}
-
- bytes, err = dpapi.Encrypt(bytes, strings.TrimSuffix(name, configFileUnencryptedSuffix))
+ info, err := file.Info()
if err != nil {
- errs[e] = err
- e++
- continue
- }
- dstFile := strings.TrimSuffix(path, configFileUnencryptedSuffix) + configFileSuffix
- if _, err = os.Stat(dstFile); err != nil && !os.IsNotExist(err) {
- errs[e] = errors.New("Unable to migrate to " + dstFile + " as it already exists")
- e++
continue
}
- err = ioutil.WriteFile(dstFile, bytes, 0600)
- if err != nil {
- errs[e] = err
- e++
- continue
- }
- err = os.Remove(path)
- if err != nil && os.Remove(dstFile) == nil {
- errs[e] = err
- e++
+ if info.Mode().Perm()&0o444 == 0 {
continue
}
+ configs[i] = name
i++
}
- return i, errs[:e]
+ return configs[:i], nil
}
func LoadFromName(name string) (*Config, error) {
@@ -129,15 +60,11 @@ func LoadFromName(name string) (*Config, error) {
}
func LoadFromPath(path string) (*Config, error) {
- if !disableAutoMigration {
- tunnelConfigurationsDirectory() // Provoke migrations, if needed.
- }
-
name, err := NameFromPath(path)
if err != nil {
return nil, err
}
- bytes, err := ioutil.ReadFile(path)
+ bytes, err := os.ReadFile(path)
if err != nil {
return nil, err
}
@@ -171,7 +98,7 @@ func NameFromPath(path string) (string, error) {
return name, nil
}
-func (config *Config) Save() error {
+func (config *Config) Save(overwrite bool) error {
if !TunnelNameIsValid(config.Name) {
return errors.New("Tunnel name is not valid")
}
@@ -185,16 +112,7 @@ func (config *Config) Save() error {
if err != nil {
return err
}
- err = ioutil.WriteFile(filename+".tmp", bytes, 0600)
- if err != nil {
- return err
- }
- err = os.Rename(filename+".tmp", filename)
- if err != nil {
- os.Remove(filename + ".tmp")
- return err
- }
- return nil
+ return writeLockedDownFile(filename, overwrite, bytes)
}
func (config *Config) Path() (string, error) {
diff --git a/conf/store_test.go b/conf/store_test.go
index fdef7ea7..3427a2b1 100644
--- a/conf/store_test.go
+++ b/conf/store_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package conf
@@ -17,7 +17,7 @@ func TestStorage(t *testing.T) {
return
}
- err = c.Save()
+ err = c.Save(true)
if err != nil {
t.Errorf("Unable to save config: %s", err.Error())
}
@@ -54,7 +54,11 @@ func TestStorage(t *testing.T) {
}
c.Interface.PrivateKey = *k
- err = c.Save()
+ err = c.Save(false)
+ if err == nil {
+ t.Error("Config disappeared or was unexpectedly overwritten")
+ }
+ err = c.Save(true)
if err != nil {
t.Errorf("Unable to save config a second time: %s", err.Error())
}
diff --git a/conf/storewatcher.go b/conf/storewatcher.go
index ffd20ee0..70a44add 100644
--- a/conf/storewatcher.go
+++ b/conf/storewatcher.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package conf
diff --git a/conf/storewatcher_windows.go b/conf/storewatcher_windows.go
index 19956263..0c4b74e7 100644
--- a/conf/storewatcher_windows.go
+++ b/conf/storewatcher_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package conf
@@ -11,20 +11,6 @@ import (
"golang.org/x/sys/windows"
)
-const (
- fncFILE_NAME uint32 = 0x00000001
- fncDIR_NAME uint32 = 0x00000002
- fncATTRIBUTES uint32 = 0x00000004
- fncSIZE uint32 = 0x00000008
- fncLAST_WRITE uint32 = 0x00000010
- fncLAST_ACCESS uint32 = 0x00000020
- fncCREATION uint32 = 0x00000040
- fncSECURITY uint32 = 0x00000100
-)
-
-//sys findFirstChangeNotification(path *uint16, watchSubtree bool, filter uint32) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = kernel32.FindFirstChangeNotificationW
-//sys findNextChangeNotification(handle windows.Handle) (err error) = kernel32.FindNextChangeNotification
-
var haveStartedWatchingConfigDir bool
func startWatchingConfigDir() {
@@ -36,7 +22,7 @@ func startWatchingConfigDir() {
h := windows.InvalidHandle
defer func() {
if h != windows.InvalidHandle {
- windows.CloseHandle(h)
+ windows.FindCloseChangeNotification(h)
}
haveStartedWatchingConfigDir = false
}()
@@ -45,7 +31,7 @@ func startWatchingConfigDir() {
if err != nil {
return
}
- h, err = findFirstChangeNotification(windows.StringToUTF16Ptr(configFileDir), true, fncFILE_NAME|fncDIR_NAME|fncATTRIBUTES|fncSIZE|fncLAST_WRITE|fncLAST_ACCESS|fncCREATION|fncSECURITY)
+ h, err = windows.FindFirstChangeNotification(configFileDir, true, windows.FILE_NOTIFY_CHANGE_FILE_NAME|windows.FILE_NOTIFY_CHANGE_DIR_NAME|windows.FILE_NOTIFY_CHANGE_ATTRIBUTES|windows.FILE_NOTIFY_CHANGE_SIZE|windows.FILE_NOTIFY_CHANGE_LAST_WRITE|windows.FILE_NOTIFY_CHANGE_LAST_ACCESS|windows.FILE_NOTIFY_CHANGE_CREATION|windows.FILE_NOTIFY_CHANGE_SECURITY)
if err != nil {
log.Printf("Unable to monitor config directory: %v", err)
return
@@ -54,7 +40,7 @@ func startWatchingConfigDir() {
s, err := windows.WaitForSingleObject(h, windows.INFINITE)
if err != nil || s == windows.WAIT_FAILED {
log.Printf("Unable to wait on config directory watcher: %v", err)
- windows.CloseHandle(h)
+ windows.FindCloseChangeNotification(h)
h = windows.InvalidHandle
goto startover
}
@@ -63,10 +49,10 @@ func startWatchingConfigDir() {
cb.cb()
}
- err = findNextChangeNotification(h)
+ err = windows.FindNextChangeNotification(h)
if err != nil {
log.Printf("Unable to monitor config directory again: %v", err)
- windows.CloseHandle(h)
+ windows.FindCloseChangeNotification(h)
h = windows.InvalidHandle
goto startover
}
diff --git a/conf/writer.go b/conf/writer.go
index 748c1d61..162962b5 100644
--- a/conf/writer.go
+++ b/conf/writer.go
@@ -1,13 +1,19 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package conf
import (
"fmt"
+ "net/netip"
"strings"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+ "golang.zx2c4.com/wireguard/windows/driver"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
func (conf *Config) ToWgQuick() string {
@@ -28,11 +34,12 @@ func (conf *Config) ToWgQuick() string {
output.WriteString(fmt.Sprintf("Address = %s\n", strings.Join(addrStrings[:], ", ")))
}
- if len(conf.Interface.DNS) > 0 {
- addrStrings := make([]string, len(conf.Interface.DNS))
- for i, address := range conf.Interface.DNS {
- addrStrings[i] = address.String()
+ if len(conf.Interface.DNS)+len(conf.Interface.DNSSearch) > 0 {
+ addrStrings := make([]string, 0, len(conf.Interface.DNS)+len(conf.Interface.DNSSearch))
+ for _, address := range conf.Interface.DNS {
+ addrStrings = append(addrStrings, address.String())
}
+ addrStrings = append(addrStrings, conf.Interface.DNSSearch...)
output.WriteString(fmt.Sprintf("DNS = %s\n", strings.Join(addrStrings[:], ", ")))
}
@@ -40,6 +47,22 @@ func (conf *Config) ToWgQuick() string {
output.WriteString(fmt.Sprintf("MTU = %d\n", conf.Interface.MTU))
}
+ if len(conf.Interface.PreUp) > 0 {
+ output.WriteString(fmt.Sprintf("PreUp = %s\n", conf.Interface.PreUp))
+ }
+ if len(conf.Interface.PostUp) > 0 {
+ output.WriteString(fmt.Sprintf("PostUp = %s\n", conf.Interface.PostUp))
+ }
+ if len(conf.Interface.PreDown) > 0 {
+ output.WriteString(fmt.Sprintf("PreDown = %s\n", conf.Interface.PreDown))
+ }
+ if len(conf.Interface.PostDown) > 0 {
+ output.WriteString(fmt.Sprintf("PostDown = %s\n", conf.Interface.PostDown))
+ }
+ if conf.Interface.TableOff {
+ output.WriteString("Table = off\n")
+ }
+
for _, peer := range conf.Peers {
output.WriteString("\n[Peer]\n")
@@ -68,43 +91,50 @@ func (conf *Config) ToWgQuick() string {
return output.String()
}
-func (conf *Config) ToUAPI() (uapi string, dnsErr error) {
- var output strings.Builder
- output.WriteString(fmt.Sprintf("private_key=%s\n", conf.Interface.PrivateKey.HexString()))
-
- if conf.Interface.ListenPort > 0 {
- output.WriteString(fmt.Sprintf("listen_port=%d\n", conf.Interface.ListenPort))
- }
-
- if len(conf.Peers) > 0 {
- output.WriteString("replace_peers=true\n")
+func (config *Config) ToDriverConfiguration() (*driver.Interface, uint32) {
+ preallocation := unsafe.Sizeof(driver.Interface{}) + uintptr(len(config.Peers))*unsafe.Sizeof(driver.Peer{})
+ for i := range config.Peers {
+ preallocation += uintptr(len(config.Peers[i].AllowedIPs)) * unsafe.Sizeof(driver.AllowedIP{})
}
-
- for _, peer := range conf.Peers {
- output.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey.HexString()))
-
- if !peer.PresharedKey.IsZero() {
- output.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PresharedKey.HexString()))
+ var c driver.ConfigBuilder
+ c.Preallocate(uint32(preallocation))
+ c.AppendInterface(&driver.Interface{
+ Flags: driver.InterfaceHasPrivateKey | driver.InterfaceHasListenPort,
+ ListenPort: config.Interface.ListenPort,
+ PrivateKey: config.Interface.PrivateKey,
+ PeerCount: uint32(len(config.Peers)),
+ })
+ for i := range config.Peers {
+ flags := driver.PeerHasPublicKey | driver.PeerHasPersistentKeepalive
+ if !config.Peers[i].PresharedKey.IsZero() {
+ flags |= driver.PeerHasPresharedKey
}
-
- if !peer.Endpoint.IsEmpty() {
- var resolvedIP string
- resolvedIP, dnsErr = resolveHostname(peer.Endpoint.Host)
- if dnsErr != nil {
- return
+ var endpoint winipcfg.RawSockaddrInet
+ if !config.Peers[i].Endpoint.IsEmpty() {
+ addr, err := netip.ParseAddr(config.Peers[i].Endpoint.Host)
+ if err == nil {
+ flags |= driver.PeerHasEndpoint
+ endpoint.SetAddrPort(netip.AddrPortFrom(addr, config.Peers[i].Endpoint.Port))
}
- resolvedEndpoint := Endpoint{resolvedIP, peer.Endpoint.Port}
- output.WriteString(fmt.Sprintf("endpoint=%s\n", resolvedEndpoint.String()))
}
-
- output.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.PersistentKeepalive))
-
- if len(peer.AllowedIPs) > 0 {
- output.WriteString("replace_allowed_ips=true\n")
- for _, address := range peer.AllowedIPs {
- output.WriteString(fmt.Sprintf("allowed_ip=%s\n", address.String()))
+ c.AppendPeer(&driver.Peer{
+ Flags: flags,
+ PublicKey: config.Peers[i].PublicKey,
+ PresharedKey: config.Peers[i].PresharedKey,
+ PersistentKeepalive: config.Peers[i].PersistentKeepalive,
+ Endpoint: endpoint,
+ AllowedIPsCount: uint32(len(config.Peers[i].AllowedIPs)),
+ })
+ for j := range config.Peers[i].AllowedIPs {
+ a := &driver.AllowedIP{Cidr: uint8(config.Peers[i].AllowedIPs[j].Bits())}
+ copy(a.Address[:], config.Peers[i].AllowedIPs[j].Addr().AsSlice())
+ if config.Peers[i].AllowedIPs[j].Addr().Is4() {
+ a.AddressFamily = windows.AF_INET
+ } else if config.Peers[i].AllowedIPs[j].Addr().Is6() {
+ a.AddressFamily = windows.AF_INET6
}
+ c.AppendAllowedIP(a)
}
}
- return output.String(), nil
+ return c.Interface()
}
diff --git a/conf/zsyscall_windows.go b/conf/zsyscall_windows.go
deleted file mode 100644
index 9dcf68fe..00000000
--- a/conf/zsyscall_windows.go
+++ /dev/null
@@ -1,83 +0,0 @@
-// Code generated by 'go generate'; DO NOT EDIT.
-
-package conf
-
-import (
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-var _ unsafe.Pointer
-
-// Do the interface allocations only once for common
-// Errno values.
-const (
- errnoERROR_IO_PENDING = 997
-)
-
-var (
- errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
-)
-
-// errnoErr returns common boxed Errno values, to prevent
-// allocations at runtime.
-func errnoErr(e syscall.Errno) error {
- switch e {
- case 0:
- return nil
- case errnoERROR_IO_PENDING:
- return errERROR_IO_PENDING
- }
- // TODO: add more here, after collecting data on the common
- // error values see on Windows. (perhaps when running
- // all.bat?)
- return e
-}
-
-var (
- modwininet = windows.NewLazySystemDLL("wininet.dll")
- modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
-
- procInternetGetConnectedState = modwininet.NewProc("InternetGetConnectedState")
- procFindFirstChangeNotificationW = modkernel32.NewProc("FindFirstChangeNotificationW")
- procFindNextChangeNotification = modkernel32.NewProc("FindNextChangeNotification")
-)
-
-func internetGetConnectedState(flags *uint32, reserved uint32) (connected bool) {
- r0, _, _ := syscall.Syscall(procInternetGetConnectedState.Addr(), 2, uintptr(unsafe.Pointer(flags)), uintptr(reserved), 0)
- connected = r0 != 0
- return
-}
-
-func findFirstChangeNotification(path *uint16, watchSubtree bool, filter uint32) (handle windows.Handle, err error) {
- var _p0 uint32
- if watchSubtree {
- _p0 = 1
- } else {
- _p0 = 0
- }
- r0, _, e1 := syscall.Syscall(procFindFirstChangeNotificationW.Addr(), 3, uintptr(unsafe.Pointer(path)), uintptr(_p0), uintptr(filter))
- handle = windows.Handle(r0)
- if handle == windows.InvalidHandle {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func findNextChangeNotification(handle windows.Handle) (err error) {
- r1, _, e1 := syscall.Syscall(procFindNextChangeNotification.Addr(), 1, uintptr(handle), 0, 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}