aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJordan Whited <jordan@tailscale.com>2023-03-06 15:58:32 -0800
committerJason A. Donenfeld <Jason@zx2c4.com>2023-03-10 14:52:24 +0100
commit2fcdaf979915be4702bf8aba4a90ac3c3ae0796b (patch)
tree5850a80bb1c6e49fdcd36f6d29e0b15f3a27eb0e
parentconn: inch BatchSize toward being non-dynamic (diff)
downloadwireguard-go-2fcdaf979915be4702bf8aba4a90ac3c3ae0796b.tar.xz
wireguard-go-2fcdaf979915be4702bf8aba4a90ac3c3ae0796b.zip
conn: fix StdNetBind fallback on Windows
If RIO is unavailable, NewWinRingBind() falls back to StdNetBind. StdNetBind uses x/net/ipv{4,6}.PacketConn for sending and receiving datagrams, specifically via the {Read,Write}Batch methods. These methods are unimplemented on Windows and will return runtime errors as a result. Additionally, only Linux benefits from these x/net types for reading and writing, so we update StdNetBind to fall back to the standard library net package for all platforms other than Linux. Reviewed-by: James Tucker <james@tailscale.com> Signed-off-by: Jordan Whited <jordan@tailscale.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--conn/bind_std.go192
-rw-r--r--conn/bind_std_test.go22
2 files changed, 150 insertions, 64 deletions
diff --git a/conn/bind_std.go b/conn/bind_std.go
index b9da4c3..a842b12 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -10,6 +10,7 @@ import (
"errors"
"net"
"net/netip"
+ "runtime"
"strconv"
"sync"
"syscall"
@@ -22,16 +23,21 @@ var (
_ Bind = (*StdNetBind)(nil)
)
-// StdNetBind implements Bind for all platforms except Windows.
+// StdNetBind implements Bind for all platforms. While Windows has its own Bind
+// (see bind_windows.go), it may fall back to StdNetBind.
+// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
+// methods for sending and receiving multiple datagrams per-syscall. See the
+// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
- mu sync.Mutex // protects following fields
- ipv4 *net.UDPConn
- ipv6 *net.UDPConn
- blackhole4 bool
- blackhole6 bool
- ipv4PC *ipv4.PacketConn
- ipv6PC *ipv6.PacketConn
- udpAddrPool sync.Pool
+ mu sync.Mutex // protects following fields
+ ipv4 *net.UDPConn
+ ipv6 *net.UDPConn
+ blackhole4 bool
+ blackhole6 bool
+ ipv4PC *ipv4.PacketConn // will be nil on non-Linux
+ ipv6PC *ipv6.PacketConn // will be nil on non-Linux
+
+ udpAddrPool sync.Pool // following fields are not guarded by mu
ipv4MsgsPool sync.Pool
ipv6MsgsPool sync.Pool
}
@@ -154,6 +160,8 @@ func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
again:
port := int(uport)
var v4conn, v6conn *net.UDPConn
+ var v4pc *ipv4.PacketConn
+ var v6pc *ipv6.PacketConn
v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
@@ -173,63 +181,92 @@ again:
}
var fns []ReceiveFunc
if v4conn != nil {
- fns = append(fns, s.receiveIPv4)
+ if runtime.GOOS == "linux" {
+ v4pc = ipv4.NewPacketConn(v4conn)
+ s.ipv4PC = v4pc
+ }
+ fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
s.ipv4 = v4conn
}
if v6conn != nil {
- fns = append(fns, s.receiveIPv6)
+ if runtime.GOOS == "linux" {
+ v6pc = ipv6.NewPacketConn(v6conn)
+ s.ipv6PC = v6pc
+ }
+ fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
s.ipv6 = v6conn
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
- s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
- s.ipv6PC = ipv6.NewPacketConn(s.ipv6)
-
return fns, uint16(port), nil
}
-func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
- msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
- defer s.ipv4MsgsPool.Put(msgs)
- for i := range buffs {
- (*msgs)[i].Buffers[0] = buffs[i]
- }
- numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
- if err != nil {
- return 0, err
- }
- for i := 0; i < numMsgs; i++ {
- msg := &(*msgs)[i]
- sizes[i] = msg.N
- addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
- ep := asEndpoint(addrPort)
- getSrcFromControl(msg.OOB, ep)
- eps[i] = ep
+func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
+ return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
+ defer s.ipv4MsgsPool.Put(msgs)
+ for i := range buffs {
+ (*msgs)[i].Buffers[0] = buffs[i]
+ }
+ var numMsgs int
+ if runtime.GOOS == "linux" {
+ numMsgs, err = pc.ReadBatch(*msgs, 0)
+ if err != nil {
+ return 0, err
+ }
+ } else {
+ msg := &(*msgs)[0]
+ msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+ if err != nil {
+ return 0, err
+ }
+ numMsgs = 1
+ }
+ for i := 0; i < numMsgs; i++ {
+ msg := &(*msgs)[i]
+ sizes[i] = msg.N
+ addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+ ep := asEndpoint(addrPort)
+ getSrcFromControl(msg.OOB, ep)
+ eps[i] = ep
+ }
+ return numMsgs, nil
}
- return numMsgs, nil
}
-func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
- msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
- defer s.ipv6MsgsPool.Put(msgs)
- for i := range buffs {
- (*msgs)[i].Buffers[0] = buffs[i]
- }
- numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
- if err != nil {
- return 0, err
- }
- for i := 0; i < numMsgs; i++ {
- msg := &(*msgs)[i]
- sizes[i] = msg.N
- addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
- ep := asEndpoint(addrPort)
- getSrcFromControl(msg.OOB, ep)
- eps[i] = ep
+func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
+ return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ msgs := s.ipv4MsgsPool.Get().(*[]ipv6.Message)
+ defer s.ipv4MsgsPool.Put(msgs)
+ for i := range buffs {
+ (*msgs)[i].Buffers[0] = buffs[i]
+ }
+ var numMsgs int
+ if runtime.GOOS == "linux" {
+ numMsgs, err = pc.ReadBatch(*msgs, 0)
+ if err != nil {
+ return 0, err
+ }
+ } else {
+ msg := &(*msgs)[0]
+ msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+ if err != nil {
+ return 0, err
+ }
+ numMsgs = 1
+ }
+ for i := 0; i < numMsgs; i++ {
+ msg := &(*msgs)[i]
+ sizes[i] = msg.N
+ addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+ ep := asEndpoint(addrPort)
+ getSrcFromControl(msg.OOB, ep)
+ eps[i] = ep
+ }
+ return numMsgs, nil
}
- return numMsgs, nil
}
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
@@ -246,10 +283,12 @@ func (s *StdNetBind) Close() error {
if s.ipv4 != nil {
err1 = s.ipv4.Close()
s.ipv4 = nil
+ s.ipv4PC = nil
}
if s.ipv6 != nil {
err2 = s.ipv6.Close()
s.ipv6 = nil
+ s.ipv6PC = nil
}
s.blackhole4 = false
s.blackhole6 = false
@@ -263,11 +302,18 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
s.mu.Lock()
blackhole := s.blackhole4
conn := s.ipv4
+ var (
+ pc4 *ipv4.PacketConn
+ pc6 *ipv6.PacketConn
+ )
is6 := false
if endpoint.DstIP().Is6() {
blackhole = s.blackhole6
conn = s.ipv6
+ pc6 = s.ipv6PC
is6 = true
+ } else {
+ pc4 = s.ipv4PC
}
s.mu.Unlock()
@@ -278,13 +324,13 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
return syscall.EAFNOSUPPORT
}
if is6 {
- return s.send6(s.ipv6PC, endpoint, buffs)
+ return s.send6(conn, pc6, endpoint, buffs)
} else {
- return s.send4(s.ipv4PC, endpoint, buffs)
+ return s.send4(conn, pc4, endpoint, buffs)
}
}
-func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
+func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as4 := ep.DstIP().As4()
copy(ua.IP, as4[:])
@@ -301,19 +347,28 @@ func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) e
err error
start int
)
- for {
- n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
- if err != nil || n == len((*msgs)[start:len(buffs)]) {
- break
+ if runtime.GOOS == "linux" {
+ for {
+ n, err = pc.WriteBatch((*msgs)[start:len(buffs)], 0)
+ if err != nil || n == len((*msgs)[start:len(buffs)]) {
+ break
+ }
+ start += n
+ }
+ } else {
+ for i, buff := range buffs {
+ _, _, err = conn.WriteMsgUDP(buff, (*msgs)[i].OOB, ua)
+ if err != nil {
+ break
+ }
}
- start += n
}
s.udpAddrPool.Put(ua)
s.ipv4MsgsPool.Put(msgs)
return err
}
-func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
+func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as16 := ep.DstIP().As16()
copy(ua.IP, as16[:])
@@ -330,12 +385,21 @@ func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) e
err error
start int
)
- for {
- n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
- if err != nil || n == len((*msgs)[start:len(buffs)]) {
- break
+ if runtime.GOOS == "linux" {
+ for {
+ n, err = pc.WriteBatch((*msgs)[start:len(buffs)], 0)
+ if err != nil || n == len((*msgs)[start:len(buffs)]) {
+ break
+ }
+ start += n
+ }
+ } else {
+ for i, buff := range buffs {
+ _, _, err = conn.WriteMsgUDP(buff, (*msgs)[i].OOB, ua)
+ if err != nil {
+ break
+ }
}
- start += n
}
s.udpAddrPool.Put(ua)
s.ipv6MsgsPool.Put(msgs)
diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go
new file mode 100644
index 0000000..76afa30
--- /dev/null
+++ b/conn/bind_std_test.go
@@ -0,0 +1,22 @@
+package conn
+
+import "testing"
+
+func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
+ bind := NewStdNetBind().(*StdNetBind)
+ fns, _, err := bind.Open(0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ bind.Close()
+ buffs := make([][]byte, 1)
+ buffs[0] = make([]byte, 1)
+ sizes := make([]int, 1)
+ eps := make([]Endpoint, 1)
+ for _, fn := range fns {
+ // The ReceiveFuncs must not access conn-related fields on StdNetBind
+ // unguarded. Close() nils the conn-related fields resulting in a panic
+ // if they violate the mutex.
+ fn(buffs, sizes, eps)
+ }
+}