aboutsummaryrefslogtreecommitdiffstats
path: root/net/ipv4/tcp.c
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--net/ipv4/tcp.c199
1 files changed, 144 insertions, 55 deletions
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 32545ecf2ab1..a3422e42784e 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -280,6 +280,12 @@
#include <asm/ioctls.h>
#include <net/busy_poll.h>
+/* Track pending CMSGs. */
+enum {
+ TCP_CMSG_INQ = 1,
+ TCP_CMSG_TS = 2
+};
+
struct percpu_counter tcp_orphan_count;
EXPORT_SYMBOL_GPL(tcp_orphan_count);
@@ -475,19 +481,11 @@ static void tcp_tx_timestamp(struct sock *sk, u16 tsflags)
}
}
-static inline bool tcp_stream_is_readable(const struct tcp_sock *tp,
- int target, struct sock *sk)
+static bool tcp_stream_is_readable(struct sock *sk, int target)
{
- int avail = READ_ONCE(tp->rcv_nxt) - READ_ONCE(tp->copied_seq);
-
- if (avail > 0) {
- if (avail >= target)
- return true;
- if (tcp_rmem_pressure(sk))
- return true;
- if (tcp_receive_window(tp) <= inet_csk(sk)->icsk_ack.rcv_mss)
- return true;
- }
+ if (tcp_epollin_ready(sk, target))
+ return true;
+
if (sk->sk_prot->stream_memory_read)
return sk->sk_prot->stream_memory_read(sk);
return false;
@@ -562,7 +560,7 @@ __poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait)
tp->urg_data)
target++;
- if (tcp_stream_is_readable(tp, target, sk))
+ if (tcp_stream_is_readable(sk, target))
mask |= EPOLLIN | EPOLLRDNORM;
if (!(sk->sk_shutdown & SEND_SHUTDOWN)) {
@@ -1010,7 +1008,7 @@ new_segment:
}
if (!(flags & MSG_NO_SHARED_FRAGS))
- skb_shinfo(skb)->tx_flags |= SKBTX_SHARED_FRAG;
+ skb_shinfo(skb)->flags |= SKBFL_SHARED_FRAG;
skb->len += copy;
skb->data_len += copy;
@@ -1217,7 +1215,7 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size)
if (flags & MSG_ZEROCOPY && size && sock_flag(sk, SOCK_ZEROCOPY)) {
skb = tcp_write_queue_tail(sk);
- uarg = sock_zerocopy_realloc(sk, size, skb_zcopy(skb));
+ uarg = msg_zerocopy_realloc(sk, size, skb_zcopy(skb));
if (!uarg) {
err = -ENOBUFS;
goto out_err;
@@ -1429,7 +1427,7 @@ out:
tcp_push(sk, flags, mss_now, tp->nonagle, size_goal);
}
out_nopush:
- sock_zerocopy_put(uarg);
+ net_zcopy_put(uarg);
return copied + copied_syn;
do_error:
@@ -1440,7 +1438,7 @@ do_fault:
if (copied + copied_syn)
goto out;
out_err:
- sock_zerocopy_put_abort(uarg, true);
+ net_zcopy_put_abort(uarg, true);
err = sk_stream_error(sk, flags, err);
/* make sure we wake any epoll edge trigger waiter */
if (unlikely(tcp_rtx_and_write_queues_empty(sk) && err == -EAGAIN)) {
@@ -1739,6 +1737,20 @@ int tcp_set_rcvlowat(struct sock *sk, int val)
}
EXPORT_SYMBOL(tcp_set_rcvlowat);
+static void tcp_update_recv_tstamps(struct sk_buff *skb,
+ struct scm_timestamping_internal *tss)
+{
+ if (skb->tstamp)
+ tss->ts[0] = ktime_to_timespec64(skb->tstamp);
+ else
+ tss->ts[0] = (struct timespec64) {0};
+
+ if (skb_hwtstamps(skb)->hwtstamp)
+ tss->ts[2] = ktime_to_timespec64(skb_hwtstamps(skb)->hwtstamp);
+ else
+ tss->ts[2] = (struct timespec64) {0};
+}
+
#ifdef CONFIG_MMU
static const struct vm_operations_struct tcp_vm_ops = {
};
@@ -1842,13 +1854,13 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,
struct scm_timestamping_internal *tss,
int *cmsg_flags);
static int receive_fallback_to_copy(struct sock *sk,
- struct tcp_zerocopy_receive *zc, int inq)
+ struct tcp_zerocopy_receive *zc, int inq,
+ struct scm_timestamping_internal *tss)
{
unsigned long copy_address = (unsigned long)zc->copybuf_address;
- struct scm_timestamping_internal tss_unused;
- int err, cmsg_flags_unused;
struct msghdr msg = {};
struct iovec iov;
+ int err;
zc->length = 0;
zc->recv_skip_hint = 0;
@@ -1862,7 +1874,7 @@ static int receive_fallback_to_copy(struct sock *sk,
return err;
err = tcp_recvmsg_locked(sk, &msg, inq, /*nonblock=*/1, /*flags=*/0,
- &tss_unused, &cmsg_flags_unused);
+ tss, &zc->msg_flags);
if (err < 0)
return err;
@@ -1903,21 +1915,27 @@ static int tcp_copy_straggler_data(struct tcp_zerocopy_receive *zc,
return (__s32)copylen;
}
-static int tcp_zerocopy_handle_leftover_data(struct tcp_zerocopy_receive *zc,
- struct sock *sk,
- struct sk_buff *skb,
- u32 *seq,
- s32 copybuf_len)
+static int tcp_zc_handle_leftover(struct tcp_zerocopy_receive *zc,
+ struct sock *sk,
+ struct sk_buff *skb,
+ u32 *seq,
+ s32 copybuf_len,
+ struct scm_timestamping_internal *tss)
{
u32 offset, copylen = min_t(u32, copybuf_len, zc->recv_skip_hint);
if (!copylen)
return 0;
/* skb is null if inq < PAGE_SIZE. */
- if (skb)
+ if (skb) {
offset = *seq - TCP_SKB_CB(skb)->seq;
- else
+ } else {
skb = tcp_recv_skb(sk, *seq, &offset);
+ if (TCP_SKB_CB(skb)->has_rxtstamp) {
+ tcp_update_recv_tstamps(skb, tss);
+ zc->msg_flags |= TCP_CMSG_TS;
+ }
+ }
zc->copybuf_len = tcp_copy_straggler_data(zc, skb, copylen, &offset,
seq);
@@ -2004,9 +2022,38 @@ static int tcp_zerocopy_vm_insert_batch(struct vm_area_struct *vma,
err);
}
+#define TCP_VALID_ZC_MSG_FLAGS (TCP_CMSG_TS)
+static void tcp_recv_timestamp(struct msghdr *msg, const struct sock *sk,
+ struct scm_timestamping_internal *tss);
+static void tcp_zc_finalize_rx_tstamp(struct sock *sk,
+ struct tcp_zerocopy_receive *zc,
+ struct scm_timestamping_internal *tss)
+{
+ unsigned long msg_control_addr;
+ struct msghdr cmsg_dummy;
+
+ msg_control_addr = (unsigned long)zc->msg_control;
+ cmsg_dummy.msg_control = (void *)msg_control_addr;
+ cmsg_dummy.msg_controllen =
+ (__kernel_size_t)zc->msg_controllen;
+ cmsg_dummy.msg_flags = in_compat_syscall()
+ ? MSG_CMSG_COMPAT : 0;
+ zc->msg_flags = 0;
+ if (zc->msg_control == msg_control_addr &&
+ zc->msg_controllen == cmsg_dummy.msg_controllen) {
+ tcp_recv_timestamp(&cmsg_dummy, sk, tss);
+ zc->msg_control = (__u64)
+ ((uintptr_t)cmsg_dummy.msg_control);
+ zc->msg_controllen =
+ (__u64)cmsg_dummy.msg_controllen;
+ zc->msg_flags = (__u32)cmsg_dummy.msg_flags;
+ }
+}
+
#define TCP_ZEROCOPY_PAGE_BATCH_SIZE 32
static int tcp_zerocopy_receive(struct sock *sk,
- struct tcp_zerocopy_receive *zc)
+ struct tcp_zerocopy_receive *zc,
+ struct scm_timestamping_internal *tss)
{
u32 length = 0, offset, vma_len, avail_len, copylen = 0;
unsigned long address = (unsigned long)zc->address;
@@ -2023,6 +2070,7 @@ static int tcp_zerocopy_receive(struct sock *sk,
int ret;
zc->copybuf_len = 0;
+ zc->msg_flags = 0;
if (address & (PAGE_SIZE - 1) || address != zc->address)
return -EINVAL;
@@ -2033,7 +2081,7 @@ static int tcp_zerocopy_receive(struct sock *sk,
sock_rps_record_flow(sk);
if (inq && inq <= copybuf_len)
- return receive_fallback_to_copy(sk, zc, inq);
+ return receive_fallback_to_copy(sk, zc, inq, tss);
if (inq < PAGE_SIZE) {
zc->length = 0;
@@ -2078,6 +2126,11 @@ static int tcp_zerocopy_receive(struct sock *sk,
} else {
skb = tcp_recv_skb(sk, seq, &offset);
}
+
+ if (TCP_SKB_CB(skb)->has_rxtstamp) {
+ tcp_update_recv_tstamps(skb, tss);
+ zc->msg_flags |= TCP_CMSG_TS;
+ }
zc->recv_skip_hint = skb->len - offset;
frags = skb_advance_to_frag(skb, offset, &offset_frag);
if (!frags || offset_frag)
@@ -2120,8 +2173,7 @@ out:
mmap_read_unlock(current->mm);
/* Try to copy straggler data. */
if (!ret)
- copylen = tcp_zerocopy_handle_leftover_data(zc, sk, skb, &seq,
- copybuf_len);
+ copylen = tcp_zc_handle_leftover(zc, sk, skb, &seq, copybuf_len, tss);
if (length + copylen) {
WRITE_ONCE(tp->copied_seq, seq);
@@ -2142,20 +2194,6 @@ out:
}
#endif
-static void tcp_update_recv_tstamps(struct sk_buff *skb,
- struct scm_timestamping_internal *tss)
-{
- if (skb->tstamp)
- tss->ts[0] = ktime_to_timespec64(skb->tstamp);
- else
- tss->ts[0] = (struct timespec64) {0};
-
- if (skb_hwtstamps(skb)->hwtstamp)
- tss->ts[2] = ktime_to_timespec64(skb_hwtstamps(skb)->hwtstamp);
- else
- tss->ts[2] = (struct timespec64) {0};
-}
-
/* Similar to __sock_recv_timestamp, but does not require an skb */
static void tcp_recv_timestamp(struct msghdr *msg, const struct sock *sk,
struct scm_timestamping_internal *tss)
@@ -2272,7 +2310,7 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,
goto out;
if (tp->recvmsg_inq)
- *cmsg_flags = 1;
+ *cmsg_flags = TCP_CMSG_INQ;
timeo = sock_rcvtimeo(sk, nonblock);
/* Urgent data needs to be handled specially. */
@@ -2453,7 +2491,7 @@ skip_copy:
if (TCP_SKB_CB(skb)->has_rxtstamp) {
tcp_update_recv_tstamps(skb, tss);
- *cmsg_flags |= 2;
+ *cmsg_flags |= TCP_CMSG_TS;
}
if (used + offset < skb->len)
@@ -2513,9 +2551,9 @@ int tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int nonblock,
release_sock(sk);
if (cmsg_flags && ret >= 0) {
- if (cmsg_flags & 2)
+ if (cmsg_flags & TCP_CMSG_TS)
tcp_recv_timestamp(msg, sk, &tss);
- if (cmsg_flags & 1) {
+ if (cmsg_flags & TCP_CMSG_INQ) {
inq = tcp_inq_hint(sk);
put_cmsg(msg, SOL_TCP, TCP_CM_INQ, sizeof(inq), &inq);
}
@@ -3767,11 +3805,24 @@ static size_t tcp_opt_stats_get_size(void)
nla_total_size(sizeof(u16)) + /* TCP_NLA_TIMEOUT_REHASH */
nla_total_size(sizeof(u32)) + /* TCP_NLA_BYTES_NOTSENT */
nla_total_size_64bit(sizeof(u64)) + /* TCP_NLA_EDT */
+ nla_total_size(sizeof(u8)) + /* TCP_NLA_TTL */
0;
}
+/* Returns TTL or hop limit of an incoming packet from skb. */
+static u8 tcp_skb_ttl_or_hop_limit(const struct sk_buff *skb)
+{
+ if (skb->protocol == htons(ETH_P_IP))
+ return ip_hdr(skb)->ttl;
+ else if (skb->protocol == htons(ETH_P_IPV6))
+ return ipv6_hdr(skb)->hop_limit;
+ else
+ return 0;
+}
+
struct sk_buff *tcp_get_timestamping_opt_stats(const struct sock *sk,
- const struct sk_buff *orig_skb)
+ const struct sk_buff *orig_skb,
+ const struct sk_buff *ack_skb)
{
const struct tcp_sock *tp = tcp_sk(sk);
struct sk_buff *stats;
@@ -3827,6 +3878,9 @@ struct sk_buff *tcp_get_timestamping_opt_stats(const struct sock *sk,
max_t(int, 0, tp->write_seq - tp->snd_nxt));
nla_put_u64_64bit(stats, TCP_NLA_EDT, orig_skb->skb_mstamp_ns,
TCP_NLA_PAD);
+ if (ack_skb)
+ nla_put_u8(stats, TCP_NLA_TTL,
+ tcp_skb_ttl_or_hop_limit(ack_skb));
return stats;
}
@@ -4083,6 +4137,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
}
#ifdef CONFIG_MMU
case TCP_ZEROCOPY_RECEIVE: {
+ struct scm_timestamping_internal tss;
struct tcp_zerocopy_receive zc = {};
int err;
@@ -4090,19 +4145,36 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
return -EFAULT;
if (len < offsetofend(struct tcp_zerocopy_receive, length))
return -EINVAL;
- if (len > sizeof(zc)) {
+ if (unlikely(len > sizeof(zc))) {
+ err = check_zeroed_user(optval + sizeof(zc),
+ len - sizeof(zc));
+ if (err < 1)
+ return err == 0 ? -EINVAL : err;
len = sizeof(zc);
if (put_user(len, optlen))
return -EFAULT;
}
if (copy_from_user(&zc, optval, len))
return -EFAULT;
+ if (zc.reserved)
+ return -EINVAL;
+ if (zc.msg_flags & ~(TCP_VALID_ZC_MSG_FLAGS))
+ return -EINVAL;
lock_sock(sk);
- err = tcp_zerocopy_receive(sk, &zc);
+ err = tcp_zerocopy_receive(sk, &zc, &tss);
+ err = BPF_CGROUP_RUN_PROG_GETSOCKOPT_KERN(sk, level, optname,
+ &zc, &len, err);
release_sock(sk);
- if (len >= offsetofend(struct tcp_zerocopy_receive, err))
- goto zerocopy_rcv_sk_err;
+ if (len >= offsetofend(struct tcp_zerocopy_receive, msg_flags))
+ goto zerocopy_rcv_cmsg;
switch (len) {
+ case offsetofend(struct tcp_zerocopy_receive, msg_flags):
+ goto zerocopy_rcv_cmsg;
+ case offsetofend(struct tcp_zerocopy_receive, msg_controllen):
+ case offsetofend(struct tcp_zerocopy_receive, msg_control):
+ case offsetofend(struct tcp_zerocopy_receive, flags):
+ case offsetofend(struct tcp_zerocopy_receive, copybuf_len):
+ case offsetofend(struct tcp_zerocopy_receive, copybuf_address):
case offsetofend(struct tcp_zerocopy_receive, err):
goto zerocopy_rcv_sk_err;
case offsetofend(struct tcp_zerocopy_receive, inq):
@@ -4111,6 +4183,11 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
default:
goto zerocopy_rcv_out;
}
+zerocopy_rcv_cmsg:
+ if (zc.msg_flags & TCP_CMSG_TS)
+ tcp_zc_finalize_rx_tstamp(sk, &zc, &tss);
+ else
+ zc.msg_flags = 0;
zerocopy_rcv_sk_err:
if (!err)
zc.err = sock_error(sk);
@@ -4133,6 +4210,18 @@ zerocopy_rcv_out:
return 0;
}
+bool tcp_bpf_bypass_getsockopt(int level, int optname)
+{
+ /* TCP do_tcp_getsockopt has optimized getsockopt implementation
+ * to avoid extra socket lock for TCP_ZEROCOPY_RECEIVE.
+ */
+ if (level == SOL_TCP && optname == TCP_ZEROCOPY_RECEIVE)
+ return true;
+
+ return false;
+}
+EXPORT_SYMBOL(tcp_bpf_bypass_getsockopt);
+
int tcp_getsockopt(struct sock *sk, int level, int optname, char __user *optval,
int __user *optlen)
{