Diffstat (limited to 'conn/bind_std.go')
-rw-r--r-- | conn/bind_std.go | 524
1 file changed, 436 insertions(+), 88 deletions(-)
diff --git a/conn/bind_std.go b/conn/bind_std.go
index cb85cfd..46df7fd 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -1,72 +1,126 @@
 /* SPDX-License-Identifier: MIT
  *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
  */

 package conn

 import (
+    "context"
     "errors"
+    "fmt"
     "net"
+    "net/netip"
+    "runtime"
+    "strconv"
     "sync"
     "syscall"
+
+    "golang.org/x/net/ipv4"
+    "golang.org/x/net/ipv6"
+)
+
+var (
+    _ Bind = (*StdNetBind)(nil)
 )

-// StdNetBind is meant to be a temporary solution on platforms for which
-// the sticky socket / source caching behavior has not yet been implemented.
-// It uses the Go's net package to implement networking.
-// See LinuxSocketBind for a proper implementation on the Linux platform.
+// StdNetBind implements Bind for all platforms. While Windows has its own Bind
+// (see bind_windows.go), it may fall back to StdNetBind.
+// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
+// methods for sending and receiving multiple datagrams per-syscall. See the
+// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
 type StdNetBind struct {
-    mu         sync.Mutex // protects following fields
-    ipv4       *net.UDPConn
-    ipv6       *net.UDPConn
+    mu            sync.Mutex // protects all fields except as specified
+    ipv4          *net.UDPConn
+    ipv6          *net.UDPConn
+    ipv4PC        *ipv4.PacketConn // will be nil on non-Linux
+    ipv6PC        *ipv6.PacketConn // will be nil on non-Linux
+    ipv4TxOffload bool
+    ipv4RxOffload bool
+    ipv6TxOffload bool
+    ipv6RxOffload bool
+
+    // these two fields are not guarded by mu
+    udpAddrPool sync.Pool
+    msgsPool    sync.Pool
+
     blackhole4 bool
     blackhole6 bool
 }

-func NewStdNetBind() Bind { return &StdNetBind{} }
+func NewStdNetBind() Bind {
+    return &StdNetBind{
+        udpAddrPool: sync.Pool{
+            New: func() any {
+                return &net.UDPAddr{
+                    IP: make([]byte, 16),
+                }
+            },
+        },
+
+        msgsPool: sync.Pool{
+            New: func() any {
+                // ipv6.Message and ipv4.Message are interchangeable as they are
+                // both aliases for x/net/internal/socket.Message.
+                msgs := make([]ipv6.Message, IdealBatchSize)
+                for i := range msgs {
+                    msgs[i].Buffers = make(net.Buffers, 1)
+                    msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
+                }
+                return &msgs
+            },
+        },
+    }
+}

-type StdNetEndpoint net.UDPAddr
+type StdNetEndpoint struct {
+    // AddrPort is the endpoint destination.
+    netip.AddrPort
+    // src is the current sticky source address and interface index, if
+    // supported. Typically this is a PKTINFO structure from/for control
+    // messages, see unix.PKTINFO for an example.
+    src []byte
+}

-var _ Bind = (*StdNetBind)(nil)
-var _ Endpoint = (*StdNetEndpoint)(nil)
+var (
+    _ Bind     = (*StdNetBind)(nil)
+    _ Endpoint = &StdNetEndpoint{}
+)

 func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
-    addr, err := parseEndpoint(s)
-    return (*StdNetEndpoint)(addr), err
+    e, err := netip.ParseAddrPort(s)
+    if err != nil {
+        return nil, err
+    }
+    return &StdNetEndpoint{
+        AddrPort: e,
+    }, nil
 }

-func (*StdNetEndpoint) ClearSrc() {}
-
-func (e *StdNetEndpoint) DstIP() net.IP {
-    return (*net.UDPAddr)(e).IP
+func (e *StdNetEndpoint) ClearSrc() {
+    if e.src != nil {
+        // Truncate src, no need to reallocate.
+        e.src = e.src[:0]
+    }
 }

-func (e *StdNetEndpoint) SrcIP() net.IP {
-    return nil // not supported
+func (e *StdNetEndpoint) DstIP() netip.Addr {
+    return e.AddrPort.Addr()
 }

+// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
+
 func (e *StdNetEndpoint) DstToBytes() []byte {
-    addr := (*net.UDPAddr)(e)
-    out := addr.IP.To4()
-    if out == nil {
-        out = addr.IP
-    }
-    out = append(out, byte(addr.Port&0xff))
-    out = append(out, byte((addr.Port>>8)&0xff))
-    return out
+    b, _ := e.AddrPort.MarshalBinary()
+    return b
 }

 func (e *StdNetEndpoint) DstToString() string {
-    return (*net.UDPAddr)(e).String()
-}
-
-func (e *StdNetEndpoint) SrcToString() string {
-    return ""
+    return e.AddrPort.String()
 }

 func listenNet(network string, port int) (*net.UDPConn, int, error) {
-    conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
+    conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
     if err != nil {
         return nil, 0, err
     }
@@ -80,17 +134,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
     if err != nil {
         return nil, 0, err
     }
-    return conn, uaddr.Port, nil
+    return conn.(*net.UDPConn), uaddr.Port, nil
 }

-func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
-    bind.mu.Lock()
-    defer bind.mu.Unlock()
+func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
+    s.mu.Lock()
+    defer s.mu.Unlock()

     var err error
     var tries int

-    if bind.ipv4 != nil || bind.ipv6 != nil {
+    if s.ipv4 != nil || s.ipv6 != nil {
         return nil, 0, ErrBindAlreadyOpen
     }

@@ -98,92 +152,207 @@ func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
     // If uport is 0, we can retry on failure.
again:
     port := int(uport)
-    var ipv4, ipv6 *net.UDPConn
+    var v4conn, v6conn *net.UDPConn
+    var v4pc *ipv4.PacketConn
+    var v6pc *ipv6.PacketConn

-    ipv4, port, err = listenNet("udp4", port)
+    v4conn, port, err = listenNet("udp4", port)
     if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
         return nil, 0, err
     }

     // Listen on the same port as we're using for ipv4.
-    ipv6, port, err = listenNet("udp6", port)
+    v6conn, port, err = listenNet("udp6", port)
     if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
-        ipv4.Close()
+        v4conn.Close()
         tries++
         goto again
     }
     if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
-        ipv4.Close()
+        v4conn.Close()
         return nil, 0, err
     }
     var fns []ReceiveFunc
-    if ipv4 != nil {
-        fns = append(fns, bind.makeReceiveIPv4(ipv4))
-        bind.ipv4 = ipv4
+    if v4conn != nil {
+        s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
+        if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+            v4pc = ipv4.NewPacketConn(v4conn)
+            s.ipv4PC = v4pc
+        }
+        fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
+        s.ipv4 = v4conn
     }
-    if ipv6 != nil {
-        fns = append(fns, bind.makeReceiveIPv6(ipv6))
-        bind.ipv6 = ipv6
+    if v6conn != nil {
+        s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
+        if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+            v6pc = ipv6.NewPacketConn(v6conn)
+            s.ipv6PC = v6pc
+        }
+        fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
+        s.ipv6 = v6conn
     }
     if len(fns) == 0 {
         return nil, 0, syscall.EAFNOSUPPORT
     }
+
     return fns, uint16(port), nil
 }

-func (bind *StdNetBind) Close() error {
-    bind.mu.Lock()
-    defer bind.mu.Unlock()
+func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
+    for i := range *msgs {
+        (*msgs)[i].OOB = (*msgs)[i].OOB[:0]
+        (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
+    }
+    s.msgsPool.Put(msgs)
+}
+
+func (s *StdNetBind) getMessages() *[]ipv6.Message {
+    return s.msgsPool.Get().(*[]ipv6.Message)
+}
+
+var (
+    // If compilation fails here these are no longer the same underlying type.
+    _ ipv6.Message = ipv4.Message{}
+)
+
+type batchReader interface {
+    ReadBatch([]ipv6.Message, int) (int, error)
+}
+
+type batchWriter interface {
+    WriteBatch([]ipv6.Message, int) (int, error)
+}
+
+func (s *StdNetBind) receiveIP(
+    br batchReader,
+    conn *net.UDPConn,
+    rxOffload bool,
+    bufs [][]byte,
+    sizes []int,
+    eps []Endpoint,
+) (n int, err error) {
+    msgs := s.getMessages()
+    for i := range bufs {
+        (*msgs)[i].Buffers[0] = bufs[i]
+        (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
+    }
+    defer s.putMessages(msgs)
+    var numMsgs int
+    if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+        if rxOffload {
+            readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
+            numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
+            if err != nil {
+                return 0, err
+            }
+            numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
+            if err != nil {
+                return 0, err
+            }
+        } else {
+            numMsgs, err = br.ReadBatch(*msgs, 0)
+            if err != nil {
+                return 0, err
+            }
+        }
+    } else {
+        msg := &(*msgs)[0]
+        msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+        if err != nil {
+            return 0, err
+        }
+        numMsgs = 1
+    }
+    for i := 0; i < numMsgs; i++ {
+        msg := &(*msgs)[i]
+        sizes[i] = msg.N
+        if sizes[i] == 0 {
+            continue
+        }
+        addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+        ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
+        getSrcFromControl(msg.OOB[:msg.NN], ep)
+        eps[i] = ep
+    }
+    return numMsgs, nil
+}
+
+func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
+    return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+        return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
+    }
+}
+
+func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
+    return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+        return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
+    }
+}
+
+// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
+// rename the IdealBatchSize constant to BatchSize.
+func (s *StdNetBind) BatchSize() int {
+    if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+        return IdealBatchSize
+    }
+    return 1
+}
+
+func (s *StdNetBind) Close() error {
+    s.mu.Lock()
+    defer s.mu.Unlock()

     var err1, err2 error
-    if bind.ipv4 != nil {
-        err1 = bind.ipv4.Close()
-        bind.ipv4 = nil
+    if s.ipv4 != nil {
+        err1 = s.ipv4.Close()
+        s.ipv4 = nil
+        s.ipv4PC = nil
     }
-    if bind.ipv6 != nil {
-        err2 = bind.ipv6.Close()
-        bind.ipv6 = nil
+    if s.ipv6 != nil {
+        err2 = s.ipv6.Close()
+        s.ipv6 = nil
+        s.ipv6PC = nil
     }
-    bind.blackhole4 = false
-    bind.blackhole6 = false
+    s.blackhole4 = false
+    s.blackhole6 = false
+    s.ipv4TxOffload = false
+    s.ipv4RxOffload = false
+    s.ipv6TxOffload = false
+    s.ipv6RxOffload = false
     if err1 != nil {
         return err1
     }
     return err2
 }

-func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
-    return func(buff []byte) (int, Endpoint, error) {
-        n, endpoint, err := conn.ReadFromUDP(buff)
-        if endpoint != nil {
-            endpoint.IP = endpoint.IP.To4()
-        }
-        return n, (*StdNetEndpoint)(endpoint), err
-    }
+type ErrUDPGSODisabled struct {
+    onLaddr  string
+    RetryErr error
 }

-func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
-    return func(buff []byte) (int, Endpoint, error) {
-        n, endpoint, err := conn.ReadFromUDP(buff)
-        return n, (*StdNetEndpoint)(endpoint), err
-    }
+func (e ErrUDPGSODisabled) Error() string {
+    return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
 }

-func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
-    var err error
-    nend, ok := endpoint.(*StdNetEndpoint)
-    if !ok {
-        return ErrWrongEndpointType
-    }
+func (e ErrUDPGSODisabled) Unwrap() error {
+    return e.RetryErr
+}

-    bind.mu.Lock()
-    blackhole := bind.blackhole4
-    conn := bind.ipv4
-    if nend.IP.To4() == nil {
-        blackhole = bind.blackhole6
-        conn = bind.ipv6
+func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
+    s.mu.Lock()
+    blackhole := s.blackhole4
+    conn := s.ipv4
+    offload := s.ipv4TxOffload
+    br := batchWriter(s.ipv4PC)
+    is6 := false
+    if endpoint.DstIP().Is6() {
+        blackhole = s.blackhole6
+        conn = s.ipv6
+        br = s.ipv6PC
+        is6 = true
+        offload = s.ipv6TxOffload
     }
-    bind.mu.Unlock()
+    s.mu.Unlock()

     if blackhole {
         return nil
@@ -191,6 +360,185 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
     if conn == nil {
         return syscall.EAFNOSUPPORT
     }
-    _, err = conn.WriteToUDP(buff, (*net.UDPAddr)(nend))
+
+    msgs := s.getMessages()
+    defer s.putMessages(msgs)
+    ua := s.udpAddrPool.Get().(*net.UDPAddr)
+    defer s.udpAddrPool.Put(ua)
+    if is6 {
+        as16 := endpoint.DstIP().As16()
+        copy(ua.IP, as16[:])
+        ua.IP = ua.IP[:16]
+    } else {
+        as4 := endpoint.DstIP().As4()
+        copy(ua.IP, as4[:])
+        ua.IP = ua.IP[:4]
+    }
+    ua.Port = int(endpoint.(*StdNetEndpoint).Port())
+    var (
+        retried bool
+        err     error
+    )
+retry:
+    if offload {
+        n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
+        err = s.send(conn, br, (*msgs)[:n])
+        if err != nil && offload && errShouldDisableUDPGSO(err) {
+            offload = false
+            s.mu.Lock()
+            if is6 {
+                s.ipv6TxOffload = false
+            } else {
+                s.ipv4TxOffload = false
+            }
+            s.mu.Unlock()
+            retried = true
+            goto retry
+        }
+    } else {
+        for i := range bufs {
+            (*msgs)[i].Addr = ua
+            (*msgs)[i].Buffers[0] = bufs[i]
+            setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
+        }
+        err = s.send(conn, br, (*msgs)[:len(bufs)])
+    }
+    if retried {
+        return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
+    }
+    return err
+}
+
+func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
+    var (
+        n     int
+        err   error
+        start int
+    )
+    if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+        for {
+            n, err = pc.WriteBatch(msgs[start:], 0)
+            if err != nil || n == len(msgs[start:]) {
+                break
+            }
+            start += n
+        }
+    } else {
+        for _, msg := range msgs {
+            _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
+            if err != nil {
+                break
+            }
+        }
+    }
     return err
 }
+
+const (
+    // Exceeding these values results in EMSGSIZE. They account for layer3 and
+    // layer4 headers. IPv6 does not need to account for itself as the payload
+    // length field is self excluding.
+    maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
+    maxIPv6PayloadLen = 1<<16 - 1 - 8
+
+    // This is a hard limit imposed by the kernel.
+    udpSegmentMaxDatagrams = 64
+)
+
+type setGSOFunc func(control *[]byte, gsoSize uint16)
+
+func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
+    var (
+        base     = -1 // index of msg we are currently coalescing into
+        gsoSize  int  // segmentation size of msgs[base]
+        dgramCnt int  // number of dgrams coalesced into msgs[base]
+        endBatch bool // tracking flag to start a new batch on next iteration of bufs
+    )
+    maxPayloadLen := maxIPv4PayloadLen
+    if ep.DstIP().Is6() {
+        maxPayloadLen = maxIPv6PayloadLen
+    }
+    for i, buf := range bufs {
+        if i > 0 {
+            msgLen := len(buf)
+            baseLenBefore := len(msgs[base].Buffers[0])
+            freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
+            if msgLen+baseLenBefore <= maxPayloadLen &&
+                msgLen <= gsoSize &&
+                msgLen <= freeBaseCap &&
+                dgramCnt < udpSegmentMaxDatagrams &&
+                !endBatch {
+                msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
+                if i == len(bufs)-1 {
+                    setGSO(&msgs[base].OOB, uint16(gsoSize))
+                }
+                dgramCnt++
+                if msgLen < gsoSize {
+                    // A smaller than gsoSize packet on the tail is legal, but
+                    // it must end the batch.
+                    endBatch = true
+                }
+                continue
+            }
+        }
+        if dgramCnt > 1 {
+            setGSO(&msgs[base].OOB, uint16(gsoSize))
+        }
+        // Reset prior to incrementing base since we are preparing to start a
+        // new potential batch.
+        endBatch = false
+        base++
+        gsoSize = len(buf)
+        setSrcControl(&msgs[base].OOB, ep)
+        msgs[base].Buffers[0] = buf
+        msgs[base].Addr = addr
+        dgramCnt = 1
+    }
+    return base + 1
+}
+
+type getGSOFunc func(control []byte) (int, error)
+
+func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
+    for i := firstMsgAt; i < len(msgs); i++ {
+        msg := &msgs[i]
+        if msg.N == 0 {
+            return n, err
+        }
+        var (
+            gsoSize    int
+            start      int
+            end        = msg.N
+            numToSplit = 1
+        )
+        gsoSize, err = getGSO(msg.OOB[:msg.NN])
+        if err != nil {
+            return n, err
+        }
+        if gsoSize > 0 {
+            numToSplit = (msg.N + gsoSize - 1) / gsoSize
+            end = gsoSize
+        }
+        for j := 0; j < numToSplit; j++ {
+            if n > i {
+                return n, errors.New("splitting coalesced packet resulted in overflow")
+            }
+            copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
+            msgs[n].N = copied
+            msgs[n].Addr = msg.Addr
+            start = end
+            end += gsoSize
+            if end > msg.N {
+                end = msg.N
+            }
+            n++
+        }
+        if i != n-1 {
+            // It is legal for bytes to move within msg.Buffers[0] as a result
+            // of splitting, so we only zero the source msg len when it is not
+            // the destination of the last split operation above.
+            msg.N = 0
+        }
+    }
+    return n, nil
+}
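Usage note: the receive path above changes ReceiveFunc to a batched contract, where bufs, sizes, and eps must each hold at least BatchSize() elements, and the first n entries of sizes/eps describe the first n bufs after a call. What follows is a minimal caller sketch, not part of the commit, assuming the exported wireguard-go API at this revision (module golang.zx2c4.com/wireguard/conn); the buffer size is an illustrative assumption sized to hold a full coalesced read under UDP GRO.

package main

import (
    "fmt"
    "log"

    "golang.zx2c4.com/wireguard/conn"
)

func main() {
    bind := conn.NewStdNetBind()
    fns, port, err := bind.Open(0) // port 0: the kernel picks an ephemeral port
    if err != nil {
        log.Fatal(err)
    }
    defer bind.Close()
    if len(fns) == 0 {
        log.Fatal("no usable sockets")
    }
    fmt.Println("listening on port", port)

    // Per BatchSize() above: IdealBatchSize on linux/android, 1 elsewhere.
    batch := bind.BatchSize()
    bufs := make([][]byte, batch)
    sizes := make([]int, batch)
    eps := make([]conn.Endpoint, batch)
    for i := range bufs {
        // Assumption: 64 KiB so a coalesced (rxOffload) segment always fits.
        bufs[i] = make([]byte, 65535)
    }

    // Each ReceiveFunc blocks on one socket (v4 or v6) and yields up to
    // `batch` datagrams per call; sizes[i] and eps[i] describe bufs[i].
    n, err := fns[0](bufs, sizes, eps)
    if err != nil {
        log.Fatal(err)
    }
    for i := 0; i < n; i++ {
        fmt.Printf("%d bytes from %s\n", sizes[i], eps[i].DstToString())
    }
}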
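Similarly, ErrUDPGSODisabled is advisory rather than fatal: by the time Send returns it, the transmission has already been retried with GSO disabled on that socket, and RetryErr carries the retry's outcome (nil means the datagrams still went out). A sketch of how a caller might unwrap it, again assuming the golang.zx2c4.com/wireguard/conn API; the TEST-NET-1 endpoint address is an illustrative assumption.

package main

import (
    "errors"
    "fmt"
    "log"

    "golang.zx2c4.com/wireguard/conn"
)

func main() {
    bind := conn.NewStdNetBind()
    if _, _, err := bind.Open(0); err != nil {
        log.Fatal(err)
    }
    defer bind.Close()

    ep, err := bind.ParseEndpoint("192.0.2.10:51820") // illustrative address
    if err != nil {
        log.Fatal(err)
    }

    bufs := [][]byte{[]byte("packet-1"), []byte("packet-2")}
    err = bind.Send(bufs, ep)

    var gsoErr conn.ErrUDPGSODisabled
    if errors.As(err, &gsoErr) {
        // GSO was disabled and the send retried; keep only the retry result.
        fmt.Println("note:", gsoErr)
        err = gsoErr.RetryErr
    }
    if err != nil {
        log.Fatal(err)
    }
}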