diff options
Diffstat (limited to 'tun/netstack/tun.go')
-rw-r--r-- | tun/netstack/tun.go | 157 |
1 files changed, 70 insertions, 87 deletions
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index c26d8ed..a7aec9e 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -1,11 +1,12 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package netstack import ( + "bytes" "context" "crypto/rand" "encoding/binary" @@ -18,15 +19,17 @@ import ( "regexp" "strconv" "strings" + "syscall" "time" "golang.zx2c4.com/wireguard/tun" "golang.org/x/net/dns/dnsmessage" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -37,69 +40,17 @@ import ( ) type netTun struct { + ep *channel.Endpoint stack *stack.Stack - dispatcher stack.NetworkDispatcher events chan tun.Event - incomingPacket chan buffer.VectorisedView + notifyHandle *channel.NotificationHandle + incomingPacket chan *buffer.View mtu int dnsServers []netip.Addr hasV4, hasV6 bool } -type ( - endpoint netTun - Net netTun -) - -func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.dispatcher = dispatcher -} - -func (e *endpoint) IsAttached() bool { - return e.dispatcher != nil -} - -func (e *endpoint) MTU() uint32 { - mtu, err := (*netTun)(e).MTU() - if err != nil { - panic(err) - } - return uint32(mtu) -} - -func (*endpoint) Capabilities() stack.LinkEndpointCapabilities { - return stack.CapabilityNone -} - -func (*endpoint) MaxHeaderLength() uint16 { - return 0 -} - -func (*endpoint) LinkAddress() tcpip.LinkAddress { - return "" -} - -func (*endpoint) Wait() {} - -func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - e.incomingPacket <- buffer.NewVectorisedView(pkt.Size(), pkt.Views()) - return nil -} - -func (e *endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - panic("not implemented") -} - -func (e *endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { - panic("not implemented") -} - -func (*endpoint) ARPHardwareType() header.ARPHardwareType { - return header.ARPHardwareNone -} - -func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { -} +type Net netTun func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { opts := stack.Options{ @@ -108,13 +59,20 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, HandleLocal: true, } dev := &netTun{ + ep: channel.New(1024, uint32(mtu), ""), stack: stack.New(opts), events: make(chan tun.Event, 10), - incomingPacket: make(chan buffer.VectorisedView), + incomingPacket: make(chan *buffer.View), dnsServers: dnsServers, mtu: mtu, } - tcpipErr := dev.stack.CreateNIC(1, (*endpoint)(dev)) + sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default + tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) + } + dev.notifyHandle = dev.ep.AddNotify(dev) + tcpipErr = dev.stack.CreateNIC(1, dev.ep) if tcpipErr != nil { return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) } @@ -127,7 +85,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, } protoAddr := tcpip.ProtocolAddress{ Protocol: protoNumber, - AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(), + AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), } tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) if tcpipErr != nil { @@ -158,48 +116,70 @@ func (tun *netTun) File() *os.File { return nil } -func (tun *netTun) Events() chan tun.Event { +func (tun *netTun) Events() <-chan tun.Event { return tun.events } -func (tun *netTun) Read(buf []byte, offset int) (int, error) { +func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { view, ok := <-tun.incomingPacket if !ok { return 0, os.ErrClosed } - return view.Read(buf[offset:]) -} -func (tun *netTun) Write(buf []byte, offset int) (int, error) { - packet := buf[offset:] - if len(packet) == 0 { - return 0, nil + n, err := view.Read(buf[0][offset:]) + if err != nil { + return 0, err } + sizes[0] = n + return 1, nil +} - pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(len(packet), []buffer.View{buffer.NewViewFromBytes(packet)})}) - switch packet[0] >> 4 { - case 4: - tun.dispatcher.DeliverNetworkPacket("", "", ipv4.ProtocolNumber, pkb) - case 6: - tun.dispatcher.DeliverNetworkPacket("", "", ipv6.ProtocolNumber, pkb) - } +func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { + for _, buf := range buf { + packet := buf[offset:] + if len(packet) == 0 { + continue + } + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) + switch packet[0] >> 4 { + case 4: + tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) + case 6: + tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) + default: + return 0, syscall.EAFNOSUPPORT + } + } return len(buf), nil } -func (tun *netTun) Flush() error { - return nil +func (tun *netTun) WriteNotify() { + pkt := tun.ep.Read() + if pkt == nil { + return + } + + view := pkt.ToView() + pkt.DecRef() + + tun.incomingPacket <- view } func (tun *netTun) Close() error { tun.stack.RemoveNIC(1) + tun.stack.Close() + tun.ep.RemoveNotify(tun.notifyHandle) + tun.ep.Close() if tun.events != nil { close(tun.events) } + if tun.incomingPacket != nil { close(tun.incomingPacket) } + return nil } @@ -207,6 +187,10 @@ func (tun *netTun) MTU() (int, error) { return tun.mtu, nil } +func (tun *netTun) BatchSize() int { + return 1 +} + func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { var protoNumber tcpip.NetworkProtocolNumber if endpoint.Addr().Is4() { @@ -216,7 +200,7 @@ func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.Networ } return tcpip.FullAddress{ NIC: 1, - Addr: tcpip.Address(endpoint.Addr().AsSlice()), + Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), Port: endpoint.Port(), }, protoNumber } @@ -434,11 +418,10 @@ func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { return 0, fmt.Errorf("ping write: mismatched protocols") } - buf := buffer.NewViewFromBytes(p) - rdr := buf.Reader() + buf := bytes.NewReader(p) rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0)) // won't block, no deadlines - n64, tcpipErr := pc.ep.Write(&rdr, tcpip.WriteOptions{ + n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{ To: &rfa, }) if tcpipErr != nil { @@ -453,8 +436,8 @@ func (pc *PingConn) Write(p []byte) (n int, err error) { } func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - e, notifyCh := waiter.NewChannelEntry(nil) - pc.wq.EventRegister(&e, waiter.EventIn) + e, notifyCh := waiter.NewChannelEntry(waiter.EventIn) + pc.wq.EventRegister(&e) defer pc.wq.EventUnregister(&e) select { @@ -472,7 +455,7 @@ func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) } - remoteAddr, _ := netip.AddrFromSlice([]byte(res.RemoteAddr.Addr)) + remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice()) return res.Count, &PingAddr{remoteAddr}, nil } @@ -488,7 +471,7 @@ func (pc *PingConn) SetDeadline(t time.Time) error { } func (pc *PingConn) SetReadDeadline(t time.Time) error { - pc.deadline.Reset(t.Sub(time.Now())) + pc.deadline.Reset(time.Until(t)) return nil } @@ -931,7 +914,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, } } } - // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled + // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled var addrs []netip.Addr if tnet.hasV6 { addrs = append(addrsV6, addrsV4...) |