aboutsummaryrefslogtreecommitdiffstats
path: root/tun/netstack/tun.go
diff options
context:
space:
mode:
Diffstat (limited to 'tun/netstack/tun.go')
-rw-r--r--tun/netstack/tun.go1055
1 files changed, 1055 insertions, 0 deletions
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
new file mode 100644
index 0000000..2b73054
--- /dev/null
+++ b/tun/netstack/tun.go
@@ -0,0 +1,1055 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package netstack
+
+import (
+ "bytes"
+ "context"
+ "crypto/rand"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/netip"
+ "os"
+ "regexp"
+ "strconv"
+ "strings"
+ "syscall"
+ "time"
+
+ "golang.zx2c4.com/wireguard/tun"
+
+ "golang.org/x/net/dns/dnsmessage"
+ "gvisor.dev/gvisor/pkg/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+type netTun struct {
+ ep *channel.Endpoint
+ stack *stack.Stack
+ events chan tun.Event
+ incomingPacket chan *buffer.View
+ mtu int
+ dnsServers []netip.Addr
+ hasV4, hasV6 bool
+}
+
+type Net netTun
+
+func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
+ HandleLocal: true,
+ }
+ dev := &netTun{
+ ep: channel.New(1024, uint32(mtu), ""),
+ stack: stack.New(opts),
+ events: make(chan tun.Event, 10),
+ incomingPacket: make(chan *buffer.View),
+ dnsServers: dnsServers,
+ mtu: mtu,
+ }
+ sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
+ tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
+ }
+ dev.ep.AddNotify(dev)
+ tcpipErr = dev.stack.CreateNIC(1, dev.ep)
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
+ }
+ for _, ip := range localAddresses {
+ var protoNumber tcpip.NetworkProtocolNumber
+ if ip.Is4() {
+ protoNumber = ipv4.ProtocolNumber
+ } else if ip.Is6() {
+ protoNumber = ipv6.ProtocolNumber
+ }
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: protoNumber,
+ AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
+ }
+ tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
+ }
+ if ip.Is4() {
+ dev.hasV4 = true
+ } else if ip.Is6() {
+ dev.hasV6 = true
+ }
+ }
+ if dev.hasV4 {
+ dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
+ }
+ if dev.hasV6 {
+ dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
+ }
+
+ dev.events <- tun.EventUp
+ return dev, (*Net)(dev), nil
+}
+
+func (tun *netTun) Name() (string, error) {
+ return "go", nil
+}
+
+func (tun *netTun) File() *os.File {
+ return nil
+}
+
+func (tun *netTun) Events() <-chan tun.Event {
+ return tun.events
+}
+
+func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
+ view, ok := <-tun.incomingPacket
+ if !ok {
+ return 0, os.ErrClosed
+ }
+
+ n, err := view.Read(buf[0][offset:])
+ if err != nil {
+ return 0, err
+ }
+ sizes[0] = n
+ return 1, nil
+}
+
+func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
+ for _, buf := range buf {
+ packet := buf[offset:]
+ if len(packet) == 0 {
+ continue
+ }
+
+ pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
+ switch packet[0] >> 4 {
+ case 4:
+ tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
+ case 6:
+ tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
+ default:
+ return 0, syscall.EAFNOSUPPORT
+ }
+ }
+ return len(buf), nil
+}
+
+func (tun *netTun) WriteNotify() {
+ pkt := tun.ep.Read()
+ if pkt.IsNil() {
+ return
+ }
+
+ view := pkt.ToView()
+ pkt.DecRef()
+
+ tun.incomingPacket <- view
+}
+
+func (tun *netTun) Close() error {
+ tun.stack.RemoveNIC(1)
+
+ if tun.events != nil {
+ close(tun.events)
+ }
+
+ tun.ep.Close()
+
+ if tun.incomingPacket != nil {
+ close(tun.incomingPacket)
+ }
+
+ return nil
+}
+
+func (tun *netTun) MTU() (int, error) {
+ return tun.mtu, nil
+}
+
+func (tun *netTun) BatchSize() int {
+ return 1
+}
+
+func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
+ var protoNumber tcpip.NetworkProtocolNumber
+ if endpoint.Addr().Is4() {
+ protoNumber = ipv4.ProtocolNumber
+ } else {
+ protoNumber = ipv6.ProtocolNumber
+ }
+ return tcpip.FullAddress{
+ NIC: 1,
+ Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
+ Port: endpoint.Port(),
+ }, protoNumber
+}
+
+func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.DialContextTCP(ctx, net.stack, fa, pn)
+}
+
+func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
+ if addr == nil {
+ return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
+ }
+ ip, _ := netip.AddrFromSlice(addr.IP)
+ return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
+}
+
+func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.DialTCP(net.stack, fa, pn)
+}
+
+func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
+ if addr == nil {
+ return net.DialTCPAddrPort(netip.AddrPort{})
+ }
+ ip, _ := netip.AddrFromSlice(addr.IP)
+ return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
+}
+
+func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.ListenTCP(net.stack, fa, pn)
+}
+
+func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
+ if addr == nil {
+ return net.ListenTCPAddrPort(netip.AddrPort{})
+ }
+ ip, _ := netip.AddrFromSlice(addr.IP)
+ return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
+}
+
+func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
+ var lfa, rfa *tcpip.FullAddress
+ var pn tcpip.NetworkProtocolNumber
+ if laddr.IsValid() || laddr.Port() > 0 {
+ var addr tcpip.FullAddress
+ addr, pn = convertToFullAddr(laddr)
+ lfa = &addr
+ }
+ if raddr.IsValid() || raddr.Port() > 0 {
+ var addr tcpip.FullAddress
+ addr, pn = convertToFullAddr(raddr)
+ rfa = &addr
+ }
+ return gonet.DialUDP(net.stack, lfa, rfa, pn)
+}
+
+func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
+ return net.DialUDPAddrPort(laddr, netip.AddrPort{})
+}
+
+func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
+ var la, ra netip.AddrPort
+ if laddr != nil {
+ ip, _ := netip.AddrFromSlice(laddr.IP)
+ la = netip.AddrPortFrom(ip, uint16(laddr.Port))
+ }
+ if raddr != nil {
+ ip, _ := netip.AddrFromSlice(raddr.IP)
+ ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
+ }
+ return net.DialUDPAddrPort(la, ra)
+}
+
+func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
+ return net.DialUDP(laddr, nil)
+}
+
+type PingConn struct {
+ laddr PingAddr
+ raddr PingAddr
+ wq waiter.Queue
+ ep tcpip.Endpoint
+ deadline *time.Timer
+}
+
+type PingAddr struct{ addr netip.Addr }
+
+func (ia PingAddr) String() string {
+ return ia.addr.String()
+}
+
+func (ia PingAddr) Network() string {
+ if ia.addr.Is4() {
+ return "ping4"
+ } else if ia.addr.Is6() {
+ return "ping6"
+ }
+ return "ping"
+}
+
+func (ia PingAddr) Addr() netip.Addr {
+ return ia.addr
+}
+
+func PingAddrFromAddr(addr netip.Addr) *PingAddr {
+ return &PingAddr{addr}
+}
+
+func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
+ if !laddr.IsValid() && !raddr.IsValid() {
+ return nil, errors.New("ping dial: invalid address")
+ }
+ v6 := laddr.Is6() || raddr.Is6()
+ bind := laddr.IsValid()
+ if !bind {
+ if v6 {
+ laddr = netip.IPv6Unspecified()
+ } else {
+ laddr = netip.IPv4Unspecified()
+ }
+ }
+
+ tn := icmp.ProtocolNumber4
+ pn := ipv4.ProtocolNumber
+ if v6 {
+ tn = icmp.ProtocolNumber6
+ pn = ipv6.ProtocolNumber
+ }
+
+ pc := &PingConn{
+ laddr: PingAddr{laddr},
+ deadline: time.NewTimer(time.Hour << 10),
+ }
+ pc.deadline.Stop()
+
+ ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
+ if tcpipErr != nil {
+ return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
+ }
+ pc.ep = ep
+
+ if bind {
+ fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
+ if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
+ return nil, fmt.Errorf("ping bind: %s", tcpipErr)
+ }
+ }
+
+ if raddr.IsValid() {
+ pc.raddr = PingAddr{raddr}
+ fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
+ if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
+ return nil, fmt.Errorf("ping connect: %s", tcpipErr)
+ }
+ }
+
+ return pc, nil
+}
+
+func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) {
+ return net.DialPingAddr(laddr, netip.Addr{})
+}
+
+func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) {
+ var la, ra netip.Addr
+ if laddr != nil {
+ la = laddr.addr
+ }
+ if raddr != nil {
+ ra = raddr.addr
+ }
+ return net.DialPingAddr(la, ra)
+}
+
+func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) {
+ var la netip.Addr
+ if laddr != nil {
+ la = laddr.addr
+ }
+ return net.ListenPingAddr(la)
+}
+
+func (pc *PingConn) LocalAddr() net.Addr {
+ return pc.laddr
+}
+
+func (pc *PingConn) RemoteAddr() net.Addr {
+ return pc.raddr
+}
+
+func (pc *PingConn) Close() error {
+ pc.deadline.Reset(0)
+ pc.ep.Close()
+ return nil
+}
+
+func (pc *PingConn) SetWriteDeadline(t time.Time) error {
+ return errors.New("not implemented")
+}
+
+func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
+ var na netip.Addr
+ switch v := addr.(type) {
+ case *PingAddr:
+ na = v.addr
+ case *net.IPAddr:
+ na, _ = netip.AddrFromSlice(v.IP)
+ default:
+ return 0, fmt.Errorf("ping write: wrong net.Addr type")
+ }
+ if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) {
+ return 0, fmt.Errorf("ping write: mismatched protocols")
+ }
+
+ buf := bytes.NewReader(p)
+ rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0))
+ // won't block, no deadlines
+ n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{
+ To: &rfa,
+ })
+ if tcpipErr != nil {
+ return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
+ }
+
+ return int(n64), nil
+}
+
+func (pc *PingConn) Write(p []byte) (n int, err error) {
+ return pc.WriteTo(p, &pc.raddr)
+}
+
+func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
+ e, notifyCh := waiter.NewChannelEntry(waiter.EventIn)
+ pc.wq.EventRegister(&e)
+ defer pc.wq.EventUnregister(&e)
+
+ select {
+ case <-pc.deadline.C:
+ return 0, nil, os.ErrDeadlineExceeded
+ case <-notifyCh:
+ }
+
+ w := tcpip.SliceWriter(p)
+
+ res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
+ NeedRemoteAddr: true,
+ })
+ if tcpipErr != nil {
+ return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
+ }
+
+ remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice())
+ return res.Count, &PingAddr{remoteAddr}, nil
+}
+
+func (pc *PingConn) Read(p []byte) (n int, err error) {
+ n, _, err = pc.ReadFrom(p)
+ return
+}
+
+func (pc *PingConn) SetDeadline(t time.Time) error {
+ // pc.SetWriteDeadline is unimplemented
+
+ return pc.SetReadDeadline(t)
+}
+
+func (pc *PingConn) SetReadDeadline(t time.Time) error {
+ pc.deadline.Reset(time.Until(t))
+ return nil
+}
+
+var (
+ errNoSuchHost = errors.New("no such host")
+ errLameReferral = errors.New("lame referral")
+ errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message")
+ errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message")
+ errServerMisbehaving = errors.New("server misbehaving")
+ errInvalidDNSResponse = errors.New("invalid DNS response")
+ errNoAnswerFromDNSServer = errors.New("no answer from DNS server")
+ errServerTemporarilyMisbehaving = errors.New("server misbehaving")
+ errCanceled = errors.New("operation was canceled")
+ errTimeout = errors.New("i/o timeout")
+ errNumericPort = errors.New("port must be numeric")
+ errNoSuitableAddress = errors.New("no suitable address found")
+ errMissingAddress = errors.New("missing address")
+)
+
+func (net *Net) LookupHost(host string) (addrs []string, err error) {
+ return net.LookupContextHost(context.Background(), host)
+}
+
+func isDomainName(s string) bool {
+ l := len(s)
+ if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
+ return false
+ }
+ last := byte('.')
+ nonNumeric := false
+ partlen := 0
+ for i := 0; i < len(s); i++ {
+ c := s[i]
+ switch {
+ default:
+ return false
+ case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
+ nonNumeric = true
+ partlen++
+ case '0' <= c && c <= '9':
+ partlen++
+ case c == '-':
+ if last == '.' {
+ return false
+ }
+ partlen++
+ nonNumeric = true
+ case c == '.':
+ if last == '.' || last == '-' {
+ return false
+ }
+ if partlen > 63 || partlen == 0 {
+ return false
+ }
+ partlen = 0
+ }
+ last = c
+ }
+ if last == '-' || partlen > 63 {
+ return false
+ }
+ return nonNumeric
+}
+
+func randU16() uint16 {
+ var b [2]byte
+ _, err := rand.Read(b[:])
+ if err != nil {
+ panic(err)
+ }
+ return binary.LittleEndian.Uint16(b[:])
+}
+
+func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
+ id = randU16()
+ b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
+ b.EnableCompression()
+ if err := b.StartQuestions(); err != nil {
+ return 0, nil, nil, err
+ }
+ if err := b.Question(q); err != nil {
+ return 0, nil, nil, err
+ }
+ tcpReq, err = b.Finish()
+ udpReq = tcpReq[2:]
+ l := len(tcpReq) - 2
+ tcpReq[0] = byte(l >> 8)
+ tcpReq[1] = byte(l)
+ return id, udpReq, tcpReq, err
+}
+
+func equalASCIIName(x, y dnsmessage.Name) bool {
+ if x.Length != y.Length {
+ return false
+ }
+ for i := 0; i < int(x.Length); i++ {
+ a := x.Data[i]
+ b := y.Data[i]
+ if 'A' <= a && a <= 'Z' {
+ a += 0x20
+ }
+ if 'A' <= b && b <= 'Z' {
+ b += 0x20
+ }
+ if a != b {
+ return false
+ }
+ }
+ return true
+}
+
+func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
+ if !respHdr.Response {
+ return false
+ }
+ if reqID != respHdr.ID {
+ return false
+ }
+ if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
+ return false
+ }
+ return true
+}
+
+func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
+ if _, err := c.Write(b); err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ b = make([]byte, 512)
+ for {
+ n, err := c.Read(b)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ var p dnsmessage.Parser
+ h, err := p.Start(b[:n])
+ if err != nil {
+ continue
+ }
+ q, err := p.Question()
+ if err != nil || !checkResponse(id, query, h, q) {
+ continue
+ }
+ return p, h, nil
+ }
+}
+
+func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
+ if _, err := c.Write(b); err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ b = make([]byte, 1280)
+ if _, err := io.ReadFull(c, b[:2]); err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ l := int(b[0])<<8 | int(b[1])
+ if l > len(b) {
+ b = make([]byte, l)
+ }
+ n, err := io.ReadFull(c, b[:l])
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ var p dnsmessage.Parser
+ h, err := p.Start(b[:n])
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
+ }
+ q, err := p.Question()
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
+ }
+ if !checkResponse(id, query, h, q) {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
+ }
+ return p, h, nil
+}
+
+func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
+ q.Class = dnsmessage.ClassINET
+ id, udpReq, tcpReq, err := newRequest(q)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
+ }
+
+ for _, useUDP := range []bool{true, false} {
+ ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
+ defer cancel()
+
+ var c net.Conn
+ var err error
+ if useUDP {
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
+ } else {
+ c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
+ }
+
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ if d, ok := ctx.Deadline(); ok && !d.IsZero() {
+ err := c.SetDeadline(d)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ }
+ var p dnsmessage.Parser
+ var h dnsmessage.Header
+ if useUDP {
+ p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
+ } else {
+ p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
+ }
+ c.Close()
+ if err != nil {
+ if err == context.Canceled {
+ err = errCanceled
+ } else if err == context.DeadlineExceeded {
+ err = errTimeout
+ }
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
+ }
+ if h.Truncated {
+ continue
+ }
+ return p, h, nil
+ }
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer
+}
+
+func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
+ if h.RCode == dnsmessage.RCodeNameError {
+ return errNoSuchHost
+ }
+ _, err := p.AnswerHeader()
+ if err != nil && err != dnsmessage.ErrSectionDone {
+ return errCannotUnmarshalDNSMessage
+ }
+ if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
+ return errLameReferral
+ }
+ if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
+ if h.RCode == dnsmessage.RCodeServerFailure {
+ return errServerTemporarilyMisbehaving
+ }
+ return errServerMisbehaving
+ }
+ return nil
+}
+
+func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ return errNoSuchHost
+ }
+ if err != nil {
+ return errCannotUnmarshalDNSMessage
+ }
+ if h.Type == qtype {
+ return nil
+ }
+ if err := p.SkipAnswer(); err != nil {
+ return errCannotUnmarshalDNSMessage
+ }
+ }
+}
+
+func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
+ var lastErr error
+
+ n, err := dnsmessage.NewName(name)
+ if err != nil {
+ return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
+ }
+ q := dnsmessage.Question{
+ Name: n,
+ Type: qtype,
+ Class: dnsmessage.ClassINET,
+ }
+
+ for i := 0; i < 2; i++ {
+ for _, server := range tnet.dnsServers {
+ p, h, err := tnet.exchange(ctx, server, q, time.Second*5)
+ if err != nil {
+ dnsErr := &net.DNSError{
+ Err: err.Error(),
+ Name: name,
+ Server: server.String(),
+ }
+ if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
+ dnsErr.IsTimeout = true
+ }
+ if _, ok := err.(*net.OpError); ok {
+ dnsErr.IsTemporary = true
+ }
+ lastErr = dnsErr
+ continue
+ }
+
+ if err := checkHeader(&p, h); err != nil {
+ dnsErr := &net.DNSError{
+ Err: err.Error(),
+ Name: name,
+ Server: server.String(),
+ }
+ if err == errServerTemporarilyMisbehaving {
+ dnsErr.IsTemporary = true
+ }
+ if err == errNoSuchHost {
+ dnsErr.IsNotFound = true
+ return p, server.String(), dnsErr
+ }
+ lastErr = dnsErr
+ continue
+ }
+
+ err = skipToAnswer(&p, qtype)
+ if err == nil {
+ return p, server.String(), nil
+ }
+ lastErr = &net.DNSError{
+ Err: err.Error(),
+ Name: name,
+ Server: server.String(),
+ }
+ if err == errNoSuchHost {
+ lastErr.(*net.DNSError).IsNotFound = true
+ return p, server.String(), lastErr
+ }
+ }
+ }
+ return dnsmessage.Parser{}, "", lastErr
+}
+
+func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) {
+ if host == "" || (!tnet.hasV6 && !tnet.hasV4) {
+ return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
+ }
+ zlen := len(host)
+ if strings.IndexByte(host, ':') != -1 {
+ if zidx := strings.LastIndexByte(host, '%'); zidx != -1 {
+ zlen = zidx
+ }
+ }
+ if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
+ return []string{ip.String()}, nil
+ }
+
+ if !isDomainName(host) {
+ return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
+ }
+ type result struct {
+ p dnsmessage.Parser
+ server string
+ error
+ }
+ var addrsV4, addrsV6 []netip.Addr
+ lanes := 0
+ if tnet.hasV4 {
+ lanes++
+ }
+ if tnet.hasV6 {
+ lanes++
+ }
+ lane := make(chan result, lanes)
+ var lastErr error
+ if tnet.hasV4 {
+ go func() {
+ p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA)
+ lane <- result{p, server, err}
+ }()
+ }
+ if tnet.hasV6 {
+ go func() {
+ p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA)
+ lane <- result{p, server, err}
+ }()
+ }
+ for l := 0; l < lanes; l++ {
+ result := <-lane
+ if result.error != nil {
+ if lastErr == nil {
+ lastErr = result.error
+ }
+ continue
+ }
+
+ loop:
+ for {
+ h, err := result.p.AnswerHeader()
+ if err != nil && err != dnsmessage.ErrSectionDone {
+ lastErr = &net.DNSError{
+ Err: errCannotMarshalDNSMessage.Error(),
+ Name: host,
+ Server: result.server,
+ }
+ }
+ if err != nil {
+ break
+ }
+ switch h.Type {
+ case dnsmessage.TypeA:
+ a, err := result.p.AResource()
+ if err != nil {
+ lastErr = &net.DNSError{
+ Err: errCannotMarshalDNSMessage.Error(),
+ Name: host,
+ Server: result.server,
+ }
+ break loop
+ }
+ addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
+
+ case dnsmessage.TypeAAAA:
+ aaaa, err := result.p.AAAAResource()
+ if err != nil {
+ lastErr = &net.DNSError{
+ Err: errCannotMarshalDNSMessage.Error(),
+ Name: host,
+ Server: result.server,
+ }
+ break loop
+ }
+ addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
+
+ default:
+ if err := result.p.SkipAnswer(); err != nil {
+ lastErr = &net.DNSError{
+ Err: errCannotMarshalDNSMessage.Error(),
+ Name: host,
+ Server: result.server,
+ }
+ break loop
+ }
+ continue
+ }
+ }
+ }
+ // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled
+ var addrs []netip.Addr
+ if tnet.hasV6 {
+ addrs = append(addrsV6, addrsV4...)
+ } else {
+ addrs = append(addrsV4, addrsV6...)
+ }
+
+ if len(addrs) == 0 && lastErr != nil {
+ return nil, lastErr
+ }
+ saddrs := make([]string, 0, len(addrs))
+ for _, ip := range addrs {
+ saddrs = append(saddrs, ip.String())
+ }
+ return saddrs, nil
+}
+
+func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
+ if deadline.IsZero() {
+ return deadline, nil
+ }
+ timeRemaining := deadline.Sub(now)
+ if timeRemaining <= 0 {
+ return time.Time{}, errTimeout
+ }
+ timeout := timeRemaining / time.Duration(addrsRemaining)
+ const saneMinimum = 2 * time.Second
+ if timeout < saneMinimum {
+ if timeRemaining < saneMinimum {
+ timeout = timeRemaining
+ } else {
+ timeout = saneMinimum
+ }
+ }
+ return now.Add(timeout), nil
+}
+
+var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
+
+func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+ if ctx == nil {
+ panic("nil context")
+ }
+ var acceptV4, acceptV6 bool
+ matches := protoSplitter.FindStringSubmatch(network)
+ if matches == nil {
+ return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
+ } else if len(matches[2]) == 0 {
+ acceptV4 = true
+ acceptV6 = true
+ } else {
+ acceptV4 = matches[2][0] == '4'
+ acceptV6 = !acceptV4
+ }
+ var host string
+ var port int
+ if matches[1] == "ping" {
+ host = address
+ } else {
+ var sport string
+ var err error
+ host, sport, err = net.SplitHostPort(address)
+ if err != nil {
+ return nil, &net.OpError{Op: "dial", Err: err}
+ }
+ port, err = strconv.Atoi(sport)
+ if err != nil || port < 0 || port > 65535 {
+ return nil, &net.OpError{Op: "dial", Err: errNumericPort}
+ }
+ }
+ allAddr, err := tnet.LookupContextHost(ctx, host)
+ if err != nil {
+ return nil, &net.OpError{Op: "dial", Err: err}
+ }
+ var addrs []netip.AddrPort
+ for _, addr := range allAddr {
+ ip, err := netip.ParseAddr(addr)
+ if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
+ addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
+ }
+ }
+ if len(addrs) == 0 && len(allAddr) != 0 {
+ return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress}
+ }
+
+ var firstErr error
+ for i, addr := range addrs {
+ select {
+ case <-ctx.Done():
+ err := ctx.Err()
+ if err == context.Canceled {
+ err = errCanceled
+ } else if err == context.DeadlineExceeded {
+ err = errTimeout
+ }
+ return nil, &net.OpError{Op: "dial", Err: err}
+ default:
+ }
+
+ dialCtx := ctx
+ if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
+ partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i)
+ if err != nil {
+ if firstErr == nil {
+ firstErr = &net.OpError{Op: "dial", Err: err}
+ }
+ break
+ }
+ if partialDeadline.Before(deadline) {
+ var cancel context.CancelFunc
+ dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
+ defer cancel()
+ }
+ }
+
+ var c net.Conn
+ switch matches[1] {
+ case "tcp":
+ c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
+ case "udp":
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
+ case "ping":
+ c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
+ }
+ if err == nil {
+ return c, nil
+ }
+ if firstErr == nil {
+ firstErr = err
+ }
+ }
+ if firstErr == nil {
+ firstErr = &net.OpError{Op: "dial", Err: errMissingAddress}
+ }
+ return nil, firstErr
+}
+
+func (tnet *Net) Dial(network, address string) (net.Conn, error) {
+ return tnet.DialContext(context.Background(), network, address)
+}