aboutsummaryrefslogtreecommitdiffstats
path: root/device
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--device/allowedips.go42
-rw-r--r--device/allowedips_rand_test.go6
-rw-r--r--device/allowedips_test.go6
-rw-r--r--device/device_test.go6
-rw-r--r--device/endpoint_test.go39
-rw-r--r--device/receive.go1
-rw-r--r--device/uapi.go11
7 files changed, 57 insertions, 54 deletions
diff --git a/device/allowedips.go b/device/allowedips.go
index c08399b..7a0b275 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -12,6 +12,8 @@ import (
"net"
"sync"
"unsafe"
+
+ "golang.zx2c4.com/go118/netip"
)
type parentIndirection struct {
@@ -26,7 +28,7 @@ type trieEntry struct {
cidr uint8
bitAtByte uint8
bitAtShift uint8
- bits net.IP
+ bits []byte
perPeerElem *list.Element
}
@@ -51,7 +53,7 @@ func swapU64(i uint64) uint64 {
return bits.ReverseBytes64(i)
}
-func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
+func commonBits(ip1, ip2 []byte) uint8 {
size := len(ip1)
if size == net.IPv4len {
a := (*uint32)(unsafe.Pointer(&ip1[0]))
@@ -85,7 +87,7 @@ func (node *trieEntry) removeFromPeerEntries() {
}
}
-func (node *trieEntry) choose(ip net.IP) byte {
+func (node *trieEntry) choose(ip []byte) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
@@ -104,7 +106,7 @@ func (node *trieEntry) zeroizePointers() {
node.parent.parentBit = nil
}
-func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
+func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
if parent.cidr == cidr {
@@ -117,7 +119,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry,
return
}
-func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
+func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
if *trie.parentBit == nil {
node := &trieEntry{
peer: peer,
@@ -207,7 +209,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
}
}
-func (node *trieEntry) lookup(ip net.IP) *Peer {
+func (node *trieEntry) lookup(ip []byte) *Peer {
var found *Peer
size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
@@ -229,13 +231,14 @@ type AllowedIPs struct {
mutex sync.RWMutex
}
-func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
+func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock()
defer table.mutex.RUnlock()
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
node := elem.Value.(*trieEntry)
- if !cb(node.bits, node.cidr) {
+ a, _ := netip.AddrFromSlice(node.bits)
+ if !cb(netip.PrefixFrom(a, int(node.cidr))) {
return
}
}
@@ -283,28 +286,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
}
}
-func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
+func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
- switch len(ip) {
- case net.IPv6len:
- parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
- case net.IPv4len:
- parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
- default:
+ if prefix.Addr().Is6() {
+ ip := prefix.Addr().As16()
+ parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else if prefix.Addr().Is4() {
+ ip := prefix.Addr().As4()
+ parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else {
panic(errors.New("inserting unknown address type"))
}
}
-func (table *AllowedIPs) Lookup(address []byte) *Peer {
+func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
- switch len(address) {
+ switch len(ip) {
case net.IPv6len:
- return table.IPv6.lookup(address)
+ return table.IPv6.lookup(ip)
case net.IPv4len:
- return table.IPv4.lookup(address)
+ return table.IPv4.lookup(ip)
default:
panic(errors.New("looking up unknown address type"))
}
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
index 16de170..ff56fe6 100644
--- a/device/allowedips_rand_test.go
+++ b/device/allowedips_rand_test.go
@@ -10,6 +10,8 @@ import (
"net"
"sort"
"testing"
+
+ "golang.zx2c4.com/go118/netip"
)
const (
@@ -93,14 +95,14 @@ func TestTrieRandom(t *testing.T) {
rand.Read(addr4[:])
cidr := uint8(rand.Intn(32) + 1)
index := rand.Intn(NumberOfPeers)
- allowedIPs.Insert(addr4[:], cidr, peers[index])
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
var addr6 [16]byte
rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
- allowedIPs.Insert(addr6[:], cidr, peers[index])
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
}
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
index 2059a88..a274997 100644
--- a/device/allowedips_test.go
+++ b/device/allowedips_test.go
@@ -9,6 +9,8 @@ import (
"math/rand"
"net"
"testing"
+
+ "golang.zx2c4.com/go118/netip"
)
type testPairCommonBits struct {
@@ -98,7 +100,7 @@ func TestTrieIPv4(t *testing.T) {
var allowedIPs AllowedIPs
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
- allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
@@ -208,7 +210,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- allowedIPs.Insert(addr, cidr, peer)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d uint32) {
diff --git a/device/device_test.go b/device/device_test.go
index 29daeb9..84221be 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -11,7 +11,6 @@ import (
"fmt"
"io"
"math/rand"
- "net"
"runtime"
"runtime/pprof"
"sync"
@@ -19,6 +18,7 @@ import (
"testing"
"time"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/conn/bindtest"
"golang.zx2c4.com/wireguard/tun/tuntest"
@@ -96,7 +96,7 @@ type testPair [2]testPeer
type testPeer struct {
tun *tuntest.ChannelTUN
dev *Device
- ip net.IP
+ ip netip.Addr
}
type SendDirection bool
@@ -159,7 +159,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
for i := range pair {
p := &pair[i]
p.tun = tuntest.NewChannelTUN()
- p.ip = net.IPv4(1, 0, 0, byte(i+1))
+ p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
level := LogLevelVerbose
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError
diff --git a/device/endpoint_test.go b/device/endpoint_test.go
index 57c361c..f1ae47e 100644
--- a/device/endpoint_test.go
+++ b/device/endpoint_test.go
@@ -7,47 +7,44 @@ package device
import (
"math/rand"
- "net"
+
+ "golang.zx2c4.com/go118/netip"
)
type DummyEndpoint struct {
- src [16]byte
- dst [16]byte
+ src, dst netip.Addr
}
func CreateDummyEndpoint() (*DummyEndpoint, error) {
- var end DummyEndpoint
- if _, err := rand.Read(end.src[:]); err != nil {
+ var src, dst [16]byte
+ if _, err := rand.Read(src[:]); err != nil {
return nil, err
}
- _, err := rand.Read(end.dst[:])
- return &end, err
+ _, err := rand.Read(dst[:])
+ return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
}
func (e *DummyEndpoint) ClearSrc() {}
func (e *DummyEndpoint) SrcToString() string {
- var addr net.UDPAddr
- addr.IP = e.SrcIP()
- addr.Port = 1000
- return addr.String()
+ return netip.AddrPortFrom(e.SrcIP(), 1000).String()
}
func (e *DummyEndpoint) DstToString() string {
- var addr net.UDPAddr
- addr.IP = e.DstIP()
- addr.Port = 1000
- return addr.String()
+ return netip.AddrPortFrom(e.DstIP(), 1000).String()
}
-func (e *DummyEndpoint) SrcToBytes() []byte {
- return e.src[:]
+func (e *DummyEndpoint) DstToBytes() []byte {
+ out := e.DstIP().AsSlice()
+ out = append(out, byte(1000&0xff))
+ out = append(out, byte((1000>>8)&0xff))
+ return out
}
-func (e *DummyEndpoint) DstIP() net.IP {
- return e.dst[:]
+func (e *DummyEndpoint) DstIP() netip.Addr {
+ return e.dst
}
-func (e *DummyEndpoint) SrcIP() net.IP {
- return e.src[:]
+func (e *DummyEndpoint) SrcIP() netip.Addr {
+ return e.src
}
diff --git a/device/receive.go b/device/receive.go
index 5857481..cc34498 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -17,7 +17,6 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
-
"golang.zx2c4.com/wireguard/conn"
)
diff --git a/device/uapi.go b/device/uapi.go
index 2306183..98e8311 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -18,6 +18,7 @@ import (
"sync/atomic"
"time"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/ipc"
)
@@ -121,8 +122,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
- device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool {
- sendf("allowed_ip=%s/%d", ip.String(), cidr)
+ device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
+ sendf("allowed_ip=%s", prefix.String())
return true
})
}
@@ -374,16 +375,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
case "allowed_ip":
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
-
- _, network, err := net.ParseCIDR(value)
+ prefix, err := netip.ParsePrefix(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
}
if peer.dummy {
return nil
}
- ones, _ := network.Mask.Size()
- device.allowedips.Insert(network.IP, uint8(ones), peer.Peer)
+ device.allowedips.Insert(prefix, peer.Peer)
case "protocol_version":
if value != "1" {