aboutsummaryrefslogtreecommitdiffstats
path: root/conn
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-02-22 04:30:31 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2021-02-23 20:00:57 +0100
commit9a29ae267cc4573f88f5d9871e2aa53ea201e873 (patch)
tree178ac0efd5722db32eb2651d6e3cbeb58bd1e828 /conn
parentdevice: cleanup unused test components (diff)
downloadwireguard-go-9a29ae267cc4573f88f5d9871e2aa53ea201e873.tar.xz
wireguard-go-9a29ae267cc4573f88f5d9871e2aa53ea201e873.zip
device: test up/down using virtual conn
This prevents port clashing bugs. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'conn')
-rw-r--r--conn/bindtest/bindtest.go136
1 files changed, 136 insertions, 0 deletions
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
new file mode 100644
index 0000000..ad8fa05
--- /dev/null
+++ b/conn/bindtest/bindtest.go
@@ -0,0 +1,136 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package bindtest
+
+import (
+ "fmt"
+ "math/rand"
+ "net"
+ "os"
+ "strconv"
+
+ "golang.zx2c4.com/wireguard/conn"
+)
+
+type ChannelBind struct {
+ rx4, tx4 *chan []byte
+ rx6, tx6 *chan []byte
+ closeSignal chan bool
+ source4, source6 ChannelEndpoint
+ target4, target6 ChannelEndpoint
+}
+
+type ChannelEndpoint uint16
+
+var _ conn.Bind = (*ChannelBind)(nil)
+var _ conn.Endpoint = (*ChannelEndpoint)(nil)
+
+func NewChannelBinds() [2]conn.Bind {
+ arx4 := make(chan []byte, 8192)
+ brx4 := make(chan []byte, 8192)
+ arx6 := make(chan []byte, 8192)
+ brx6 := make(chan []byte, 8192)
+ var binds [2]ChannelBind
+ binds[0].rx4 = &arx4
+ binds[0].tx4 = &brx4
+ binds[1].rx4 = &brx4
+ binds[1].tx4 = &arx4
+ binds[0].rx6 = &arx6
+ binds[0].tx6 = &brx6
+ binds[1].rx6 = &brx6
+ binds[1].tx6 = &arx6
+ binds[0].target4 = ChannelEndpoint(1)
+ binds[1].target4 = ChannelEndpoint(2)
+ binds[0].target6 = ChannelEndpoint(3)
+ binds[1].target6 = ChannelEndpoint(4)
+ binds[0].source4 = binds[1].target4
+ binds[0].source6 = binds[1].target6
+ binds[1].source4 = binds[0].target4
+ binds[1].source6 = binds[0].target6
+ return [2]conn.Bind{&binds[0], &binds[1]}
+}
+
+func (c ChannelEndpoint) ClearSrc() {}
+
+func (c ChannelEndpoint) SrcToString() string { return "" }
+
+func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) }
+
+func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
+
+func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
+
+func (c ChannelEndpoint) SrcIP() net.IP { return nil }
+
+func (c *ChannelBind) Open(port uint16) (actualPort uint16, err error) {
+ c.closeSignal = make(chan bool)
+ if rand.Uint32()&1 == 0 {
+ return uint16(c.source4), nil
+ } else {
+ return uint16(c.source6), nil
+ }
+}
+
+func (c *ChannelBind) Close() error {
+ if c.closeSignal != nil {
+ select {
+ case <-c.closeSignal:
+ default:
+ close(c.closeSignal)
+ }
+ }
+ return nil
+}
+
+func (c *ChannelBind) SetMark(mark uint32) error { return nil }
+
+func (c *ChannelBind) ReceiveIPv6(b []byte) (n int, ep conn.Endpoint, err error) {
+ select {
+ case <-c.closeSignal:
+ return 0, nil, net.ErrClosed
+ case rx := <-*c.rx6:
+ return copy(b, rx), c.target6, nil
+ }
+}
+
+func (c *ChannelBind) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
+ select {
+ case <-c.closeSignal:
+ return 0, nil, net.ErrClosed
+ case rx := <-*c.rx4:
+ return copy(b, rx), c.target4, nil
+ }
+}
+
+func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
+ select {
+ case <-c.closeSignal:
+ return net.ErrClosed
+ default:
+ bc := make([]byte, len(b))
+ copy(bc, b)
+ if ep.(ChannelEndpoint) == c.target4 {
+ *c.tx4 <- bc
+ } else if ep.(ChannelEndpoint) == c.target6 {
+ *c.tx6 <- bc
+ } else {
+ return os.ErrInvalid
+ }
+ }
+ return nil
+}
+
+func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
+ _, port, err := net.SplitHostPort(s)
+ if err != nil {
+ return nil, err
+ }
+ i, err := strconv.ParseUint(port, 10, 16)
+ if err != nil {
+ return nil, err
+ }
+ return ChannelEndpoint(i), nil
+}