diff options
Diffstat (limited to 'conf')
-rw-r--r-- | conf/admin_windows.go | 36 | ||||
-rw-r--r-- | conf/config.go | 83 | ||||
-rw-r--r-- | conf/dnsresolver_windows.go | 71 | ||||
-rw-r--r-- | conf/dpapi/dpapi_windows.go | 73 | ||||
-rw-r--r-- | conf/dpapi/dpapi_windows_test.go | 8 | ||||
-rw-r--r-- | conf/dpapi/mksyscall.go | 8 | ||||
-rw-r--r-- | conf/dpapi/zdpapi_windows.go | 68 | ||||
-rw-r--r-- | conf/filewriter_windows.go | 90 | ||||
-rw-r--r-- | conf/migration_windows.go | 108 | ||||
-rw-r--r-- | conf/mksyscall.go | 8 | ||||
-rw-r--r-- | conf/name.go | 33 | ||||
-rw-r--r-- | conf/parser.go | 281 | ||||
-rw-r--r-- | conf/parser_test.go | 17 | ||||
-rw-r--r-- | conf/path_windows.go | 106 | ||||
-rw-r--r-- | conf/store.go | 112 | ||||
-rw-r--r-- | conf/store_test.go | 10 | ||||
-rw-r--r-- | conf/storewatcher.go | 2 | ||||
-rw-r--r-- | conf/storewatcher_windows.go | 26 | ||||
-rw-r--r-- | conf/writer.go | 104 | ||||
-rw-r--r-- | conf/zsyscall_windows.go | 83 |
20 files changed, 638 insertions, 689 deletions
diff --git a/conf/admin_windows.go b/conf/admin_windows.go new file mode 100644 index 00000000..91d4bdc9 --- /dev/null +++ b/conf/admin_windows.go @@ -0,0 +1,36 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import "golang.org/x/sys/windows/registry" + +const adminRegKey = `Software\WireGuard` + +var adminKey registry.Key + +func openAdminKey() (registry.Key, error) { + if adminKey != 0 { + return adminKey, nil + } + var err error + adminKey, err = registry.OpenKey(registry.LOCAL_MACHINE, adminRegKey, registry.QUERY_VALUE|registry.WOW64_64KEY) + if err != nil { + return 0, err + } + return adminKey, nil +} + +func AdminBool(name string) bool { + key, err := openAdminKey() + if err != nil { + return false + } + val, _, err := key.GetIntegerValue(name) + if err != nil { + return false + } + return val != 0 +} diff --git a/conf/config.go b/conf/config.go index 9f5dbcc1..74ffacf6 100644 --- a/conf/config.go +++ b/conf/config.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package conf @@ -9,9 +9,8 @@ import ( "crypto/rand" "crypto/subtle" "encoding/base64" - "encoding/hex" "fmt" - "net" + "net/netip" "strings" "time" @@ -22,19 +21,16 @@ import ( const KeyLength = 32 -type IPCidr struct { - IP net.IP - Cidr uint8 -} - type Endpoint struct { Host string Port uint16 } -type Key [KeyLength]byte -type HandshakeTime time.Duration -type Bytes uint64 +type ( + Key [KeyLength]byte + HandshakeTime time.Duration + Bytes uint64 +) type Config struct { Name string @@ -44,16 +40,22 @@ type Config struct { type Interface struct { PrivateKey Key - Addresses []IPCidr + Addresses []netip.Prefix ListenPort uint16 MTU uint16 - DNS []net.IP + DNS []netip.Addr + DNSSearch []string + PreUp string + PostUp string + PreDown string + PostDown string + TableOff bool } type Peer struct { PublicKey Key PresharedKey Key - AllowedIPs []IPCidr + AllowedIPs []netip.Prefix Endpoint Endpoint PersistentKeepalive uint16 @@ -62,26 +64,37 @@ type Peer struct { LastHandshakeTime HandshakeTime } -func (r *IPCidr) String() string { - return fmt.Sprintf("%s/%d", r.IP.String(), r.Cidr) -} - -func (r *IPCidr) Bits() uint8 { - if r.IP.To4() != nil { - return 32 +func (conf *Config) IntersectsWith(other *Config) bool { + allRoutes := make(map[netip.Prefix]bool, len(conf.Interface.Addresses)*2+len(conf.Peers)*3) + for _, a := range conf.Interface.Addresses { + allRoutes[netip.PrefixFrom(a.Addr(), a.Addr().BitLen())] = true + allRoutes[a.Masked()] = true } - return 128 -} - -func (r *IPCidr) IPNet() net.IPNet { - return net.IPNet{ - IP: r.IP, - Mask: net.CIDRMask(int(r.Cidr), int(r.Bits())), + for i := range conf.Peers { + for _, a := range conf.Peers[i].AllowedIPs { + allRoutes[a.Masked()] = true + } } + for _, a := range other.Interface.Addresses { + if allRoutes[netip.PrefixFrom(a.Addr(), a.Addr().BitLen())] { + return true + } + if allRoutes[a.Masked()] { + return true + } + } + for i := range other.Peers { + for _, a := range other.Peers[i].AllowedIPs { + if allRoutes[a.Masked()] { + return true + } + } + } + return false } func (e *Endpoint) String() string { - if strings.IndexByte(e.Host, ':') > 0 { + if strings.IndexByte(e.Host, ':') != -1 { return fmt.Sprintf("[%s]:%d", e.Host, e.Port) } return fmt.Sprintf("%s:%d", e.Host, e.Port) @@ -95,10 +108,6 @@ func (k *Key) String() string { return base64.StdEncoding.EncodeToString(k[:]) } -func (k *Key) HexString() string { - return hex.EncodeToString(k[:]) -} - func (k *Key) IsZero() bool { var zeros Key return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1 @@ -229,3 +238,11 @@ func (conf *Config) DeduplicateNetworkEntries() { peer.AllowedIPs = peer.AllowedIPs[:i] } } + +func (conf *Config) Redact() { + conf.Interface.PrivateKey = Key{} + for i := range conf.Peers { + conf.Peers[i].PublicKey = Key{} + conf.Peers[i].PresharedKey = Key{} + } +} diff --git a/conf/dnsresolver_windows.go b/conf/dnsresolver_windows.go index d6c2f1c7..a299c475 100644 --- a/conf/dnsresolver_windows.go +++ b/conf/dnsresolver_windows.go @@ -1,43 +1,41 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package conf import ( - "fmt" "log" - "net" - "syscall" + "net/netip" "time" "unsafe" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/windows/services" ) -//sys internetGetConnectedState(flags *uint32, reserved uint32) (connected bool) = wininet.InternetGetConnectedState - func resolveHostname(name string) (resolvedIPString string, err error) { maxTries := 10 - systemJustBooted := windows.DurationSinceBoot() <= time.Minute*4 - if systemJustBooted { - maxTries *= 4 + if services.StartedAtBoot() { + maxTries *= 3 } for i := 0; i < maxTries; i++ { + if i > 0 { + time.Sleep(time.Second * 4) + } resolvedIPString, err = resolveHostnameOnce(name) if err == nil { return } if err == windows.WSATRY_AGAIN { - log.Printf("Temporary DNS error when resolving %s, sleeping for 4 seconds", name) - time.Sleep(time.Second * 4) + log.Printf("Temporary DNS error when resolving %s, so sleeping for 4 seconds", name) continue } - var state uint32 - if err == windows.WSAHOST_NOT_FOUND && systemJustBooted && !internetGetConnectedState(&state, 0) { - log.Printf("Host not found when resolving %s, but no Internet connection available, sleeping for 4 seconds", name) - time.Sleep(time.Second * 4) + if err == windows.WSAHOST_NOT_FOUND && services.StartedAtBoot() { + log.Printf("Host not found when resolving %s at boot time, so sleeping for 4 seconds", name) continue } return @@ -65,28 +63,35 @@ func resolveHostnameOnce(name string) (resolvedIPString string, err error) { return } defer windows.FreeAddrInfoW(result) - ipv6 := "" + var v6 netip.Addr for ; result != nil; result = result.Next { - addr := unsafe.Pointer(result.Addr) - switch result.Family { - case windows.AF_INET: - a := (*syscall.RawSockaddrInet4)(addr).Addr - return net.IP{a[0], a[1], a[2], a[3]}.String(), nil - case windows.AF_INET6: - if len(ipv6) != 0 { - continue - } - a := (*syscall.RawSockaddrInet6)(addr).Addr - ipv6 = net.IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}.String() - scope := uint32((*syscall.RawSockaddrInet6)(addr).Scope_id) - if scope != 0 { - ipv6 += fmt.Sprintf("%%%d", scope) - } + if result.Family != windows.AF_INET && result.Family != windows.AF_INET6 { + continue + } + addr := (*winipcfg.RawSockaddrInet)(unsafe.Pointer(result.Addr)).Addr() + if addr.Is4() { + return addr.String(), nil + } else if !v6.IsValid() && addr.Is6() { + v6 = addr } } - if len(ipv6) != 0 { - return ipv6, nil + if v6.IsValid() { + return v6.String(), nil } err = windows.WSAHOST_NOT_FOUND return } + +func (config *Config) ResolveEndpoints() error { + for i := range config.Peers { + if config.Peers[i].Endpoint.IsEmpty() { + continue + } + var err error + config.Peers[i].Endpoint.Host, err = resolveHostname(config.Peers[i].Endpoint.Host) + if err != nil { + return err + } + } + return nil +} diff --git a/conf/dpapi/dpapi_windows.go b/conf/dpapi/dpapi_windows.go index 851ec1ee..49a32915 100644 --- a/conf/dpapi/dpapi_windows.go +++ b/conf/dpapi/dpapi_windows.go @@ -1,86 +1,53 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package dpapi import ( "errors" + "fmt" "runtime" "unsafe" "golang.org/x/sys/windows" ) -const ( - dpCRYPTPROTECT_UI_FORBIDDEN uint32 = 0x1 - dpCRYPTPROTECT_LOCAL_MACHINE uint32 = 0x4 - dpCRYPTPROTECT_CRED_SYNC uint32 = 0x8 - dpCRYPTPROTECT_AUDIT uint32 = 0x10 - dpCRYPTPROTECT_NO_RECOVERY uint32 = 0x20 - dpCRYPTPROTECT_VERIFY_PROTECTION uint32 = 0x40 - dpCRYPTPROTECT_CRED_REGENERATE uint32 = 0x80 -) - -type dpBlob struct { - len uint32 - data uintptr -} - -func bytesToBlob(bytes []byte) *dpBlob { - blob := &dpBlob{} - blob.len = uint32(len(bytes)) +func bytesToBlob(bytes []byte) *windows.DataBlob { + blob := &windows.DataBlob{Size: uint32(len(bytes))} if len(bytes) > 0 { - blob.data = uintptr(unsafe.Pointer(&bytes[0])) + blob.Data = &bytes[0] } return blob } -//sys cryptProtectData(dataIn *dpBlob, name *uint16, optionalEntropy *dpBlob, reserved uintptr, promptStruct uintptr, flags uint32, dataOut *dpBlob) (err error) = crypt32.CryptProtectData - func Encrypt(data []byte, name string) ([]byte, error) { - out := dpBlob{} - err := cryptProtectData(bytesToBlob(data), windows.StringToUTF16Ptr(name), nil, 0, 0, dpCRYPTPROTECT_UI_FORBIDDEN, &out) + out := windows.DataBlob{} + err := windows.CryptProtectData(bytesToBlob(data), windows.StringToUTF16Ptr(name), nil, 0, nil, windows.CRYPTPROTECT_UI_FORBIDDEN, &out) if err != nil { - return nil, errors.New("Unable to encrypt DPAPI protected data: " + err.Error()) + return nil, fmt.Errorf("unable to encrypt DPAPI protected data: %w", err) } - - outSlice := *(*[]byte)(unsafe.Pointer(&(struct { - addr uintptr - len int - cap int - }{out.data, int(out.len), int(out.len)}))) - ret := make([]byte, len(outSlice)) - copy(ret, outSlice) - windows.LocalFree(windows.Handle(out.data)) - + ret := make([]byte, out.Size) + copy(ret, unsafe.Slice(out.Data, out.Size)) + windows.LocalFree(windows.Handle(unsafe.Pointer(out.Data))) return ret, nil } -//sys cryptUnprotectData(dataIn *dpBlob, name **uint16, optionalEntropy *dpBlob, reserved uintptr, promptStruct uintptr, flags uint32, dataOut *dpBlob) (err error) = crypt32.CryptUnprotectData - func Decrypt(data []byte, name string) ([]byte, error) { - out := dpBlob{} + out := windows.DataBlob{} var outName *uint16 utf16Name, err := windows.UTF16PtrFromString(name) if err != nil { return nil, err } - - err = cryptUnprotectData(bytesToBlob(data), &outName, nil, 0, 0, dpCRYPTPROTECT_UI_FORBIDDEN, &out) + err = windows.CryptUnprotectData(bytesToBlob(data), &outName, nil, 0, nil, windows.CRYPTPROTECT_UI_FORBIDDEN, &out) if err != nil { - return nil, errors.New("Unable to decrypt DPAPI protected data: " + err.Error()) + return nil, fmt.Errorf("unable to decrypt DPAPI protected data: %w", err) } - - outSlice := *(*[]byte)(unsafe.Pointer(&(struct { - addr uintptr - len int - cap int - }{out.data, int(out.len), int(out.len)}))) - ret := make([]byte, len(outSlice)) - copy(ret, outSlice) - windows.LocalFree(windows.Handle(out.data)) + ret := make([]byte, out.Size) + copy(ret, unsafe.Slice(out.Data, out.Size)) + windows.LocalFree(windows.Handle(unsafe.Pointer(out.Data))) // Note: this ridiculous open-coded strcmp is not constant time. different := false @@ -94,14 +61,14 @@ func Decrypt(data []byte, name string) ([]byte, error) { if *a == 0 || *b == 0 { break } - a = (*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(a)) + 2)) - b = (*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(b)) + 2)) + a = (*uint16)(unsafe.Add(unsafe.Pointer(a), 2)) + b = (*uint16)(unsafe.Add(unsafe.Pointer(b), 2)) } runtime.KeepAlive(utf16Name) windows.LocalFree(windows.Handle(unsafe.Pointer(outName))) if different { - return nil, errors.New("The input name does not match the stored name") + return nil, errors.New("input name does not match the stored name") } return ret, nil diff --git a/conf/dpapi/dpapi_windows_test.go b/conf/dpapi/dpapi_windows_test.go index 8356f2d4..fd7307e6 100644 --- a/conf/dpapi/dpapi_windows_test.go +++ b/conf/dpapi/dpapi_windows_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package dpapi @@ -53,11 +53,7 @@ func TestRoundTrip(t *testing.T) { if err != nil { t.Errorf("Unable to get utf16 chars for name: %s", err) } - nameUtf16Bytes := *(*[]byte)(unsafe.Pointer(&struct { - addr *byte - len int - cap int - }{(*byte)(unsafe.Pointer(&nameUtf16[0])), len(nameUtf16) * 2, cap(nameUtf16) * 2})) + nameUtf16Bytes := unsafe.Slice((*byte)(unsafe.Pointer(&nameUtf16[0])), len(nameUtf16)*2) i := bytes.Index(eCorrupt, nameUtf16Bytes) if i == -1 { t.Error("Unable to find ad in blob") diff --git a/conf/dpapi/mksyscall.go b/conf/dpapi/mksyscall.go deleted file mode 100644 index 3d467f76..00000000 --- a/conf/dpapi/mksyscall.go +++ /dev/null @@ -1,8 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package dpapi - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zdpapi_windows.go dpapi_windows.go diff --git a/conf/dpapi/zdpapi_windows.go b/conf/dpapi/zdpapi_windows.go deleted file mode 100644 index e48d36b2..00000000 --- a/conf/dpapi/zdpapi_windows.go +++ /dev/null @@ -1,68 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package dpapi - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modcrypt32 = windows.NewLazySystemDLL("crypt32.dll") - - procCryptProtectData = modcrypt32.NewProc("CryptProtectData") - procCryptUnprotectData = modcrypt32.NewProc("CryptUnprotectData") -) - -func cryptProtectData(dataIn *dpBlob, name *uint16, optionalEntropy *dpBlob, reserved uintptr, promptStruct uintptr, flags uint32, dataOut *dpBlob) (err error) { - r1, _, e1 := syscall.Syscall9(procCryptProtectData.Addr(), 7, uintptr(unsafe.Pointer(dataIn)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(optionalEntropy)), uintptr(reserved), uintptr(promptStruct), uintptr(flags), uintptr(unsafe.Pointer(dataOut)), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func cryptUnprotectData(dataIn *dpBlob, name **uint16, optionalEntropy *dpBlob, reserved uintptr, promptStruct uintptr, flags uint32, dataOut *dpBlob) (err error) { - r1, _, e1 := syscall.Syscall9(procCryptUnprotectData.Addr(), 7, uintptr(unsafe.Pointer(dataIn)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(optionalEntropy)), uintptr(reserved), uintptr(promptStruct), uintptr(flags), uintptr(unsafe.Pointer(dataOut)), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} diff --git a/conf/filewriter_windows.go b/conf/filewriter_windows.go new file mode 100644 index 00000000..c6bb2b45 --- /dev/null +++ b/conf/filewriter_windows.go @@ -0,0 +1,90 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "crypto/rand" + "encoding/hex" + "path/filepath" + "sync/atomic" + "unsafe" + + "golang.org/x/sys/windows" +) + +var encryptedFileSd unsafe.Pointer + +func randomFileName() string { + var randBytes [32]byte + _, err := rand.Read(randBytes[:]) + if err != nil { + panic(err) + } + return hex.EncodeToString(randBytes[:]) + ".tmp" +} + +func writeLockedDownFile(destination string, overwrite bool, contents []byte) error { + var err error + sa := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{}))} + sa.SecurityDescriptor = (*windows.SECURITY_DESCRIPTOR)(atomic.LoadPointer(&encryptedFileSd)) + if sa.SecurityDescriptor == nil { + sa.SecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYG:SYD:PAI(A;;FA;;;SY)(A;;SD;;;BA)") + if err != nil { + return err + } + atomic.StorePointer(&encryptedFileSd, unsafe.Pointer(sa.SecurityDescriptor)) + } + destination16, err := windows.UTF16FromString(destination) + if err != nil { + return err + } + tmpDestination := filepath.Join(filepath.Dir(destination), randomFileName()) + tmpDestination16, err := windows.UTF16PtrFromString(tmpDestination) + if err != nil { + return err + } + handle, err := windows.CreateFile(tmpDestination16, windows.GENERIC_WRITE|windows.DELETE, windows.FILE_SHARE_READ, sa, windows.CREATE_ALWAYS, windows.FILE_ATTRIBUTE_NORMAL, 0) + if err != nil { + return err + } + defer windows.CloseHandle(handle) + deleteIt := func() { + yes := byte(1) + windows.SetFileInformationByHandle(handle, windows.FileDispositionInfo, &yes, 1) + } + n, err := windows.Write(handle, contents) + if err != nil { + deleteIt() + return err + } + if n != len(contents) { + deleteIt() + return windows.ERROR_IO_INCOMPLETE + } + fileRenameInfo := &struct { + replaceIfExists byte + rootDirectory windows.Handle + fileNameLength uint32 + fileName [windows.MAX_PATH]uint16 + }{replaceIfExists: func() byte { + if overwrite { + return 1 + } else { + return 0 + } + }(), fileNameLength: uint32(len(destination16) - 1)} + if len(destination16) > len(fileRenameInfo.fileName) { + deleteIt() + return windows.ERROR_BUFFER_OVERFLOW + } + copy(fileRenameInfo.fileName[:], destination16[:]) + err = windows.SetFileInformationByHandle(handle, windows.FileRenameInfo, (*byte)(unsafe.Pointer(fileRenameInfo)), uint32(unsafe.Sizeof(*fileRenameInfo))) + if err != nil { + deleteIt() + return err + } + return nil +} diff --git a/conf/migration_windows.go b/conf/migration_windows.go index 72b298b6..ed288f3c 100644 --- a/conf/migration_windows.go +++ b/conf/migration_windows.go @@ -1,51 +1,109 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package conf import ( + "errors" + "io" "log" + "os" "path/filepath" "strings" + "sync" + "time" "golang.org/x/sys/windows" ) -func maybeMigrate(c string) { - if disableAutoMigration { - return - } +var ( + migrating sync.Mutex + lastMigrationTimer *time.Timer +) - vol := filepath.VolumeName(c) - withoutVol := strings.TrimPrefix(c, vol) - oldRoot := filepath.Join(vol, "\\windows.old") - oldC := filepath.Join(oldRoot, withoutVol) +type MigrationCallback func(name, oldPath, newPath string) - sd, err := windows.GetNamedSecurityInfo(oldRoot, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION) - if err == windows.ERROR_PATH_NOT_FOUND || err == windows.ERROR_FILE_NOT_FOUND { - return +func MigrateUnencryptedConfigs(migrated MigrationCallback) { migrateUnencryptedConfigs(3, migrated) } + +func migrateUnencryptedConfigs(sharingBase int, migrated MigrationCallback) { + if migrated == nil { + migrated = func(_, _, _ string) {} } + migrating.Lock() + defer migrating.Unlock() + configFileDir, err := tunnelConfigurationsDirectory() if err != nil { - log.Printf("Not migrating configuration from ‘%s’ due to GetNamedSecurityInfo error: %v", oldRoot, err) return } - owner, defaulted, err := sd.Owner() + files, err := os.ReadDir(configFileDir) if err != nil { - log.Printf("Not migrating configuration from ‘%s’ due to GetSecurityDescriptorOwner error: %v", oldRoot, err) return } - if defaulted || (!owner.IsWellKnown(windows.WinLocalSystemSid) && !owner.IsWellKnown(windows.WinBuiltinAdministratorsSid)) { - log.Printf("Not migrating configuration from ‘%s’, as it is not explicitly owned by SYSTEM or Built-in Administrators, but rather ‘%v’", oldRoot, owner) - return - } - err = windows.MoveFileEx(windows.StringToUTF16Ptr(oldC), windows.StringToUTF16Ptr(c), windows.MOVEFILE_COPY_ALLOWED) - if err != nil { - if err != windows.ERROR_FILE_NOT_FOUND && err != windows.ERROR_ALREADY_EXISTS { - log.Printf("Not migrating configuration from ‘%s’ due to error when moving files: %v", oldRoot, err) + ignoreSharingViolations := false + for _, file := range files { + path := filepath.Join(configFileDir, file.Name()) + name := filepath.Base(file.Name()) + if len(name) <= len(configFileUnencryptedSuffix) || !strings.HasSuffix(name, configFileUnencryptedSuffix) { + continue } - return + if !file.Type().IsRegular() { + continue + } + info, err := file.Info() + if err != nil { + continue + } + if info.Mode().Perm()&0o444 == 0 { + continue + } + + var bytes []byte + var config *Config + var newPath string + // We don't use os.ReadFile, because we actually want RDWR, so that we can take advantage + // of Windows file locking for ensuring the file is finished being written. + f, err := os.OpenFile(path, os.O_RDWR, 0) + if err != nil { + if errors.Is(err, windows.ERROR_SHARING_VIOLATION) { + if ignoreSharingViolations { + continue + } else if sharingBase > 0 { + if lastMigrationTimer != nil { + lastMigrationTimer.Stop() + } + lastMigrationTimer = time.AfterFunc(time.Second/time.Duration(sharingBase*sharingBase), func() { migrateUnencryptedConfigs(sharingBase-1, migrated) }) + ignoreSharingViolations = true + continue + } + } + goto error + } + bytes, err = io.ReadAll(f) + f.Close() + if err != nil { + goto error + } + config, err = FromWgQuickWithUnknownEncoding(string(bytes), strings.TrimSuffix(name, configFileUnencryptedSuffix)) + if err != nil { + goto error + } + err = config.Save(false) + if err != nil { + goto error + } + err = os.Remove(path) + if err != nil { + goto error + } + newPath, err = config.Path() + if err != nil { + goto error + } + migrated(config.Name, path, newPath) + continue + error: + log.Printf("Unable to ingest and encrypt %#q: %v", path, err) } - log.Printf("Migrated configuration from ‘%s’", oldRoot) } diff --git a/conf/mksyscall.go b/conf/mksyscall.go deleted file mode 100644 index 2706c304..00000000 --- a/conf/mksyscall.go +++ /dev/null @@ -1,8 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package conf - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go dnsresolver_windows.go migration_windows.go storewatcher_windows.go diff --git a/conf/name.go b/conf/name.go index 87c463af..0d084070 100644 --- a/conf/name.go +++ b/conf/name.go @@ -1,11 +1,12 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package conf import ( + "errors" "regexp" "strconv" "strings" @@ -17,15 +18,13 @@ var reservedNames = []string{ "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9", } -const serviceNameForbidden = "$" -const netshellDllForbidden = "\\/:*?\"<>|\t" -const specialChars = "/\\<>:\"|?*\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x00" - -var allowedNameFormat *regexp.Regexp +const ( + serviceNameForbidden = "$" + netshellDllForbidden = "\\/:*?\"<>|\t" + specialChars = "/\\<>:\"|?*\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x00" +) -func init() { - allowedNameFormat = regexp.MustCompile("^[a-zA-Z0-9_=+.-]{1,32}$") -} +var allowedNameFormat = regexp.MustCompile("^[a-zA-Z0-9_=+.-]{1,32}$") func isReserved(name string) bool { if len(name) == 0 { @@ -35,6 +34,14 @@ func isReserved(name string) bool { if strings.EqualFold(name, reserved) { return true } + for i := len(name) - 1; i >= 0; i-- { + if name[i] == '.' { + if strings.EqualFold(name[:i], reserved) { + return true + } + break + } + } } return false } @@ -55,6 +62,7 @@ type naturalSortToken struct { maybeString string maybeNumber int } + type naturalSortString struct { originalString string tokens []naturalSortToken @@ -110,3 +118,10 @@ func TunnelNameIsLess(a, b string) bool { } return false } + +func ServiceNameOfTunnel(tunnelName string) (string, error) { + if !TunnelNameIsValid(tunnelName) { + return "", errors.New("Tunnel name is not valid") + } + return "WireGuardTunnel$" + tunnelName, nil +} diff --git a/conf/parser.go b/conf/parser.go index 5f44edb2..b1da9816 100644 --- a/conf/parser.go +++ b/conf/parser.go @@ -1,20 +1,20 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package conf import ( "encoding/base64" - "encoding/hex" - "net" + "net/netip" "strconv" "strings" - "time" + "golang.org/x/sys/windows" "golang.org/x/text/encoding/unicode" + "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/l18n" ) @@ -27,43 +27,16 @@ func (e *ParseError) Error() string { return l18n.Sprintf("%s: %q", e.why, e.offender) } -func parseIPCidr(s string) (ipcidr *IPCidr, err error) { - var addrStr, cidrStr string - var cidr int - - i := strings.IndexByte(s, '/') - if i < 0 { - addrStr = s - } else { - addrStr, cidrStr = s[:i], s[i+1:] - } - - err = &ParseError{l18n.Sprintf("Invalid IP address"), s} - addr := net.ParseIP(addrStr) - if addr == nil { - return +func parseIPCidr(s string) (netip.Prefix, error) { + ipcidr, err := netip.ParsePrefix(s) + if err == nil { + return ipcidr, nil } - maybeV4 := addr.To4() - if maybeV4 != nil { - addr = maybeV4 - } - if len(cidrStr) > 0 { - err = &ParseError{l18n.Sprintf("Invalid network prefix length"), s} - cidr, err = strconv.Atoi(cidrStr) - if err != nil || cidr < 0 || cidr > 128 { - return - } - if cidr > 32 && maybeV4 != nil { - return - } - } else { - if maybeV4 != nil { - cidr = 32 - } else { - cidr = 128 - } + addr, err := netip.ParseAddr(s) + if err != nil { + return netip.Prefix{}, &ParseError{l18n.Sprintf("Invalid IP address: "), s} } - return &IPCidr{addr, uint8(cidr)}, nil + return netip.PrefixFrom(addr, addr.BitLen()), nil } func parseEndpoint(s string) (*Endpoint, error) { @@ -87,8 +60,8 @@ func parseEndpoint(s string) (*Endpoint, error) { if i := strings.LastIndexByte(host, '%'); i > 1 { end = i } - maybeV6 := net.ParseIP(host[1:end]) - if maybeV6 == nil || len(maybeV6) != net.IPv6len { + maybeV6, err2 := netip.ParseAddr(host[1:end]) + if err2 != nil || !maybeV6.Is6() { return nil, err } } else { @@ -96,7 +69,7 @@ func parseEndpoint(s string) (*Endpoint, error) { } host = host[1 : len(host)-1] } - return &Endpoint{host, uint16(port)}, nil + return &Endpoint{host, port}, nil } func parseMTU(s string) (uint16, error) { @@ -135,21 +108,18 @@ func parsePersistentKeepalive(s string) (uint16, error) { return uint16(m), nil } -func parseKeyBase64(s string) (*Key, error) { - k, err := base64.StdEncoding.DecodeString(s) - if err != nil { - return nil, &ParseError{l18n.Sprintf("Invalid key: %v", err), s} - } - if len(k) != KeyLength { - return nil, &ParseError{l18n.Sprintf("Keys must decode to exactly 32 bytes"), s} +func parseTableOff(s string) (bool, error) { + if s == "off" { + return true, nil + } else if s == "auto" || s == "main" { + return false, nil } - var key Key - copy(key[:], k) - return &key, nil + _, err := strconv.ParseUint(s, 10, 32) + return false, err } -func parseKeyHex(s string) (*Key, error) { - k, err := hex.DecodeString(s) +func parseKeyBase64(s string) (*Key, error) { + k, err := base64.StdEncoding.DecodeString(s) if err != nil { return nil, &ParseError{l18n.Sprintf("Invalid key: %v", err), s} } @@ -161,14 +131,6 @@ func parseKeyHex(s string) (*Key, error) { return &key, nil } -func parseBytesOrStamp(s string) (uint64, error) { - b, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return 0, &ParseError{l18n.Sprintf("Number must be a number between 0 and 2^64-1: %v", err), s} - } - return b, nil -} - func splitList(s string) ([]string, error) { var out []string for _, split := range strings.Split(s, ",") { @@ -195,7 +157,7 @@ func (c *Config) maybeAddPeer(p *Peer) { } } -func FromWgQuick(s string, name string) (*Config, error) { +func FromWgQuick(s, name string) (*Config, error) { if !TunnelNameIsValid(name) { return nil, &ParseError{l18n.Sprintf("Tunnel name is not valid"), name} } @@ -205,10 +167,7 @@ func FromWgQuick(s string, name string) (*Config, error) { sawPrivateKey := false var peer *Peer for _, line := range lines { - pound := strings.IndexByte(line, '#') - if pound >= 0 { - line = line[:pound] - } + line, _, _ = strings.Cut(line, "#") line = strings.TrimSpace(line) lineLower := strings.ToLower(line) if len(line) == 0 { @@ -230,7 +189,7 @@ func FromWgQuick(s string, name string) (*Config, error) { } equals := strings.IndexByte(line, '=') if equals < 0 { - return nil, &ParseError{l18n.Sprintf("Invalid config key is missing an equals separator"), line} + return nil, &ParseError{l18n.Sprintf("Config key is missing an equals separator"), line} } key, val := strings.TrimSpace(lineLower[:equals]), strings.TrimSpace(line[equals+1:]) if len(val) == 0 { @@ -267,7 +226,7 @@ func FromWgQuick(s string, name string) (*Config, error) { if err != nil { return nil, err } - conf.Interface.Addresses = append(conf.Interface.Addresses, *a) + conf.Interface.Addresses = append(conf.Interface.Addresses, a) } case "dns": addresses, err := splitList(val) @@ -275,12 +234,27 @@ func FromWgQuick(s string, name string) (*Config, error) { return nil, err } for _, address := range addresses { - a := net.ParseIP(address) - if a == nil { - return nil, &ParseError{l18n.Sprintf("Invalid IP address"), address} + a, err := netip.ParseAddr(address) + if err != nil { + conf.Interface.DNSSearch = append(conf.Interface.DNSSearch, address) + } else { + conf.Interface.DNS = append(conf.Interface.DNS, a) } - conf.Interface.DNS = append(conf.Interface.DNS, a) } + case "preup": + conf.Interface.PreUp = val + case "postup": + conf.Interface.PostUp = val + case "predown": + conf.Interface.PreDown = val + case "postdown": + conf.Interface.PostDown = val + case "table": + tableOff, err := parseTableOff(val) + if err != nil { + return nil, err + } + conf.Interface.TableOff = tableOff default: return nil, &ParseError{l18n.Sprintf("Invalid key for [Interface] section"), key} } @@ -308,7 +282,7 @@ func FromWgQuick(s string, name string) (*Config, error) { if err != nil { return nil, err } - peer.AllowedIPs = append(peer.AllowedIPs, *a) + peer.AllowedIPs = append(peer.AllowedIPs, a) } case "persistentkeepalive": p, err := parsePersistentKeepalive(val) @@ -341,7 +315,7 @@ func FromWgQuick(s string, name string) (*Config, error) { return &conf, nil } -func FromWgQuickWithUnknownEncoding(s string, name string) (*Config, error) { +func FromWgQuickWithUnknownEncoding(s, name string) (*Config, error) { c, firstErr := FromWgQuick(s, name) if firstErr == nil { return c, nil @@ -358,128 +332,69 @@ func FromWgQuickWithUnknownEncoding(s string, name string) (*Config, error) { return nil, firstErr } -func FromUAPI(s string, existingConfig *Config) (*Config, error) { - lines := strings.Split(s, "\n") - parserState := inInterfaceSection +func FromDriverConfiguration(interfaze *driver.Interface, existingConfig *Config) *Config { conf := Config{ Name: existingConfig.Name, Interface: Interface{ Addresses: existingConfig.Interface.Addresses, DNS: existingConfig.Interface.DNS, + DNSSearch: existingConfig.Interface.DNSSearch, MTU: existingConfig.Interface.MTU, + PreUp: existingConfig.Interface.PreUp, + PostUp: existingConfig.Interface.PostUp, + PreDown: existingConfig.Interface.PreDown, + PostDown: existingConfig.Interface.PostDown, + TableOff: existingConfig.Interface.TableOff, }, } - var peer *Peer - for _, line := range lines { - if len(line) == 0 { - continue + if interfaze.Flags&driver.InterfaceHasPrivateKey != 0 { + conf.Interface.PrivateKey = interfaze.PrivateKey + } + if interfaze.Flags&driver.InterfaceHasListenPort != 0 { + conf.Interface.ListenPort = interfaze.ListenPort + } + var p *driver.Peer + for i := uint32(0); i < interfaze.PeerCount; i++ { + if p == nil { + p = interfaze.FirstPeer() + } else { + p = p.NextPeer() } - equals := strings.IndexByte(line, '=') - if equals < 0 { - return nil, &ParseError{l18n.Sprintf("Invalid config key is missing an equals separator"), line} + peer := Peer{} + if p.Flags&driver.PeerHasPublicKey != 0 { + peer.PublicKey = p.PublicKey } - key, val := line[:equals], line[equals+1:] - if len(val) == 0 { - return nil, &ParseError{l18n.Sprintf("Key must have a value"), line} + if p.Flags&driver.PeerHasPresharedKey != 0 { + peer.PresharedKey = p.PresharedKey } - switch key { - case "public_key": - conf.maybeAddPeer(peer) - peer = &Peer{} - parserState = inPeerSection - case "errno": - if val == "0" { - continue - } else { - return nil, &ParseError{l18n.Sprintf("Error in getting configuration"), val} - } + if p.Flags&driver.PeerHasEndpoint != 0 { + peer.Endpoint.Port = p.Endpoint.Port() + peer.Endpoint.Host = p.Endpoint.Addr().String() } - if parserState == inInterfaceSection { - switch key { - case "private_key": - k, err := parseKeyHex(val) - if err != nil { - return nil, err - } - conf.Interface.PrivateKey = *k - case "listen_port": - p, err := parsePort(val) - if err != nil { - return nil, err - } - conf.Interface.ListenPort = p - case "fwmark": - // Ignored for now. - - default: - return nil, &ParseError{l18n.Sprintf("Invalid key for interface section"), key} + if p.Flags&driver.PeerHasPersistentKeepalive != 0 { + peer.PersistentKeepalive = p.PersistentKeepalive + } + peer.TxBytes = Bytes(p.TxBytes) + peer.RxBytes = Bytes(p.RxBytes) + if p.LastHandshake != 0 { + peer.LastHandshakeTime = HandshakeTime((p.LastHandshake - 116444736000000000) * 100) + } + var a *driver.AllowedIP + for j := uint32(0); j < p.AllowedIPsCount; j++ { + if a == nil { + a = p.FirstAllowedIP() + } else { + a = a.NextAllowedIP() } - } else if parserState == inPeerSection { - switch key { - case "public_key": - k, err := parseKeyHex(val) - if err != nil { - return nil, err - } - peer.PublicKey = *k - case "preshared_key": - k, err := parseKeyHex(val) - if err != nil { - return nil, err - } - peer.PresharedKey = *k - case "protocol_version": - if val != "1" { - return nil, &ParseError{l18n.Sprintf("Protocol version must be 1"), val} - } - case "allowed_ip": - a, err := parseIPCidr(val) - if err != nil { - return nil, err - } - peer.AllowedIPs = append(peer.AllowedIPs, *a) - case "persistent_keepalive_interval": - p, err := parsePersistentKeepalive(val) - if err != nil { - return nil, err - } - peer.PersistentKeepalive = p - case "endpoint": - e, err := parseEndpoint(val) - if err != nil { - return nil, err - } - peer.Endpoint = *e - case "tx_bytes": - b, err := parseBytesOrStamp(val) - if err != nil { - return nil, err - } - peer.TxBytes = Bytes(b) - case "rx_bytes": - b, err := parseBytesOrStamp(val) - if err != nil { - return nil, err - } - peer.RxBytes = Bytes(b) - case "last_handshake_time_sec": - t, err := parseBytesOrStamp(val) - if err != nil { - return nil, err - } - peer.LastHandshakeTime += HandshakeTime(time.Duration(t) * time.Second) - case "last_handshake_time_nsec": - t, err := parseBytesOrStamp(val) - if err != nil { - return nil, err - } - peer.LastHandshakeTime += HandshakeTime(time.Duration(t) * time.Nanosecond) - default: - return nil, &ParseError{l18n.Sprintf("Invalid key for peer section"), key} + var ip netip.Addr + if a.AddressFamily == windows.AF_INET { + ip = netip.AddrFrom4(*(*[4]byte)(a.Address[:4])) + } else if a.AddressFamily == windows.AF_INET6 { + ip = netip.AddrFrom16(*(*[16]byte)(a.Address[:16])) } + peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(ip, int(a.Cidr))) } + conf.Peers = append(conf.Peers, peer) } - conf.maybeAddPeer(peer) - - return &conf, nil + return &conf } diff --git a/conf/parser_test.go b/conf/parser_test.go index a6afbf53..25d906fd 100644 --- a/conf/parser_test.go +++ b/conf/parser_test.go @@ -1,12 +1,12 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package conf import ( - "net" + "net/netip" "reflect" "runtime" "testing" @@ -45,7 +45,7 @@ func noError(t *testing.T, err error) bool { return false } -func equal(t *testing.T, expected, actual interface{}) bool { +func equal(t *testing.T, expected, actual any) bool { if reflect.DeepEqual(expected, actual) { return true } @@ -53,7 +53,8 @@ func equal(t *testing.T, expected, actual interface{}) bool { t.Errorf("Failed equals at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected) return false } -func lenTest(t *testing.T, actualO interface{}, expected int) bool { + +func lenTest(t *testing.T, actualO any, expected int) bool { actual := reflect.ValueOf(actualO).Len() if reflect.DeepEqual(expected, actual) { return true @@ -62,7 +63,8 @@ func lenTest(t *testing.T, actualO interface{}, expected int) bool { t.Errorf("Wrong length at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected) return false } -func contains(t *testing.T, list, element interface{}) bool { + +func contains(t *testing.T, list, element any) bool { listValue := reflect.ValueOf(list) for i := 0; i < listValue.Len(); i++ { if reflect.DeepEqual(listValue.Index(i).Interface(), element) { @@ -77,10 +79,9 @@ func contains(t *testing.T, list, element interface{}) bool { func TestFromWgQuick(t *testing.T) { conf, err := FromWgQuick(testInput, "test") if noError(t, err) { - lenTest(t, conf.Interface.Addresses, 2) - contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 10, 0, 1), uint8(16)}) - contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 192, 122, 1), uint8(24)}) + contains(t, conf.Interface.Addresses, netip.PrefixFrom(netip.AddrFrom4([4]byte{0, 10, 0, 1}), 16)) + contains(t, conf.Interface.Addresses, netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 192, 122, 1}), 24)) equal(t, "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", conf.Interface.PrivateKey.String()) equal(t, uint16(51820), conf.Interface.ListenPort) diff --git a/conf/path_windows.go b/conf/path_windows.go index a53968c5..0ff0a057 100644 --- a/conf/path_windows.go +++ b/conf/path_windows.go @@ -1,33 +1,36 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package conf import ( + "errors" "os" "path/filepath" + "strings" + "unsafe" "golang.org/x/sys/windows" ) -var cachedConfigFileDir string -var cachedRootDir string -var disableAutoMigration bool +var ( + cachedConfigFileDir string + cachedRootDir string +) func tunnelConfigurationsDirectory() (string, error) { if cachedConfigFileDir != "" { return cachedConfigFileDir, nil } - root, err := RootDirectory() + root, err := RootDirectory(true) if err != nil { return "", err } c := filepath.Join(root, "Configurations") - maybeMigrate(c) - err = os.MkdirAll(c, os.ModeDir|0700) - if err != nil { + err = os.Mkdir(c, os.ModeDir|0o700) + if err != nil && !os.IsExist(err) { return "", err } cachedConfigFileDir = c @@ -39,22 +42,97 @@ func tunnelConfigurationsDirectory() (string, error) { // consumers of our libraries who might want to do strange things. func PresetRootDirectory(root string) { cachedRootDir = root - disableAutoMigration = true } -func RootDirectory() (string, error) { +func RootDirectory(create bool) (string, error) { if cachedRootDir != "" { return cachedRootDir, nil } - root, err := windows.KnownFolderPath(windows.FOLDERID_LocalAppData, windows.KF_FLAG_CREATE) + root, err := windows.KnownFolderPath(windows.FOLDERID_ProgramFiles, windows.KF_FLAG_DEFAULT) + if err != nil { + return "", err + } + root = filepath.Join(root, "WireGuard") + if !create { + return filepath.Join(root, "Data"), nil + } + root16, err := windows.UTF16PtrFromString(root) + if err != nil { + return "", err + } + + // The root directory inherits its ACL from Program Files; we don't want to mess with that + err = windows.CreateDirectory(root16, nil) + if err != nil && err != windows.ERROR_ALREADY_EXISTS { + return "", err + } + + dataDirectorySd, err := windows.SecurityDescriptorFromString("O:SYG:SYD:PAI(A;OICI;FA;;;SY)(A;OICI;FA;;;BA)") if err != nil { return "", err } - c := filepath.Join(root, "WireGuard") - err = os.MkdirAll(c, os.ModeDir|0700) + dataDirectorySa := &windows.SecurityAttributes{ + Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})), + SecurityDescriptor: dataDirectorySd, + } + + data := filepath.Join(root, "Data") + data16, err := windows.UTF16PtrFromString(data) if err != nil { return "", err } - cachedRootDir = c + var dataHandle windows.Handle + for { + err = windows.CreateDirectory(data16, dataDirectorySa) + if err != nil && err != windows.ERROR_ALREADY_EXISTS { + return "", err + } + dataHandle, err = windows.CreateFile(data16, windows.READ_CONTROL|windows.WRITE_OWNER|windows.WRITE_DAC, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_BACKUP_SEMANTICS|windows.FILE_FLAG_OPEN_REPARSE_POINT|windows.FILE_ATTRIBUTE_DIRECTORY, 0) + if err != nil && err != windows.ERROR_FILE_NOT_FOUND { + return "", err + } + if err == nil { + break + } + } + defer windows.CloseHandle(dataHandle) + var fileInfo windows.ByHandleFileInformation + err = windows.GetFileInformationByHandle(dataHandle, &fileInfo) + if err != nil { + return "", err + } + if fileInfo.FileAttributes&windows.FILE_ATTRIBUTE_DIRECTORY == 0 { + return "", errors.New("Data directory is actually a file") + } + if fileInfo.FileAttributes&windows.FILE_ATTRIBUTE_REPARSE_POINT != 0 { + return "", errors.New("Data directory is reparse point") + } + buf := make([]uint16, windows.MAX_PATH+4) + for { + bufLen, err := windows.GetFinalPathNameByHandle(dataHandle, &buf[0], uint32(len(buf)), 0) + if err != nil { + return "", err + } + if bufLen < uint32(len(buf)) { + break + } + buf = make([]uint16, bufLen) + } + if !strings.EqualFold(`\\?\`+data, windows.UTF16ToString(buf[:])) { + return "", errors.New("Data directory jumped to unexpected location") + } + err = windows.SetKernelObjectSecurity(dataHandle, windows.DACL_SECURITY_INFORMATION|windows.GROUP_SECURITY_INFORMATION|windows.OWNER_SECURITY_INFORMATION|windows.PROTECTED_DACL_SECURITY_INFORMATION, dataDirectorySd) + if err != nil { + return "", err + } + cachedRootDir = data return cachedRootDir, nil } + +func LogFile(createRoot bool) (string, error) { + root, err := RootDirectory(createRoot) + if err != nil { + return "", err + } + return filepath.Join(root, "log.bin"), nil +} diff --git a/conf/store.go b/conf/store.go index 21bd3a22..02807b77 100644 --- a/conf/store.go +++ b/conf/store.go @@ -1,13 +1,12 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package conf import ( "errors" - "io/ioutil" "os" "path/filepath" "strings" @@ -15,109 +14,41 @@ import ( "golang.zx2c4.com/wireguard/windows/conf/dpapi" ) -const configFileSuffix = ".conf.dpapi" -const configFileUnencryptedSuffix = ".conf" +const ( + configFileSuffix = ".conf.dpapi" + configFileUnencryptedSuffix = ".conf" +) func ListConfigNames() ([]string, error) { configFileDir, err := tunnelConfigurationsDirectory() if err != nil { return nil, err } - files, err := ioutil.ReadDir(configFileDir) + files, err := os.ReadDir(configFileDir) if err != nil { return nil, err } configs := make([]string, len(files)) i := 0 for _, file := range files { - name := filepath.Base(file.Name()) - if len(name) <= len(configFileSuffix) || !strings.HasSuffix(name, configFileSuffix) { - continue - } - if !file.Mode().IsRegular() || file.Mode().Perm()&0444 == 0 { - continue - } - name = strings.TrimSuffix(name, configFileSuffix) - if !TunnelNameIsValid(name) { - continue - } - configs[i] = name - i++ - } - return configs[:i], nil -} - -func MigrateUnencryptedConfigs() (int, []error) { - configFileDir, err := tunnelConfigurationsDirectory() - if err != nil { - return 0, []error{err} - } - files, err := ioutil.ReadDir(configFileDir) - if err != nil { - return 0, []error{err} - } - errs := make([]error, len(files)) - i := 0 - e := 0 - for _, file := range files { - path := filepath.Join(configFileDir, file.Name()) - name := filepath.Base(file.Name()) - if len(name) <= len(configFileUnencryptedSuffix) || !strings.HasSuffix(name, configFileUnencryptedSuffix) { - continue - } - if !file.Mode().IsRegular() || file.Mode().Perm()&0444 == 0 { - continue - } - - // We don't use ioutil's ReadFile, because we actually want RDWR, so that we can take advantage - // of Windows file locking for ensuring the file is finished being written. - f, err := os.OpenFile(path, os.O_RDWR, 0) - if err != nil { - errs[e] = err - e++ - continue - } - bytes, err := ioutil.ReadAll(f) - f.Close() + name, err := NameFromPath(file.Name()) if err != nil { - errs[e] = err - e++ continue } - _, err = FromWgQuickWithUnknownEncoding(string(bytes), "input") - if err != nil { - errs[e] = err - e++ + if !file.Type().IsRegular() { continue } - - bytes, err = dpapi.Encrypt(bytes, strings.TrimSuffix(name, configFileUnencryptedSuffix)) + info, err := file.Info() if err != nil { - errs[e] = err - e++ - continue - } - dstFile := strings.TrimSuffix(path, configFileUnencryptedSuffix) + configFileSuffix - if _, err = os.Stat(dstFile); err != nil && !os.IsNotExist(err) { - errs[e] = errors.New("Unable to migrate to " + dstFile + " as it already exists") - e++ continue } - err = ioutil.WriteFile(dstFile, bytes, 0600) - if err != nil { - errs[e] = err - e++ - continue - } - err = os.Remove(path) - if err != nil && os.Remove(dstFile) == nil { - errs[e] = err - e++ + if info.Mode().Perm()&0o444 == 0 { continue } + configs[i] = name i++ } - return i, errs[:e] + return configs[:i], nil } func LoadFromName(name string) (*Config, error) { @@ -129,15 +60,11 @@ func LoadFromName(name string) (*Config, error) { } func LoadFromPath(path string) (*Config, error) { - if !disableAutoMigration { - tunnelConfigurationsDirectory() // Provoke migrations, if needed. - } - name, err := NameFromPath(path) if err != nil { return nil, err } - bytes, err := ioutil.ReadFile(path) + bytes, err := os.ReadFile(path) if err != nil { return nil, err } @@ -171,7 +98,7 @@ func NameFromPath(path string) (string, error) { return name, nil } -func (config *Config) Save() error { +func (config *Config) Save(overwrite bool) error { if !TunnelNameIsValid(config.Name) { return errors.New("Tunnel name is not valid") } @@ -185,16 +112,7 @@ func (config *Config) Save() error { if err != nil { return err } - err = ioutil.WriteFile(filename+".tmp", bytes, 0600) - if err != nil { - return err - } - err = os.Rename(filename+".tmp", filename) - if err != nil { - os.Remove(filename + ".tmp") - return err - } - return nil + return writeLockedDownFile(filename, overwrite, bytes) } func (config *Config) Path() (string, error) { diff --git a/conf/store_test.go b/conf/store_test.go index fdef7ea7..3427a2b1 100644 --- a/conf/store_test.go +++ b/conf/store_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package conf @@ -17,7 +17,7 @@ func TestStorage(t *testing.T) { return } - err = c.Save() + err = c.Save(true) if err != nil { t.Errorf("Unable to save config: %s", err.Error()) } @@ -54,7 +54,11 @@ func TestStorage(t *testing.T) { } c.Interface.PrivateKey = *k - err = c.Save() + err = c.Save(false) + if err == nil { + t.Error("Config disappeared or was unexpectedly overwritten") + } + err = c.Save(true) if err != nil { t.Errorf("Unable to save config a second time: %s", err.Error()) } diff --git a/conf/storewatcher.go b/conf/storewatcher.go index ffd20ee0..70a44add 100644 --- a/conf/storewatcher.go +++ b/conf/storewatcher.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package conf diff --git a/conf/storewatcher_windows.go b/conf/storewatcher_windows.go index 19956263..0c4b74e7 100644 --- a/conf/storewatcher_windows.go +++ b/conf/storewatcher_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package conf @@ -11,20 +11,6 @@ import ( "golang.org/x/sys/windows" ) -const ( - fncFILE_NAME uint32 = 0x00000001 - fncDIR_NAME uint32 = 0x00000002 - fncATTRIBUTES uint32 = 0x00000004 - fncSIZE uint32 = 0x00000008 - fncLAST_WRITE uint32 = 0x00000010 - fncLAST_ACCESS uint32 = 0x00000020 - fncCREATION uint32 = 0x00000040 - fncSECURITY uint32 = 0x00000100 -) - -//sys findFirstChangeNotification(path *uint16, watchSubtree bool, filter uint32) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = kernel32.FindFirstChangeNotificationW -//sys findNextChangeNotification(handle windows.Handle) (err error) = kernel32.FindNextChangeNotification - var haveStartedWatchingConfigDir bool func startWatchingConfigDir() { @@ -36,7 +22,7 @@ func startWatchingConfigDir() { h := windows.InvalidHandle defer func() { if h != windows.InvalidHandle { - windows.CloseHandle(h) + windows.FindCloseChangeNotification(h) } haveStartedWatchingConfigDir = false }() @@ -45,7 +31,7 @@ func startWatchingConfigDir() { if err != nil { return } - h, err = findFirstChangeNotification(windows.StringToUTF16Ptr(configFileDir), true, fncFILE_NAME|fncDIR_NAME|fncATTRIBUTES|fncSIZE|fncLAST_WRITE|fncLAST_ACCESS|fncCREATION|fncSECURITY) + h, err = windows.FindFirstChangeNotification(configFileDir, true, windows.FILE_NOTIFY_CHANGE_FILE_NAME|windows.FILE_NOTIFY_CHANGE_DIR_NAME|windows.FILE_NOTIFY_CHANGE_ATTRIBUTES|windows.FILE_NOTIFY_CHANGE_SIZE|windows.FILE_NOTIFY_CHANGE_LAST_WRITE|windows.FILE_NOTIFY_CHANGE_LAST_ACCESS|windows.FILE_NOTIFY_CHANGE_CREATION|windows.FILE_NOTIFY_CHANGE_SECURITY) if err != nil { log.Printf("Unable to monitor config directory: %v", err) return @@ -54,7 +40,7 @@ func startWatchingConfigDir() { s, err := windows.WaitForSingleObject(h, windows.INFINITE) if err != nil || s == windows.WAIT_FAILED { log.Printf("Unable to wait on config directory watcher: %v", err) - windows.CloseHandle(h) + windows.FindCloseChangeNotification(h) h = windows.InvalidHandle goto startover } @@ -63,10 +49,10 @@ func startWatchingConfigDir() { cb.cb() } - err = findNextChangeNotification(h) + err = windows.FindNextChangeNotification(h) if err != nil { log.Printf("Unable to monitor config directory again: %v", err) - windows.CloseHandle(h) + windows.FindCloseChangeNotification(h) h = windows.InvalidHandle goto startover } diff --git a/conf/writer.go b/conf/writer.go index 748c1d61..162962b5 100644 --- a/conf/writer.go +++ b/conf/writer.go @@ -1,13 +1,19 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package conf import ( "fmt" + "net/netip" "strings" + "unsafe" + + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/windows/driver" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) func (conf *Config) ToWgQuick() string { @@ -28,11 +34,12 @@ func (conf *Config) ToWgQuick() string { output.WriteString(fmt.Sprintf("Address = %s\n", strings.Join(addrStrings[:], ", "))) } - if len(conf.Interface.DNS) > 0 { - addrStrings := make([]string, len(conf.Interface.DNS)) - for i, address := range conf.Interface.DNS { - addrStrings[i] = address.String() + if len(conf.Interface.DNS)+len(conf.Interface.DNSSearch) > 0 { + addrStrings := make([]string, 0, len(conf.Interface.DNS)+len(conf.Interface.DNSSearch)) + for _, address := range conf.Interface.DNS { + addrStrings = append(addrStrings, address.String()) } + addrStrings = append(addrStrings, conf.Interface.DNSSearch...) output.WriteString(fmt.Sprintf("DNS = %s\n", strings.Join(addrStrings[:], ", "))) } @@ -40,6 +47,22 @@ func (conf *Config) ToWgQuick() string { output.WriteString(fmt.Sprintf("MTU = %d\n", conf.Interface.MTU)) } + if len(conf.Interface.PreUp) > 0 { + output.WriteString(fmt.Sprintf("PreUp = %s\n", conf.Interface.PreUp)) + } + if len(conf.Interface.PostUp) > 0 { + output.WriteString(fmt.Sprintf("PostUp = %s\n", conf.Interface.PostUp)) + } + if len(conf.Interface.PreDown) > 0 { + output.WriteString(fmt.Sprintf("PreDown = %s\n", conf.Interface.PreDown)) + } + if len(conf.Interface.PostDown) > 0 { + output.WriteString(fmt.Sprintf("PostDown = %s\n", conf.Interface.PostDown)) + } + if conf.Interface.TableOff { + output.WriteString("Table = off\n") + } + for _, peer := range conf.Peers { output.WriteString("\n[Peer]\n") @@ -68,43 +91,50 @@ func (conf *Config) ToWgQuick() string { return output.String() } -func (conf *Config) ToUAPI() (uapi string, dnsErr error) { - var output strings.Builder - output.WriteString(fmt.Sprintf("private_key=%s\n", conf.Interface.PrivateKey.HexString())) - - if conf.Interface.ListenPort > 0 { - output.WriteString(fmt.Sprintf("listen_port=%d\n", conf.Interface.ListenPort)) - } - - if len(conf.Peers) > 0 { - output.WriteString("replace_peers=true\n") +func (config *Config) ToDriverConfiguration() (*driver.Interface, uint32) { + preallocation := unsafe.Sizeof(driver.Interface{}) + uintptr(len(config.Peers))*unsafe.Sizeof(driver.Peer{}) + for i := range config.Peers { + preallocation += uintptr(len(config.Peers[i].AllowedIPs)) * unsafe.Sizeof(driver.AllowedIP{}) } - - for _, peer := range conf.Peers { - output.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey.HexString())) - - if !peer.PresharedKey.IsZero() { - output.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PresharedKey.HexString())) + var c driver.ConfigBuilder + c.Preallocate(uint32(preallocation)) + c.AppendInterface(&driver.Interface{ + Flags: driver.InterfaceHasPrivateKey | driver.InterfaceHasListenPort, + ListenPort: config.Interface.ListenPort, + PrivateKey: config.Interface.PrivateKey, + PeerCount: uint32(len(config.Peers)), + }) + for i := range config.Peers { + flags := driver.PeerHasPublicKey | driver.PeerHasPersistentKeepalive + if !config.Peers[i].PresharedKey.IsZero() { + flags |= driver.PeerHasPresharedKey } - - if !peer.Endpoint.IsEmpty() { - var resolvedIP string - resolvedIP, dnsErr = resolveHostname(peer.Endpoint.Host) - if dnsErr != nil { - return + var endpoint winipcfg.RawSockaddrInet + if !config.Peers[i].Endpoint.IsEmpty() { + addr, err := netip.ParseAddr(config.Peers[i].Endpoint.Host) + if err == nil { + flags |= driver.PeerHasEndpoint + endpoint.SetAddrPort(netip.AddrPortFrom(addr, config.Peers[i].Endpoint.Port)) } - resolvedEndpoint := Endpoint{resolvedIP, peer.Endpoint.Port} - output.WriteString(fmt.Sprintf("endpoint=%s\n", resolvedEndpoint.String())) } - - output.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.PersistentKeepalive)) - - if len(peer.AllowedIPs) > 0 { - output.WriteString("replace_allowed_ips=true\n") - for _, address := range peer.AllowedIPs { - output.WriteString(fmt.Sprintf("allowed_ip=%s\n", address.String())) + c.AppendPeer(&driver.Peer{ + Flags: flags, + PublicKey: config.Peers[i].PublicKey, + PresharedKey: config.Peers[i].PresharedKey, + PersistentKeepalive: config.Peers[i].PersistentKeepalive, + Endpoint: endpoint, + AllowedIPsCount: uint32(len(config.Peers[i].AllowedIPs)), + }) + for j := range config.Peers[i].AllowedIPs { + a := &driver.AllowedIP{Cidr: uint8(config.Peers[i].AllowedIPs[j].Bits())} + copy(a.Address[:], config.Peers[i].AllowedIPs[j].Addr().AsSlice()) + if config.Peers[i].AllowedIPs[j].Addr().Is4() { + a.AddressFamily = windows.AF_INET + } else if config.Peers[i].AllowedIPs[j].Addr().Is6() { + a.AddressFamily = windows.AF_INET6 } + c.AppendAllowedIP(a) } } - return output.String(), nil + return c.Interface() } diff --git a/conf/zsyscall_windows.go b/conf/zsyscall_windows.go deleted file mode 100644 index 9dcf68fe..00000000 --- a/conf/zsyscall_windows.go +++ /dev/null @@ -1,83 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package conf - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modwininet = windows.NewLazySystemDLL("wininet.dll") - modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - - procInternetGetConnectedState = modwininet.NewProc("InternetGetConnectedState") - procFindFirstChangeNotificationW = modkernel32.NewProc("FindFirstChangeNotificationW") - procFindNextChangeNotification = modkernel32.NewProc("FindNextChangeNotification") -) - -func internetGetConnectedState(flags *uint32, reserved uint32) (connected bool) { - r0, _, _ := syscall.Syscall(procInternetGetConnectedState.Addr(), 2, uintptr(unsafe.Pointer(flags)), uintptr(reserved), 0) - connected = r0 != 0 - return -} - -func findFirstChangeNotification(path *uint16, watchSubtree bool, filter uint32) (handle windows.Handle, err error) { - var _p0 uint32 - if watchSubtree { - _p0 = 1 - } else { - _p0 = 0 - } - r0, _, e1 := syscall.Syscall(procFindFirstChangeNotificationW.Addr(), 3, uintptr(unsafe.Pointer(path)), uintptr(_p0), uintptr(filter)) - handle = windows.Handle(r0) - if handle == windows.InvalidHandle { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func findNextChangeNotification(handle windows.Handle) (err error) { - r1, _, e1 := syscall.Syscall(procFindNextChangeNotification.Addr(), 1, uintptr(handle), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} |