aboutsummaryrefslogtreecommitdiffstats
path: root/tun/netstack/tun.go
diff options
context:
space:
mode:
Diffstat (limited to 'tun/netstack/tun.go')
-rw-r--r--tun/netstack/tun.go157
1 files changed, 70 insertions, 87 deletions
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
index c26d8ed..a7aec9e 100644
--- a/tun/netstack/tun.go
+++ b/tun/netstack/tun.go
@@ -1,11 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package netstack
import (
+ "bytes"
"context"
"crypto/rand"
"encoding/binary"
@@ -18,15 +19,17 @@ import (
"regexp"
"strconv"
"strings"
+ "syscall"
"time"
"golang.zx2c4.com/wireguard/tun"
"golang.org/x/net/dns/dnsmessage"
+ "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -37,69 +40,17 @@ import (
)
type netTun struct {
+ ep *channel.Endpoint
stack *stack.Stack
- dispatcher stack.NetworkDispatcher
events chan tun.Event
- incomingPacket chan buffer.VectorisedView
+ notifyHandle *channel.NotificationHandle
+ incomingPacket chan *buffer.View
mtu int
dnsServers []netip.Addr
hasV4, hasV6 bool
}
-type (
- endpoint netTun
- Net netTun
-)
-
-func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
- e.dispatcher = dispatcher
-}
-
-func (e *endpoint) IsAttached() bool {
- return e.dispatcher != nil
-}
-
-func (e *endpoint) MTU() uint32 {
- mtu, err := (*netTun)(e).MTU()
- if err != nil {
- panic(err)
- }
- return uint32(mtu)
-}
-
-func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return stack.CapabilityNone
-}
-
-func (*endpoint) MaxHeaderLength() uint16 {
- return 0
-}
-
-func (*endpoint) LinkAddress() tcpip.LinkAddress {
- return ""
-}
-
-func (*endpoint) Wait() {}
-
-func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.incomingPacket <- buffer.NewVectorisedView(pkt.Size(), pkt.Views())
- return nil
-}
-
-func (e *endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- panic("not implemented")
-}
-
-func (e *endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
- panic("not implemented")
-}
-
-func (*endpoint) ARPHardwareType() header.ARPHardwareType {
- return header.ARPHardwareNone
-}
-
-func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
-}
+type Net netTun
func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
opts := stack.Options{
@@ -108,13 +59,20 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device,
HandleLocal: true,
}
dev := &netTun{
+ ep: channel.New(1024, uint32(mtu), ""),
stack: stack.New(opts),
events: make(chan tun.Event, 10),
- incomingPacket: make(chan buffer.VectorisedView),
+ incomingPacket: make(chan *buffer.View),
dnsServers: dnsServers,
mtu: mtu,
}
- tcpipErr := dev.stack.CreateNIC(1, (*endpoint)(dev))
+ sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
+ tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
+ }
+ dev.notifyHandle = dev.ep.AddNotify(dev)
+ tcpipErr = dev.stack.CreateNIC(1, dev.ep)
if tcpipErr != nil {
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
@@ -127,7 +85,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device,
}
protoAddr := tcpip.ProtocolAddress{
Protocol: protoNumber,
- AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
+ AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
}
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
if tcpipErr != nil {
@@ -158,48 +116,70 @@ func (tun *netTun) File() *os.File {
return nil
}
-func (tun *netTun) Events() chan tun.Event {
+func (tun *netTun) Events() <-chan tun.Event {
return tun.events
}
-func (tun *netTun) Read(buf []byte, offset int) (int, error) {
+func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
view, ok := <-tun.incomingPacket
if !ok {
return 0, os.ErrClosed
}
- return view.Read(buf[offset:])
-}
-func (tun *netTun) Write(buf []byte, offset int) (int, error) {
- packet := buf[offset:]
- if len(packet) == 0 {
- return 0, nil
+ n, err := view.Read(buf[0][offset:])
+ if err != nil {
+ return 0, err
}
+ sizes[0] = n
+ return 1, nil
+}
- pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(len(packet), []buffer.View{buffer.NewViewFromBytes(packet)})})
- switch packet[0] >> 4 {
- case 4:
- tun.dispatcher.DeliverNetworkPacket("", "", ipv4.ProtocolNumber, pkb)
- case 6:
- tun.dispatcher.DeliverNetworkPacket("", "", ipv6.ProtocolNumber, pkb)
- }
+func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
+ for _, buf := range buf {
+ packet := buf[offset:]
+ if len(packet) == 0 {
+ continue
+ }
+ pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
+ switch packet[0] >> 4 {
+ case 4:
+ tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
+ case 6:
+ tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
+ default:
+ return 0, syscall.EAFNOSUPPORT
+ }
+ }
return len(buf), nil
}
-func (tun *netTun) Flush() error {
- return nil
+func (tun *netTun) WriteNotify() {
+ pkt := tun.ep.Read()
+ if pkt == nil {
+ return
+ }
+
+ view := pkt.ToView()
+ pkt.DecRef()
+
+ tun.incomingPacket <- view
}
func (tun *netTun) Close() error {
tun.stack.RemoveNIC(1)
+ tun.stack.Close()
+ tun.ep.RemoveNotify(tun.notifyHandle)
+ tun.ep.Close()
if tun.events != nil {
close(tun.events)
}
+
if tun.incomingPacket != nil {
close(tun.incomingPacket)
}
+
return nil
}
@@ -207,6 +187,10 @@ func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil
}
+func (tun *netTun) BatchSize() int {
+ return 1
+}
+
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
var protoNumber tcpip.NetworkProtocolNumber
if endpoint.Addr().Is4() {
@@ -216,7 +200,7 @@ func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.Networ
}
return tcpip.FullAddress{
NIC: 1,
- Addr: tcpip.Address(endpoint.Addr().AsSlice()),
+ Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
Port: endpoint.Port(),
}, protoNumber
}
@@ -434,11 +418,10 @@ func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return 0, fmt.Errorf("ping write: mismatched protocols")
}
- buf := buffer.NewViewFromBytes(p)
- rdr := buf.Reader()
+ buf := bytes.NewReader(p)
rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0))
// won't block, no deadlines
- n64, tcpipErr := pc.ep.Write(&rdr, tcpip.WriteOptions{
+ n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{
To: &rfa,
})
if tcpipErr != nil {
@@ -453,8 +436,8 @@ func (pc *PingConn) Write(p []byte) (n int, err error) {
}
func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
- e, notifyCh := waiter.NewChannelEntry(nil)
- pc.wq.EventRegister(&e, waiter.EventIn)
+ e, notifyCh := waiter.NewChannelEntry(waiter.EventIn)
+ pc.wq.EventRegister(&e)
defer pc.wq.EventUnregister(&e)
select {
@@ -472,7 +455,7 @@ func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
}
- remoteAddr, _ := netip.AddrFromSlice([]byte(res.RemoteAddr.Addr))
+ remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice())
return res.Count, &PingAddr{remoteAddr}, nil
}
@@ -488,7 +471,7 @@ func (pc *PingConn) SetDeadline(t time.Time) error {
}
func (pc *PingConn) SetReadDeadline(t time.Time) error {
- pc.deadline.Reset(t.Sub(time.Now()))
+ pc.deadline.Reset(time.Until(t))
return nil
}
@@ -931,7 +914,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
}
}
- // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
+ // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled
var addrs []netip.Addr
if tnet.hasV6 {
addrs = append(addrsV6, addrsV4...)