aboutsummaryrefslogtreecommitdiffstats
path: root/device/device_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/device_test.go')
-rw-r--r--device/device_test.go594
1 files changed, 416 insertions, 178 deletions
diff --git a/device/device_test.go b/device/device_test.go
index 14cc605..fff172b 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -1,238 +1,476 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
- "bufio"
"bytes"
- "encoding/binary"
+ "encoding/hex"
+ "fmt"
"io"
- "net"
+ "math/rand"
+ "net/netip"
"os"
- "strings"
+ "runtime"
+ "runtime/pprof"
+ "sync"
+ "sync/atomic"
"testing"
"time"
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/conn/bindtest"
"golang.zx2c4.com/wireguard/tun"
+ "golang.zx2c4.com/wireguard/tun/tuntest"
)
-func TestTwoDevicePing(t *testing.T) {
- // TODO(crawshaw): pick unused ports on localhost
- cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58
-listen_port=53511
-replace_peers=true
-public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725
-protocol_version=1
-replace_allowed_ips=true
-allowed_ip=1.0.0.2/32
-endpoint=127.0.0.1:53512`
- tun1 := NewChannelTUN()
- dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
- dev1.Up()
- defer dev1.Close()
- if err := dev1.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg1))); err != nil {
- t.Fatal(err)
- }
-
- cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768
-listen_port=53512
-replace_peers=true
-public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427
-protocol_version=1
-replace_allowed_ips=true
-allowed_ip=1.0.0.1/32
-endpoint=127.0.0.1:53511`
- tun2 := NewChannelTUN()
- dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
- dev2.Up()
- defer dev2.Close()
- if err := dev2.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg2))); err != nil {
- t.Fatal(err)
+// uapiCfg returns a string that contains cfg formatted use with IpcSet.
+// cfg is a series of alternating key/value strings.
+// uapiCfg exists because editors and humans like to insert
+// whitespace into configs, which can cause failures, some of which are silent.
+// For example, a leading blank newline causes the remainder
+// of the config to be silently ignored.
+func uapiCfg(cfg ...string) string {
+ if len(cfg)%2 != 0 {
+ panic("odd number of args to uapiReader")
}
-
- t.Run("ping 1.0.0.1", func(t *testing.T) {
- msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
- tun2.Outbound <- msg2to1
- select {
- case msgRecv := <-tun1.Inbound:
- if !bytes.Equal(msg2to1, msgRecv) {
- t.Error("ping did not transit correctly")
- }
- case <-time.After(300 * time.Millisecond):
- t.Error("ping did not transit")
+ buf := new(bytes.Buffer)
+ for i, s := range cfg {
+ buf.WriteString(s)
+ sep := byte('\n')
+ if i%2 == 0 {
+ sep = '='
}
- })
+ buf.WriteByte(sep)
+ }
+ return buf.String()
+}
- t.Run("ping 1.0.0.2", func(t *testing.T) {
- msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
- tun1.Outbound <- msg1to2
- select {
- case msgRecv := <-tun2.Inbound:
- if !bytes.Equal(msg1to2, msgRecv) {
- t.Error("return ping did not transit correctly")
- }
- case <-time.After(300 * time.Millisecond):
- t.Error("return ping did not transit")
- }
- })
+// genConfigs generates a pair of configs that connect to each other.
+// The configs use distinct, probably-usable ports.
+func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
+ var key1, key2 NoisePrivateKey
+ _, err := rand.Read(key1[:])
+ if err != nil {
+ tb.Errorf("unable to generate private key random bytes: %v", err)
+ }
+ _, err = rand.Read(key2[:])
+ if err != nil {
+ tb.Errorf("unable to generate private key random bytes: %v", err)
+ }
+ pub1, pub2 := key1.publicKey(), key2.publicKey()
+
+ cfgs[0] = uapiCfg(
+ "private_key", hex.EncodeToString(key1[:]),
+ "listen_port", "0",
+ "replace_peers", "true",
+ "public_key", hex.EncodeToString(pub2[:]),
+ "protocol_version", "1",
+ "replace_allowed_ips", "true",
+ "allowed_ip", "1.0.0.2/32",
+ )
+ endpointCfgs[0] = uapiCfg(
+ "public_key", hex.EncodeToString(pub2[:]),
+ "endpoint", "127.0.0.1:%d",
+ )
+ cfgs[1] = uapiCfg(
+ "private_key", hex.EncodeToString(key2[:]),
+ "listen_port", "0",
+ "replace_peers", "true",
+ "public_key", hex.EncodeToString(pub1[:]),
+ "protocol_version", "1",
+ "replace_allowed_ips", "true",
+ "allowed_ip", "1.0.0.1/32",
+ )
+ endpointCfgs[1] = uapiCfg(
+ "public_key", hex.EncodeToString(pub1[:]),
+ "endpoint", "127.0.0.1:%d",
+ )
+ return
+}
+
+// A testPair is a pair of testPeers.
+type testPair [2]testPeer
+
+// A testPeer is a peer used for testing.
+type testPeer struct {
+ tun *tuntest.ChannelTUN
+ dev *Device
+ ip netip.Addr
}
-func ping(dst, src net.IP) []byte {
- localPort := uint16(1337)
- seq := uint16(0)
+type SendDirection bool
- payload := make([]byte, 4)
- binary.BigEndian.PutUint16(payload[0:], localPort)
- binary.BigEndian.PutUint16(payload[2:], seq)
+const (
+ Ping SendDirection = true
+ Pong SendDirection = false
+)
- return genICMPv4(payload, dst, src)
+func (d SendDirection) String() string {
+ if d == Ping {
+ return "ping"
+ }
+ return "pong"
}
-// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
-func checksum(buf []byte, initial uint16) uint16 {
- v := uint32(initial)
- for i := 0; i < len(buf)-1; i += 2 {
- v += uint32(binary.BigEndian.Uint16(buf[i:]))
+func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) {
+ tb.Helper()
+ p0, p1 := pair[0], pair[1]
+ if !ping {
+ // pong is the new ping
+ p0, p1 = p1, p0
}
- if len(buf)%2 == 1 {
- v += uint32(buf[len(buf)-1]) << 8
+ msg := tuntest.Ping(p0.ip, p1.ip)
+ p1.tun.Outbound <- msg
+ timer := time.NewTimer(5 * time.Second)
+ defer timer.Stop()
+ var err error
+ select {
+ case msgRecv := <-p0.tun.Inbound:
+ if !bytes.Equal(msg, msgRecv) {
+ err = fmt.Errorf("%s did not transit correctly", ping)
+ }
+ case <-timer.C:
+ err = fmt.Errorf("%s did not transit", ping)
+ case <-done:
}
- for v > 0xffff {
- v = (v >> 16) + (v & 0xffff)
+ if err != nil {
+ // The error may have occurred because the test is done.
+ select {
+ case <-done:
+ return
+ default:
+ }
+ // Real error.
+ tb.Error(err)
}
- return ^uint16(v)
}
-func genICMPv4(payload []byte, dst, src net.IP) []byte {
- const (
- icmpv4ProtocolNumber = 1
- icmpv4Echo = 8
- icmpv4ChecksumOffset = 2
- icmpv4Size = 8
- ipv4Size = 20
- ipv4TotalLenOffset = 2
- ipv4ChecksumOffset = 10
- ttl = 65
- )
+// genTestPair creates a testPair.
+func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
+ cfg, endpointCfg := genConfigs(tb)
+ var binds [2]conn.Bind
+ if realSocket {
+ binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
+ } else {
+ binds = bindtest.NewChannelBinds()
+ }
+ // Bring up a ChannelTun for each config.
+ for i := range pair {
+ p := &pair[i]
+ p.tun = tuntest.NewChannelTUN()
+ p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
+ level := LogLevelVerbose
+ if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
+ level = LogLevelError
+ }
+ p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
+ if err := p.dev.IpcSet(cfg[i]); err != nil {
+ tb.Errorf("failed to configure device %d: %v", i, err)
+ p.dev.Close()
+ continue
+ }
+ if err := p.dev.Up(); err != nil {
+ tb.Errorf("failed to bring up device %d: %v", i, err)
+ p.dev.Close()
+ continue
+ }
+ endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port)
+ }
+ for i := range pair {
+ p := &pair[i]
+ if err := p.dev.IpcSet(endpointCfg[i]); err != nil {
+ tb.Errorf("failed to configure device endpoint %d: %v", i, err)
+ p.dev.Close()
+ continue
+ }
+ // The device is ready. Close it when the test completes.
+ tb.Cleanup(p.dev.Close)
+ }
+ return
+}
+
+func TestTwoDevicePing(t *testing.T) {
+ goroutineLeakCheck(t)
+ pair := genTestPair(t, true)
+ t.Run("ping 1.0.0.1", func(t *testing.T) {
+ pair.Send(t, Ping, nil)
+ })
+ t.Run("ping 1.0.0.2", func(t *testing.T) {
+ pair.Send(t, Pong, nil)
+ })
+}
+
+func TestUpDown(t *testing.T) {
+ goroutineLeakCheck(t)
+ const itrials = 50
+ const otrials = 10
+
+ for n := 0; n < otrials; n++ {
+ pair := genTestPair(t, false)
+ for i := range pair {
+ for k := range pair[i].dev.peers.keyMap {
+ pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
+ }
+ }
+ var wg sync.WaitGroup
+ wg.Add(len(pair))
+ for i := range pair {
+ go func(d *Device) {
+ defer wg.Done()
+ for i := 0; i < itrials; i++ {
+ if err := d.Up(); err != nil {
+ t.Errorf("failed up bring up device: %v", err)
+ }
+ time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
+ if err := d.Down(); err != nil {
+ t.Errorf("failed to bring down device: %v", err)
+ }
+ time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
+ }
+ }(pair[i].dev)
+ }
+ wg.Wait()
+ for i := range pair {
+ pair[i].dev.Up()
+ pair[i].dev.Close()
+ }
+ }
+}
- hdr := make([]byte, ipv4Size+icmpv4Size)
+// TestConcurrencySafety does other things concurrently with tunnel use.
+// It is intended to be used with the race detector to catch data races.
+func TestConcurrencySafety(t *testing.T) {
+ pair := genTestPair(t, true)
+ done := make(chan struct{})
- ip := hdr[0:ipv4Size]
- icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
+ const warmupIters = 10
+ var warmup sync.WaitGroup
+ warmup.Add(warmupIters)
+ go func() {
+ // Send data continuously back and forth until we're done.
+ // Note that we may continue to attempt to send data
+ // even after done is closed.
+ i := warmupIters
+ for ping := Ping; ; ping = !ping {
+ pair.Send(t, ping, done)
+ select {
+ case <-done:
+ return
+ default:
+ }
+ if i > 0 {
+ warmup.Done()
+ i--
+ }
+ }
+ }()
+ warmup.Wait()
- // https://tools.ietf.org/html/rfc792
- icmpv4[0] = icmpv4Echo // type
- icmpv4[1] = 0 // code
- chksum := ^checksum(icmpv4, checksum(payload, 0))
- binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
+ applyCfg := func(cfg string) {
+ err := pair[0].dev.IpcSet(cfg)
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
- // https://tools.ietf.org/html/rfc760 section 3.1
- length := uint16(len(hdr) + len(payload))
- ip[0] = (4 << 4) | (ipv4Size / 4)
- binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
- ip[8] = ttl
- ip[9] = icmpv4ProtocolNumber
- copy(ip[12:], src.To4())
- copy(ip[16:], dst.To4())
- chksum = ^checksum(ip[:], 0)
- binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
+ // Change persistent_keepalive_interval concurrently with tunnel use.
+ t.Run("persistentKeepaliveInterval", func(t *testing.T) {
+ var pub NoisePublicKey
+ for key := range pair[0].dev.peers.keyMap {
+ pub = key
+ break
+ }
+ cfg := uapiCfg(
+ "public_key", hex.EncodeToString(pub[:]),
+ "persistent_keepalive_interval", "1",
+ )
+ for i := 0; i < 1000; i++ {
+ applyCfg(cfg)
+ }
+ })
- var v []byte
- v = append(v, hdr...)
- v = append(v, payload...)
- return []byte(v)
-}
+ // Change private keys concurrently with tunnel use.
+ t.Run("privateKey", func(t *testing.T) {
+ bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777")
+ good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:]))
+ // Set iters to a large number like 1000 to flush out data races quickly.
+ // Don't leave it large. That can cause logical races
+ // in which the handshake is interleaved with key changes
+ // such that the private key appears to be unchanging but
+ // other state gets reset, which can cause handshake failures like
+ // "Received packet with invalid mac1".
+ const iters = 1
+ for i := 0; i < iters; i++ {
+ applyCfg(bad)
+ applyCfg(good)
+ }
+ })
-// TODO(crawshaw): find a reusable home for this. package devicetest?
-type ChannelTUN struct {
- Inbound chan []byte // incoming packets, closed on TUN close
- Outbound chan []byte // outbound packets, blocks forever on TUN close
+ // Perform bind updates and keepalive sends concurrently with tunnel use.
+ t.Run("bindUpdate and keepalive", func(t *testing.T) {
+ const iters = 10
+ for i := 0; i < iters; i++ {
+ for _, peer := range pair {
+ peer.dev.BindUpdate()
+ peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
+ }
+ }
+ })
- closed chan struct{}
- events chan tun.Event
- tun chTun
+ close(done)
}
-func NewChannelTUN() *ChannelTUN {
- c := &ChannelTUN{
- Inbound: make(chan []byte),
- Outbound: make(chan []byte),
- closed: make(chan struct{}),
- events: make(chan tun.Event, 1),
+func BenchmarkLatency(b *testing.B) {
+ pair := genTestPair(b, true)
+
+ // Establish a connection.
+ pair.Send(b, Ping, nil)
+ pair.Send(b, Pong, nil)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ pair.Send(b, Ping, nil)
+ pair.Send(b, Pong, nil)
}
- c.tun.c = c
- c.events <- tun.EventUp
- return c
}
-func (c *ChannelTUN) TUN() tun.Device {
- return &c.tun
-}
+func BenchmarkThroughput(b *testing.B) {
+ pair := genTestPair(b, true)
-type chTun struct {
- c *ChannelTUN
-}
+ // Establish a connection.
+ pair.Send(b, Ping, nil)
+ pair.Send(b, Pong, nil)
-func (t *chTun) File() *os.File { return nil }
+ // Measure how long it takes to receive b.N packets,
+ // starting when we receive the first packet.
+ var recv atomic.Uint64
+ var elapsed time.Duration
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ var start time.Time
+ for {
+ <-pair[0].tun.Inbound
+ new := recv.Add(1)
+ if new == 1 {
+ start = time.Now()
+ }
+ // Careful! Don't change this to else if; b.N can be equal to 1.
+ if new == uint64(b.N) {
+ elapsed = time.Since(start)
+ return
+ }
+ }
+ }()
-func (t *chTun) Read(data []byte, offset int) (int, error) {
- select {
- case <-t.c.closed:
- return 0, io.EOF // TODO(crawshaw): what is the correct error value?
- case msg := <-t.c.Outbound:
- return copy(data[offset:], msg), nil
+ // Send packets as fast as we can until we've received enough.
+ ping := tuntest.Ping(pair[0].ip, pair[1].ip)
+ pingc := pair[1].tun.Outbound
+ var sent uint64
+ for recv.Load() != uint64(b.N) {
+ sent++
+ pingc <- ping
}
+ wg.Wait()
+
+ b.ReportMetric(float64(elapsed)/float64(b.N), "ns/op")
+ b.ReportMetric(1-float64(b.N)/float64(sent), "packet-loss")
}
-// Write is called by the wireguard device to deliver a packet for routing.
-func (t *chTun) Write(data []byte, offset int) (int, error) {
- if offset == -1 {
- close(t.c.closed)
- close(t.c.events)
- return 0, io.EOF
+func BenchmarkUAPIGet(b *testing.B) {
+ pair := genTestPair(b, true)
+ pair.Send(b, Ping, nil)
+ pair.Send(b, Pong, nil)
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ pair[0].dev.IpcGetOperation(io.Discard)
}
- msg := make([]byte, len(data)-offset)
- copy(msg, data[offset:])
- select {
- case <-t.c.closed:
- return 0, io.EOF // TODO(crawshaw): what is the correct error value?
- case t.c.Inbound <- msg:
- return len(data) - offset, nil
+}
+
+func goroutineLeakCheck(t *testing.T) {
+ goroutines := func() (int, []byte) {
+ p := pprof.Lookup("goroutine")
+ b := new(bytes.Buffer)
+ p.WriteTo(b, 1)
+ return p.Count(), b.Bytes()
}
+
+ startGoroutines, startStacks := goroutines()
+ t.Cleanup(func() {
+ if t.Failed() {
+ return
+ }
+ // Give goroutines time to exit, if they need it.
+ for i := 0; i < 10000; i++ {
+ if runtime.NumGoroutine() <= startGoroutines {
+ return
+ }
+ time.Sleep(1 * time.Millisecond)
+ }
+ endGoroutines, endStacks := goroutines()
+ t.Logf("starting stacks:\n%s\n", startStacks)
+ t.Logf("ending stacks:\n%s\n", endStacks)
+ t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
+ })
}
-func (t *chTun) Flush() error { return nil }
-func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
-func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
-func (t *chTun) Events() chan tun.Event { return t.c.events }
-func (t *chTun) Close() error {
- t.Write(nil, -1)
- return nil
+type fakeBindSized struct {
+ size int
}
-func assertNil(t *testing.T, err error) {
- if err != nil {
- t.Fatal(err)
- }
+func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+ return nil, 0, nil
}
+func (b *fakeBindSized) Close() error { return nil }
+func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
+func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
+func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
+func (b *fakeBindSized) BatchSize() int { return b.size }
-func assertEqual(t *testing.T, a, b []byte) {
- if !bytes.Equal(a, b) {
- t.Fatal(a, "!=", b)
- }
+type fakeTUNDeviceSized struct {
+ size int
}
-func randDevice(t *testing.T) *Device {
- sk, err := newPrivateKey()
- if err != nil {
- t.Fatal(err)
+func (t *fakeTUNDeviceSized) File() *os.File { return nil }
+func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
+ return 0, nil
+}
+func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
+func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
+func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
+func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
+func (t *fakeTUNDeviceSized) Close() error { return nil }
+func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
+
+func TestBatchSize(t *testing.T) {
+ d := Device{}
+
+ d.net.bind = &fakeBindSized{1}
+ d.tun.device = &fakeTUNDeviceSized{1}
+ if want, got := 1, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{1}
+ d.tun.device = &fakeTUNDeviceSized{128}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{128}
+ d.tun.device = &fakeTUNDeviceSized{1}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{128}
+ d.tun.device = &fakeTUNDeviceSized{128}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
}
- tun := newDummyTUN("dummy")
- logger := NewLogger(LogLevelError, "")
- device := NewDevice(tun, logger)
- device.SetPrivateKey(sk)
- return device
}