aboutsummaryrefslogtreecommitdiffstats
path: root/tun/netstack/tun.go
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tun/netstack/tun.go422
1 files changed, 314 insertions, 108 deletions
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
index f1c03f4..2b73054 100644
--- a/tun/netstack/tun.go
+++ b/tun/netstack/tun.go
@@ -1,11 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package netstack
import (
+ "bytes"
"context"
"crypto/rand"
"encoding/binary"
@@ -13,102 +14,64 @@ import (
"fmt"
"io"
"net"
+ "net/netip"
"os"
+ "regexp"
"strconv"
"strings"
+ "syscall"
"time"
- "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/tun"
"golang.org/x/net/dns/dnsmessage"
+ "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
type netTun struct {
+ ep *channel.Endpoint
stack *stack.Stack
- dispatcher stack.NetworkDispatcher
events chan tun.Event
- incomingPacket chan buffer.VectorisedView
+ incomingPacket chan *buffer.View
mtu int
dnsServers []netip.Addr
hasV4, hasV6 bool
}
-type endpoint netTun
-type Net netTun
-
-func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
- e.dispatcher = dispatcher
-}
-
-func (e *endpoint) IsAttached() bool {
- return e.dispatcher != nil
-}
-
-func (e *endpoint) MTU() uint32 {
- mtu, err := (*netTun)(e).MTU()
- if err != nil {
- panic(err)
- }
- return uint32(mtu)
-}
-
-func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return stack.CapabilityNone
-}
-
-func (*endpoint) MaxHeaderLength() uint16 {
- return 0
-}
-
-func (*endpoint) LinkAddress() tcpip.LinkAddress {
- return ""
-}
-
-func (*endpoint) Wait() {}
-
-func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.incomingPacket <- buffer.NewVectorisedView(pkt.Size(), pkt.Views())
- return nil
-}
-
-func (e *endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- panic("not implemented")
-}
-
-func (e *endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
- panic("not implemented")
-}
-func (*endpoint) ARPHardwareType() header.ARPHardwareType {
- return header.ARPHardwareNone
-}
-
-func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
-}
+type Net netTun
func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
HandleLocal: true,
}
dev := &netTun{
+ ep: channel.New(1024, uint32(mtu), ""),
stack: stack.New(opts),
events: make(chan tun.Event, 10),
- incomingPacket: make(chan buffer.VectorisedView),
+ incomingPacket: make(chan *buffer.View),
dnsServers: dnsServers,
mtu: mtu,
}
- tcpipErr := dev.stack.CreateNIC(1, (*endpoint)(dev))
+ sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
+ tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
+ }
+ dev.ep.AddNotify(dev)
+ tcpipErr = dev.stack.CreateNIC(1, dev.ep)
if tcpipErr != nil {
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
@@ -121,7 +84,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device,
}
protoAddr := tcpip.ProtocolAddress{
Protocol: protoNumber,
- AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
+ AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
}
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
if tcpipErr != nil {
@@ -152,37 +115,54 @@ func (tun *netTun) File() *os.File {
return nil
}
-func (tun *netTun) Events() chan tun.Event {
+func (tun *netTun) Events() <-chan tun.Event {
return tun.events
}
-func (tun *netTun) Read(buf []byte, offset int) (int, error) {
+func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
view, ok := <-tun.incomingPacket
if !ok {
return 0, os.ErrClosed
}
- return view.Read(buf[offset:])
-}
-func (tun *netTun) Write(buf []byte, offset int) (int, error) {
- packet := buf[offset:]
- if len(packet) == 0 {
- return 0, nil
+ n, err := view.Read(buf[0][offset:])
+ if err != nil {
+ return 0, err
}
+ sizes[0] = n
+ return 1, nil
+}
- pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(len(packet), []buffer.View{buffer.NewViewFromBytes(packet)})})
- switch packet[0] >> 4 {
- case 4:
- tun.dispatcher.DeliverNetworkPacket("", "", ipv4.ProtocolNumber, pkb)
- case 6:
- tun.dispatcher.DeliverNetworkPacket("", "", ipv6.ProtocolNumber, pkb)
- }
+func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
+ for _, buf := range buf {
+ packet := buf[offset:]
+ if len(packet) == 0 {
+ continue
+ }
+ pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
+ switch packet[0] >> 4 {
+ case 4:
+ tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
+ case 6:
+ tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
+ default:
+ return 0, syscall.EAFNOSUPPORT
+ }
+ }
return len(buf), nil
}
-func (tun *netTun) Flush() error {
- return nil
+func (tun *netTun) WriteNotify() {
+ pkt := tun.ep.Read()
+ if pkt.IsNil() {
+ return
+ }
+
+ view := pkt.ToView()
+ pkt.DecRef()
+
+ tun.incomingPacket <- view
}
func (tun *netTun) Close() error {
@@ -191,9 +171,13 @@ func (tun *netTun) Close() error {
if tun.events != nil {
close(tun.events)
}
+
+ tun.ep.Close()
+
if tun.incomingPacket != nil {
close(tun.incomingPacket)
}
+
return nil
}
@@ -201,6 +185,10 @@ func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil
}
+func (tun *netTun) BatchSize() int {
+ return 1
+}
+
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
var protoNumber tcpip.NetworkProtocolNumber
if endpoint.Addr().Is4() {
@@ -210,7 +198,7 @@ func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.Networ
}
return tcpip.FullAddress{
NIC: 1,
- Addr: tcpip.Address(endpoint.Addr().AsSlice()),
+ Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
Port: endpoint.Port(),
}, protoNumber
}
@@ -224,7 +212,8 @@ func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.T
if addr == nil {
return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
}
- return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
+ ip, _ := netip.AddrFromSlice(addr.IP)
+ return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
}
func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
@@ -236,7 +225,8 @@ func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil {
return net.DialTCPAddrPort(netip.AddrPort{})
}
- return net.DialTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
+ ip, _ := netip.AddrFromSlice(addr.IP)
+ return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
}
func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
@@ -248,7 +238,8 @@ func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
if addr == nil {
return net.ListenTCPAddrPort(netip.AddrPort{})
}
- return net.ListenTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
+ ip, _ := netip.AddrFromSlice(addr.IP)
+ return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
}
func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
@@ -267,17 +258,221 @@ func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, er
return gonet.DialUDP(net.stack, lfa, rfa, pn)
}
+func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
+ return net.DialUDPAddrPort(laddr, netip.AddrPort{})
+}
+
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
var la, ra netip.AddrPort
if laddr != nil {
- la = netip.AddrPortFrom(netip.AddrFromSlice(laddr.IP), uint16(laddr.Port))
+ ip, _ := netip.AddrFromSlice(laddr.IP)
+ la = netip.AddrPortFrom(ip, uint16(laddr.Port))
}
if raddr != nil {
- ra = netip.AddrPortFrom(netip.AddrFromSlice(raddr.IP), uint16(raddr.Port))
+ ip, _ := netip.AddrFromSlice(raddr.IP)
+ ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
}
return net.DialUDPAddrPort(la, ra)
}
+func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
+ return net.DialUDP(laddr, nil)
+}
+
+type PingConn struct {
+ laddr PingAddr
+ raddr PingAddr
+ wq waiter.Queue
+ ep tcpip.Endpoint
+ deadline *time.Timer
+}
+
+type PingAddr struct{ addr netip.Addr }
+
+func (ia PingAddr) String() string {
+ return ia.addr.String()
+}
+
+func (ia PingAddr) Network() string {
+ if ia.addr.Is4() {
+ return "ping4"
+ } else if ia.addr.Is6() {
+ return "ping6"
+ }
+ return "ping"
+}
+
+func (ia PingAddr) Addr() netip.Addr {
+ return ia.addr
+}
+
+func PingAddrFromAddr(addr netip.Addr) *PingAddr {
+ return &PingAddr{addr}
+}
+
+func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
+ if !laddr.IsValid() && !raddr.IsValid() {
+ return nil, errors.New("ping dial: invalid address")
+ }
+ v6 := laddr.Is6() || raddr.Is6()
+ bind := laddr.IsValid()
+ if !bind {
+ if v6 {
+ laddr = netip.IPv6Unspecified()
+ } else {
+ laddr = netip.IPv4Unspecified()
+ }
+ }
+
+ tn := icmp.ProtocolNumber4
+ pn := ipv4.ProtocolNumber
+ if v6 {
+ tn = icmp.ProtocolNumber6
+ pn = ipv6.ProtocolNumber
+ }
+
+ pc := &PingConn{
+ laddr: PingAddr{laddr},
+ deadline: time.NewTimer(time.Hour << 10),
+ }
+ pc.deadline.Stop()
+
+ ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
+ if tcpipErr != nil {
+ return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
+ }
+ pc.ep = ep
+
+ if bind {
+ fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
+ if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
+ return nil, fmt.Errorf("ping bind: %s", tcpipErr)
+ }
+ }
+
+ if raddr.IsValid() {
+ pc.raddr = PingAddr{raddr}
+ fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
+ if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
+ return nil, fmt.Errorf("ping connect: %s", tcpipErr)
+ }
+ }
+
+ return pc, nil
+}
+
+func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) {
+ return net.DialPingAddr(laddr, netip.Addr{})
+}
+
+func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) {
+ var la, ra netip.Addr
+ if laddr != nil {
+ la = laddr.addr
+ }
+ if raddr != nil {
+ ra = raddr.addr
+ }
+ return net.DialPingAddr(la, ra)
+}
+
+func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) {
+ var la netip.Addr
+ if laddr != nil {
+ la = laddr.addr
+ }
+ return net.ListenPingAddr(la)
+}
+
+func (pc *PingConn) LocalAddr() net.Addr {
+ return pc.laddr
+}
+
+func (pc *PingConn) RemoteAddr() net.Addr {
+ return pc.raddr
+}
+
+func (pc *PingConn) Close() error {
+ pc.deadline.Reset(0)
+ pc.ep.Close()
+ return nil
+}
+
+func (pc *PingConn) SetWriteDeadline(t time.Time) error {
+ return errors.New("not implemented")
+}
+
+func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
+ var na netip.Addr
+ switch v := addr.(type) {
+ case *PingAddr:
+ na = v.addr
+ case *net.IPAddr:
+ na, _ = netip.AddrFromSlice(v.IP)
+ default:
+ return 0, fmt.Errorf("ping write: wrong net.Addr type")
+ }
+ if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) {
+ return 0, fmt.Errorf("ping write: mismatched protocols")
+ }
+
+ buf := bytes.NewReader(p)
+ rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0))
+ // won't block, no deadlines
+ n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{
+ To: &rfa,
+ })
+ if tcpipErr != nil {
+ return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
+ }
+
+ return int(n64), nil
+}
+
+func (pc *PingConn) Write(p []byte) (n int, err error) {
+ return pc.WriteTo(p, &pc.raddr)
+}
+
+func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
+ e, notifyCh := waiter.NewChannelEntry(waiter.EventIn)
+ pc.wq.EventRegister(&e)
+ defer pc.wq.EventUnregister(&e)
+
+ select {
+ case <-pc.deadline.C:
+ return 0, nil, os.ErrDeadlineExceeded
+ case <-notifyCh:
+ }
+
+ w := tcpip.SliceWriter(p)
+
+ res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
+ NeedRemoteAddr: true,
+ })
+ if tcpipErr != nil {
+ return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
+ }
+
+ remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice())
+ return res.Count, &PingAddr{remoteAddr}, nil
+}
+
+func (pc *PingConn) Read(p []byte) (n int, err error) {
+ n, _, err = pc.ReadFrom(p)
+ return
+}
+
+func (pc *PingConn) SetDeadline(t time.Time) error {
+ // pc.SetWriteDeadline is unimplemented
+
+ return pc.SetReadDeadline(t)
+}
+
+func (pc *PingConn) SetReadDeadline(t time.Time) error {
+ pc.deadline.Reset(time.Until(t))
+ return nil
+}
+
var (
errNoSuchHost = errors.New("no such host")
errLameReferral = errors.New("lame referral")
@@ -476,7 +671,10 @@ func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Q
return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
if d, ok := ctx.Deadline(); ok && !d.IsZero() {
- c.SetDeadline(d)
+ err := c.SetDeadline(d)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
}
var p dnsmessage.Parser
var h dnsmessage.Header
@@ -714,7 +912,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
}
}
- // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
+ // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled
var addrs []netip.Addr
if tnet.hasV6 {
addrs = append(addrsV6, addrsV4...)
@@ -752,33 +950,38 @@ func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, er
return now.Add(timeout), nil
}
+var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
+
func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if ctx == nil {
panic("nil context")
}
- var acceptV4, acceptV6, useUDP bool
- if len(network) == 3 {
+ var acceptV4, acceptV6 bool
+ matches := protoSplitter.FindStringSubmatch(network)
+ if matches == nil {
+ return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
+ } else if len(matches[2]) == 0 {
acceptV4 = true
acceptV6 = true
- } else if len(network) == 4 {
- acceptV4 = network[3] == '4'
- acceptV6 = network[3] == '6'
- }
- if !acceptV4 && !acceptV6 {
- return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
- }
- if network[:3] == "udp" {
- useUDP = true
- } else if network[:3] != "tcp" {
- return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
- }
- host, sport, err := net.SplitHostPort(address)
- if err != nil {
- return nil, &net.OpError{Op: "dial", Err: err}
+ } else {
+ acceptV4 = matches[2][0] == '4'
+ acceptV6 = !acceptV4
}
- port, err := strconv.Atoi(sport)
- if err != nil || port < 0 || port > 65535 {
- return nil, &net.OpError{Op: "dial", Err: errNumericPort}
+ var host string
+ var port int
+ if matches[1] == "ping" {
+ host = address
+ } else {
+ var sport string
+ var err error
+ host, sport, err = net.SplitHostPort(address)
+ if err != nil {
+ return nil, &net.OpError{Op: "dial", Err: err}
+ }
+ port, err = strconv.Atoi(sport)
+ if err != nil || port < 0 || port > 65535 {
+ return nil, &net.OpError{Op: "dial", Err: errNumericPort}
+ }
}
allAddr, err := tnet.LookupContextHost(ctx, host)
if err != nil {
@@ -826,10 +1029,13 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
}
var c net.Conn
- if useUDP {
- c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
- } else {
+ switch matches[1] {
+ case "tcp":
c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
+ case "udp":
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
+ case "ping":
+ c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
}
if err == nil {
return c, nil