diff options
Diffstat (limited to 'device/device_test.go')
-rw-r--r-- | device/device_test.go | 85 |
1 files changed, 77 insertions, 8 deletions
diff --git a/device/device_test.go b/device/device_test.go index 29daeb9..fff172b 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -11,7 +11,8 @@ import ( "fmt" "io" "math/rand" - "net" + "net/netip" + "os" "runtime" "runtime/pprof" "sync" @@ -21,6 +22,7 @@ import ( "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn/bindtest" + "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun/tuntest" ) @@ -48,7 +50,7 @@ func uapiCfg(cfg ...string) string { // genConfigs generates a pair of configs that connect to each other. // The configs use distinct, probably-usable ports. -func genConfigs(tb testing.TB) (cfgs [2]string, endpointCfgs [2]string) { +func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { var key1, key2 NoisePrivateKey _, err := rand.Read(key1[:]) if err != nil { @@ -96,7 +98,7 @@ type testPair [2]testPeer type testPeer struct { tun *tuntest.ChannelTUN dev *Device - ip net.IP + ip netip.Addr } type SendDirection bool @@ -159,7 +161,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { for i := range pair { p := &pair[i] p.tun = tuntest.NewChannelTUN() - p.ip = net.IPv4(1, 0, 0, byte(i+1)) + p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)}) level := LogLevelVerbose if _, ok := tb.(*testing.B); ok && !testing.Verbose() { level = LogLevelError @@ -307,6 +309,17 @@ func TestConcurrencySafety(t *testing.T) { } }) + // Perform bind updates and keepalive sends concurrently with tunnel use. + t.Run("bindUpdate and keepalive", func(t *testing.T) { + const iters = 10 + for i := 0; i < iters; i++ { + for _, peer := range pair { + peer.dev.BindUpdate() + peer.dev.SendKeepalivesToPeersWithCurrentKeypair() + } + } + }) + close(done) } @@ -333,7 +346,7 @@ func BenchmarkThroughput(b *testing.B) { // Measure how long it takes to receive b.N packets, // starting when we receive the first packet. - var recv uint64 + var recv atomic.Uint64 var elapsed time.Duration var wg sync.WaitGroup wg.Add(1) @@ -342,7 +355,7 @@ func BenchmarkThroughput(b *testing.B) { var start time.Time for { <-pair[0].tun.Inbound - new := atomic.AddUint64(&recv, 1) + new := recv.Add(1) if new == 1 { start = time.Now() } @@ -358,7 +371,7 @@ func BenchmarkThroughput(b *testing.B) { ping := tuntest.Ping(pair[0].ip, pair[1].ip) pingc := pair[1].tun.Outbound var sent uint64 - for atomic.LoadUint64(&recv) != uint64(b.N) { + for recv.Load() != uint64(b.N) { sent++ pingc <- ping } @@ -405,3 +418,59 @@ func goroutineLeakCheck(t *testing.T) { t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines) }) } + +type fakeBindSized struct { + size int +} + +func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { + return nil, 0, nil +} +func (b *fakeBindSized) Close() error { return nil } +func (b *fakeBindSized) SetMark(mark uint32) error { return nil } +func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil } +func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } +func (b *fakeBindSized) BatchSize() int { return b.size } + +type fakeTUNDeviceSized struct { + size int +} + +func (t *fakeTUNDeviceSized) File() *os.File { return nil } +func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { + return 0, nil +} +func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil } +func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil } +func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil } +func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil } +func (t *fakeTUNDeviceSized) Close() error { return nil } +func (t *fakeTUNDeviceSized) BatchSize() int { return t.size } + +func TestBatchSize(t *testing.T) { + d := Device{} + + d.net.bind = &fakeBindSized{1} + d.tun.device = &fakeTUNDeviceSized{1} + if want, got := 1, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{1} + d.tun.device = &fakeTUNDeviceSized{128} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{128} + d.tun.device = &fakeTUNDeviceSized{1} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{128} + d.tun.device = &fakeTUNDeviceSized{128} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } +} |