aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-11-05 01:52:54 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2021-11-23 22:03:15 +0100
commitef8d6804d77d9ce09f0e2c7f6d85bbe222712b73 (patch)
tree5b4a3b53dfb092f10cf11fbe0b5724f58df3a1bf
parentdevice: only propagate roaming value before peer is referenced elsewhere (diff)
downloadwireguard-go-ef8d6804d77d9ce09f0e2c7f6d85bbe222712b73.tar.xz
wireguard-go-ef8d6804d77d9ce09f0e2c7f6d85bbe222712b73.zip
global: use netip where possible now
There are more places where we'll need to add it later, when Go 1.18 comes out with support for it in the "net" package. Also, allowedips still uses slices internally, which might be suboptimal. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--conn/bind_linux.go50
-rw-r--r--conn/bind_std.go19
-rw-r--r--conn/bind_windows.go19
-rw-r--r--conn/bindtest/bindtest.go14
-rw-r--r--conn/conn.go37
-rw-r--r--device/allowedips.go42
-rw-r--r--device/allowedips_rand_test.go6
-rw-r--r--device/allowedips_test.go6
-rw-r--r--device/device_test.go6
-rw-r--r--device/endpoint_test.go39
-rw-r--r--device/receive.go1
-rw-r--r--device/uapi.go11
-rw-r--r--go.mod7
-rw-r--r--go.sum13
-rw-r--r--ratelimiter/ratelimiter.go58
-rw-r--r--ratelimiter/ratelimiter_test.go33
-rw-r--r--tun/netstack/examples/http_client.go7
-rw-r--r--tun/netstack/examples/http_server.go6
-rw-r--r--tun/netstack/go.mod1
-rw-r--r--tun/netstack/go.sum4
-rw-r--r--tun/netstack/tun.go143
-rw-r--r--tun/tuntest/tuntest.go10
22 files changed, 247 insertions, 285 deletions
diff --git a/conn/bind_linux.go b/conn/bind_linux.go
index 7b970e6..da0670a 100644
--- a/conn/bind_linux.go
+++ b/conn/bind_linux.go
@@ -14,6 +14,7 @@ import (
"unsafe"
"golang.org/x/sys/unix"
+ "golang.zx2c4.com/go118/netip"
)
type ipv4Source struct {
@@ -70,32 +71,30 @@ var _ Bind = (*LinuxSocketBind)(nil)
func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
var end LinuxSocketEndpoint
- addr, err := parseEndpoint(s)
+ e, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
- ipv4 := addr.IP.To4()
- if ipv4 != nil {
+ if e.Addr().Is4() {
dst := end.dst4()
end.isV6 = false
- dst.Port = addr.Port
- copy(dst.Addr[:], ipv4)
+ dst.Port = int(e.Port())
+ dst.Addr = e.Addr().As4()
end.ClearSrc()
return &end, nil
}
- ipv6 := addr.IP.To16()
- if ipv6 != nil {
- zone, err := zoneToUint32(addr.Zone)
+ if e.Addr().Is6() {
+ zone, err := zoneToUint32(e.Addr().Zone())
if err != nil {
return nil, err
}
dst := end.dst6()
end.isV6 = true
- dst.Port = addr.Port
+ dst.Port = int(e.Port())
dst.ZoneId = zone
- copy(dst.Addr[:], ipv6[:])
+ dst.Addr = e.Addr().As16()
end.ClearSrc()
return &end, nil
}
@@ -266,29 +265,19 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
}
}
-func (end *LinuxSocketEndpoint) SrcIP() net.IP {
+func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
if !end.isV6 {
- return net.IPv4(
- end.src4().Src[0],
- end.src4().Src[1],
- end.src4().Src[2],
- end.src4().Src[3],
- )
+ return netip.AddrFrom4(end.src4().Src)
} else {
- return end.src6().src[:]
+ return netip.AddrFrom16(end.src6().src)
}
}
-func (end *LinuxSocketEndpoint) DstIP() net.IP {
+func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
if !end.isV6 {
- return net.IPv4(
- end.dst4().Addr[0],
- end.dst4().Addr[1],
- end.dst4().Addr[2],
- end.dst4().Addr[3],
- )
+ return netip.AddrFrom4(end.dst4().Addr)
} else {
- return end.dst6().Addr[:]
+ return netip.AddrFrom16(end.dst6().Addr)
}
}
@@ -305,14 +294,13 @@ func (end *LinuxSocketEndpoint) SrcToString() string {
}
func (end *LinuxSocketEndpoint) DstToString() string {
- var udpAddr net.UDPAddr
- udpAddr.IP = end.DstIP()
+ var port int
if !end.isV6 {
- udpAddr.Port = end.dst4().Port
+ port = end.dst4().Port
} else {
- udpAddr.Port = end.dst6().Port
+ port = end.dst6().Port
}
- return udpAddr.String()
+ return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
}
func (end *LinuxSocketEndpoint) ClearDst() {
diff --git a/conn/bind_std.go b/conn/bind_std.go
index cb85cfd..a3cbb15 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -10,6 +10,8 @@ import (
"net"
"sync"
"syscall"
+
+ "golang.zx2c4.com/go118/netip"
)
// StdNetBind is meant to be a temporary solution on platforms for which
@@ -32,18 +34,23 @@ var _ Bind = (*StdNetBind)(nil)
var _ Endpoint = (*StdNetEndpoint)(nil)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
- addr, err := parseEndpoint(s)
- return (*StdNetEndpoint)(addr), err
+ e, err := netip.ParseAddrPort(s)
+ return (*StdNetEndpoint)(&net.UDPAddr{
+ IP: e.Addr().AsSlice(),
+ Port: int(e.Port()),
+ Zone: e.Addr().Zone(),
+ }), err
}
func (*StdNetEndpoint) ClearSrc() {}
-func (e *StdNetEndpoint) DstIP() net.IP {
- return (*net.UDPAddr)(e).IP
+func (e *StdNetEndpoint) DstIP() netip.Addr {
+ a, _ := netip.AddrFromSlice((*net.UDPAddr)(e).IP)
+ return a
}
-func (e *StdNetEndpoint) SrcIP() net.IP {
- return nil // not supported
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+ return netip.Addr{} // not supported
}
func (e *StdNetEndpoint) DstToBytes() []byte {
diff --git a/conn/bind_windows.go b/conn/bind_windows.go
index 42e06ad..26a3af8 100644
--- a/conn/bind_windows.go
+++ b/conn/bind_windows.go
@@ -15,6 +15,7 @@ import (
"unsafe"
"golang.org/x/sys/windows"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn/winrio"
)
@@ -128,18 +129,18 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
func (*WinRingEndpoint) ClearSrc() {}
-func (e *WinRingEndpoint) DstIP() net.IP {
+func (e *WinRingEndpoint) DstIP() netip.Addr {
switch e.family {
case windows.AF_INET:
- return append([]byte{}, e.data[2:6]...)
+ return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
case windows.AF_INET6:
- return append([]byte{}, e.data[6:22]...)
+ return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
}
- return nil
+ return netip.Addr{}
}
-func (e *WinRingEndpoint) SrcIP() net.IP {
- return nil // not supported
+func (e *WinRingEndpoint) SrcIP() netip.Addr {
+ return netip.Addr{} // not supported
}
func (e *WinRingEndpoint) DstToBytes() []byte {
@@ -161,15 +162,13 @@ func (e *WinRingEndpoint) DstToBytes() []byte {
func (e *WinRingEndpoint) DstToString() string {
switch e.family {
case windows.AF_INET:
- addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
- return addr.String()
+ netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
case windows.AF_INET6:
var zone string
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
zone = strconv.FormatUint(uint64(scope), 10)
}
- addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
- return addr.String()
+ return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
}
return ""
}
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
index 7d43fb3..6a45896 100644
--- a/conn/bindtest/bindtest.go
+++ b/conn/bindtest/bindtest.go
@@ -10,8 +10,8 @@ import (
"math/rand"
"net"
"os"
- "strconv"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
)
@@ -61,9 +61,9 @@ func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
-func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
+func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
-func (c ChannelEndpoint) SrcIP() net.IP { return nil }
+func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
c.closeSignal = make(chan bool)
@@ -119,13 +119,9 @@ func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
}
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
- _, port, err := net.SplitHostPort(s)
+ addr, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
- i, err := strconv.ParseUint(port, 10, 16)
- if err != nil {
- return nil, err
- }
- return ChannelEndpoint(i), nil
+ return ChannelEndpoint(addr.Port()), nil
}
diff --git a/conn/conn.go b/conn/conn.go
index 9cce9ad..35fb6b1 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -9,10 +9,11 @@ package conn
import (
"errors"
"fmt"
- "net"
"reflect"
"runtime"
"strings"
+
+ "golang.zx2c4.com/go118/netip"
)
// A ReceiveFunc receives a single inbound packet from the network.
@@ -68,8 +69,8 @@ type Endpoint interface {
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
- DstIP() net.IP
- SrcIP() net.IP
+ DstIP() netip.Addr
+ SrcIP() netip.Addr
}
var (
@@ -119,33 +120,3 @@ func (fn ReceiveFunc) PrettyName() string {
}
return name
}
-
-func parseEndpoint(s string) (*net.UDPAddr, error) {
- // ensure that the host is an IP address
-
- host, _, err := net.SplitHostPort(s)
- if err != nil {
- return nil, err
- }
- if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
- // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
- // trying to make sure with a small sanity test that this is a real IP address and
- // not something that's likely to incur DNS lookups.
- host = host[:i]
- }
- if ip := net.ParseIP(host); ip == nil {
- return nil, errors.New("Failed to parse IP address: " + host)
- }
-
- // parse address and port
-
- addr, err := net.ResolveUDPAddr("udp", s)
- if err != nil {
- return nil, err
- }
- ip4 := addr.IP.To4()
- if ip4 != nil {
- addr.IP = ip4
- }
- return addr, err
-}
diff --git a/device/allowedips.go b/device/allowedips.go
index c08399b..7a0b275 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -12,6 +12,8 @@ import (
"net"
"sync"
"unsafe"
+
+ "golang.zx2c4.com/go118/netip"
)
type parentIndirection struct {
@@ -26,7 +28,7 @@ type trieEntry struct {
cidr uint8
bitAtByte uint8
bitAtShift uint8
- bits net.IP
+ bits []byte
perPeerElem *list.Element
}
@@ -51,7 +53,7 @@ func swapU64(i uint64) uint64 {
return bits.ReverseBytes64(i)
}
-func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
+func commonBits(ip1, ip2 []byte) uint8 {
size := len(ip1)
if size == net.IPv4len {
a := (*uint32)(unsafe.Pointer(&ip1[0]))
@@ -85,7 +87,7 @@ func (node *trieEntry) removeFromPeerEntries() {
}
}
-func (node *trieEntry) choose(ip net.IP) byte {
+func (node *trieEntry) choose(ip []byte) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
@@ -104,7 +106,7 @@ func (node *trieEntry) zeroizePointers() {
node.parent.parentBit = nil
}
-func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
+func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
if parent.cidr == cidr {
@@ -117,7 +119,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry,
return
}
-func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
+func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
if *trie.parentBit == nil {
node := &trieEntry{
peer: peer,
@@ -207,7 +209,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
}
}
-func (node *trieEntry) lookup(ip net.IP) *Peer {
+func (node *trieEntry) lookup(ip []byte) *Peer {
var found *Peer
size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
@@ -229,13 +231,14 @@ type AllowedIPs struct {
mutex sync.RWMutex
}
-func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
+func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock()
defer table.mutex.RUnlock()
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
node := elem.Value.(*trieEntry)
- if !cb(node.bits, node.cidr) {
+ a, _ := netip.AddrFromSlice(node.bits)
+ if !cb(netip.PrefixFrom(a, int(node.cidr))) {
return
}
}
@@ -283,28 +286,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
}
}
-func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
+func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
- switch len(ip) {
- case net.IPv6len:
- parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
- case net.IPv4len:
- parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
- default:
+ if prefix.Addr().Is6() {
+ ip := prefix.Addr().As16()
+ parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else if prefix.Addr().Is4() {
+ ip := prefix.Addr().As4()
+ parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else {
panic(errors.New("inserting unknown address type"))
}
}
-func (table *AllowedIPs) Lookup(address []byte) *Peer {
+func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
- switch len(address) {
+ switch len(ip) {
case net.IPv6len:
- return table.IPv6.lookup(address)
+ return table.IPv6.lookup(ip)
case net.IPv4len:
- return table.IPv4.lookup(address)
+ return table.IPv4.lookup(ip)
default:
panic(errors.New("looking up unknown address type"))
}
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
index 16de170..ff56fe6 100644
--- a/device/allowedips_rand_test.go
+++ b/device/allowedips_rand_test.go
@@ -10,6 +10,8 @@ import (
"net"
"sort"
"testing"
+
+ "golang.zx2c4.com/go118/netip"
)
const (
@@ -93,14 +95,14 @@ func TestTrieRandom(t *testing.T) {
rand.Read(addr4[:])
cidr := uint8(rand.Intn(32) + 1)
index := rand.Intn(NumberOfPeers)
- allowedIPs.Insert(addr4[:], cidr, peers[index])
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
var addr6 [16]byte
rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
- allowedIPs.Insert(addr6[:], cidr, peers[index])
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
}
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
index 2059a88..a274997 100644
--- a/device/allowedips_test.go
+++ b/device/allowedips_test.go
@@ -9,6 +9,8 @@ import (
"math/rand"
"net"
"testing"
+
+ "golang.zx2c4.com/go118/netip"
)
type testPairCommonBits struct {
@@ -98,7 +100,7 @@ func TestTrieIPv4(t *testing.T) {
var allowedIPs AllowedIPs
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
- allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
@@ -208,7 +210,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- allowedIPs.Insert(addr, cidr, peer)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d uint32) {
diff --git a/device/device_test.go b/device/device_test.go
index 29daeb9..84221be 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -11,7 +11,6 @@ import (
"fmt"
"io"
"math/rand"
- "net"
"runtime"
"runtime/pprof"
"sync"
@@ -19,6 +18,7 @@ import (
"testing"
"time"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/conn/bindtest"
"golang.zx2c4.com/wireguard/tun/tuntest"
@@ -96,7 +96,7 @@ type testPair [2]testPeer
type testPeer struct {
tun *tuntest.ChannelTUN
dev *Device
- ip net.IP
+ ip netip.Addr
}
type SendDirection bool
@@ -159,7 +159,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
for i := range pair {
p := &pair[i]
p.tun = tuntest.NewChannelTUN()
- p.ip = net.IPv4(1, 0, 0, byte(i+1))
+ p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
level := LogLevelVerbose
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError
diff --git a/device/endpoint_test.go b/device/endpoint_test.go
index 57c361c..f1ae47e 100644
--- a/device/endpoint_test.go
+++ b/device/endpoint_test.go
@@ -7,47 +7,44 @@ package device
import (
"math/rand"
- "net"
+
+ "golang.zx2c4.com/go118/netip"
)
type DummyEndpoint struct {
- src [16]byte
- dst [16]byte
+ src, dst netip.Addr
}
func CreateDummyEndpoint() (*DummyEndpoint, error) {
- var end DummyEndpoint
- if _, err := rand.Read(end.src[:]); err != nil {
+ var src, dst [16]byte
+ if _, err := rand.Read(src[:]); err != nil {
return nil, err
}
- _, err := rand.Read(end.dst[:])
- return &end, err
+ _, err := rand.Read(dst[:])
+ return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
}
func (e *DummyEndpoint) ClearSrc() {}
func (e *DummyEndpoint) SrcToString() string {
- var addr net.UDPAddr
- addr.IP = e.SrcIP()
- addr.Port = 1000
- return addr.String()
+ return netip.AddrPortFrom(e.SrcIP(), 1000).String()
}
func (e *DummyEndpoint) DstToString() string {
- var addr net.UDPAddr
- addr.IP = e.DstIP()
- addr.Port = 1000
- return addr.String()
+ return netip.AddrPortFrom(e.DstIP(), 1000).String()
}
-func (e *DummyEndpoint) SrcToBytes() []byte {
- return e.src[:]
+func (e *DummyEndpoint) DstToBytes() []byte {
+ out := e.DstIP().AsSlice()
+ out = append(out, byte(1000&0xff))
+ out = append(out, byte((1000>>8)&0xff))
+ return out
}
-func (e *DummyEndpoint) DstIP() net.IP {
- return e.dst[:]
+func (e *DummyEndpoint) DstIP() netip.Addr {
+ return e.dst
}
-func (e *DummyEndpoint) SrcIP() net.IP {
- return e.src[:]
+func (e *DummyEndpoint) SrcIP() netip.Addr {
+ return e.src
}
diff --git a/device/receive.go b/device/receive.go
index 5857481..cc34498 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -17,7 +17,6 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
-
"golang.zx2c4.com/wireguard/conn"
)
diff --git a/device/uapi.go b/device/uapi.go
index 2306183..98e8311 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -18,6 +18,7 @@ import (
"sync/atomic"
"time"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/ipc"
)
@@ -121,8 +122,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
- device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool {
- sendf("allowed_ip=%s/%d", ip.String(), cidr)
+ device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
+ sendf("allowed_ip=%s", prefix.String())
return true
})
}
@@ -374,16 +375,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
case "allowed_ip":
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
-
- _, network, err := net.ParseCIDR(value)
+ prefix, err := netip.ParsePrefix(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
}
if peer.dummy {
return nil
}
- ones, _ := network.Mask.Size()
- device.allowedips.Insert(network.IP, uint8(ones), peer.Peer)
+ device.allowedips.Insert(prefix, peer.Peer)
case "protocol_version":
if value != "1" {
diff --git a/go.mod b/go.mod
index 856bb6c..b510960 100644
--- a/go.mod
+++ b/go.mod
@@ -3,8 +3,9 @@ module golang.zx2c4.com/wireguard
go 1.17
require (
- golang.org/x/crypto v0.0.0-20210921155107-089bfa567519
- golang.org/x/net v0.0.0-20211101193420-4a448f8816b3
- golang.org/x/sys v0.0.0-20211103235746-7861aae1554b
+ golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa
+ golang.org/x/net v0.0.0-20211111083644-e5c967477495
+ golang.org/x/sys v0.0.0-20211110154304-99a53858aa08
+ golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
)
diff --git a/go.sum b/go.sum
index 37fe067..78f7367 100644
--- a/go.sum
+++ b/go.sum
@@ -1,16 +1,19 @@
-golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg=
-golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
+golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa h1:idItI2DDfCokpg0N51B2VtiLdJ4vAuXC9fnCb2gACo4=
+golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.0.0-20211101193420-4a448f8816b3 h1:VrJZAjbekhoRn7n5FBujY31gboH+iB3pdLxn3gE9FjU=
-golang.org/x/net v0.0.0-20211101193420-4a448f8816b3/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/net v0.0.0-20211111083644-e5c967477495 h1:cjxxlQm6d4kYbhpZ2ghvmI8xnq0AG+jXmzrhzfkyu5A=
+golang.org/x/net v0.0.0-20211111083644-e5c967477495/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4=
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20211110154304-99a53858aa08 h1:WecRHqgE09JBkh/584XIE6PMz5KKE/vER4izNUi30AQ=
+golang.org/x/sys v0.0.0-20211110154304-99a53858aa08/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d h1:9+v0G0naRhLPOJEeJOL6NuXTtAHHwmkyZlgQJ0XcQ8I=
+golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go
index 2f7aa2a..8e78d5e 100644
--- a/ratelimiter/ratelimiter.go
+++ b/ratelimiter/ratelimiter.go
@@ -6,9 +6,10 @@
package ratelimiter
import (
- "net"
"sync"
"time"
+
+ "golang.zx2c4.com/go118/netip"
)
const (
@@ -30,8 +31,7 @@ type Ratelimiter struct {
timeNow func() time.Time
stopReset chan struct{} // send to reset, close to stop
- tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
- tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
+ table map[netip.Addr]*RatelimiterEntry
}
func (rate *Ratelimiter) Close() {
@@ -57,8 +57,7 @@ func (rate *Ratelimiter) Init() {
}
rate.stopReset = make(chan struct{})
- rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
- rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
+ rate.table = make(map[netip.Addr]*RatelimiterEntry)
stopReset := rate.stopReset // store in case Init is called again.
@@ -87,71 +86,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) {
rate.mu.Lock()
defer rate.mu.Unlock()
- for key, entry := range rate.tableIPv4 {
+ for key, entry := range rate.table {
entry.mu.Lock()
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv4, key)
+ delete(rate.table, key)
}
entry.mu.Unlock()
}
- for key, entry := range rate.tableIPv6 {
- entry.mu.Lock()
- if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv6, key)
- }
- entry.mu.Unlock()
- }
-
- return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
+ return len(rate.table) == 0
}
-func (rate *Ratelimiter) Allow(ip net.IP) bool {
+func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
var entry *RatelimiterEntry
- var keyIPv4 [net.IPv4len]byte
- var keyIPv6 [net.IPv6len]byte
-
// lookup entry
-
- IPv4 := ip.To4()
- IPv6 := ip.To16()
-
rate.mu.RLock()
-
- if IPv4 != nil {
- copy(keyIPv4[:], IPv4)
- entry = rate.tableIPv4[keyIPv4]
- } else {
- copy(keyIPv6[:], IPv6)
- entry = rate.tableIPv6[keyIPv6]
- }
-
+ entry = rate.table[ip]
rate.mu.RUnlock()
// make new entry if not found
-
if entry == nil {
entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost
entry.lastTime = rate.timeNow()
rate.mu.Lock()
- if IPv4 != nil {
- rate.tableIPv4[keyIPv4] = entry
- if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
- rate.stopReset <- struct{}{}
- }
- } else {
- rate.tableIPv6[keyIPv6] = entry
- if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
- rate.stopReset <- struct{}{}
- }
+ rate.table[ip] = entry
+ if len(rate.table) == 1 {
+ rate.stopReset <- struct{}{}
}
rate.mu.Unlock()
return true
}
// add tokens to entry
-
entry.mu.Lock()
now := rate.timeNow()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
@@ -161,7 +128,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
}
// subtract cost of packet
-
if entry.tokens > packetCost {
entry.tokens -= packetCost
entry.mu.Unlock()
diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go
index f231fe5..3e06ff7 100644
--- a/ratelimiter/ratelimiter_test.go
+++ b/ratelimiter/ratelimiter_test.go
@@ -6,9 +6,10 @@
package ratelimiter
import (
- "net"
"testing"
"time"
+
+ "golang.zx2c4.com/go118/netip"
)
type result struct {
@@ -71,21 +72,21 @@ func TestRatelimiter(t *testing.T) {
text: "packet following 2 packet burst",
})
- ips := []net.IP{
- net.ParseIP("127.0.0.1"),
- net.ParseIP("192.168.1.1"),
- net.ParseIP("172.167.2.3"),
- net.ParseIP("97.231.252.215"),
- net.ParseIP("248.97.91.167"),
- net.ParseIP("188.208.233.47"),
- net.ParseIP("104.2.183.179"),
- net.ParseIP("72.129.46.120"),
- net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
- net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
- net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
- net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
- net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
- net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
+ ips := []netip.Addr{
+ netip.MustParseAddr("127.0.0.1"),
+ netip.MustParseAddr("192.168.1.1"),
+ netip.MustParseAddr("172.167.2.3"),
+ netip.MustParseAddr("97.231.252.215"),
+ netip.MustParseAddr("248.97.91.167"),
+ netip.MustParseAddr("188.208.233.47"),
+ netip.MustParseAddr("104.2.183.179"),
+ netip.MustParseAddr("72.129.46.120"),
+ netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
+ netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
+ netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
+ netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
+ netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
+ netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
}
now := time.Now()
diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go
index 6ac2859..b39b453 100644
--- a/tun/netstack/examples/http_client.go
+++ b/tun/netstack/examples/http_client.go
@@ -1,4 +1,5 @@
//go:build ignore
+// +build ignore
/* SPDX-License-Identifier: MIT
*
@@ -10,9 +11,9 @@ package main
import (
"io"
"log"
- "net"
"net/http"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
@@ -20,8 +21,8 @@ import (
func main() {
tun, tnet, err := netstack.CreateNetTUN(
- []net.IP{net.ParseIP("192.168.4.29")},
- []net.IP{net.ParseIP("8.8.8.8")},
+ []netip.Addr{netip.MustParseAddr("192.168.4.29")},
+ []netip.Addr{netip.MustParseAddr("8.8.8.8")},
1420)
if err != nil {
log.Panic(err)
diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go
index 577c6ea..40f7804 100644
--- a/tun/netstack/examples/http_server.go
+++ b/tun/netstack/examples/http_server.go
@@ -1,4 +1,5 @@
//go:build ignore
+// +build ignore
/* SPDX-License-Identifier: MIT
*
@@ -13,6 +14,7 @@ import (
"net"
"net/http"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
@@ -20,8 +22,8 @@ import (
func main() {
tun, tnet, err := netstack.CreateNetTUN(
- []net.IP{net.ParseIP("192.168.4.29")},
- []net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")},
+ []netip.Addr{netip.MustParseAddr("192.168.4.29")},
+ []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
1420,
)
if err != nil {
diff --git a/tun/netstack/go.mod b/tun/netstack/go.mod
index 8db9f4b..46b57ba 100644
--- a/tun/netstack/go.mod
+++ b/tun/netstack/go.mod
@@ -6,6 +6,7 @@ require (
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect
+ golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22
gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6
)
diff --git a/tun/netstack/go.sum b/tun/netstack/go.sum
index 78c025c..01bfbc7 100644
--- a/tun/netstack/go.sum
+++ b/tun/netstack/go.sum
@@ -805,6 +805,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5 h1:mV4w4F7AtWXoDNkko9odoTdWpNwyDh8jx+S1fOZKDLg=
+golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
+golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 h1:nFvpdzrHF9IPo9xPgayHWObCATpQYKky8VSSdt9lf9E=
+golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 h1:ytS28bw9HtZVDRMDxviC6ryCJuccw+zXhh04u2IRWJw=
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
index 24d0835..f1c03f4 100644
--- a/tun/netstack/tun.go
+++ b/tun/netstack/tun.go
@@ -18,6 +18,7 @@ import (
"strings"
"time"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/tun"
"golang.org/x/net/dns/dnsmessage"
@@ -38,7 +39,7 @@ type netTun struct {
events chan tun.Event
incomingPacket chan buffer.VectorisedView
mtu int
- dnsServers []net.IP
+ dnsServers []netip.Addr
hasV4, hasV6 bool
}
type endpoint netTun
@@ -94,7 +95,7 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
}
-func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) {
+func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
@@ -112,25 +113,23 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
for _, ip := range localAddresses {
- if ip4 := ip.To4(); ip4 != nil {
- protoAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.Address(ip4).WithPrefix(),
- }
- tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
- if tcpipErr != nil {
- return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip4, tcpipErr)
- }
+ var protoNumber tcpip.NetworkProtocolNumber
+ if ip.Is4() {
+ protoNumber = ipv4.ProtocolNumber
+ } else if ip.Is6() {
+ protoNumber = ipv6.ProtocolNumber
+ }
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: protoNumber,
+ AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
+ }
+ tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
+ }
+ if ip.Is4() {
dev.hasV4 = true
- } else {
- protoAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: tcpip.Address(ip).WithPrefix(),
- }
- tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
- if tcpipErr != nil {
- return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
- }
+ } else if ip.Is6() {
dev.hasV6 = true
}
}
@@ -202,62 +201,83 @@ func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil
}
-func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
- if ip4 := ip.To4(); ip4 != nil {
- return tcpip.FullAddress{
- NIC: 1,
- Addr: tcpip.Address(ip4),
- Port: uint16(port),
- }, ipv4.ProtocolNumber
+func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
+ var protoNumber tcpip.NetworkProtocolNumber
+ if endpoint.Addr().Is4() {
+ protoNumber = ipv4.ProtocolNumber
} else {
- return tcpip.FullAddress{
- NIC: 1,
- Addr: tcpip.Address(ip),
- Port: uint16(port),
- }, ipv6.ProtocolNumber
+ protoNumber = ipv6.ProtocolNumber
}
+ return tcpip.FullAddress{
+ NIC: 1,
+ Addr: tcpip.Address(endpoint.Addr().AsSlice()),
+ Port: endpoint.Port(),
+ }, protoNumber
+}
+
+func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.DialContextTCP(ctx, net.stack, fa, pn)
}
func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil {
- panic("todo: deal with auto addr semantics for nil addr")
+ return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
}
- fa, pn := convertToFullAddr(addr.IP, addr.Port)
- return gonet.DialContextTCP(ctx, net.stack, fa, pn)
+ return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
+}
+
+func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.DialTCP(net.stack, fa, pn)
}
func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil {
- panic("todo: deal with auto addr semantics for nil addr")
+ return net.DialTCPAddrPort(netip.AddrPort{})
}
- fa, pn := convertToFullAddr(addr.IP, addr.Port)
- return gonet.DialTCP(net.stack, fa, pn)
+ return net.DialTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
+}
+
+func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.ListenTCP(net.stack, fa, pn)
}
func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
if addr == nil {
- panic("todo: deal with auto addr semantics for nil addr")
+ return net.ListenTCPAddrPort(netip.AddrPort{})
}
- fa, pn := convertToFullAddr(addr.IP, addr.Port)
- return gonet.ListenTCP(net.stack, fa, pn)
+ return net.ListenTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
}
-func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
+func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
var lfa, rfa *tcpip.FullAddress
var pn tcpip.NetworkProtocolNumber
- if laddr != nil {
+ if laddr.IsValid() || laddr.Port() > 0 {
var addr tcpip.FullAddress
- addr, pn = convertToFullAddr(laddr.IP, laddr.Port)
+ addr, pn = convertToFullAddr(laddr)
lfa = &addr
}
- if raddr != nil {
+ if raddr.IsValid() || raddr.Port() > 0 {
var addr tcpip.FullAddress
- addr, pn = convertToFullAddr(raddr.IP, raddr.Port)
+ addr, pn = convertToFullAddr(raddr)
rfa = &addr
}
return gonet.DialUDP(net.stack, lfa, rfa, pn)
}
+func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
+ var la, ra netip.AddrPort
+ if laddr != nil {
+ la = netip.AddrPortFrom(netip.AddrFromSlice(laddr.IP), uint16(laddr.Port))
+ }
+ if raddr != nil {
+ ra = netip.AddrPortFrom(netip.AddrFromSlice(raddr.IP), uint16(raddr.Port))
+ }
+ return net.DialUDPAddrPort(la, ra)
+}
+
var (
errNoSuchHost = errors.New("no such host")
errLameReferral = errors.New("lame referral")
@@ -433,7 +453,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by
return p, h, nil
}
-func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
+func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
q.Class = dnsmessage.ClassINET
id, udpReq, tcpReq, err := newRequest(q)
if err != nil {
@@ -447,9 +467,9 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest
var c net.Conn
var err error
if useUDP {
- c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53})
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
} else {
- c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53})
+ c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
}
if err != nil {
@@ -600,8 +620,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
zlen = zidx
}
}
- if ip := net.ParseIP(host[:zlen]); ip != nil {
- return []string{host[:zlen]}, nil
+ if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
+ return []string{ip.String()}, nil
}
if !isDomainName(host) {
@@ -612,7 +632,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
server string
error
}
- var addrsV4, addrsV6 []net.IP
+ var addrsV4, addrsV6 []netip.Addr
lanes := 0
if tnet.hasV4 {
lanes++
@@ -667,7 +687,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
break loop
}
- addrsV4 = append(addrsV4, net.IP(a.A[:]))
+ addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
case dnsmessage.TypeAAAA:
aaaa, err := result.p.AAAAResource()
@@ -679,7 +699,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
break loop
}
- addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:]))
+ addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
default:
if err := result.p.SkipAnswer(); err != nil {
@@ -695,7 +715,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
}
// We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
- var addrs []net.IP
+ var addrs []netip.Addr
if tnet.hasV6 {
addrs = append(addrsV6, addrsV4...)
} else {
@@ -764,12 +784,11 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
if err != nil {
return nil, &net.OpError{Op: "dial", Err: err}
}
- var addrs []net.IP
+ var addrs []netip.AddrPort
for _, addr := range allAddr {
- if strings.IndexByte(addr, ':') != -1 && acceptV6 {
- addrs = append(addrs, net.ParseIP(addr))
- } else if strings.IndexByte(addr, '.') != -1 && acceptV4 {
- addrs = append(addrs, net.ParseIP(addr))
+ ip, err := netip.ParseAddr(addr)
+ if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
+ addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
}
}
if len(addrs) == 0 && len(allAddr) != 0 {
@@ -808,9 +827,9 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
var c net.Conn
if useUDP {
- c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: addr, Port: port})
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
} else {
- c, err = tnet.DialContextTCP(dialCtx, &net.TCPAddr{IP: addr, Port: port})
+ c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
}
if err == nil {
return c, nil
diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go
index d89db71..bdf0467 100644
--- a/tun/tuntest/tuntest.go
+++ b/tun/tuntest/tuntest.go
@@ -8,13 +8,13 @@ package tuntest
import (
"encoding/binary"
"io"
- "net"
"os"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/tun"
)
-func Ping(dst, src net.IP) []byte {
+func Ping(dst, src netip.Addr) []byte {
localPort := uint16(1337)
seq := uint16(0)
@@ -40,7 +40,7 @@ func checksum(buf []byte, initial uint16) uint16 {
return ^uint16(v)
}
-func genICMPv4(payload []byte, dst, src net.IP) []byte {
+func genICMPv4(payload []byte, dst, src netip.Addr) []byte {
const (
icmpv4ProtocolNumber = 1
icmpv4Echo = 8
@@ -70,8 +70,8 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte {
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
ip[8] = ttl
ip[9] = icmpv4ProtocolNumber
- copy(ip[12:], src.To4())
- copy(ip[16:], dst.To4())
+ copy(ip[12:], src.AsSlice())
+ copy(ip[16:], dst.AsSlice())
chksum = ^checksum(ip[:], 0)
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)