aboutsummaryrefslogtreecommitdiffstats
path: root/tun/netstack/tun.go
diff options
context:
space:
mode:
Diffstat (limited to 'tun/netstack/tun.go')
-rw-r--r--tun/netstack/tun.go143
1 files changed, 81 insertions, 62 deletions
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
index 24d0835..f1c03f4 100644
--- a/tun/netstack/tun.go
+++ b/tun/netstack/tun.go
@@ -18,6 +18,7 @@ import (
"strings"
"time"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/tun"
"golang.org/x/net/dns/dnsmessage"
@@ -38,7 +39,7 @@ type netTun struct {
events chan tun.Event
incomingPacket chan buffer.VectorisedView
mtu int
- dnsServers []net.IP
+ dnsServers []netip.Addr
hasV4, hasV6 bool
}
type endpoint netTun
@@ -94,7 +95,7 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
}
-func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) {
+func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
@@ -112,25 +113,23 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
for _, ip := range localAddresses {
- if ip4 := ip.To4(); ip4 != nil {
- protoAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.Address(ip4).WithPrefix(),
- }
- tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
- if tcpipErr != nil {
- return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip4, tcpipErr)
- }
+ var protoNumber tcpip.NetworkProtocolNumber
+ if ip.Is4() {
+ protoNumber = ipv4.ProtocolNumber
+ } else if ip.Is6() {
+ protoNumber = ipv6.ProtocolNumber
+ }
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: protoNumber,
+ AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
+ }
+ tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
+ }
+ if ip.Is4() {
dev.hasV4 = true
- } else {
- protoAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: tcpip.Address(ip).WithPrefix(),
- }
- tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
- if tcpipErr != nil {
- return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
- }
+ } else if ip.Is6() {
dev.hasV6 = true
}
}
@@ -202,62 +201,83 @@ func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil
}
-func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
- if ip4 := ip.To4(); ip4 != nil {
- return tcpip.FullAddress{
- NIC: 1,
- Addr: tcpip.Address(ip4),
- Port: uint16(port),
- }, ipv4.ProtocolNumber
+func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
+ var protoNumber tcpip.NetworkProtocolNumber
+ if endpoint.Addr().Is4() {
+ protoNumber = ipv4.ProtocolNumber
} else {
- return tcpip.FullAddress{
- NIC: 1,
- Addr: tcpip.Address(ip),
- Port: uint16(port),
- }, ipv6.ProtocolNumber
+ protoNumber = ipv6.ProtocolNumber
}
+ return tcpip.FullAddress{
+ NIC: 1,
+ Addr: tcpip.Address(endpoint.Addr().AsSlice()),
+ Port: endpoint.Port(),
+ }, protoNumber
+}
+
+func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.DialContextTCP(ctx, net.stack, fa, pn)
}
func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil {
- panic("todo: deal with auto addr semantics for nil addr")
+ return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
}
- fa, pn := convertToFullAddr(addr.IP, addr.Port)
- return gonet.DialContextTCP(ctx, net.stack, fa, pn)
+ return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
+}
+
+func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.DialTCP(net.stack, fa, pn)
}
func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil {
- panic("todo: deal with auto addr semantics for nil addr")
+ return net.DialTCPAddrPort(netip.AddrPort{})
}
- fa, pn := convertToFullAddr(addr.IP, addr.Port)
- return gonet.DialTCP(net.stack, fa, pn)
+ return net.DialTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
+}
+
+func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.ListenTCP(net.stack, fa, pn)
}
func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
if addr == nil {
- panic("todo: deal with auto addr semantics for nil addr")
+ return net.ListenTCPAddrPort(netip.AddrPort{})
}
- fa, pn := convertToFullAddr(addr.IP, addr.Port)
- return gonet.ListenTCP(net.stack, fa, pn)
+ return net.ListenTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
}
-func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
+func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
var lfa, rfa *tcpip.FullAddress
var pn tcpip.NetworkProtocolNumber
- if laddr != nil {
+ if laddr.IsValid() || laddr.Port() > 0 {
var addr tcpip.FullAddress
- addr, pn = convertToFullAddr(laddr.IP, laddr.Port)
+ addr, pn = convertToFullAddr(laddr)
lfa = &addr
}
- if raddr != nil {
+ if raddr.IsValid() || raddr.Port() > 0 {
var addr tcpip.FullAddress
- addr, pn = convertToFullAddr(raddr.IP, raddr.Port)
+ addr, pn = convertToFullAddr(raddr)
rfa = &addr
}
return gonet.DialUDP(net.stack, lfa, rfa, pn)
}
+func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
+ var la, ra netip.AddrPort
+ if laddr != nil {
+ la = netip.AddrPortFrom(netip.AddrFromSlice(laddr.IP), uint16(laddr.Port))
+ }
+ if raddr != nil {
+ ra = netip.AddrPortFrom(netip.AddrFromSlice(raddr.IP), uint16(raddr.Port))
+ }
+ return net.DialUDPAddrPort(la, ra)
+}
+
var (
errNoSuchHost = errors.New("no such host")
errLameReferral = errors.New("lame referral")
@@ -433,7 +453,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by
return p, h, nil
}
-func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
+func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
q.Class = dnsmessage.ClassINET
id, udpReq, tcpReq, err := newRequest(q)
if err != nil {
@@ -447,9 +467,9 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest
var c net.Conn
var err error
if useUDP {
- c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53})
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
} else {
- c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53})
+ c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
}
if err != nil {
@@ -600,8 +620,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
zlen = zidx
}
}
- if ip := net.ParseIP(host[:zlen]); ip != nil {
- return []string{host[:zlen]}, nil
+ if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
+ return []string{ip.String()}, nil
}
if !isDomainName(host) {
@@ -612,7 +632,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
server string
error
}
- var addrsV4, addrsV6 []net.IP
+ var addrsV4, addrsV6 []netip.Addr
lanes := 0
if tnet.hasV4 {
lanes++
@@ -667,7 +687,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
break loop
}
- addrsV4 = append(addrsV4, net.IP(a.A[:]))
+ addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
case dnsmessage.TypeAAAA:
aaaa, err := result.p.AAAAResource()
@@ -679,7 +699,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
break loop
}
- addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:]))
+ addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
default:
if err := result.p.SkipAnswer(); err != nil {
@@ -695,7 +715,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
}
// We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
- var addrs []net.IP
+ var addrs []netip.Addr
if tnet.hasV6 {
addrs = append(addrsV6, addrsV4...)
} else {
@@ -764,12 +784,11 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
if err != nil {
return nil, &net.OpError{Op: "dial", Err: err}
}
- var addrs []net.IP
+ var addrs []netip.AddrPort
for _, addr := range allAddr {
- if strings.IndexByte(addr, ':') != -1 && acceptV6 {
- addrs = append(addrs, net.ParseIP(addr))
- } else if strings.IndexByte(addr, '.') != -1 && acceptV4 {
- addrs = append(addrs, net.ParseIP(addr))
+ ip, err := netip.ParseAddr(addr)
+ if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
+ addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
}
}
if len(addrs) == 0 && len(allAddr) != 0 {
@@ -808,9 +827,9 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
var c net.Conn
if useUDP {
- c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: addr, Port: port})
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
} else {
- c, err = tnet.DialContextTCP(dialCtx, &net.TCPAddr{IP: addr, Port: port})
+ c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
}
if err == nil {
return c, nil