diff options
-rw-r--r-- | kernel/bpf/sockmap.c | 52 | ||||
-rw-r--r-- | net/core/filter.c | 52 |
2 files changed, 54 insertions, 50 deletions
diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c index cf5195c7c331..ce63e5801746 100644 --- a/kernel/bpf/sockmap.c +++ b/kernel/bpf/sockmap.c @@ -236,7 +236,7 @@ static int bpf_tcp_init(struct sock *sk) } static void smap_release_sock(struct smap_psock *psock, struct sock *sock); -static int free_start_sg(struct sock *sk, struct sk_msg_buff *md); +static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge); static void bpf_tcp_release(struct sock *sk) { @@ -248,7 +248,7 @@ static void bpf_tcp_release(struct sock *sk) goto out; if (psock->cork) { - free_start_sg(psock->sock, psock->cork); + free_start_sg(psock->sock, psock->cork, true); kfree(psock->cork); psock->cork = NULL; } @@ -330,14 +330,14 @@ static void bpf_tcp_close(struct sock *sk, long timeout) close_fun = psock->save_close; if (psock->cork) { - free_start_sg(psock->sock, psock->cork); + free_start_sg(psock->sock, psock->cork, true); kfree(psock->cork); psock->cork = NULL; } list_for_each_entry_safe(md, mtmp, &psock->ingress, list) { list_del(&md->list); - free_start_sg(psock->sock, md); + free_start_sg(psock->sock, md, true); kfree(md); } @@ -369,7 +369,7 @@ static void bpf_tcp_close(struct sock *sk, long timeout) /* If another thread deleted this object skip deletion. * The refcnt on psock may or may not be zero. */ - if (l) { + if (l && l == link) { hlist_del_rcu(&link->hash_node); smap_release_sock(psock, link->sk); free_htab_elem(htab, link); @@ -570,14 +570,16 @@ static void free_bytes_sg(struct sock *sk, int bytes, md->sg_start = i; } -static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md) +static int free_sg(struct sock *sk, int start, + struct sk_msg_buff *md, bool charge) { struct scatterlist *sg = md->sg_data; int i = start, free = 0; while (sg[i].length) { free += sg[i].length; - sk_mem_uncharge(sk, sg[i].length); + if (charge) + sk_mem_uncharge(sk, sg[i].length); if (!md->skb) put_page(sg_page(&sg[i])); sg[i].length = 0; @@ -594,9 +596,9 @@ static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md) return free; } -static int free_start_sg(struct sock *sk, struct sk_msg_buff *md) +static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge) { - int free = free_sg(sk, md->sg_start, md); + int free = free_sg(sk, md->sg_start, md, charge); md->sg_start = md->sg_end; return free; @@ -604,7 +606,7 @@ static int free_start_sg(struct sock *sk, struct sk_msg_buff *md) static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md) { - return free_sg(sk, md->sg_curr, md); + return free_sg(sk, md->sg_curr, md, true); } static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md) @@ -718,7 +720,7 @@ static int bpf_tcp_ingress(struct sock *sk, int apply_bytes, list_add_tail(&r->list, &psock->ingress); sk->sk_data_ready(sk); } else { - free_start_sg(sk, r); + free_start_sg(sk, r, true); kfree(r); } @@ -752,14 +754,10 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send, release_sock(sk); } smap_release_sock(psock, sk); - if (unlikely(err)) - goto out; - return 0; + return err; out_rcu: rcu_read_unlock(); -out: - free_bytes_sg(NULL, send, md, false); - return err; + return 0; } static inline void bpf_md_init(struct smap_psock *psock) @@ -822,7 +820,7 @@ more_data: case __SK_PASS: err = bpf_tcp_push(sk, send, m, flags, true); if (unlikely(err)) { - *copied -= free_start_sg(sk, m); + *copied -= free_start_sg(sk, m, true); break; } @@ -845,16 +843,17 @@ more_data: lock_sock(sk); if (unlikely(err < 0)) { - free_start_sg(sk, m); + int free = free_start_sg(sk, m, false); + psock->sg_size = 0; if (!cork) - *copied -= send; + *copied -= free; } else { psock->sg_size -= send; } if (cork) { - free_start_sg(sk, m); + free_start_sg(sk, m, true); psock->sg_size = 0; kfree(m); m = NULL; @@ -912,6 +911,8 @@ static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, if (unlikely(flags & MSG_ERRQUEUE)) return inet_recv_error(sk, msg, len, addr_len); + if (!skb_queue_empty(&sk->sk_receive_queue)) + return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len); rcu_read_lock(); psock = smap_psock_sk(sk); @@ -922,9 +923,6 @@ static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, goto out; rcu_read_unlock(); - if (!skb_queue_empty(&sk->sk_receive_queue)) - return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len); - lock_sock(sk); bytes_ready: while (copied != len) { @@ -1122,7 +1120,7 @@ wait_for_memory: err = sk_stream_wait_memory(sk, &timeo); if (err) { if (m && m != psock->cork) - free_start_sg(sk, m); + free_start_sg(sk, m, true); goto out_err; } } @@ -1581,13 +1579,13 @@ static void smap_gc_work(struct work_struct *w) bpf_prog_put(psock->bpf_tx_msg); if (psock->cork) { - free_start_sg(psock->sock, psock->cork); + free_start_sg(psock->sock, psock->cork, true); kfree(psock->cork); } list_for_each_entry_safe(md, mtmp, &psock->ingress, list) { list_del(&md->list); - free_start_sg(psock->sock, md); + free_start_sg(psock->sock, md, true); kfree(md); } diff --git a/net/core/filter.c b/net/core/filter.c index c25eb36f1320..2c7801f6737a 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -2282,14 +2282,21 @@ static const struct bpf_func_proto bpf_msg_cork_bytes_proto = { .arg2_type = ARG_ANYTHING, }; +#define sk_msg_iter_var(var) \ + do { \ + var++; \ + if (var == MAX_SKB_FRAGS) \ + var = 0; \ + } while (0) + BPF_CALL_4(bpf_msg_pull_data, struct sk_msg_buff *, msg, u32, start, u32, end, u64, flags) { unsigned int len = 0, offset = 0, copy = 0; + int bytes = end - start, bytes_sg_total; struct scatterlist *sg = msg->sg_data; int first_sg, last_sg, i, shift; unsigned char *p, *to, *from; - int bytes = end - start; struct page *page; if (unlikely(flags || end <= start)) @@ -2299,21 +2306,22 @@ BPF_CALL_4(bpf_msg_pull_data, i = msg->sg_start; do { len = sg[i].length; - offset += len; if (start < offset + len) break; - i++; - if (i == MAX_SKB_FRAGS) - i = 0; + offset += len; + sk_msg_iter_var(i); } while (i != msg->sg_end); if (unlikely(start >= offset + len)) return -EINVAL; - if (!msg->sg_copy[i] && bytes <= len) - goto out; - first_sg = i; + /* The start may point into the sg element so we need to also + * account for the headroom. + */ + bytes_sg_total = start - offset + bytes; + if (!msg->sg_copy[i] && bytes_sg_total <= len) + goto out; /* At this point we need to linearize multiple scatterlist * elements or a single shared page. Either way we need to @@ -2327,15 +2335,13 @@ BPF_CALL_4(bpf_msg_pull_data, */ do { copy += sg[i].length; - i++; - if (i == MAX_SKB_FRAGS) - i = 0; - if (bytes < copy) + sk_msg_iter_var(i); + if (bytes_sg_total <= copy) break; } while (i != msg->sg_end); last_sg = i; - if (unlikely(copy < end - start)) + if (unlikely(bytes_sg_total > copy)) return -EINVAL; page = alloc_pages(__GFP_NOWARN | GFP_ATOMIC, get_order(copy)); @@ -2355,9 +2361,7 @@ BPF_CALL_4(bpf_msg_pull_data, sg[i].length = 0; put_page(sg_page(&sg[i])); - i++; - if (i == MAX_SKB_FRAGS) - i = 0; + sk_msg_iter_var(i); } while (i != last_sg); sg[first_sg].length = copy; @@ -2367,11 +2371,15 @@ BPF_CALL_4(bpf_msg_pull_data, * had a single entry though we can just replace it and * be done. Otherwise walk the ring and shift the entries. */ - shift = last_sg - first_sg - 1; + WARN_ON_ONCE(last_sg == first_sg); + shift = last_sg > first_sg ? + last_sg - first_sg - 1 : + MAX_SKB_FRAGS - first_sg + last_sg - 1; if (!shift) goto out; - i = first_sg + 1; + i = first_sg; + sk_msg_iter_var(i); do { int move_from; @@ -2388,15 +2396,13 @@ BPF_CALL_4(bpf_msg_pull_data, sg[move_from].page_link = 0; sg[move_from].offset = 0; - i++; - if (i == MAX_SKB_FRAGS) - i = 0; + sk_msg_iter_var(i); } while (1); msg->sg_end -= shift; if (msg->sg_end < 0) msg->sg_end += MAX_SKB_FRAGS; out: - msg->data = sg_virt(&sg[i]) + start - offset; + msg->data = sg_virt(&sg[first_sg]) + start - offset; msg->data_end = msg->data + bytes; return 0; @@ -7281,7 +7287,7 @@ static u32 sk_reuseport_convert_ctx_access(enum bpf_access_type type, break; case offsetof(struct sk_reuseport_md, ip_protocol): - BUILD_BUG_ON(hweight_long(SK_FL_PROTO_MASK) != BITS_PER_BYTE); + BUILD_BUG_ON(HWEIGHT32(SK_FL_PROTO_MASK) != BITS_PER_BYTE); SK_REUSEPORT_LOAD_SK_FIELD_SIZE_OFF(__sk_flags_offset, BPF_W, 0); *insn++ = BPF_ALU32_IMM(BPF_AND, si->dst_reg, SK_FL_PROTO_MASK); |