From ef8d6804d77d9ce09f0e2c7f6d85bbe222712b73 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Fri, 5 Nov 2021 01:52:54 +0100 Subject: global: use netip where possible now There are more places where we'll need to add it later, when Go 1.18 comes out with support for it in the "net" package. Also, allowedips still uses slices internally, which might be suboptimal. Signed-off-by: Jason A. Donenfeld --- tun/netstack/tun.go | 143 +++++++++++++++++++++++++++++----------------------- 1 file changed, 81 insertions(+), 62 deletions(-) (limited to 'tun/netstack/tun.go') diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index 24d0835..f1c03f4 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -18,6 +18,7 @@ import ( "strings" "time" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/tun" "golang.org/x/net/dns/dnsmessage" @@ -38,7 +39,7 @@ type netTun struct { events chan tun.Event incomingPacket chan buffer.VectorisedView mtu int - dnsServers []net.IP + dnsServers []netip.Addr hasV4, hasV6 bool } type endpoint netTun @@ -94,7 +95,7 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType { func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { } -func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) { +func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, @@ -112,25 +113,23 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) } for _, ip := range localAddresses { - if ip4 := ip.To4(); ip4 != nil { - protoAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.Address(ip4).WithPrefix(), - } - tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) - if tcpipErr != nil { - return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip4, tcpipErr) - } + var protoNumber tcpip.NetworkProtocolNumber + if ip.Is4() { + protoNumber = ipv4.ProtocolNumber + } else if ip.Is6() { + protoNumber = ipv6.ProtocolNumber + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: protoNumber, + AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(), + } + tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) + } + if ip.Is4() { dev.hasV4 = true - } else { - protoAddr := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.Address(ip).WithPrefix(), - } - tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) - if tcpipErr != nil { - return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) - } + } else if ip.Is6() { dev.hasV6 = true } } @@ -202,62 +201,83 @@ func (tun *netTun) MTU() (int, error) { return tun.mtu, nil } -func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { - if ip4 := ip.To4(); ip4 != nil { - return tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(ip4), - Port: uint16(port), - }, ipv4.ProtocolNumber +func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { + var protoNumber tcpip.NetworkProtocolNumber + if endpoint.Addr().Is4() { + protoNumber = ipv4.ProtocolNumber } else { - return tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(ip), - Port: uint16(port), - }, ipv6.ProtocolNumber + protoNumber = ipv6.ProtocolNumber } + return tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.Address(endpoint.Addr().AsSlice()), + Port: endpoint.Port(), + }, protoNumber +} + +func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialContextTCP(ctx, net.stack, fa, pn) } func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { if addr == nil { - panic("todo: deal with auto addr semantics for nil addr") + return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) } - fa, pn := convertToFullAddr(addr.IP, addr.Port) - return gonet.DialContextTCP(ctx, net.stack, fa, pn) + return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))) +} + +func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialTCP(net.stack, fa, pn) } func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { if addr == nil { - panic("todo: deal with auto addr semantics for nil addr") + return net.DialTCPAddrPort(netip.AddrPort{}) } - fa, pn := convertToFullAddr(addr.IP, addr.Port) - return gonet.DialTCP(net.stack, fa, pn) + return net.DialTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))) +} + +func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { + fa, pn := convertToFullAddr(addr) + return gonet.ListenTCP(net.stack, fa, pn) } func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { if addr == nil { - panic("todo: deal with auto addr semantics for nil addr") + return net.ListenTCPAddrPort(netip.AddrPort{}) } - fa, pn := convertToFullAddr(addr.IP, addr.Port) - return gonet.ListenTCP(net.stack, fa, pn) + return net.ListenTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))) } -func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { +func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { var lfa, rfa *tcpip.FullAddress var pn tcpip.NetworkProtocolNumber - if laddr != nil { + if laddr.IsValid() || laddr.Port() > 0 { var addr tcpip.FullAddress - addr, pn = convertToFullAddr(laddr.IP, laddr.Port) + addr, pn = convertToFullAddr(laddr) lfa = &addr } - if raddr != nil { + if raddr.IsValid() || raddr.Port() > 0 { var addr tcpip.FullAddress - addr, pn = convertToFullAddr(raddr.IP, raddr.Port) + addr, pn = convertToFullAddr(raddr) rfa = &addr } return gonet.DialUDP(net.stack, lfa, rfa, pn) } +func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { + var la, ra netip.AddrPort + if laddr != nil { + la = netip.AddrPortFrom(netip.AddrFromSlice(laddr.IP), uint16(laddr.Port)) + } + if raddr != nil { + ra = netip.AddrPortFrom(netip.AddrFromSlice(raddr.IP), uint16(raddr.Port)) + } + return net.DialUDPAddrPort(la, ra) +} + var ( errNoSuchHost = errors.New("no such host") errLameReferral = errors.New("lame referral") @@ -433,7 +453,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by return p, h, nil } -func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { +func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { q.Class = dnsmessage.ClassINET id, udpReq, tcpReq, err := newRequest(q) if err != nil { @@ -447,9 +467,9 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest var c net.Conn var err error if useUDP { - c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53}) + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53)) } else { - c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53}) + c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53)) } if err != nil { @@ -600,8 +620,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, zlen = zidx } } - if ip := net.ParseIP(host[:zlen]); ip != nil { - return []string{host[:zlen]}, nil + if ip, err := netip.ParseAddr(host[:zlen]); err == nil { + return []string{ip.String()}, nil } if !isDomainName(host) { @@ -612,7 +632,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, server string error } - var addrsV4, addrsV6 []net.IP + var addrsV4, addrsV6 []netip.Addr lanes := 0 if tnet.hasV4 { lanes++ @@ -667,7 +687,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, } break loop } - addrsV4 = append(addrsV4, net.IP(a.A[:])) + addrsV4 = append(addrsV4, netip.AddrFrom4(a.A)) case dnsmessage.TypeAAAA: aaaa, err := result.p.AAAAResource() @@ -679,7 +699,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, } break loop } - addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:])) + addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA)) default: if err := result.p.SkipAnswer(); err != nil { @@ -695,7 +715,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, } } // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled - var addrs []net.IP + var addrs []netip.Addr if tnet.hasV6 { addrs = append(addrsV6, addrsV4...) } else { @@ -764,12 +784,11 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. if err != nil { return nil, &net.OpError{Op: "dial", Err: err} } - var addrs []net.IP + var addrs []netip.AddrPort for _, addr := range allAddr { - if strings.IndexByte(addr, ':') != -1 && acceptV6 { - addrs = append(addrs, net.ParseIP(addr)) - } else if strings.IndexByte(addr, '.') != -1 && acceptV4 { - addrs = append(addrs, net.ParseIP(addr)) + ip, err := netip.ParseAddr(addr) + if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) { + addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port))) } } if len(addrs) == 0 && len(allAddr) != 0 { @@ -808,9 +827,9 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. var c net.Conn if useUDP { - c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: addr, Port: port}) + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr) } else { - c, err = tnet.DialContextTCP(dialCtx, &net.TCPAddr{IP: addr, Port: port}) + c, err = tnet.DialContextTCPAddrPort(dialCtx, addr) } if err == nil { return c, nil -- cgit v1.2.3-59-g8ed1b