From e11de8481c4c349456125e5ad6782d82d423a5b9 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Fri, 29 Jan 2021 22:29:22 +0100 Subject: winipcfg: move to undocumented DNS function Signed-off-by: Jason A. Donenfeld --- tunnel/addressconfig.go | 18 +--- tunnel/interfacewatcher.go | 3 +- tunnel/winipcfg/luid.go | 166 +++++++++-------------------------- tunnel/winipcfg/netsh.go | 65 +++++++++++++- tunnel/winipcfg/types.go | 37 ++++++++ tunnel/winipcfg/winipcfg.go | 27 ++++++ tunnel/winipcfg/winipcfg_test.go | 53 ++--------- tunnel/winipcfg/zwinipcfg_windows.go | 37 ++++++++ 8 files changed, 212 insertions(+), 194 deletions(-) diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go index 74166ab2..a61905ea 100644 --- a/tunnel/addressconfig.go +++ b/tunnel/addressconfig.go @@ -160,23 +160,7 @@ func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *t return err } - dnsSearch := "" - if len(conf.Interface.DNSSearch) > 0 { - dnsSearch = conf.Interface.DNSSearch[0] - } - err = luid.SetDNSDomain(dnsSearch) - if err != nil { - return nil - } - if len(conf.Interface.DNSSearch) > 1 { - log.Printf("Warning: %d DNS search domains were specified, but only one is supported, so the first one (%s) was used.", len(conf.Interface.DNSSearch), dnsSearch) - } - err = luid.SetDNSForFamily(family, conf.Interface.DNS) - if err != nil { - return err - } - - return nil + return luid.SetDNS(family, conf.Interface.DNS, conf.Interface.DNSSearch) } func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error { diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go index 8298169f..80406874 100644 --- a/tunnel/interfacewatcher.go +++ b/tunnel/interfacewatcher.go @@ -192,9 +192,10 @@ func (iw *interfaceWatcher) Destroy() { luid := winipcfg.LUID(tun.LUID()) luid.FlushRoutes(windows.AF_INET) luid.FlushIPAddresses(windows.AF_INET) + luid.FlushDNS(windows.AF_INET) luid.FlushRoutes(windows.AF_INET6) luid.FlushIPAddresses(windows.AF_INET6) - luid.FlushDNS() + luid.FlushDNS(windows.AF_INET6) } iw.setupMutex.Unlock() } diff --git a/tunnel/winipcfg/luid.go b/tunnel/winipcfg/luid.go index 88b946d3..efdf1ba4 100644 --- a/tunnel/winipcfg/luid.go +++ b/tunnel/winipcfg/luid.go @@ -7,11 +7,10 @@ package winipcfg import ( "errors" - "fmt" "net" + "strings" "golang.org/x/sys/windows" - "golang.org/x/sys/windows/registry" ) // LUID represents a network interface. @@ -317,141 +316,60 @@ func (luid LUID) DNS() ([]net.IP, error) { return r, nil } -const ( - netshCmdTemplateFlush4 = "interface ipv4 set dnsservers name=%d source=static address=none validate=no register=both" - netshCmdTemplateFlush6 = "interface ipv6 set dnsservers name=%d source=static address=none validate=no register=both" - netshCmdTemplateAdd4 = "interface ipv4 add dnsservers name=%d address=%s validate=no" - netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no" -) - -// FlushDNS method clears all DNS servers associated with the adapter. -func (luid LUID) FlushDNS() error { - cmds := make([]string, 0, 2) - ipif4, err := luid.IPInterface(windows.AF_INET) - if err == nil { - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateFlush4, ipif4.InterfaceIndex)) - } - ipif6, err := luid.IPInterface(windows.AF_INET6) - if err == nil { - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateFlush6, ipif6.InterfaceIndex)) +// SetDNS method clears previous and associates new DNS servers and search domains with the adapter for a specific family. +func (luid LUID) SetDNS(family AddressFamily, servers []net.IP, domains []string) error { + if family != windows.AF_INET && family != windows.AF_INET6 { + return windows.ERROR_PROTOCOL_UNREACHABLE } - if len(cmds) == 0 { - return nil - } - return runNetsh(cmds) -} - -// AddDNS method associates additional DNS servers with the adapter. -func (luid LUID) AddDNS(dnses []net.IP) error { - var ipif4, ipif6 *MibIPInterfaceRow - var err error - cmds := make([]string, 0, len(dnses)) - for i := 0; i < len(dnses); i++ { - if v4 := dnses[i].To4(); v4 != nil { - if ipif4 == nil { - ipif4, err = luid.IPInterface(windows.AF_INET) - if err != nil { - return err - } - } - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif4.InterfaceIndex, v4.String())) - } else if v6 := dnses[i].To16(); v6 != nil { - if ipif6 == nil { - ipif6, err = luid.IPInterface(windows.AF_INET6) - if err != nil { - return err - } - } - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif6.InterfaceIndex, v6.String())) + var filteredServers []string + for _, server := range servers { + if v4 := server.To4(); v4 != nil && family == windows.AF_INET { + filteredServers = append(filteredServers, v4.String()) + } else if v6 := server.To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 { + filteredServers = append(filteredServers, v6.String()) } } - - if len(cmds) == 0 { - return nil - } - return runNetsh(cmds) -} - -// SetDNS method clears previous and associates new DNS servers with the adapter. -func (luid LUID) SetDNS(dnses []net.IP) error { - cmds := make([]string, 0, 2+len(dnses)) - ipif4, err := luid.IPInterface(windows.AF_INET) - if err == nil { - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateFlush4, ipif4.InterfaceIndex)) - } - ipif6, err := luid.IPInterface(windows.AF_INET6) - if err == nil { - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateFlush6, ipif6.InterfaceIndex)) - } - for i := 0; i < len(dnses); i++ { - if v4 := dnses[i].To4(); v4 != nil { - if ipif4 == nil { - return windows.ERROR_NOT_SUPPORTED - } - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif4.InterfaceIndex, v4.String())) - } else if v6 := dnses[i].To16(); v6 != nil { - if ipif6 == nil { - return windows.ERROR_NOT_SUPPORTED - } - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif6.InterfaceIndex, v6.String())) - } - } - - if len(cmds) == 0 { - return nil - } - return runNetsh(cmds) -} - -// SetDNSForFamily method clears previous and associates new DNS servers with the adapter for a specific family. -func (luid LUID) SetDNSForFamily(family AddressFamily, dnses []net.IP) error { - var templateFlush string - if family == windows.AF_INET { - templateFlush = netshCmdTemplateFlush4 - } else if family == windows.AF_INET6 { - templateFlush = netshCmdTemplateFlush6 - } - - cmds := make([]string, 0, 1+len(dnses)) - ipif, err := luid.IPInterface(family) + servers16, err := windows.UTF16PtrFromString(strings.Join(filteredServers, ",")) if err != nil { return err } - cmds = append(cmds, fmt.Sprintf(templateFlush, ipif.InterfaceIndex)) - for i := 0; i < len(dnses); i++ { - if v4 := dnses[i].To4(); v4 != nil && family == windows.AF_INET { - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, v4.String())) - } else if v6 := dnses[i].To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 { - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, v6.String())) - } + domains16, err := windows.UTF16PtrFromString(strings.Join(domains, ",")) + if err != nil { + return err } - return runNetsh(cmds) -} - -// SetDNSDomain method sets the interface-specific DNS domain. -func (luid LUID) SetDNSDomain(domain string) error { guid, err := luid.GUID() if err != nil { - return fmt.Errorf("Error converting luid to guid: %w", err) + return err } - key, err := registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", guid), registry.QUERY_VALUE) - if err != nil { - return fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %w", err) + var maybeV6 uint64 + if family == windows.AF_INET6 { + maybeV6 = disFlagsIPv6 + } + // For >= Windows 10 1809 + err = setInterfaceDnsSettings(*guid, &dnsInterfaceSettings{ + Version: disVersion1, + Flags: disFlagsNameServer | disFlagsSearchList | maybeV6, + NameServer: servers16, + SearchList: domains16, + }) + if err == nil || !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) { + return err } - paths, _, err := key.GetStringsValue("IpConfig") - key.Close() + + // For < Windows 10 1809 + err = luid.fallbackSetDNSForFamily(family, servers) if err != nil { - return fmt.Errorf("Error reading IpConfig registry key: %w", err) - } - if len(paths) == 0 { - return errors.New("No TCP/IP interfaces found on adapter") + return err } - key, err = registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), registry.SET_VALUE) - if err != nil { - return fmt.Errorf("Unable to open TCP/IP network registry key: %w", err) + if len(domains) > 0 { + return luid.fallbackSetDNSDomain(domains[0]) + } else { + return luid.fallbackSetDNSDomain("") } - err = key.SetStringValue("Domain", domain) - key.Close() - return err +} + +// FlushDNS method clears all DNS servers associated with the adapter. +func (luid LUID) FlushDNS(family AddressFamily) error { + return luid.SetDNS(family, nil, nil) } diff --git a/tunnel/winipcfg/netsh.go b/tunnel/winipcfg/netsh.go index 803a9a35..1f3d12d0 100644 --- a/tunnel/winipcfg/netsh.go +++ b/tunnel/winipcfg/netsh.go @@ -7,25 +7,25 @@ package winipcfg import ( "bytes" + "errors" "fmt" "io" + "net" "os/exec" "path/filepath" "strings" "syscall" "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" ) -// I wish we didn't have to do this. netiohlp.dll (what's used by netsh.exe) has some nice tricks with writing directly -// to the registry and the nsi kernel object, but it's not clear copying those makes for a stable interface. WMI doesn't -// work with v6. CMI isn't in Windows 7. func runNetsh(cmds []string) error { system32, err := windows.GetSystemDirectory() if err != nil { return err } - cmd := exec.Command(filepath.Join(system32, "netsh.exe")) // I wish we could append (, "-f", "CONIN$") but Go sets up the process context wrong. + cmd := exec.Command(filepath.Join(system32, "netsh.exe")) cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} stdin, err := cmd.StdinPipe() @@ -49,3 +49,60 @@ func runNetsh(cmds []string) error { } return nil } + +const ( + netshCmdTemplateFlush4 = "interface ipv4 set dnsservers name=%d source=static address=none validate=no register=both" + netshCmdTemplateFlush6 = "interface ipv6 set dnsservers name=%d source=static address=none validate=no register=both" + netshCmdTemplateAdd4 = "interface ipv4 add dnsservers name=%d address=%s validate=no" + netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no" +) + +func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []net.IP) error { + var templateFlush string + if family == windows.AF_INET { + templateFlush = netshCmdTemplateFlush4 + } else if family == windows.AF_INET6 { + templateFlush = netshCmdTemplateFlush6 + } + + cmds := make([]string, 0, 1+len(dnses)) + ipif, err := luid.IPInterface(family) + if err != nil { + return err + } + cmds = append(cmds, fmt.Sprintf(templateFlush, ipif.InterfaceIndex)) + for i := 0; i < len(dnses); i++ { + if v4 := dnses[i].To4(); v4 != nil && family == windows.AF_INET { + cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, v4.String())) + } else if v6 := dnses[i].To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 { + cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, v6.String())) + } + } + return runNetsh(cmds) +} + +func (luid LUID) fallbackSetDNSDomain(domain string) error { + guid, err := luid.GUID() + if err != nil { + return fmt.Errorf("Error converting luid to guid: %w", err) + } + key, err := registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", guid), registry.QUERY_VALUE) + if err != nil { + return fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %w", err) + } + paths, _, err := key.GetStringsValue("IpConfig") + key.Close() + if err != nil { + return fmt.Errorf("Error reading IpConfig registry key: %w", err) + } + if len(paths) == 0 { + return errors.New("No TCP/IP interfaces found on adapter") + } + key, err = registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), registry.SET_VALUE) + if err != nil { + return fmt.Errorf("Unable to open TCP/IP network registry key: %w", err) + } + err = key.SetStringValue("Domain", domain) + key.Close() + return err +} diff --git a/tunnel/winipcfg/types.go b/tunnel/winipcfg/types.go index d60a149d..02f7f788 100644 --- a/tunnel/winipcfg/types.go +++ b/tunnel/winipcfg/types.go @@ -940,6 +940,43 @@ func (tab *mibIPforwardTable2) free() { freeMibTable(unsafe.Pointer(tab)) } +// +// Undocumented DNS API +// + +// dnsInterfaceSettings is mean to be used with setInterfaceDnsSettings +type dnsInterfaceSettings struct { + Version uint32 + _ [4]byte + Flags uint64 + Domain *uint16 + NameServer *uint16 + SearchList *uint16 + RegistrationEnabled uint32 + RegisterAdapterName uint32 + EnableLLMNR uint32 + QueryAdapterName uint32 + ProfileNameServer *uint16 +} + +const ( + disVersion1 = 1 + disVersion2 = 2 + + disFlagsIPv6 = 0x1 + disFlagsNameServer = 0x2 + disFlagsSearchList = 0x4 + disFlagsRegistrationEnabled = 0x8 + disFlagsRegisterAdapterName = 0x10 + disFlagsDomain = 0x20 + disFlagsHostname = 0x40 // ?? + disFlagsEnableLLMNR = 0x80 + disFlagsQueryAdapterName = 0x100 + disFlagsProfileNameServer = 0x200 + disFlagsVersion2 = 0x400 // ?? - v2 only + disFlagsMoreFlags = 0x800 // ?? - v2 only +) + // unsafeSlice updates the slice slicePtr to be a slice // referencing the provided data with its length & capacity set to // lenCap. diff --git a/tunnel/winipcfg/winipcfg.go b/tunnel/winipcfg/winipcfg.go index 8d119f0e..54040a1c 100644 --- a/tunnel/winipcfg/winipcfg.go +++ b/tunnel/winipcfg/winipcfg.go @@ -6,6 +6,7 @@ package winipcfg import ( + "runtime" "unsafe" "golang.org/x/sys/windows" @@ -166,3 +167,29 @@ func GetIPForwardTable2(family AddressFamily) ([]MibIPforwardRow2, error) { // https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-cancelmibchangenotify2 //sys cancelMibChangeNotify2(notificationHandle windows.Handle) (ret error) = iphlpapi.CancelMibChangeNotify2 + +// +// Undocumented DNS API +// + +//sys setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *dnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings? +//sys setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *dnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings? +//sys setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *dnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings? + +// The GUID is passed by value, not by reference, which means different +// things on different calling conventions. On amd64, this means it's +// passed by reference anyway, while on arm, arm64, and 386, it's split +// into words. +func setInterfaceDnsSettings(guid windows.GUID, settings *dnsInterfaceSettings) error { + words := (*[4]uintptr)(unsafe.Pointer(&guid)) + switch runtime.GOARCH { + case "amd64": + return setInterfaceDnsSettingsByPtr(&guid, settings) + case "arm64": + return setInterfaceDnsSettingsByQwords(words[0], words[1], settings) + case "arm", "386": + return setInterfaceDnsSettingsByDwords(words[0], words[1], words[2], words[3], settings) + default: + panic("unknown calling convention") + } +} diff --git a/tunnel/winipcfg/winipcfg_test.go b/tunnel/winipcfg/winipcfg_test.go index e88999ca..7689a0c1 100644 --- a/tunnel/winipcfg/winipcfg_test.go +++ b/tunnel/winipcfg/winipcfg_test.go @@ -25,6 +25,7 @@ import ( "bytes" "net" "strings" + "syscall" "testing" "time" @@ -592,7 +593,7 @@ func TestFlushDNS(t *testing.T) { t.Errorf("LUID.DNS() returned an error: %w", err) } - err = ifc.LUID.FlushDNS() + err = ifc.LUID.FlushDNS(syscall.AF_INET) if err != nil { t.Errorf("LUID.FlushDNS() returned an error: %w", err) } @@ -613,51 +614,7 @@ func TestFlushDNS(t *testing.T) { t.Errorf("DNSServerAddresses contains %d items, although FlushDNS is executed successfully.", n) } - err = ifc.LUID.SetDNS(prevDNSes) - if err != nil { - t.Errorf("LUID.SetDNS() returned an error: %v.", err) - } -} - -func TestAddDNS(t *testing.T) { - ifc, err := getTestInterface() - if err != nil { - t.Errorf("getTestInterface() returned an error: %w", err) - return - } - if !runningElevated() { - t.Errorf("%s requires elevation", t.Name()) - return - } - - prevDNSes, err := ifc.LUID.DNS() - if err != nil { - t.Errorf("LUID.DNS() returned an error: %w", err) - } - expectedDNSes := append(prevDNSes, dnsesToSet...) - - err = ifc.LUID.AddDNS(dnsesToSet) - if err != nil { - t.Errorf("LUID.AddDNS() returned an error: %w", err) - return - } - - ifc, _ = getTestInterface() - - newDNSes, err := ifc.LUID.DNS() - if err != nil { - t.Errorf("LUID.DNS() returned an error: %w", err) - } else if len(newDNSes) != len(expectedDNSes) { - t.Errorf("expectedDNSes contains %d items, while DNSServerAddresses contains %d.", len(expectedDNSes), len(newDNSes)) - } else { - for i := range expectedDNSes { - if !expectedDNSes[i].Equal(newDNSes[i]) { - t.Errorf("expectedDNSes[%d] = %s while DNSServerAddresses[%d] = %s.", i, expectedDNSes[i].String(), i, newDNSes[i].String()) - } - } - } - - err = ifc.LUID.SetDNS(prevDNSes) + err = ifc.LUID.SetDNS(windows.AF_INET, prevDNSes, nil) if err != nil { t.Errorf("LUID.SetDNS() returned an error: %v.", err) } @@ -679,7 +636,7 @@ func TestSetDNS(t *testing.T) { t.Errorf("LUID.DNS() returned an error: %w", err) } - err = ifc.LUID.SetDNS(dnsesToSet) + err = ifc.LUID.SetDNS(windows.AF_INET, dnsesToSet, nil) if err != nil { t.Errorf("LUID.SetDNS() returned an error: %w", err) return @@ -700,7 +657,7 @@ func TestSetDNS(t *testing.T) { } } - err = ifc.LUID.SetDNS(prevDNSes) + err = ifc.LUID.SetDNS(windows.AF_INET, prevDNSes, nil) if err != nil { t.Errorf("LUID.SetDNS() returned an error: %v.", err) } diff --git a/tunnel/winipcfg/zwinipcfg_windows.go b/tunnel/winipcfg/zwinipcfg_windows.go index cf661548..c4bf3b00 100644 --- a/tunnel/winipcfg/zwinipcfg_windows.go +++ b/tunnel/winipcfg/zwinipcfg_windows.go @@ -66,6 +66,7 @@ var ( procNotifyIpInterfaceChange = modiphlpapi.NewProc("NotifyIpInterfaceChange") procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2") procNotifyUnicastIpAddressChange = modiphlpapi.NewProc("NotifyUnicastIpAddressChange") + procSetInterfaceDnsSettings = modiphlpapi.NewProc("SetInterfaceDnsSettings") procSetIpForwardEntry2 = modiphlpapi.NewProc("SetIpForwardEntry2") procSetIpInterfaceEntry = modiphlpapi.NewProc("SetIpInterfaceEntry") procSetUnicastIpAddressEntry = modiphlpapi.NewProc("SetUnicastIpAddressEntry") @@ -279,6 +280,42 @@ func notifyUnicastIPAddressChange(family AddressFamily, callback uintptr, caller return } +func setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *dnsInterfaceSettings) (ret error) { + ret = procSetInterfaceDnsSettings.Find() + if ret != nil { + return + } + r0, _, _ := syscall.Syscall6(procSetInterfaceDnsSettings.Addr(), 5, uintptr(guid1), uintptr(guid2), uintptr(guid3), uintptr(guid4), uintptr(unsafe.Pointer(settings)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *dnsInterfaceSettings) (ret error) { + ret = procSetInterfaceDnsSettings.Find() + if ret != nil { + return + } + r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(settings)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *dnsInterfaceSettings) (ret error) { + ret = procSetInterfaceDnsSettings.Find() + if ret != nil { + return + } + r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 3, uintptr(guid1), uintptr(guid2), uintptr(unsafe.Pointer(settings))) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + func setIPForwardEntry2(route *MibIPforwardRow2) (ret error) { r0, _, _ := syscall.Syscall(procSetIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) if r0 != 0 { -- cgit v1.2.3-59-g8ed1b