diff options
Diffstat (limited to 'conn/bindtest/bindtest.go')
-rw-r--r-- | conn/bindtest/bindtest.go | 61 |
1 files changed, 33 insertions, 28 deletions
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 7d43fb3..74e7add 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package bindtest @@ -9,8 +9,8 @@ import ( "fmt" "math/rand" "net" + "net/netip" "os" - "strconv" "golang.zx2c4.com/wireguard/conn" ) @@ -25,8 +25,10 @@ type ChannelBind struct { type ChannelEndpoint uint16 -var _ conn.Bind = (*ChannelBind)(nil) -var _ conn.Endpoint = (*ChannelEndpoint)(nil) +var ( + _ conn.Bind = (*ChannelBind)(nil) + _ conn.Endpoint = (*ChannelEndpoint)(nil) +) func NewChannelBinds() [2]conn.Bind { arx4 := make(chan []byte, 8192) @@ -61,9 +63,9 @@ func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } -func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) } +func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } -func (c ChannelEndpoint) SrcIP() net.IP { return nil } +func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} } func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { c.closeSignal = make(chan bool) @@ -87,45 +89,48 @@ func (c *ChannelBind) Close() error { return nil } +func (c *ChannelBind) BatchSize() int { return 1 } + func (c *ChannelBind) SetMark(mark uint32) error { return nil } func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { - return func(b []byte) (n int, ep conn.Endpoint, err error) { + return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { select { case <-c.closeSignal: - return 0, nil, net.ErrClosed + return 0, net.ErrClosed case rx := <-ch: - return copy(b, rx), c.target6, nil + copied := copy(bufs[0], rx) + sizes[0] = copied + eps[0] = c.target6 + return 1, nil } } } -func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error { - select { - case <-c.closeSignal: - return net.ErrClosed - default: - bc := make([]byte, len(b)) - copy(bc, b) - if ep.(ChannelEndpoint) == c.target4 { - *c.tx4 <- bc - } else if ep.(ChannelEndpoint) == c.target6 { - *c.tx6 <- bc - } else { - return os.ErrInvalid +func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error { + for _, b := range bufs { + select { + case <-c.closeSignal: + return net.ErrClosed + default: + bc := make([]byte, len(b)) + copy(bc, b) + if ep.(ChannelEndpoint) == c.target4 { + *c.tx4 <- bc + } else if ep.(ChannelEndpoint) == c.target6 { + *c.tx6 <- bc + } else { + return os.ErrInvalid + } } } return nil } func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { - _, port, err := net.SplitHostPort(s) - if err != nil { - return nil, err - } - i, err := strconv.ParseUint(port, 10, 16) + addr, err := netip.ParseAddrPort(s) if err != nil { return nil, err } - return ChannelEndpoint(i), nil + return ChannelEndpoint(addr.Port()), nil } |