From 76bfd58543e57a7e91e1ef096780acaaea151712 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Tue, 7 May 2019 09:27:01 +0200 Subject: conf: retry DNS resoluion when no internet present --- conf/dnsresolver_windows.go | 88 +++++++++++++++++++++++++++++++++++++++++++++ conf/mksyscall.go | 2 +- conf/writer.go | 27 ++++---------- conf/zsyscall_windows.go | 8 +++++ 4 files changed, 103 insertions(+), 22 deletions(-) create mode 100644 conf/dnsresolver_windows.go diff --git a/conf/dnsresolver_windows.go b/conf/dnsresolver_windows.go new file mode 100644 index 00000000..9c1e817c --- /dev/null +++ b/conf/dnsresolver_windows.go @@ -0,0 +1,88 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package conf + +import ( + "fmt" + "golang.org/x/sys/windows" + "log" + "net" + "syscall" + "time" + "unsafe" +) + +//sys internetGetConnectedState(flags *uint32, reserved uint32) (connected bool) = wininet.InternetGetConnectedState + +func resolveHostname(name string) (resolvedIpString string, err error) { + const maxTries = 10 + + for i := 0; i < maxTries; i++ { + resolvedIpString, err = resolveHostnameOnce(name) + if err == nil { + return + } + if err == windows.WSATRY_AGAIN { + log.Printf("Temporary DNS error when resolving %s, sleeping for 4 seconds", name) + time.Sleep(time.Second * 4) + continue + } + var state uint32 + if err == windows.WSAHOST_NOT_FOUND && !internetGetConnectedState(&state, 0) { + log.Printf("Host not found when resolving %s, but no Internet connection available, sleeping for 4 seconds", name) + time.Sleep(time.Second * 4) + continue + } + return + } + return +} + +func resolveHostnameOnce(name string) (resolvedIpString string, err error) { + hints := windows.AddrinfoW{ + Family: windows.AF_UNSPEC, + Socktype: windows.SOCK_DGRAM, + Protocol: windows.IPPROTO_IP, + } + var result *windows.AddrinfoW + name16, err := windows.UTF16PtrFromString(name) + if err != nil { + return + } + err = windows.GetAddrInfoW(name16, nil, &hints, &result) + if err != nil { + return + } + if result == nil { + err = windows.WSAHOST_NOT_FOUND + return + } + defer windows.FreeAddrInfoW(result) + ipv6 := "" + for ; result != nil; result = result.Next { + addr := unsafe.Pointer(result.Addr) + switch result.Family { + case windows.AF_INET: + a := (*syscall.RawSockaddrInet4)(addr).Addr + return net.IP{a[0], a[1], a[2], a[3]}.String(), nil + case windows.AF_INET6: + if len(ipv6) != 0 { + continue + } + a := (*syscall.RawSockaddrInet6)(addr).Addr + ipv6 = net.IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}.String() + scope := uint32((*syscall.RawSockaddrInet6)(addr).Scope_id) + if scope != 0 { + ipv6 += fmt.Sprintf("%%%d", scope) + } + } + } + if len(ipv6) != 0 { + return ipv6, nil + } + err = windows.WSAHOST_NOT_FOUND + return +} diff --git a/conf/mksyscall.go b/conf/mksyscall.go index 2bdb0204..b5c4857f 100644 --- a/conf/mksyscall.go +++ b/conf/mksyscall.go @@ -5,4 +5,4 @@ package conf -//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go path_windows.go storewatcher_windows.go +//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go dnsresolver_windows.go path_windows.go storewatcher_windows.go diff --git a/conf/writer.go b/conf/writer.go index 642d14a7..d9504178 100644 --- a/conf/writer.go +++ b/conf/writer.go @@ -6,9 +6,7 @@ package conf import ( - "errors" "fmt" - "net" "strings" ) @@ -70,7 +68,7 @@ func (conf *Config) ToWgQuick() string { return output.String() } -func (conf *Config) ToUAPI() (string, error) { +func (conf *Config) ToUAPI() (uapi string, dnsErr error) { var output strings.Builder output.WriteString(fmt.Sprintf("private_key=%s\n", conf.Interface.PrivateKey.HexString())) @@ -90,25 +88,12 @@ func (conf *Config) ToUAPI() (string, error) { } if !peer.Endpoint.IsEmpty() { - ips, err := net.LookupIP(peer.Endpoint.Host) - if err != nil { - return "", err + var resolvedIp string + resolvedIp, dnsErr = resolveHostname(peer.Endpoint.Host) + if dnsErr != nil { + return } - var ip net.IP - for _, iterip := range ips { - iterip = iterip.To4() - if iterip != nil { - ip = iterip - break - } - if ip == nil { - ip = iterip - } - } - if ip == nil { - return "", errors.New("Unable to resolve IP address of endpoint") - } - resolvedEndpoint := Endpoint{ip.String(), peer.Endpoint.Port} + resolvedEndpoint := Endpoint{resolvedIp, peer.Endpoint.Port} output.WriteString(fmt.Sprintf("endpoint=%s\n", resolvedEndpoint.String())) } diff --git a/conf/zsyscall_windows.go b/conf/zsyscall_windows.go index 64bec1f4..d8984bef 100644 --- a/conf/zsyscall_windows.go +++ b/conf/zsyscall_windows.go @@ -37,16 +37,24 @@ func errnoErr(e syscall.Errno) error { } var ( + modwininet = windows.NewLazySystemDLL("wininet.dll") modole32 = windows.NewLazySystemDLL("ole32.dll") modshell32 = windows.NewLazySystemDLL("shell32.dll") modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + procInternetGetConnectedState = modwininet.NewProc("InternetGetConnectedState") procCoTaskMemFree = modole32.NewProc("CoTaskMemFree") procSHGetKnownFolderPath = modshell32.NewProc("SHGetKnownFolderPath") procFindFirstChangeNotificationW = modkernel32.NewProc("FindFirstChangeNotificationW") procFindNextChangeNotification = modkernel32.NewProc("FindNextChangeNotification") ) +func internetGetConnectedState(flags *uint32, reserved uint32) (connected bool) { + r0, _, _ := syscall.Syscall(procInternetGetConnectedState.Addr(), 2, uintptr(unsafe.Pointer(flags)), uintptr(reserved), 0) + connected = r0 != 0 + return +} + func coTaskMemFree(pointer uintptr) { syscall.Syscall(procCoTaskMemFree.Addr(), 1, uintptr(pointer), 0, 0) return -- cgit v1.2.3-59-g8ed1b