summaryrefslogtreecommitdiffstats
path: root/tun
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-01-11 16:28:12 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2021-01-13 16:33:40 +0100
commit675955de5d0a1bad66cd7e99671b031fbce8f589 (patch)
tree17c44fc557f462b55e745b3257b9db404cff0471 /tun
parentdevice: receive: do not exit immediately on transient UDP receive errors (diff)
downloadwireguard-go-675955de5d0a1bad66cd7e99671b031fbce8f589.tar.xz
wireguard-go-675955de5d0a1bad66cd7e99671b031fbce8f589.zip
tun: add tcpip stack tunnel abstraction
This allows people to initiate connections over WireGuard without any underlying operating system support. I'm not crazy about the trash it adds to go.sum, but the code this actually adds to the binaries seems contained to the gvisor repo. For the TCP/IP implementation, it uses gvisor. And it borrows some internals from the Go standard library's resolver in order to bring Dial and DialContext to tun_net, along with the LookupHost helper function. This allows for things like HTTP2-over-TLS to work quite well: package main import ( "io" "log" "net" "net/http" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" ) func main() { tun, tnet, err := tun.CreateNetTUN([]net.IP{net.ParseIP("192.168.4.29")}, []net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")}, 1420) if err != nil { log.Panic(err) } dev := device.NewDevice(tun, &device.Logger{log.Default(), log.Default(), log.Default()}) dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b endpoint=163.172.161.0:12912 allowed_ip=0.0.0.0/0 `) dev.Up() client := http.Client{ Transport: &http.Transport{ DialContext: tnet.DialContext, }, } resp, err := client.Get("https://www.zx2c4.com/ip") if err != nil { log.Panic(err) } body, err := io.ReadAll(resp.Body) if err != nil { log.Panic(err) } log.Println(string(body)) } Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'tun')
-rw-r--r--tun/tun_net.go816
1 files changed, 816 insertions, 0 deletions
diff --git a/tun/tun_net.go b/tun/tun_net.go
new file mode 100644
index 0000000..8543341
--- /dev/null
+++ b/tun/tun_net.go
@@ -0,0 +1,816 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "context"
+ "crypto/rand"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "strconv"
+ "strings"
+ "time"
+
+ "golang.org/x/net/dns/dnsmessage"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+type netTun struct {
+ stack *stack.Stack
+ dispatcher stack.NetworkDispatcher
+ events chan Event
+ incomingPacket chan buffer.VectorisedView
+ mtu int
+ dnsServers []net.IP
+ hasV4, hasV6 bool
+}
+type endpoint netTun
+type Net netTun
+
+func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+func (e *endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+func (e *endpoint) MTU() uint32 {
+ mtu, err := (*netTun)(e).MTU()
+ if err != nil {
+ panic(err)
+ }
+ return uint32(mtu)
+}
+
+func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return stack.CapabilityNone
+}
+
+func (*endpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+func (*endpoint) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+func (*endpoint) Wait() {}
+
+func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.incomingPacket <- buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+ return nil
+}
+
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+func (*endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareNone
+}
+
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
+
+func CreateNetTUN(localAddresses []net.IP, dnsServers []net.IP, mtu int) (Device, *Net, error) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
+ HandleLocal: true,
+ }
+ dev := &netTun{
+ stack: stack.New(opts),
+ events: make(chan Event, 10),
+ incomingPacket: make(chan buffer.VectorisedView),
+ dnsServers: dnsServers,
+ mtu: mtu,
+ }
+ tcpipErr := dev.stack.CreateNIC(1, (*endpoint)(dev))
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("CreateNIC: %w", tcpipErr)
+ }
+ for _, ip := range localAddresses {
+ if ip4 := ip.To4(); ip4 != nil {
+ tcpipErr = dev.stack.AddAddress(1, ipv4.ProtocolNumber, tcpip.Address(ip4))
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("AddAddress(%v): %w", ip4, tcpipErr)
+ }
+ dev.hasV4 = true
+ } else {
+ tcpipErr = dev.stack.AddAddress(1, ipv6.ProtocolNumber, tcpip.Address(ip))
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("AddAddress(%v): %w", ip4, tcpipErr)
+ }
+ dev.hasV6 = true
+ }
+ }
+ if dev.hasV4 {
+ dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
+ }
+ if dev.hasV6 {
+ dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
+ }
+
+ dev.events <- EventUp
+ return dev, (*Net)(dev), nil
+}
+
+func (tun *netTun) Name() (string, error) {
+ return "go", nil
+}
+
+func (tun *netTun) File() *os.File {
+ return nil
+}
+
+func (tun *netTun) Events() chan Event {
+ return tun.events
+}
+
+func (tun *netTun) Read(buf []byte, offset int) (int, error) {
+ view, ok := <-tun.incomingPacket
+ if !ok {
+ return 0, os.ErrClosed
+ }
+ return view.Read(buf[offset:])
+}
+
+func (tun *netTun) Write(buf []byte, offset int) (int, error) {
+ packet := buf[offset:]
+ if len(packet) == 0 {
+ return 0, nil
+ }
+
+ pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(len(packet), []buffer.View{buffer.NewViewFromBytes(packet)})})
+ switch packet[0] >> 4 {
+ case 4:
+ tun.dispatcher.DeliverNetworkPacket("", "", ipv4.ProtocolNumber, pkb)
+ case 6:
+ tun.dispatcher.DeliverNetworkPacket("", "", ipv6.ProtocolNumber, pkb)
+ }
+
+ return len(buf), nil
+}
+
+func (tun *netTun) Flush() error {
+ return nil
+}
+
+func (tun *netTun) Close() error {
+ tun.stack.RemoveNIC(1)
+
+ if tun.events != nil {
+ close(tun.events)
+ }
+ if tun.incomingPacket != nil {
+ close(tun.incomingPacket)
+ }
+ return nil
+}
+
+func (tun *netTun) MTU() (int, error) {
+ return tun.mtu, nil
+}
+
+func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
+ if ip4 := ip.To4(); ip4 != nil {
+ return tcpip.FullAddress{
+ NIC: 1,
+ Addr: tcpip.Address(ip4),
+ Port: uint16(port),
+ }, ipv4.ProtocolNumber
+ } else {
+ return tcpip.FullAddress{
+ NIC: 1,
+ Addr: tcpip.Address(ip),
+ Port: uint16(port),
+ }, ipv6.ProtocolNumber
+ }
+}
+
+func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
+ if addr == nil {
+ panic("todo: deal with auto addr semantics for nil addr")
+ }
+ fa, pn := convertToFullAddr(addr.IP, addr.Port)
+ return gonet.DialContextTCP(ctx, net.stack, fa, pn)
+}
+
+func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
+ if addr == nil {
+ panic("todo: deal with auto addr semantics for nil addr")
+ }
+ fa, pn := convertToFullAddr(addr.IP, addr.Port)
+ return gonet.DialTCP(net.stack, fa, pn)
+}
+
+func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
+ if addr == nil {
+ panic("todo: deal with auto addr semantics for nil addr")
+ }
+ fa, pn := convertToFullAddr(addr.IP, addr.Port)
+ return gonet.ListenTCP(net.stack, fa, pn)
+}
+
+func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
+ var lfa, rfa *tcpip.FullAddress
+ var pn tcpip.NetworkProtocolNumber
+ if laddr != nil {
+ var addr tcpip.FullAddress
+ addr, pn = convertToFullAddr(laddr.IP, laddr.Port)
+ lfa = &addr
+ }
+ if raddr != nil {
+ var addr tcpip.FullAddress
+ addr, pn = convertToFullAddr(raddr.IP, raddr.Port)
+ rfa = &addr
+ }
+ return gonet.DialUDP(net.stack, lfa, rfa, pn)
+}
+
+var (
+ errNoSuchHost = errors.New("no such host")
+ errLameReferral = errors.New("lame referral")
+ errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message")
+ errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message")
+ errServerMisbehaving = errors.New("server misbehaving")
+ errInvalidDNSResponse = errors.New("invalid DNS response")
+ errNoAnswerFromDNSServer = errors.New("no answer from DNS server")
+ errServerTemporarilyMisbehaving = errors.New("server misbehaving")
+ errCanceled = errors.New("operation was canceled")
+ errTimeout = errors.New("i/o timeout")
+ errNumericPort = errors.New("port must be numeric")
+ errNoSuitableAddress = errors.New("no suitable address found")
+ errMissingAddress = errors.New("missing address")
+)
+
+func (net *Net) LookupHost(host string) (addrs []string, err error) {
+ return net.LookupContextHost(context.Background(), host)
+}
+
+func isDomainName(s string) bool {
+ l := len(s)
+ if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
+ return false
+ }
+ last := byte('.')
+ nonNumeric := false
+ partlen := 0
+ for i := 0; i < len(s); i++ {
+ c := s[i]
+ switch {
+ default:
+ return false
+ case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
+ nonNumeric = true
+ partlen++
+ case '0' <= c && c <= '9':
+ partlen++
+ case c == '-':
+ if last == '.' {
+ return false
+ }
+ partlen++
+ nonNumeric = true
+ case c == '.':
+ if last == '.' || last == '-' {
+ return false
+ }
+ if partlen > 63 || partlen == 0 {
+ return false
+ }
+ partlen = 0
+ }
+ last = c
+ }
+ if last == '-' || partlen > 63 {
+ return false
+ }
+ return nonNumeric
+}
+
+func randU16() uint16 {
+ var b [2]byte
+ _, err := rand.Read(b[:])
+ if err != nil {
+ panic(err)
+ }
+ return binary.LittleEndian.Uint16(b[:])
+}
+
+func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
+ id = randU16()
+ b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
+ b.EnableCompression()
+ if err := b.StartQuestions(); err != nil {
+ return 0, nil, nil, err
+ }
+ if err := b.Question(q); err != nil {
+ return 0, nil, nil, err
+ }
+ tcpReq, err = b.Finish()
+ udpReq = tcpReq[2:]
+ l := len(tcpReq) - 2
+ tcpReq[0] = byte(l >> 8)
+ tcpReq[1] = byte(l)
+ return id, udpReq, tcpReq, err
+}
+
+func equalASCIIName(x, y dnsmessage.Name) bool {
+ if x.Length != y.Length {
+ return false
+ }
+ for i := 0; i < int(x.Length); i++ {
+ a := x.Data[i]
+ b := y.Data[i]
+ if 'A' <= a && a <= 'Z' {
+ a += 0x20
+ }
+ if 'A' <= b && b <= 'Z' {
+ b += 0x20
+ }
+ if a != b {
+ return false
+ }
+ }
+ return true
+}
+
+func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
+ if !respHdr.Response {
+ return false
+ }
+ if reqID != respHdr.ID {
+ return false
+ }
+ if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
+ return false
+ }
+ return true
+}
+
+func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
+ if _, err := c.Write(b); err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ b = make([]byte, 512)
+ for {
+ n, err := c.Read(b)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ var p dnsmessage.Parser
+ h, err := p.Start(b[:n])
+ if err != nil {
+ continue
+ }
+ q, err := p.Question()
+ if err != nil || !checkResponse(id, query, h, q) {
+ continue
+ }
+ return p, h, nil
+ }
+}
+
+func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
+ if _, err := c.Write(b); err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ b = make([]byte, 1280)
+ if _, err := io.ReadFull(c, b[:2]); err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ l := int(b[0])<<8 | int(b[1])
+ if l > len(b) {
+ b = make([]byte, l)
+ }
+ n, err := io.ReadFull(c, b[:l])
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ var p dnsmessage.Parser
+ h, err := p.Start(b[:n])
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
+ }
+ q, err := p.Question()
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
+ }
+ if !checkResponse(id, query, h, q) {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
+ }
+ return p, h, nil
+}
+
+func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
+ q.Class = dnsmessage.ClassINET
+ id, udpReq, tcpReq, err := newRequest(q)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
+ }
+
+ for _, useUDP := range []bool{true, false} {
+ ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
+ defer cancel()
+
+ var c net.Conn
+ var err error
+ if useUDP {
+ c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53})
+ } else {
+ c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53})
+ }
+
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ if d, ok := ctx.Deadline(); ok && !d.IsZero() {
+ c.SetDeadline(d)
+ }
+ var p dnsmessage.Parser
+ var h dnsmessage.Header
+ if useUDP {
+ p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
+ } else {
+ p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
+ }
+ c.Close()
+ if err != nil {
+ if err == context.Canceled {
+ err = errCanceled
+ } else if err == context.DeadlineExceeded {
+ err = errTimeout
+ }
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
+ }
+ if h.Truncated {
+ continue
+ }
+ return p, h, nil
+ }
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer
+}
+
+func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
+ if h.RCode == dnsmessage.RCodeNameError {
+ return errNoSuchHost
+ }
+ _, err := p.AnswerHeader()
+ if err != nil && err != dnsmessage.ErrSectionDone {
+ return errCannotUnmarshalDNSMessage
+ }
+ if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
+ return errLameReferral
+ }
+ if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
+ if h.RCode == dnsmessage.RCodeServerFailure {
+ return errServerTemporarilyMisbehaving
+ }
+ return errServerMisbehaving
+ }
+ return nil
+}
+
+func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ return errNoSuchHost
+ }
+ if err != nil {
+ return errCannotUnmarshalDNSMessage
+ }
+ if h.Type == qtype {
+ return nil
+ }
+ if err := p.SkipAnswer(); err != nil {
+ return errCannotUnmarshalDNSMessage
+ }
+ }
+}
+
+func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
+ var lastErr error
+
+ n, err := dnsmessage.NewName(name)
+ if err != nil {
+ return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
+ }
+ q := dnsmessage.Question{
+ Name: n,
+ Type: qtype,
+ Class: dnsmessage.ClassINET,
+ }
+
+ for i := 0; i < 2; i++ {
+ for _, server := range tnet.dnsServers {
+ p, h, err := tnet.exchange(ctx, server, q, time.Second*5)
+ if err != nil {
+ dnsErr := &net.DNSError{
+ Err: err.Error(),
+ Name: name,
+ Server: server.String(),
+ }
+ if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
+ dnsErr.IsTimeout = true
+ }
+ if _, ok := err.(*net.OpError); ok {
+ dnsErr.IsTemporary = true
+ }
+ lastErr = dnsErr
+ continue
+ }
+
+ if err := checkHeader(&p, h); err != nil {
+ dnsErr := &net.DNSError{
+ Err: err.Error(),
+ Name: name,
+ Server: server.String(),
+ }
+ if err == errServerTemporarilyMisbehaving {
+ dnsErr.IsTemporary = true
+ }
+ if err == errNoSuchHost {
+ dnsErr.IsNotFound = true
+ return p, server.String(), dnsErr
+ }
+ lastErr = dnsErr
+ continue
+ }
+
+ err = skipToAnswer(&p, qtype)
+ if err == nil {
+ return p, server.String(), nil
+ }
+ lastErr = &net.DNSError{
+ Err: err.Error(),
+ Name: name,
+ Server: server.String(),
+ }
+ if err == errNoSuchHost {
+ lastErr.(*net.DNSError).IsNotFound = true
+ return p, server.String(), lastErr
+ }
+ }
+ }
+ return dnsmessage.Parser{}, "", lastErr
+}
+
+func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) {
+ if host == "" || (!tnet.hasV6 && !tnet.hasV4) {
+ return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
+ }
+ zlen := len(host)
+ if strings.IndexByte(host, ':') != -1 {
+ if zidx := strings.LastIndexByte(host, '%'); zidx != -1 {
+ zlen = zidx
+ }
+ }
+ if ip := net.ParseIP(host[:zlen]); ip != nil {
+ return []string{host[:zlen]}, nil
+ }
+
+ if !isDomainName(host) {
+ return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
+ }
+ type result struct {
+ p dnsmessage.Parser
+ server string
+ error
+ }
+ var addrsV4, addrsV6 []net.IP
+ lanes := 0
+ if tnet.hasV4 {
+ lanes++
+ }
+ if tnet.hasV6 {
+ lanes++
+ }
+ lane := make(chan result, lanes)
+ var lastErr error
+ if tnet.hasV4 {
+ go func() {
+ p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA)
+ lane <- result{p, server, err}
+ }()
+ }
+ if tnet.hasV6 {
+ go func() {
+ p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA)
+ lane <- result{p, server, err}
+ }()
+ }
+ for l := 0; l < lanes; l++ {
+ result := <-lane
+ if result.error != nil {
+ if lastErr == nil {
+ lastErr = result.error
+ }
+ continue
+ }
+
+ loop:
+ for {
+ h, err := result.p.AnswerHeader()
+ if err != nil && err != dnsmessage.ErrSectionDone {
+ lastErr = &net.DNSError{
+ Err: errCannotMarshalDNSMessage.Error(),
+ Name: host,
+ Server: result.server,
+ }
+ }
+ if err != nil {
+ break
+ }
+ switch h.Type {
+ case dnsmessage.TypeA:
+ a, err := result.p.AResource()
+ if err != nil {
+ lastErr = &net.DNSError{
+ Err: errCannotMarshalDNSMessage.Error(),
+ Name: host,
+ Server: result.server,
+ }
+ break loop
+ }
+ addrsV4 = append(addrsV4, net.IP(a.A[:]))
+
+ case dnsmessage.TypeAAAA:
+ aaaa, err := result.p.AAAAResource()
+ if err != nil {
+ lastErr = &net.DNSError{
+ Err: errCannotMarshalDNSMessage.Error(),
+ Name: host,
+ Server: result.server,
+ }
+ break loop
+ }
+ addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:]))
+
+ default:
+ if err := result.p.SkipAnswer(); err != nil {
+ lastErr = &net.DNSError{
+ Err: errCannotMarshalDNSMessage.Error(),
+ Name: host,
+ Server: result.server,
+ }
+ break loop
+ }
+ continue
+ }
+ }
+ }
+ // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
+ var addrs []net.IP
+ if tnet.hasV6 {
+ addrs = append(addrsV6, addrsV4...)
+ } else {
+ addrs = append(addrsV4, addrsV6...)
+ }
+
+ if len(addrs) == 0 && lastErr != nil {
+ return nil, lastErr
+ }
+ saddrs := make([]string, 0, len(addrs))
+ for _, ip := range addrs {
+ saddrs = append(saddrs, ip.String())
+ }
+ return saddrs, nil
+}
+
+func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
+ if deadline.IsZero() {
+ return deadline, nil
+ }
+ timeRemaining := deadline.Sub(now)
+ if timeRemaining <= 0 {
+ return time.Time{}, errTimeout
+ }
+ timeout := timeRemaining / time.Duration(addrsRemaining)
+ const saneMinimum = 2 * time.Second
+ if timeout < saneMinimum {
+ if timeRemaining < saneMinimum {
+ timeout = timeRemaining
+ } else {
+ timeout = saneMinimum
+ }
+ }
+ return now.Add(timeout), nil
+}
+
+func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+ if ctx == nil {
+ panic("nil context")
+ }
+ var acceptV4, acceptV6, useUDP bool
+ if len(network) == 3 {
+ acceptV4 = true
+ acceptV6 = true
+ } else if len(network) == 4 {
+ acceptV4 = network[3] == '4'
+ acceptV6 = network[3] == '6'
+ }
+ if !acceptV4 && !acceptV6 {
+ return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
+ }
+ if network[:3] == "udp" {
+ useUDP = true
+ } else if network[:3] != "tcp" {
+ return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
+ }
+ host, sport, err := net.SplitHostPort(address)
+ if err != nil {
+ return nil, &net.OpError{Op: "dial", Err: err}
+ }
+ port, err := strconv.Atoi(sport)
+ if err != nil || port < 0 || port > 65535 {
+ return nil, &net.OpError{Op: "dial", Err: errNumericPort}
+ }
+ allAddr, err := tnet.LookupContextHost(ctx, host)
+ if err != nil {
+ return nil, &net.OpError{Op: "dial", Err: err}
+ }
+ var addrs []net.IP
+ for _, addr := range allAddr {
+ if strings.IndexByte(addr, ':') != -1 && acceptV6 {
+ addrs = append(addrs, net.ParseIP(addr))
+ } else if strings.IndexByte(addr, '.') != -1 && acceptV4 {
+ addrs = append(addrs, net.ParseIP(addr))
+ }
+ }
+ if len(addrs) == 0 && len(allAddr) != 0 {
+ return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress}
+ }
+
+ var firstErr error
+ for i, addr := range addrs {
+ select {
+ case <-ctx.Done():
+ err := ctx.Err()
+ if err == context.Canceled {
+ err = errCanceled
+ } else if err == context.DeadlineExceeded {
+ err = errTimeout
+ }
+ return nil, &net.OpError{Op: "dial", Err: err}
+ default:
+ }
+
+ dialCtx := ctx
+ if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
+ partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i)
+ if err != nil {
+ if firstErr == nil {
+ firstErr = &net.OpError{Op: "dial", Err: err}
+ }
+ break
+ }
+ if partialDeadline.Before(deadline) {
+ var cancel context.CancelFunc
+ dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
+ defer cancel()
+ }
+ }
+
+ var c net.Conn
+ if useUDP {
+ c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: addr, Port: port})
+ } else {
+ c, err = tnet.DialContextTCP(dialCtx, &net.TCPAddr{IP: addr, Port: port})
+ }
+ if err == nil {
+ return c, nil
+ }
+ if firstErr == nil {
+ firstErr = err
+ }
+ }
+ if firstErr == nil {
+ firstErr = &net.OpError{Op: "dial", Err: errMissingAddress}
+ }
+ return nil, firstErr
+}
+
+func (tnet *Net) Dial(network, address string) (net.Conn, error) {
+ return tnet.DialContext(context.Background(), network, address)
+}