aboutsummaryrefslogtreecommitdiffstats
path: root/tun/netstack/tun.go
diff options
context:
space:
mode:
authorThomas H. Ptacek <thomas@sockpuppet.org>2022-01-31 16:55:36 -0600
committerJason A. Donenfeld <Jason@zx2c4.com>2022-02-01 20:16:42 +0100
commita702597e228006460b1b9f01fb9d8cc12327132f (patch)
tree1fb3df08ceb5b5f1388b69fc38cea896b7557de1 /tun/netstack/tun.go
parentversion: bump snapshot (diff)
downloadwireguard-go-a702597e228006460b1b9f01fb9d8cc12327132f.tar.xz
wireguard-go-a702597e228006460b1b9f01fb9d8cc12327132f.zip
tun/netstack: implement ICMP ping
Provide a PacketConn interface for netstack's ICMP endpoint; netstack currently only provides EchoRequest/EchoResponse ICMP support, so this code exposes only an interface for doing ping. Currently is missing: - Write deadlines - Context support Signed-off-by: Thomas Ptacek <thomas@sockpuppet.org> [Jason: rework structure, match std go interfaces, add example code] Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'tun/netstack/tun.go')
-rw-r--r--tun/netstack/tun.go231
1 files changed, 207 insertions, 24 deletions
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
index fb7f07d..f0e954b 100644
--- a/tun/netstack/tun.go
+++ b/tun/netstack/tun.go
@@ -14,6 +14,7 @@ import (
"io"
"net"
"os"
+ "regexp"
"strconv"
"strings"
"time"
@@ -29,8 +30,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
type netTun struct {
@@ -101,7 +104,7 @@ func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.Network
func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
HandleLocal: true,
}
dev := &netTun{
@@ -281,6 +284,178 @@ func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
return net.DialUDPAddrPort(la, ra)
}
+type PingConn struct {
+ laddr PingAddr
+ raddr PingAddr
+ wq waiter.Queue
+ ep tcpip.Endpoint
+ deadline time.Time
+}
+
+type PingAddr struct{ addr netip.Addr }
+
+func (ia PingAddr) String() string {
+ return ia.addr.String()
+}
+
+func (ia PingAddr) Network() string {
+ if ia.addr.Is4() {
+ return "ping4"
+ } else if ia.addr.Is6() {
+ return "ping6"
+ }
+ return "ping"
+}
+
+func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
+ v6 := laddr.Is6() || raddr.Is6()
+ bind := laddr.IsValid()
+ if !bind {
+ if v6 {
+ laddr = netip.IPv6Unspecified()
+ } else {
+ laddr = netip.IPv4Unspecified()
+ }
+ }
+
+ tn := icmp.ProtocolNumber4
+ pn := ipv4.ProtocolNumber
+ if v6 {
+ tn = icmp.ProtocolNumber6
+ pn = ipv6.ProtocolNumber
+ }
+
+ pc := &PingConn{laddr: PingAddr{laddr}}
+
+ ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
+ if tcpipErr != nil {
+ return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
+ }
+ pc.ep = ep
+
+ if bind {
+ fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
+ if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
+ return nil, fmt.Errorf("ping bind: %s", tcpipErr)
+ }
+ }
+
+ if raddr.IsValid() {
+ pc.raddr = PingAddr{raddr}
+ fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
+ if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
+ return nil, fmt.Errorf("ping connect: %s", tcpipErr)
+ }
+ }
+
+ return pc, nil
+}
+
+func (pc *PingConn) LocalAddr() net.Addr {
+ return pc.laddr
+}
+
+func (pc *PingConn) RemoteAddr() net.Addr {
+ return pc.raddr
+}
+
+func (pc *PingConn) Close() error {
+ pc.ep.Close()
+ return nil
+}
+
+func (pc *PingConn) SetWriteDeadline(t time.Time) error {
+ return errors.New("not implemented")
+}
+
+func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
+ ia, ok := addr.(PingAddr)
+ if !ok || !((ia.addr.Is4() && pc.laddr.addr.Is4()) || (ia.addr.Is6() && pc.laddr.addr.Is6())) {
+ return 0, fmt.Errorf("ping write: mismatched protocols")
+ }
+
+ var buf buffer.View
+ if ia.addr.Is4() {
+ buf = buffer.NewView(header.ICMPv4MinimumSize + len(p))
+ copy(buf[header.ICMPv4MinimumSize:], p)
+ icmp := header.ICMPv4(buf)
+ icmp.SetType(header.ICMPv4Echo)
+ } else if ia.addr.Is6() {
+ buf = buffer.NewView(header.ICMPv6MinimumSize + len(p))
+ copy(buf[header.ICMPv6MinimumSize:], p)
+ icmp := header.ICMPv6(buf)
+ icmp.SetType(header.ICMPv6EchoRequest)
+ }
+
+ rdr := buf.Reader()
+ rfa, _ := convertToFullAddr(netip.AddrPortFrom(ia.addr, 0))
+ // won't block, no deadlines
+ n64, tcpipErr := pc.ep.Write(&rdr, tcpip.WriteOptions{
+ To: &rfa,
+ })
+ if tcpipErr != nil {
+ return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
+ }
+
+ return int(n64), nil
+}
+
+func (pc *PingConn) Write(p []byte) (n int, err error) {
+ return pc.WriteTo(p, pc.raddr)
+}
+
+func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
+ e, notifyCh := waiter.NewChannelEntry(nil)
+ pc.wq.EventRegister(&e, waiter.EventIn)
+ defer pc.wq.EventUnregister(&e)
+
+ deadline := pc.deadline
+
+ if deadline.IsZero() {
+ <-notifyCh
+ } else {
+ select {
+ case <-time.NewTimer(deadline.Sub(time.Now())).C:
+ return 0, nil, os.ErrDeadlineExceeded
+ case <-notifyCh:
+ }
+ }
+
+ min := header.ICMPv6MinimumSize
+ if pc.laddr.addr.Is4() {
+ min = header.ICMPv4MinimumSize
+ }
+ reply := make([]byte, min+len(p))
+ w := tcpip.SliceWriter(reply)
+
+ res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
+ NeedRemoteAddr: true,
+ })
+ if tcpipErr != nil {
+ return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
+ }
+
+ addr = PingAddr{netip.AddrFromSlice([]byte(res.RemoteAddr.Addr))}
+ copy(p, reply[min:res.Count])
+ return res.Count - min, addr, nil
+}
+
+func (pc *PingConn) Read(p []byte) (n int, err error) {
+ n, _, err = pc.ReadFrom(p)
+ return
+}
+
+func (pc *PingConn) SetDeadline(t time.Time) error {
+ // pc.SetWriteDeadline is unimplemented
+
+ return pc.SetReadDeadline(t)
+}
+
+func (pc *PingConn) SetReadDeadline(t time.Time) error {
+ pc.deadline = t
+ return nil
+}
+
var (
errNoSuchHost = errors.New("no such host")
errLameReferral = errors.New("lame referral")
@@ -755,33 +930,38 @@ func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, er
return now.Add(timeout), nil
}
+var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
+
func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if ctx == nil {
panic("nil context")
}
- var acceptV4, acceptV6, useUDP bool
- if len(network) == 3 {
+ var acceptV4, acceptV6 bool
+ matches := protoSplitter.FindStringSubmatch(network)
+ if matches == nil {
+ return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
+ } else if len(matches[2]) == 0 {
acceptV4 = true
acceptV6 = true
- } else if len(network) == 4 {
- acceptV4 = network[3] == '4'
- acceptV6 = network[3] == '6'
- }
- if !acceptV4 && !acceptV6 {
- return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
- }
- if network[:3] == "udp" {
- useUDP = true
- } else if network[:3] != "tcp" {
- return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
- }
- host, sport, err := net.SplitHostPort(address)
- if err != nil {
- return nil, &net.OpError{Op: "dial", Err: err}
+ } else {
+ acceptV4 = matches[2][0] == '4'
+ acceptV6 = !acceptV4
}
- port, err := strconv.Atoi(sport)
- if err != nil || port < 0 || port > 65535 {
- return nil, &net.OpError{Op: "dial", Err: errNumericPort}
+ var host string
+ var port int
+ if matches[1] == "ping" {
+ host = address
+ } else {
+ var sport string
+ var err error
+ host, sport, err = net.SplitHostPort(address)
+ if err != nil {
+ return nil, &net.OpError{Op: "dial", Err: err}
+ }
+ port, err = strconv.Atoi(sport)
+ if err != nil || port < 0 || port > 65535 {
+ return nil, &net.OpError{Op: "dial", Err: errNumericPort}
+ }
}
allAddr, err := tnet.LookupContextHost(ctx, host)
if err != nil {
@@ -829,10 +1009,13 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
}
var c net.Conn
- if useUDP {
- c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
- } else {
+ switch matches[1] {
+ case "tcp":
c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
+ case "udp":
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
+ case "ping":
+ c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
}
if err == nil {
return c, nil