diff options
Diffstat (limited to 'net/ipv4/tcp_bpf.c')
-rw-r--r-- | net/ipv4/tcp_bpf.c | 66 |
1 files changed, 47 insertions, 19 deletions
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c index f70aa0932bd6..cf9c3e8f7ccb 100644 --- a/net/ipv4/tcp_bpf.c +++ b/net/ipv4/tcp_bpf.c @@ -138,10 +138,9 @@ int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, struct sk_psock *psock = sk_psock_get(sk); int ret; - if (unlikely(!psock)) { - sk_msg_free(sk, msg); - return 0; - } + if (unlikely(!psock)) + return -EPIPE; + ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) : tcp_bpf_push_locked(sk, msg, bytes, flags, false); sk_psock_put(sk, psock); @@ -175,7 +174,6 @@ static int tcp_msg_wait_data(struct sock *sk, struct sk_psock *psock, static int tcp_bpf_recvmsg_parser(struct sock *sk, struct msghdr *msg, size_t len, - int nonblock, int flags, int *addr_len) { @@ -187,7 +185,7 @@ static int tcp_bpf_recvmsg_parser(struct sock *sk, psock = sk_psock_get(sk); if (unlikely(!psock)) - return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len); + return tcp_recvmsg(sk, msg, len, flags, addr_len); lock_sock(sk); msg_bytes_ready: @@ -196,19 +194,46 @@ msg_bytes_ready: long timeo; int data; - timeo = sock_rcvtimeo(sk, nonblock); + if (sock_flag(sk, SOCK_DONE)) + goto out; + + if (sk->sk_err) { + copied = sock_error(sk); + goto out; + } + + if (sk->sk_shutdown & RCV_SHUTDOWN) + goto out; + + if (sk->sk_state == TCP_CLOSE) { + copied = -ENOTCONN; + goto out; + } + + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + if (!timeo) { + copied = -EAGAIN; + goto out; + } + + if (signal_pending(current)) { + copied = sock_intr_errno(timeo); + goto out; + } + data = tcp_msg_wait_data(sk, psock, timeo); if (data && !sk_psock_queue_empty(psock)) goto msg_bytes_ready; copied = -EAGAIN; } +out: release_sock(sk); sk_psock_put(sk, psock); return copied; } static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, - int nonblock, int flags, int *addr_len) + int flags, int *addr_len) { struct sk_psock *psock; int copied, ret; @@ -218,11 +243,11 @@ static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, psock = sk_psock_get(sk); if (unlikely(!psock)) - return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len); + return tcp_recvmsg(sk, msg, len, flags, addr_len); if (!skb_queue_empty(&sk->sk_receive_queue) && sk_psock_queue_empty(psock)) { sk_psock_put(sk, psock); - return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len); + return tcp_recvmsg(sk, msg, len, flags, addr_len); } lock_sock(sk); msg_bytes_ready: @@ -231,14 +256,14 @@ msg_bytes_ready: long timeo; int data; - timeo = sock_rcvtimeo(sk, nonblock); + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); data = tcp_msg_wait_data(sk, psock, timeo); if (data) { if (!sk_psock_queue_empty(psock)) goto msg_bytes_ready; release_sock(sk); sk_psock_put(sk, psock); - return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len); + return tcp_recvmsg(sk, msg, len, flags, addr_len); } copied = -EAGAIN; } @@ -253,7 +278,7 @@ static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock, { bool cork = false, enospc = sk_msg_full(msg); struct sock *sk_redir; - u32 tosend, delta = 0; + u32 tosend, origsize, sent, delta = 0; u32 eval = __SK_NONE; int ret; @@ -311,7 +336,9 @@ more_data: sk_msg_return(sk, msg, tosend); release_sock(sk); + origsize = msg->sg.size; ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags); + sent = origsize - msg->sg.size; if (eval == __SK_REDIRECT) sock_put(sk_redir); @@ -348,8 +375,11 @@ more_data: } if (msg && msg->sg.data[msg->sg.start].page_link && - msg->sg.data[msg->sg.start].length) + msg->sg.data[msg->sg.start].length) { + if (eval == __SK_REDIRECT) + sk_mem_charge(sk, tosend - sent); goto more_data; + } } return ret; } @@ -512,6 +542,7 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS], struct proto *base) { prot[TCP_BPF_BASE] = *base; + prot[TCP_BPF_BASE].destroy = sock_map_destroy; prot[TCP_BPF_BASE].close = sock_map_close; prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg; prot[TCP_BPF_BASE].sock_is_readable = sk_msg_is_readable; @@ -578,14 +609,11 @@ int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) } else { sk->sk_write_space = psock->saved_write_space; /* Pairs with lockless read in sk_clone_lock() */ - WRITE_ONCE(sk->sk_prot, psock->sk_proto); + sock_replace_proto(sk, psock->sk_proto); } return 0; } - if (inet_csk_has_ulp(sk)) - return -EINVAL; - if (sk->sk_family == AF_INET6) { if (tcp_bpf_assert_proto_ops(psock->sk_proto)) return -EINVAL; @@ -594,7 +622,7 @@ int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) } /* Pairs with lockless read in sk_clone_lock() */ - WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]); + sock_replace_proto(sk, &tcp_bpf_prots[family][config]); return 0; } EXPORT_SYMBOL_GPL(tcp_bpf_update_proto); |