aboutsummaryrefslogtreecommitdiffstats
path: root/conn/bind_std.go
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--conn/bind_std.go522
1 files changed, 427 insertions, 95 deletions
diff --git a/conn/bind_std.go b/conn/bind_std.go
index b6a7ab3..46df7fd 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -1,69 +1,126 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
+ "context"
"errors"
+ "fmt"
"net"
"net/netip"
+ "runtime"
+ "strconv"
"sync"
"syscall"
+
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
)
-// StdNetBind is meant to be a temporary solution on platforms for which
-// the sticky socket / source caching behavior has not yet been implemented.
-// It uses the Go's net package to implement networking.
-// See LinuxSocketBind for a proper implementation on the Linux platform.
+var (
+ _ Bind = (*StdNetBind)(nil)
+)
+
+// StdNetBind implements Bind for all platforms. While Windows has its own Bind
+// (see bind_windows.go), it may fall back to StdNetBind.
+// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
+// methods for sending and receiving multiple datagrams per-syscall. See the
+// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
- mu sync.Mutex // protects following fields
- ipv4 *net.UDPConn
- ipv6 *net.UDPConn
+ mu sync.Mutex // protects all fields except as specified
+ ipv4 *net.UDPConn
+ ipv6 *net.UDPConn
+ ipv4PC *ipv4.PacketConn // will be nil on non-Linux
+ ipv6PC *ipv6.PacketConn // will be nil on non-Linux
+ ipv4TxOffload bool
+ ipv4RxOffload bool
+ ipv6TxOffload bool
+ ipv6RxOffload bool
+
+ // these two fields are not guarded by mu
+ udpAddrPool sync.Pool
+ msgsPool sync.Pool
+
blackhole4 bool
blackhole6 bool
}
-func NewStdNetBind() Bind { return &StdNetBind{} }
+func NewStdNetBind() Bind {
+ return &StdNetBind{
+ udpAddrPool: sync.Pool{
+ New: func() any {
+ return &net.UDPAddr{
+ IP: make([]byte, 16),
+ }
+ },
+ },
-type StdNetEndpoint netip.AddrPort
+ msgsPool: sync.Pool{
+ New: func() any {
+ // ipv6.Message and ipv4.Message are interchangeable as they are
+ // both aliases for x/net/internal/socket.Message.
+ msgs := make([]ipv6.Message, IdealBatchSize)
+ for i := range msgs {
+ msgs[i].Buffers = make(net.Buffers, 1)
+ msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
+ }
+ return &msgs
+ },
+ },
+ }
+}
+
+type StdNetEndpoint struct {
+ // AddrPort is the endpoint destination.
+ netip.AddrPort
+ // src is the current sticky source address and interface index, if
+ // supported. Typically this is a PKTINFO structure from/for control
+ // messages, see unix.PKTINFO for an example.
+ src []byte
+}
var (
_ Bind = (*StdNetBind)(nil)
- _ Endpoint = StdNetEndpoint{}
+ _ Endpoint = &StdNetEndpoint{}
)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
e, err := netip.ParseAddrPort(s)
- return asEndpoint(e), err
+ if err != nil {
+ return nil, err
+ }
+ return &StdNetEndpoint{
+ AddrPort: e,
+ }, nil
}
-func (StdNetEndpoint) ClearSrc() {}
-
-func (e StdNetEndpoint) DstIP() netip.Addr {
- return (netip.AddrPort)(e).Addr()
+func (e *StdNetEndpoint) ClearSrc() {
+ if e.src != nil {
+ // Truncate src, no need to reallocate.
+ e.src = e.src[:0]
+ }
}
-func (e StdNetEndpoint) SrcIP() netip.Addr {
- return netip.Addr{} // not supported
+func (e *StdNetEndpoint) DstIP() netip.Addr {
+ return e.AddrPort.Addr()
}
-func (e StdNetEndpoint) DstToBytes() []byte {
- b, _ := (netip.AddrPort)(e).MarshalBinary()
- return b
-}
+// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
-func (e StdNetEndpoint) DstToString() string {
- return (netip.AddrPort)(e).String()
+func (e *StdNetEndpoint) DstToBytes() []byte {
+ b, _ := e.AddrPort.MarshalBinary()
+ return b
}
-func (e StdNetEndpoint) SrcToString() string {
- return ""
+func (e *StdNetEndpoint) DstToString() string {
+ return e.AddrPort.String()
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
- conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
+ conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil {
return nil, 0, err
}
@@ -77,17 +134,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
if err != nil {
return nil, 0, err
}
- return conn, uaddr.Port, nil
+ return conn.(*net.UDPConn), uaddr.Port, nil
}
-func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
- bind.mu.Lock()
- defer bind.mu.Unlock()
+func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
var err error
var tries int
- if bind.ipv4 != nil || bind.ipv6 != nil {
+ if s.ipv4 != nil || s.ipv6 != nil {
return nil, 0, ErrBindAlreadyOpen
}
@@ -95,90 +152,207 @@ func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
// If uport is 0, we can retry on failure.
again:
port := int(uport)
- var ipv4, ipv6 *net.UDPConn
+ var v4conn, v6conn *net.UDPConn
+ var v4pc *ipv4.PacketConn
+ var v6pc *ipv6.PacketConn
- ipv4, port, err = listenNet("udp4", port)
+ v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
- ipv6, port, err = listenNet("udp6", port)
+ v6conn, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
- ipv4.Close()
+ v4conn.Close()
tries++
goto again
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
- ipv4.Close()
+ v4conn.Close()
return nil, 0, err
}
var fns []ReceiveFunc
- if ipv4 != nil {
- fns = append(fns, bind.makeReceiveIPv4(ipv4))
- bind.ipv4 = ipv4
+ if v4conn != nil {
+ s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ v4pc = ipv4.NewPacketConn(v4conn)
+ s.ipv4PC = v4pc
+ }
+ fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
+ s.ipv4 = v4conn
}
- if ipv6 != nil {
- fns = append(fns, bind.makeReceiveIPv6(ipv6))
- bind.ipv6 = ipv6
+ if v6conn != nil {
+ s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ v6pc = ipv6.NewPacketConn(v6conn)
+ s.ipv6PC = v6pc
+ }
+ fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
+ s.ipv6 = v6conn
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
+
return fns, uint16(port), nil
}
-func (bind *StdNetBind) Close() error {
- bind.mu.Lock()
- defer bind.mu.Unlock()
+func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
+ for i := range *msgs {
+ (*msgs)[i].OOB = (*msgs)[i].OOB[:0]
+ (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
+ }
+ s.msgsPool.Put(msgs)
+}
+
+func (s *StdNetBind) getMessages() *[]ipv6.Message {
+ return s.msgsPool.Get().(*[]ipv6.Message)
+}
+
+var (
+ // If compilation fails here these are no longer the same underlying type.
+ _ ipv6.Message = ipv4.Message{}
+)
+
+type batchReader interface {
+ ReadBatch([]ipv6.Message, int) (int, error)
+}
+
+type batchWriter interface {
+ WriteBatch([]ipv6.Message, int) (int, error)
+}
+
+func (s *StdNetBind) receiveIP(
+ br batchReader,
+ conn *net.UDPConn,
+ rxOffload bool,
+ bufs [][]byte,
+ sizes []int,
+ eps []Endpoint,
+) (n int, err error) {
+ msgs := s.getMessages()
+ for i := range bufs {
+ (*msgs)[i].Buffers[0] = bufs[i]
+ (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
+ }
+ defer s.putMessages(msgs)
+ var numMsgs int
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ if rxOffload {
+ readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
+ numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
+ if err != nil {
+ return 0, err
+ }
+ numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
+ if err != nil {
+ return 0, err
+ }
+ } else {
+ numMsgs, err = br.ReadBatch(*msgs, 0)
+ if err != nil {
+ return 0, err
+ }
+ }
+ } else {
+ msg := &(*msgs)[0]
+ msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+ if err != nil {
+ return 0, err
+ }
+ numMsgs = 1
+ }
+ for i := 0; i < numMsgs; i++ {
+ msg := &(*msgs)[i]
+ sizes[i] = msg.N
+ if sizes[i] == 0 {
+ continue
+ }
+ addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+ ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
+ getSrcFromControl(msg.OOB[:msg.NN], ep)
+ eps[i] = ep
+ }
+ return numMsgs, nil
+}
+
+func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
+ return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
+ }
+}
+
+func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
+ return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
+ }
+}
+
+// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
+// rename the IdealBatchSize constant to BatchSize.
+func (s *StdNetBind) BatchSize() int {
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ return IdealBatchSize
+ }
+ return 1
+}
+
+func (s *StdNetBind) Close() error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
var err1, err2 error
- if bind.ipv4 != nil {
- err1 = bind.ipv4.Close()
- bind.ipv4 = nil
+ if s.ipv4 != nil {
+ err1 = s.ipv4.Close()
+ s.ipv4 = nil
+ s.ipv4PC = nil
}
- if bind.ipv6 != nil {
- err2 = bind.ipv6.Close()
- bind.ipv6 = nil
+ if s.ipv6 != nil {
+ err2 = s.ipv6.Close()
+ s.ipv6 = nil
+ s.ipv6PC = nil
}
- bind.blackhole4 = false
- bind.blackhole6 = false
+ s.blackhole4 = false
+ s.blackhole6 = false
+ s.ipv4TxOffload = false
+ s.ipv4RxOffload = false
+ s.ipv6TxOffload = false
+ s.ipv6RxOffload = false
if err1 != nil {
return err1
}
return err2
}
-func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
- return func(buff []byte) (int, Endpoint, error) {
- n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
- return n, asEndpoint(endpoint), err
- }
+type ErrUDPGSODisabled struct {
+ onLaddr string
+ RetryErr error
}
-func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
- return func(buff []byte) (int, Endpoint, error) {
- n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
- return n, asEndpoint(endpoint), err
- }
+func (e ErrUDPGSODisabled) Error() string {
+ return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
}
-func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
- var err error
- nend, ok := endpoint.(StdNetEndpoint)
- if !ok {
- return ErrWrongEndpointType
- }
- addrPort := netip.AddrPort(nend)
+func (e ErrUDPGSODisabled) Unwrap() error {
+ return e.RetryErr
+}
- bind.mu.Lock()
- blackhole := bind.blackhole4
- conn := bind.ipv4
- if addrPort.Addr().Is6() {
- blackhole = bind.blackhole6
- conn = bind.ipv6
+func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
+ s.mu.Lock()
+ blackhole := s.blackhole4
+ conn := s.ipv4
+ offload := s.ipv4TxOffload
+ br := batchWriter(s.ipv4PC)
+ is6 := false
+ if endpoint.DstIP().Is6() {
+ blackhole = s.blackhole6
+ conn = s.ipv6
+ br = s.ipv6PC
+ is6 = true
+ offload = s.ipv6TxOffload
}
- bind.mu.Unlock()
+ s.mu.Unlock()
if blackhole {
return nil
@@ -186,27 +360,185 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
if conn == nil {
return syscall.EAFNOSUPPORT
}
- _, err = conn.WriteToUDPAddrPort(buff, addrPort)
+
+ msgs := s.getMessages()
+ defer s.putMessages(msgs)
+ ua := s.udpAddrPool.Get().(*net.UDPAddr)
+ defer s.udpAddrPool.Put(ua)
+ if is6 {
+ as16 := endpoint.DstIP().As16()
+ copy(ua.IP, as16[:])
+ ua.IP = ua.IP[:16]
+ } else {
+ as4 := endpoint.DstIP().As4()
+ copy(ua.IP, as4[:])
+ ua.IP = ua.IP[:4]
+ }
+ ua.Port = int(endpoint.(*StdNetEndpoint).Port())
+ var (
+ retried bool
+ err error
+ )
+retry:
+ if offload {
+ n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
+ err = s.send(conn, br, (*msgs)[:n])
+ if err != nil && offload && errShouldDisableUDPGSO(err) {
+ offload = false
+ s.mu.Lock()
+ if is6 {
+ s.ipv6TxOffload = false
+ } else {
+ s.ipv4TxOffload = false
+ }
+ s.mu.Unlock()
+ retried = true
+ goto retry
+ }
+ } else {
+ for i := range bufs {
+ (*msgs)[i].Addr = ua
+ (*msgs)[i].Buffers[0] = bufs[i]
+ setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
+ }
+ err = s.send(conn, br, (*msgs)[:len(bufs)])
+ }
+ if retried {
+ return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
+ }
+ return err
+}
+
+func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
+ var (
+ n int
+ err error
+ start int
+ )
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+ for {
+ n, err = pc.WriteBatch(msgs[start:], 0)
+ if err != nil || n == len(msgs[start:]) {
+ break
+ }
+ start += n
+ }
+ } else {
+ for _, msg := range msgs {
+ _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
+ if err != nil {
+ break
+ }
+ }
+ }
return err
}
-// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
-// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
-// but Endpoints are immutable, so we can re-use them.
-var endpointPool = sync.Pool{
- New: func() any {
- return make(map[netip.AddrPort]Endpoint)
- },
+const (
+ // Exceeding these values results in EMSGSIZE. They account for layer3 and
+ // layer4 headers. IPv6 does not need to account for itself as the payload
+ // length field is self excluding.
+ maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
+ maxIPv6PayloadLen = 1<<16 - 1 - 8
+
+ // This is a hard limit imposed by the kernel.
+ udpSegmentMaxDatagrams = 64
+)
+
+type setGSOFunc func(control *[]byte, gsoSize uint16)
+
+func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
+ var (
+ base = -1 // index of msg we are currently coalescing into
+ gsoSize int // segmentation size of msgs[base]
+ dgramCnt int // number of dgrams coalesced into msgs[base]
+ endBatch bool // tracking flag to start a new batch on next iteration of bufs
+ )
+ maxPayloadLen := maxIPv4PayloadLen
+ if ep.DstIP().Is6() {
+ maxPayloadLen = maxIPv6PayloadLen
+ }
+ for i, buf := range bufs {
+ if i > 0 {
+ msgLen := len(buf)
+ baseLenBefore := len(msgs[base].Buffers[0])
+ freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
+ if msgLen+baseLenBefore <= maxPayloadLen &&
+ msgLen <= gsoSize &&
+ msgLen <= freeBaseCap &&
+ dgramCnt < udpSegmentMaxDatagrams &&
+ !endBatch {
+ msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
+ if i == len(bufs)-1 {
+ setGSO(&msgs[base].OOB, uint16(gsoSize))
+ }
+ dgramCnt++
+ if msgLen < gsoSize {
+ // A smaller than gsoSize packet on the tail is legal, but
+ // it must end the batch.
+ endBatch = true
+ }
+ continue
+ }
+ }
+ if dgramCnt > 1 {
+ setGSO(&msgs[base].OOB, uint16(gsoSize))
+ }
+ // Reset prior to incrementing base since we are preparing to start a
+ // new potential batch.
+ endBatch = false
+ base++
+ gsoSize = len(buf)
+ setSrcControl(&msgs[base].OOB, ep)
+ msgs[base].Buffers[0] = buf
+ msgs[base].Addr = addr
+ dgramCnt = 1
+ }
+ return base + 1
}
-// asEndpoint returns an Endpoint containing ap.
-func asEndpoint(ap netip.AddrPort) Endpoint {
- m := endpointPool.Get().(map[netip.AddrPort]Endpoint)
- defer endpointPool.Put(m)
- e, ok := m[ap]
- if !ok {
- e = Endpoint(StdNetEndpoint(ap))
- m[ap] = e
+type getGSOFunc func(control []byte) (int, error)
+
+func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
+ for i := firstMsgAt; i < len(msgs); i++ {
+ msg := &msgs[i]
+ if msg.N == 0 {
+ return n, err
+ }
+ var (
+ gsoSize int
+ start int
+ end = msg.N
+ numToSplit = 1
+ )
+ gsoSize, err = getGSO(msg.OOB[:msg.NN])
+ if err != nil {
+ return n, err
+ }
+ if gsoSize > 0 {
+ numToSplit = (msg.N + gsoSize - 1) / gsoSize
+ end = gsoSize
+ }
+ for j := 0; j < numToSplit; j++ {
+ if n > i {
+ return n, errors.New("splitting coalesced packet resulted in overflow")
+ }
+ copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
+ msgs[n].N = copied
+ msgs[n].Addr = msg.Addr
+ start = end
+ end += gsoSize
+ if end > msg.N {
+ end = msg.N
+ }
+ n++
+ }
+ if i != n-1 {
+ // It is legal for bytes to move within msg.Buffers[0] as a result
+ // of splitting, so we only zero the source msg len when it is not
+ // the destination of the last split operation above.
+ msg.N = 0
+ }
}
- return e
+ return n, nil
}