diff options
Diffstat (limited to 'tun/tuntest/tuntest.go')
-rw-r--r-- | tun/tuntest/tuntest.go | 65 |
1 files changed, 35 insertions, 30 deletions
diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go index e94c6d8..9c4564f 100644 --- a/tun/tuntest/tuntest.go +++ b/tun/tuntest/tuntest.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package tuntest @@ -8,13 +8,13 @@ package tuntest import ( "encoding/binary" "io" - "net" + "net/netip" "os" "golang.zx2c4.com/wireguard/tun" ) -func Ping(dst, src net.IP) []byte { +func Ping(dst, src netip.Addr) []byte { localPort := uint16(1337) seq := uint16(0) @@ -40,7 +40,7 @@ func checksum(buf []byte, initial uint16) uint16 { return ^uint16(v) } -func genICMPv4(payload []byte, dst, src net.IP) []byte { +func genICMPv4(payload []byte, dst, src netip.Addr) []byte { const ( icmpv4ProtocolNumber = 1 icmpv4Echo = 8 @@ -50,12 +50,13 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte { ipv4TotalLenOffset = 2 ipv4ChecksumOffset = 10 ttl = 65 + headerSize = ipv4Size + icmpv4Size ) - hdr := make([]byte, ipv4Size+icmpv4Size) + pkt := make([]byte, headerSize+len(payload)) - ip := hdr[0:ipv4Size] - icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size] + ip := pkt[0:ipv4Size] + icmpv4 := pkt[ipv4Size : ipv4Size+icmpv4Size] // https://tools.ietf.org/html/rfc792 icmpv4[0] = icmpv4Echo // type @@ -64,23 +65,20 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte { binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum) // https://tools.ietf.org/html/rfc760 section 3.1 - length := uint16(len(hdr) + len(payload)) + length := uint16(len(pkt)) ip[0] = (4 << 4) | (ipv4Size / 4) binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) ip[8] = ttl ip[9] = icmpv4ProtocolNumber - copy(ip[12:], src.To4()) - copy(ip[16:], dst.To4()) + copy(ip[12:], src.AsSlice()) + copy(ip[16:], dst.AsSlice()) chksum = ^checksum(ip[:], 0) binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) - var v []byte - v = append(v, hdr...) - v = append(v, payload...) - return []byte(v) + copy(pkt[headerSize:], payload) + return pkt } -// TODO(crawshaw): find a reusable home for this. package devicetest? type ChannelTUN struct { Inbound chan []byte // incoming packets, closed on TUN close Outbound chan []byte // outbound packets, blocks forever on TUN close @@ -112,38 +110,45 @@ type chTun struct { func (t *chTun) File() *os.File { return nil } -func (t *chTun) Read(data []byte, offset int) (int, error) { +func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) { select { case <-t.c.closed: - return 0, io.EOF // TODO(crawshaw): what is the correct error value? + return 0, os.ErrClosed case msg := <-t.c.Outbound: - return copy(data[offset:], msg), nil + n := copy(packets[0][offset:], msg) + sizes[0] = n + return 1, nil } } // Write is called by the wireguard device to deliver a packet for routing. -func (t *chTun) Write(data []byte, offset int) (int, error) { +func (t *chTun) Write(packets [][]byte, offset int) (int, error) { if offset == -1 { close(t.c.closed) close(t.c.events) return 0, io.EOF } - msg := make([]byte, len(data)-offset) - copy(msg, data[offset:]) - select { - case <-t.c.closed: - return 0, io.EOF // TODO(crawshaw): what is the correct error value? - case t.c.Inbound <- msg: - return len(data) - offset, nil + for i, data := range packets { + msg := make([]byte, len(data)-offset) + copy(msg, data[offset:]) + select { + case <-t.c.closed: + return i, os.ErrClosed + case t.c.Inbound <- msg: + } } + return len(packets), nil +} + +func (t *chTun) BatchSize() int { + return 1 } const DefaultMTU = 1420 -func (t *chTun) Flush() error { return nil } -func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } -func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } -func (t *chTun) Events() chan tun.Event { return t.c.events } +func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } +func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } +func (t *chTun) Events() <-chan tun.Event { return t.c.events } func (t *chTun) Close() error { t.Write(nil, -1) return nil |