diff options
Diffstat (limited to 'device/device_test.go')
-rw-r--r-- | device/device_test.go | 594 |
1 files changed, 416 insertions, 178 deletions
diff --git a/device/device_test.go b/device/device_test.go index 14cc605..fff172b 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -1,238 +1,476 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( - "bufio" "bytes" - "encoding/binary" + "encoding/hex" + "fmt" "io" - "net" + "math/rand" + "net/netip" "os" - "strings" + "runtime" + "runtime/pprof" + "sync" + "sync/atomic" "testing" "time" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/conn/bindtest" "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/tun/tuntest" ) -func TestTwoDevicePing(t *testing.T) { - // TODO(crawshaw): pick unused ports on localhost - cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58 -listen_port=53511 -replace_peers=true -public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725 -protocol_version=1 -replace_allowed_ips=true -allowed_ip=1.0.0.2/32 -endpoint=127.0.0.1:53512` - tun1 := NewChannelTUN() - dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: ")) - dev1.Up() - defer dev1.Close() - if err := dev1.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg1))); err != nil { - t.Fatal(err) - } - - cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768 -listen_port=53512 -replace_peers=true -public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427 -protocol_version=1 -replace_allowed_ips=true -allowed_ip=1.0.0.1/32 -endpoint=127.0.0.1:53511` - tun2 := NewChannelTUN() - dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: ")) - dev2.Up() - defer dev2.Close() - if err := dev2.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg2))); err != nil { - t.Fatal(err) +// uapiCfg returns a string that contains cfg formatted use with IpcSet. +// cfg is a series of alternating key/value strings. +// uapiCfg exists because editors and humans like to insert +// whitespace into configs, which can cause failures, some of which are silent. +// For example, a leading blank newline causes the remainder +// of the config to be silently ignored. +func uapiCfg(cfg ...string) string { + if len(cfg)%2 != 0 { + panic("odd number of args to uapiReader") } - - t.Run("ping 1.0.0.1", func(t *testing.T) { - msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2")) - tun2.Outbound <- msg2to1 - select { - case msgRecv := <-tun1.Inbound: - if !bytes.Equal(msg2to1, msgRecv) { - t.Error("ping did not transit correctly") - } - case <-time.After(300 * time.Millisecond): - t.Error("ping did not transit") + buf := new(bytes.Buffer) + for i, s := range cfg { + buf.WriteString(s) + sep := byte('\n') + if i%2 == 0 { + sep = '=' } - }) + buf.WriteByte(sep) + } + return buf.String() +} - t.Run("ping 1.0.0.2", func(t *testing.T) { - msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) - tun1.Outbound <- msg1to2 - select { - case msgRecv := <-tun2.Inbound: - if !bytes.Equal(msg1to2, msgRecv) { - t.Error("return ping did not transit correctly") - } - case <-time.After(300 * time.Millisecond): - t.Error("return ping did not transit") - } - }) +// genConfigs generates a pair of configs that connect to each other. +// The configs use distinct, probably-usable ports. +func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { + var key1, key2 NoisePrivateKey + _, err := rand.Read(key1[:]) + if err != nil { + tb.Errorf("unable to generate private key random bytes: %v", err) + } + _, err = rand.Read(key2[:]) + if err != nil { + tb.Errorf("unable to generate private key random bytes: %v", err) + } + pub1, pub2 := key1.publicKey(), key2.publicKey() + + cfgs[0] = uapiCfg( + "private_key", hex.EncodeToString(key1[:]), + "listen_port", "0", + "replace_peers", "true", + "public_key", hex.EncodeToString(pub2[:]), + "protocol_version", "1", + "replace_allowed_ips", "true", + "allowed_ip", "1.0.0.2/32", + ) + endpointCfgs[0] = uapiCfg( + "public_key", hex.EncodeToString(pub2[:]), + "endpoint", "127.0.0.1:%d", + ) + cfgs[1] = uapiCfg( + "private_key", hex.EncodeToString(key2[:]), + "listen_port", "0", + "replace_peers", "true", + "public_key", hex.EncodeToString(pub1[:]), + "protocol_version", "1", + "replace_allowed_ips", "true", + "allowed_ip", "1.0.0.1/32", + ) + endpointCfgs[1] = uapiCfg( + "public_key", hex.EncodeToString(pub1[:]), + "endpoint", "127.0.0.1:%d", + ) + return +} + +// A testPair is a pair of testPeers. +type testPair [2]testPeer + +// A testPeer is a peer used for testing. +type testPeer struct { + tun *tuntest.ChannelTUN + dev *Device + ip netip.Addr } -func ping(dst, src net.IP) []byte { - localPort := uint16(1337) - seq := uint16(0) +type SendDirection bool - payload := make([]byte, 4) - binary.BigEndian.PutUint16(payload[0:], localPort) - binary.BigEndian.PutUint16(payload[2:], seq) +const ( + Ping SendDirection = true + Pong SendDirection = false +) - return genICMPv4(payload, dst, src) +func (d SendDirection) String() string { + if d == Ping { + return "ping" + } + return "pong" } -// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071. -func checksum(buf []byte, initial uint16) uint16 { - v := uint32(initial) - for i := 0; i < len(buf)-1; i += 2 { - v += uint32(binary.BigEndian.Uint16(buf[i:])) +func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) { + tb.Helper() + p0, p1 := pair[0], pair[1] + if !ping { + // pong is the new ping + p0, p1 = p1, p0 } - if len(buf)%2 == 1 { - v += uint32(buf[len(buf)-1]) << 8 + msg := tuntest.Ping(p0.ip, p1.ip) + p1.tun.Outbound <- msg + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + var err error + select { + case msgRecv := <-p0.tun.Inbound: + if !bytes.Equal(msg, msgRecv) { + err = fmt.Errorf("%s did not transit correctly", ping) + } + case <-timer.C: + err = fmt.Errorf("%s did not transit", ping) + case <-done: } - for v > 0xffff { - v = (v >> 16) + (v & 0xffff) + if err != nil { + // The error may have occurred because the test is done. + select { + case <-done: + return + default: + } + // Real error. + tb.Error(err) } - return ^uint16(v) } -func genICMPv4(payload []byte, dst, src net.IP) []byte { - const ( - icmpv4ProtocolNumber = 1 - icmpv4Echo = 8 - icmpv4ChecksumOffset = 2 - icmpv4Size = 8 - ipv4Size = 20 - ipv4TotalLenOffset = 2 - ipv4ChecksumOffset = 10 - ttl = 65 - ) +// genTestPair creates a testPair. +func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { + cfg, endpointCfg := genConfigs(tb) + var binds [2]conn.Bind + if realSocket { + binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind() + } else { + binds = bindtest.NewChannelBinds() + } + // Bring up a ChannelTun for each config. + for i := range pair { + p := &pair[i] + p.tun = tuntest.NewChannelTUN() + p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)}) + level := LogLevelVerbose + if _, ok := tb.(*testing.B); ok && !testing.Verbose() { + level = LogLevelError + } + p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i))) + if err := p.dev.IpcSet(cfg[i]); err != nil { + tb.Errorf("failed to configure device %d: %v", i, err) + p.dev.Close() + continue + } + if err := p.dev.Up(); err != nil { + tb.Errorf("failed to bring up device %d: %v", i, err) + p.dev.Close() + continue + } + endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port) + } + for i := range pair { + p := &pair[i] + if err := p.dev.IpcSet(endpointCfg[i]); err != nil { + tb.Errorf("failed to configure device endpoint %d: %v", i, err) + p.dev.Close() + continue + } + // The device is ready. Close it when the test completes. + tb.Cleanup(p.dev.Close) + } + return +} + +func TestTwoDevicePing(t *testing.T) { + goroutineLeakCheck(t) + pair := genTestPair(t, true) + t.Run("ping 1.0.0.1", func(t *testing.T) { + pair.Send(t, Ping, nil) + }) + t.Run("ping 1.0.0.2", func(t *testing.T) { + pair.Send(t, Pong, nil) + }) +} + +func TestUpDown(t *testing.T) { + goroutineLeakCheck(t) + const itrials = 50 + const otrials = 10 + + for n := 0; n < otrials; n++ { + pair := genTestPair(t, false) + for i := range pair { + for k := range pair[i].dev.peers.keyMap { + pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:]))) + } + } + var wg sync.WaitGroup + wg.Add(len(pair)) + for i := range pair { + go func(d *Device) { + defer wg.Done() + for i := 0; i < itrials; i++ { + if err := d.Up(); err != nil { + t.Errorf("failed up bring up device: %v", err) + } + time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) + if err := d.Down(); err != nil { + t.Errorf("failed to bring down device: %v", err) + } + time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) + } + }(pair[i].dev) + } + wg.Wait() + for i := range pair { + pair[i].dev.Up() + pair[i].dev.Close() + } + } +} - hdr := make([]byte, ipv4Size+icmpv4Size) +// TestConcurrencySafety does other things concurrently with tunnel use. +// It is intended to be used with the race detector to catch data races. +func TestConcurrencySafety(t *testing.T) { + pair := genTestPair(t, true) + done := make(chan struct{}) - ip := hdr[0:ipv4Size] - icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size] + const warmupIters = 10 + var warmup sync.WaitGroup + warmup.Add(warmupIters) + go func() { + // Send data continuously back and forth until we're done. + // Note that we may continue to attempt to send data + // even after done is closed. + i := warmupIters + for ping := Ping; ; ping = !ping { + pair.Send(t, ping, done) + select { + case <-done: + return + default: + } + if i > 0 { + warmup.Done() + i-- + } + } + }() + warmup.Wait() - // https://tools.ietf.org/html/rfc792 - icmpv4[0] = icmpv4Echo // type - icmpv4[1] = 0 // code - chksum := ^checksum(icmpv4, checksum(payload, 0)) - binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum) + applyCfg := func(cfg string) { + err := pair[0].dev.IpcSet(cfg) + if err != nil { + t.Fatal(err) + } + } - // https://tools.ietf.org/html/rfc760 section 3.1 - length := uint16(len(hdr) + len(payload)) - ip[0] = (4 << 4) | (ipv4Size / 4) - binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) - ip[8] = ttl - ip[9] = icmpv4ProtocolNumber - copy(ip[12:], src.To4()) - copy(ip[16:], dst.To4()) - chksum = ^checksum(ip[:], 0) - binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) + // Change persistent_keepalive_interval concurrently with tunnel use. + t.Run("persistentKeepaliveInterval", func(t *testing.T) { + var pub NoisePublicKey + for key := range pair[0].dev.peers.keyMap { + pub = key + break + } + cfg := uapiCfg( + "public_key", hex.EncodeToString(pub[:]), + "persistent_keepalive_interval", "1", + ) + for i := 0; i < 1000; i++ { + applyCfg(cfg) + } + }) - var v []byte - v = append(v, hdr...) - v = append(v, payload...) - return []byte(v) -} + // Change private keys concurrently with tunnel use. + t.Run("privateKey", func(t *testing.T) { + bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777") + good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:])) + // Set iters to a large number like 1000 to flush out data races quickly. + // Don't leave it large. That can cause logical races + // in which the handshake is interleaved with key changes + // such that the private key appears to be unchanging but + // other state gets reset, which can cause handshake failures like + // "Received packet with invalid mac1". + const iters = 1 + for i := 0; i < iters; i++ { + applyCfg(bad) + applyCfg(good) + } + }) -// TODO(crawshaw): find a reusable home for this. package devicetest? -type ChannelTUN struct { - Inbound chan []byte // incoming packets, closed on TUN close - Outbound chan []byte // outbound packets, blocks forever on TUN close + // Perform bind updates and keepalive sends concurrently with tunnel use. + t.Run("bindUpdate and keepalive", func(t *testing.T) { + const iters = 10 + for i := 0; i < iters; i++ { + for _, peer := range pair { + peer.dev.BindUpdate() + peer.dev.SendKeepalivesToPeersWithCurrentKeypair() + } + } + }) - closed chan struct{} - events chan tun.Event - tun chTun + close(done) } -func NewChannelTUN() *ChannelTUN { - c := &ChannelTUN{ - Inbound: make(chan []byte), - Outbound: make(chan []byte), - closed: make(chan struct{}), - events: make(chan tun.Event, 1), +func BenchmarkLatency(b *testing.B) { + pair := genTestPair(b, true) + + // Establish a connection. + pair.Send(b, Ping, nil) + pair.Send(b, Pong, nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pair.Send(b, Ping, nil) + pair.Send(b, Pong, nil) } - c.tun.c = c - c.events <- tun.EventUp - return c } -func (c *ChannelTUN) TUN() tun.Device { - return &c.tun -} +func BenchmarkThroughput(b *testing.B) { + pair := genTestPair(b, true) -type chTun struct { - c *ChannelTUN -} + // Establish a connection. + pair.Send(b, Ping, nil) + pair.Send(b, Pong, nil) -func (t *chTun) File() *os.File { return nil } + // Measure how long it takes to receive b.N packets, + // starting when we receive the first packet. + var recv atomic.Uint64 + var elapsed time.Duration + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + var start time.Time + for { + <-pair[0].tun.Inbound + new := recv.Add(1) + if new == 1 { + start = time.Now() + } + // Careful! Don't change this to else if; b.N can be equal to 1. + if new == uint64(b.N) { + elapsed = time.Since(start) + return + } + } + }() -func (t *chTun) Read(data []byte, offset int) (int, error) { - select { - case <-t.c.closed: - return 0, io.EOF // TODO(crawshaw): what is the correct error value? - case msg := <-t.c.Outbound: - return copy(data[offset:], msg), nil + // Send packets as fast as we can until we've received enough. + ping := tuntest.Ping(pair[0].ip, pair[1].ip) + pingc := pair[1].tun.Outbound + var sent uint64 + for recv.Load() != uint64(b.N) { + sent++ + pingc <- ping } + wg.Wait() + + b.ReportMetric(float64(elapsed)/float64(b.N), "ns/op") + b.ReportMetric(1-float64(b.N)/float64(sent), "packet-loss") } -// Write is called by the wireguard device to deliver a packet for routing. -func (t *chTun) Write(data []byte, offset int) (int, error) { - if offset == -1 { - close(t.c.closed) - close(t.c.events) - return 0, io.EOF +func BenchmarkUAPIGet(b *testing.B) { + pair := genTestPair(b, true) + pair.Send(b, Ping, nil) + pair.Send(b, Pong, nil) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + pair[0].dev.IpcGetOperation(io.Discard) } - msg := make([]byte, len(data)-offset) - copy(msg, data[offset:]) - select { - case <-t.c.closed: - return 0, io.EOF // TODO(crawshaw): what is the correct error value? - case t.c.Inbound <- msg: - return len(data) - offset, nil +} + +func goroutineLeakCheck(t *testing.T) { + goroutines := func() (int, []byte) { + p := pprof.Lookup("goroutine") + b := new(bytes.Buffer) + p.WriteTo(b, 1) + return p.Count(), b.Bytes() } + + startGoroutines, startStacks := goroutines() + t.Cleanup(func() { + if t.Failed() { + return + } + // Give goroutines time to exit, if they need it. + for i := 0; i < 10000; i++ { + if runtime.NumGoroutine() <= startGoroutines { + return + } + time.Sleep(1 * time.Millisecond) + } + endGoroutines, endStacks := goroutines() + t.Logf("starting stacks:\n%s\n", startStacks) + t.Logf("ending stacks:\n%s\n", endStacks) + t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines) + }) } -func (t *chTun) Flush() error { return nil } -func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } -func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } -func (t *chTun) Events() chan tun.Event { return t.c.events } -func (t *chTun) Close() error { - t.Write(nil, -1) - return nil +type fakeBindSized struct { + size int } -func assertNil(t *testing.T, err error) { - if err != nil { - t.Fatal(err) - } +func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { + return nil, 0, nil } +func (b *fakeBindSized) Close() error { return nil } +func (b *fakeBindSized) SetMark(mark uint32) error { return nil } +func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil } +func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } +func (b *fakeBindSized) BatchSize() int { return b.size } -func assertEqual(t *testing.T, a, b []byte) { - if !bytes.Equal(a, b) { - t.Fatal(a, "!=", b) - } +type fakeTUNDeviceSized struct { + size int } -func randDevice(t *testing.T) *Device { - sk, err := newPrivateKey() - if err != nil { - t.Fatal(err) +func (t *fakeTUNDeviceSized) File() *os.File { return nil } +func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { + return 0, nil +} +func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil } +func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil } +func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil } +func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil } +func (t *fakeTUNDeviceSized) Close() error { return nil } +func (t *fakeTUNDeviceSized) BatchSize() int { return t.size } + +func TestBatchSize(t *testing.T) { + d := Device{} + + d.net.bind = &fakeBindSized{1} + d.tun.device = &fakeTUNDeviceSized{1} + if want, got := 1, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{1} + d.tun.device = &fakeTUNDeviceSized{128} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{128} + d.tun.device = &fakeTUNDeviceSized{1} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{128} + d.tun.device = &fakeTUNDeviceSized{128} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) } - tun := newDummyTUN("dummy") - logger := NewLogger(LogLevelError, "") - device := NewDevice(tun, logger) - device.SetPrivateKey(sk) - return device } |