aboutsummaryrefslogtreecommitdiffstats
path: root/conn/bindtest/bindtest.go
diff options
context:
space:
mode:
Diffstat (limited to 'conn/bindtest/bindtest.go')
-rw-r--r--conn/bindtest/bindtest.go61
1 files changed, 33 insertions, 28 deletions
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
index 7d43fb3..74e7add 100644
--- a/conn/bindtest/bindtest.go
+++ b/conn/bindtest/bindtest.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bindtest
@@ -9,8 +9,8 @@ import (
"fmt"
"math/rand"
"net"
+ "net/netip"
"os"
- "strconv"
"golang.zx2c4.com/wireguard/conn"
)
@@ -25,8 +25,10 @@ type ChannelBind struct {
type ChannelEndpoint uint16
-var _ conn.Bind = (*ChannelBind)(nil)
-var _ conn.Endpoint = (*ChannelEndpoint)(nil)
+var (
+ _ conn.Bind = (*ChannelBind)(nil)
+ _ conn.Endpoint = (*ChannelEndpoint)(nil)
+)
func NewChannelBinds() [2]conn.Bind {
arx4 := make(chan []byte, 8192)
@@ -61,9 +63,9 @@ func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
-func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
+func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
-func (c ChannelEndpoint) SrcIP() net.IP { return nil }
+func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
c.closeSignal = make(chan bool)
@@ -87,45 +89,48 @@ func (c *ChannelBind) Close() error {
return nil
}
+func (c *ChannelBind) BatchSize() int { return 1 }
+
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
- return func(b []byte) (n int, ep conn.Endpoint, err error) {
+ return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
select {
case <-c.closeSignal:
- return 0, nil, net.ErrClosed
+ return 0, net.ErrClosed
case rx := <-ch:
- return copy(b, rx), c.target6, nil
+ copied := copy(bufs[0], rx)
+ sizes[0] = copied
+ eps[0] = c.target6
+ return 1, nil
}
}
}
-func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
- select {
- case <-c.closeSignal:
- return net.ErrClosed
- default:
- bc := make([]byte, len(b))
- copy(bc, b)
- if ep.(ChannelEndpoint) == c.target4 {
- *c.tx4 <- bc
- } else if ep.(ChannelEndpoint) == c.target6 {
- *c.tx6 <- bc
- } else {
- return os.ErrInvalid
+func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
+ for _, b := range bufs {
+ select {
+ case <-c.closeSignal:
+ return net.ErrClosed
+ default:
+ bc := make([]byte, len(b))
+ copy(bc, b)
+ if ep.(ChannelEndpoint) == c.target4 {
+ *c.tx4 <- bc
+ } else if ep.(ChannelEndpoint) == c.target6 {
+ *c.tx6 <- bc
+ } else {
+ return os.ErrInvalid
+ }
}
}
return nil
}
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
- _, port, err := net.SplitHostPort(s)
- if err != nil {
- return nil, err
- }
- i, err := strconv.ParseUint(port, 10, 16)
+ addr, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
- return ChannelEndpoint(i), nil
+ return ChannelEndpoint(addr.Port()), nil
}