From a702597e228006460b1b9f01fb9d8cc12327132f Mon Sep 17 00:00:00 2001 From: "Thomas H. Ptacek" Date: Mon, 31 Jan 2022 16:55:36 -0600 Subject: tun/netstack: implement ICMP ping Provide a PacketConn interface for netstack's ICMP endpoint; netstack currently only provides EchoRequest/EchoResponse ICMP support, so this code exposes only an interface for doing ping. Currently is missing: - Write deadlines - Context support Signed-off-by: Thomas Ptacek [Jason: rework structure, match std go interfaces, add example code] Signed-off-by: Jason A. Donenfeld --- tun/netstack/tun.go | 231 ++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 207 insertions(+), 24 deletions(-) (limited to 'tun/netstack/tun.go') diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index fb7f07d..f0e954b 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -14,6 +14,7 @@ import ( "io" "net" "os" + "regexp" "strconv" "strings" "time" @@ -29,8 +30,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" ) type netTun struct { @@ -101,7 +104,7 @@ func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.Network func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, HandleLocal: true, } dev := &netTun{ @@ -281,6 +284,178 @@ func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { return net.DialUDPAddrPort(la, ra) } +type PingConn struct { + laddr PingAddr + raddr PingAddr + wq waiter.Queue + ep tcpip.Endpoint + deadline time.Time +} + +type PingAddr struct{ addr netip.Addr } + +func (ia PingAddr) String() string { + return ia.addr.String() +} + +func (ia PingAddr) Network() string { + if ia.addr.Is4() { + return "ping4" + } else if ia.addr.Is6() { + return "ping6" + } + return "ping" +} + +func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) { + v6 := laddr.Is6() || raddr.Is6() + bind := laddr.IsValid() + if !bind { + if v6 { + laddr = netip.IPv6Unspecified() + } else { + laddr = netip.IPv4Unspecified() + } + } + + tn := icmp.ProtocolNumber4 + pn := ipv4.ProtocolNumber + if v6 { + tn = icmp.ProtocolNumber6 + pn = ipv6.ProtocolNumber + } + + pc := &PingConn{laddr: PingAddr{laddr}} + + ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq) + if tcpipErr != nil { + return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr) + } + pc.ep = ep + + if bind { + fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0)) + if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil { + return nil, fmt.Errorf("ping bind: %s", tcpipErr) + } + } + + if raddr.IsValid() { + pc.raddr = PingAddr{raddr} + fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0)) + if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil { + return nil, fmt.Errorf("ping connect: %s", tcpipErr) + } + } + + return pc, nil +} + +func (pc *PingConn) LocalAddr() net.Addr { + return pc.laddr +} + +func (pc *PingConn) RemoteAddr() net.Addr { + return pc.raddr +} + +func (pc *PingConn) Close() error { + pc.ep.Close() + return nil +} + +func (pc *PingConn) SetWriteDeadline(t time.Time) error { + return errors.New("not implemented") +} + +func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + ia, ok := addr.(PingAddr) + if !ok || !((ia.addr.Is4() && pc.laddr.addr.Is4()) || (ia.addr.Is6() && pc.laddr.addr.Is6())) { + return 0, fmt.Errorf("ping write: mismatched protocols") + } + + var buf buffer.View + if ia.addr.Is4() { + buf = buffer.NewView(header.ICMPv4MinimumSize + len(p)) + copy(buf[header.ICMPv4MinimumSize:], p) + icmp := header.ICMPv4(buf) + icmp.SetType(header.ICMPv4Echo) + } else if ia.addr.Is6() { + buf = buffer.NewView(header.ICMPv6MinimumSize + len(p)) + copy(buf[header.ICMPv6MinimumSize:], p) + icmp := header.ICMPv6(buf) + icmp.SetType(header.ICMPv6EchoRequest) + } + + rdr := buf.Reader() + rfa, _ := convertToFullAddr(netip.AddrPortFrom(ia.addr, 0)) + // won't block, no deadlines + n64, tcpipErr := pc.ep.Write(&rdr, tcpip.WriteOptions{ + To: &rfa, + }) + if tcpipErr != nil { + return int(n64), fmt.Errorf("ping write: %s", tcpipErr) + } + + return int(n64), nil +} + +func (pc *PingConn) Write(p []byte) (n int, err error) { + return pc.WriteTo(p, pc.raddr) +} + +func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + e, notifyCh := waiter.NewChannelEntry(nil) + pc.wq.EventRegister(&e, waiter.EventIn) + defer pc.wq.EventUnregister(&e) + + deadline := pc.deadline + + if deadline.IsZero() { + <-notifyCh + } else { + select { + case <-time.NewTimer(deadline.Sub(time.Now())).C: + return 0, nil, os.ErrDeadlineExceeded + case <-notifyCh: + } + } + + min := header.ICMPv6MinimumSize + if pc.laddr.addr.Is4() { + min = header.ICMPv4MinimumSize + } + reply := make([]byte, min+len(p)) + w := tcpip.SliceWriter(reply) + + res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{ + NeedRemoteAddr: true, + }) + if tcpipErr != nil { + return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) + } + + addr = PingAddr{netip.AddrFromSlice([]byte(res.RemoteAddr.Addr))} + copy(p, reply[min:res.Count]) + return res.Count - min, addr, nil +} + +func (pc *PingConn) Read(p []byte) (n int, err error) { + n, _, err = pc.ReadFrom(p) + return +} + +func (pc *PingConn) SetDeadline(t time.Time) error { + // pc.SetWriteDeadline is unimplemented + + return pc.SetReadDeadline(t) +} + +func (pc *PingConn) SetReadDeadline(t time.Time) error { + pc.deadline = t + return nil +} + var ( errNoSuchHost = errors.New("no such host") errLameReferral = errors.New("lame referral") @@ -755,33 +930,38 @@ func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, er return now.Add(timeout), nil } +var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`) + func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) { if ctx == nil { panic("nil context") } - var acceptV4, acceptV6, useUDP bool - if len(network) == 3 { + var acceptV4, acceptV6 bool + matches := protoSplitter.FindStringSubmatch(network) + if matches == nil { + return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)} + } else if len(matches[2]) == 0 { acceptV4 = true acceptV6 = true - } else if len(network) == 4 { - acceptV4 = network[3] == '4' - acceptV6 = network[3] == '6' - } - if !acceptV4 && !acceptV6 { - return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)} - } - if network[:3] == "udp" { - useUDP = true - } else if network[:3] != "tcp" { - return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)} - } - host, sport, err := net.SplitHostPort(address) - if err != nil { - return nil, &net.OpError{Op: "dial", Err: err} + } else { + acceptV4 = matches[2][0] == '4' + acceptV6 = !acceptV4 } - port, err := strconv.Atoi(sport) - if err != nil || port < 0 || port > 65535 { - return nil, &net.OpError{Op: "dial", Err: errNumericPort} + var host string + var port int + if matches[1] == "ping" { + host = address + } else { + var sport string + var err error + host, sport, err = net.SplitHostPort(address) + if err != nil { + return nil, &net.OpError{Op: "dial", Err: err} + } + port, err = strconv.Atoi(sport) + if err != nil || port < 0 || port > 65535 { + return nil, &net.OpError{Op: "dial", Err: errNumericPort} + } } allAddr, err := tnet.LookupContextHost(ctx, host) if err != nil { @@ -829,10 +1009,13 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. } var c net.Conn - if useUDP { - c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr) - } else { + switch matches[1] { + case "tcp": c, err = tnet.DialContextTCPAddrPort(dialCtx, addr) + case "udp": + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr) + case "ping": + c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr()) } if err == nil { return c, nil -- cgit v1.2.3-59-g8ed1b