Diffstat (limited to 'net/core/filter.c')

 -rw-r--r--  net/core/filter.c | 819
 1 file changed, 574 insertions(+), 245 deletions(-)
diff --git a/net/core/filter.c b/net/core/filter.c
index c25eb36f1320..35c6933c2622 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -38,6 +38,7 @@
 #include <net/protocol.h>
 #include <net/netlink.h>
 #include <linux/skbuff.h>
+#include <linux/skmsg.h>
 #include <net/sock.h>
 #include <net/flow_dissector.h>
 #include <linux/errno.h>
@@ -58,13 +59,17 @@
 #include <net/busy_poll.h>
 #include <net/tcp.h>
 #include <net/xfrm.h>
+#include <net/udp.h>
 #include <linux/bpf_trace.h>
 #include <net/xdp_sock.h>
 #include <linux/inetdevice.h>
+#include <net/inet_hashtables.h>
+#include <net/inet6_hashtables.h>
 #include <net/ip_fib.h>
 #include <net/flow.h>
 #include <net/arp.h>
 #include <net/ipv6.h>
+#include <net/net_namespace.h>
 #include <linux/seg6_local.h>
 #include <net/seg6.h>
 #include <net/seg6_local.h>
@@ -2138,123 +2143,7 @@ static const struct bpf_func_proto bpf_redirect_proto = {
         .arg2_type      = ARG_ANYTHING,
 };
 
-BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
-           struct bpf_map *, map, void *, key, u64, flags)
-{
-        struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
-
-        /* If user passes invalid input drop the packet. */
-        if (unlikely(flags & ~(BPF_F_INGRESS)))
-                return SK_DROP;
-
-        tcb->bpf.flags = flags;
-        tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
-        if (!tcb->bpf.sk_redir)
-                return SK_DROP;
-
-        return SK_PASS;
-}
-
-static const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
-        .func           = bpf_sk_redirect_hash,
-        .gpl_only       = false,
-        .ret_type       = RET_INTEGER,
-        .arg1_type      = ARG_PTR_TO_CTX,
-        .arg2_type      = ARG_CONST_MAP_PTR,
-        .arg3_type      = ARG_PTR_TO_MAP_KEY,
-        .arg4_type      = ARG_ANYTHING,
-};
-
-BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
-           struct bpf_map *, map, u32, key, u64, flags)
-{
-        struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
-
-        /* If user passes invalid input drop the packet. */
-        if (unlikely(flags & ~(BPF_F_INGRESS)))
-                return SK_DROP;
-
-        tcb->bpf.flags = flags;
-        tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
-        if (!tcb->bpf.sk_redir)
-                return SK_DROP;
-
-        return SK_PASS;
-}
-
-struct sock *do_sk_redirect_map(struct sk_buff *skb)
-{
-        struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
-
-        return tcb->bpf.sk_redir;
-}
-
-static const struct bpf_func_proto bpf_sk_redirect_map_proto = {
-        .func           = bpf_sk_redirect_map,
-        .gpl_only       = false,
-        .ret_type       = RET_INTEGER,
-        .arg1_type      = ARG_PTR_TO_CTX,
-        .arg2_type      = ARG_CONST_MAP_PTR,
-        .arg3_type      = ARG_ANYTHING,
-        .arg4_type      = ARG_ANYTHING,
-};
-
-BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg_buff *, msg,
-           struct bpf_map *, map, void *, key, u64, flags)
-{
-        /* If user passes invalid input drop the packet. */
-        if (unlikely(flags & ~(BPF_F_INGRESS)))
-                return SK_DROP;
-
-        msg->flags = flags;
-        msg->sk_redir = __sock_hash_lookup_elem(map, key);
-        if (!msg->sk_redir)
-                return SK_DROP;
-
-        return SK_PASS;
-}
-
-static const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
-        .func           = bpf_msg_redirect_hash,
-        .gpl_only       = false,
-        .ret_type       = RET_INTEGER,
-        .arg1_type      = ARG_PTR_TO_CTX,
-        .arg2_type      = ARG_CONST_MAP_PTR,
-        .arg3_type      = ARG_PTR_TO_MAP_KEY,
-        .arg4_type      = ARG_ANYTHING,
-};
-
-BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg_buff *, msg,
-           struct bpf_map *, map, u32, key, u64, flags)
-{
-        /* If user passes invalid input drop the packet. */
-        if (unlikely(flags & ~(BPF_F_INGRESS)))
-                return SK_DROP;
-
-        msg->flags = flags;
-        msg->sk_redir = __sock_map_lookup_elem(map, key);
-        if (!msg->sk_redir)
-                return SK_DROP;
-
-        return SK_PASS;
-}
-
-struct sock *do_msg_redirect_map(struct sk_msg_buff *msg)
-{
-        return msg->sk_redir;
-}
-
-static const struct bpf_func_proto bpf_msg_redirect_map_proto = {
-        .func           = bpf_msg_redirect_map,
-        .gpl_only       = false,
-        .ret_type       = RET_INTEGER,
-        .arg1_type      = ARG_PTR_TO_CTX,
-        .arg2_type      = ARG_CONST_MAP_PTR,
-        .arg3_type      = ARG_ANYTHING,
-        .arg4_type      = ARG_ANYTHING,
-};
-
-BPF_CALL_2(bpf_msg_apply_bytes, struct sk_msg_buff *, msg, u32, bytes)
+BPF_CALL_2(bpf_msg_apply_bytes, struct sk_msg *, msg, u32, bytes)
 {
         msg->apply_bytes = bytes;
         return 0;
@@ -2268,7 +2157,7 @@ static const struct bpf_func_proto bpf_msg_apply_bytes_proto = {
         .arg2_type      = ARG_ANYTHING,
 };
 
-BPF_CALL_2(bpf_msg_cork_bytes, struct sk_msg_buff *, msg, u32, bytes)
+BPF_CALL_2(bpf_msg_cork_bytes, struct sk_msg *, msg, u32, bytes)
 {
         msg->cork_bytes = bytes;
         return 0;
@@ -2282,39 +2171,39 @@ static const struct bpf_func_proto bpf_msg_cork_bytes_proto = {
         .arg2_type      = ARG_ANYTHING,
 };
 
-BPF_CALL_4(bpf_msg_pull_data,
-           struct sk_msg_buff *, msg, u32, start, u32, end, u64, flags)
+BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
+           u32, end, u64, flags)
 {
-        unsigned int len = 0, offset = 0, copy = 0;
-        struct scatterlist *sg = msg->sg_data;
-        int first_sg, last_sg, i, shift;
-        unsigned char *p, *to, *from;
-        int bytes = end - start;
+        u32 len = 0, offset = 0, copy = 0, poffset = 0, bytes = end - start;
+        u32 first_sge, last_sge, i, shift, bytes_sg_total;
+        struct scatterlist *sge;
+        u8 *raw, *to, *from;
         struct page *page;
 
         if (unlikely(flags || end <= start))
                 return -EINVAL;
 
         /* First find the starting scatterlist element */
-        i = msg->sg_start;
+        i = msg->sg.start;
         do {
-                len = sg[i].length;
-                offset += len;
+                len = sk_msg_elem(msg, i)->length;
                 if (start < offset + len)
                         break;
-                i++;
-                if (i == MAX_SKB_FRAGS)
-                        i = 0;
-        } while (i != msg->sg_end);
+                offset += len;
+                sk_msg_iter_var_next(i);
+        } while (i != msg->sg.end);
 
         if (unlikely(start >= offset + len))
                 return -EINVAL;
 
-        if (!msg->sg_copy[i] && bytes <= len)
+        first_sge = i;
+        /* The start may point into the sg element so we need to also
+         * account for the headroom.
+         */
+        bytes_sg_total = start - offset + bytes;
+        if (!msg->sg.copy[i] && bytes_sg_total <= len)
                 goto out;
 
-        first_sg = i;
-
         /* At this point we need to linearize multiple scatterlist
          * elements or a single shared page. Either way we need to
         * copy into a linear buffer exclusively owned by BPF. Then
@@ -2326,79 +2215,75 @@ BPF_CALL_4(bpf_msg_pull_data,
          * will copy the entire sg entry.
          */
         do {
-                copy += sg[i].length;
-                i++;
-                if (i == MAX_SKB_FRAGS)
-                        i = 0;
-                if (bytes < copy)
+                copy += sk_msg_elem(msg, i)->length;
+                sk_msg_iter_var_next(i);
+                if (bytes_sg_total <= copy)
                         break;
-        } while (i != msg->sg_end);
-        last_sg = i;
+        } while (i != msg->sg.end);
+        last_sge = i;
 
-        if (unlikely(copy < end - start))
+        if (unlikely(bytes_sg_total > copy))
                 return -EINVAL;
 
-        page = alloc_pages(__GFP_NOWARN | GFP_ATOMIC, get_order(copy));
+        page = alloc_pages(__GFP_NOWARN | GFP_ATOMIC | __GFP_COMP,
+                           get_order(copy));
         if (unlikely(!page))
                 return -ENOMEM;
 
-        p = page_address(page);
-        offset = 0;
+        raw = page_address(page);
 
-        i = first_sg;
+        i = first_sge;
         do {
-                from = sg_virt(&sg[i]);
-                len = sg[i].length;
-                to = p + offset;
+                sge = sk_msg_elem(msg, i);
+                from = sg_virt(sge);
+                len = sge->length;
+                to = raw + poffset;
 
                 memcpy(to, from, len);
-                offset += len;
-                sg[i].length = 0;
-                put_page(sg_page(&sg[i]));
+                poffset += len;
+                sge->length = 0;
+                put_page(sg_page(sge));
 
-                i++;
-                if (i == MAX_SKB_FRAGS)
-                        i = 0;
-        } while (i != last_sg);
+                sk_msg_iter_var_next(i);
+        } while (i != last_sge);
 
-        sg[first_sg].length = copy;
-        sg_set_page(&sg[first_sg], page, copy, 0);
+        sg_set_page(&msg->sg.data[first_sge], page, copy, 0);
 
         /* To repair sg ring we need to shift entries. If we only
          * had a single entry though we can just replace it and
          * be done. Otherwise walk the ring and shift the entries.
          */
-        shift = last_sg - first_sg - 1;
+        WARN_ON_ONCE(last_sge == first_sge);
+        shift = last_sge > first_sge ?
+                last_sge - first_sge - 1 :
+                MAX_SKB_FRAGS - first_sge + last_sge - 1;
         if (!shift)
                 goto out;
 
-        i = first_sg + 1;
+        i = first_sge;
+        sk_msg_iter_var_next(i);
         do {
-                int move_from;
+                u32 move_from;
 
-                if (i + shift >= MAX_SKB_FRAGS)
-                        move_from = i + shift - MAX_SKB_FRAGS;
+                if (i + shift >= MAX_MSG_FRAGS)
+                        move_from = i + shift - MAX_MSG_FRAGS;
                 else
                         move_from = i + shift;
-
-                if (move_from == msg->sg_end)
+                if (move_from == msg->sg.end)
                         break;
 
-                sg[i] = sg[move_from];
-                sg[move_from].length = 0;
-                sg[move_from].page_link = 0;
-                sg[move_from].offset = 0;
-
-                i++;
-                if (i == MAX_SKB_FRAGS)
-                        i = 0;
+                msg->sg.data[i] = msg->sg.data[move_from];
+                msg->sg.data[move_from].length = 0;
+                msg->sg.data[move_from].page_link = 0;
+                msg->sg.data[move_from].offset = 0;
+                sk_msg_iter_var_next(i);
         } while (1);
-        msg->sg_end -= shift;
-        if (msg->sg_end < 0)
-                msg->sg_end += MAX_SKB_FRAGS;
+
+        msg->sg.end = msg->sg.end - shift > msg->sg.end ?
+                      msg->sg.end - shift + MAX_MSG_FRAGS :
+                      msg->sg.end - shift;
 out:
-        msg->data = sg_virt(&sg[i]) + start - offset;
+        msg->data = sg_virt(&msg->sg.data[first_sge]) + start - offset;
         msg->data_end = msg->data + bytes;
-
         return 0;
 }
@@ -2412,6 +2297,137 @@ static const struct bpf_func_proto bpf_msg_pull_data_proto = {
         .arg4_type      = ARG_ANYTHING,
 };
 
+BPF_CALL_4(bpf_msg_push_data, struct sk_msg *, msg, u32, start,
+           u32, len, u64, flags)
+{
+        struct scatterlist sge, nsge, nnsge, rsge = {0}, *psge;
+        u32 new, i = 0, l, space, copy = 0, offset = 0;
+        u8 *raw, *to, *from;
+        struct page *page;
+
+        if (unlikely(flags))
+                return -EINVAL;
+
+        /* First find the starting scatterlist element */
+        i = msg->sg.start;
+        do {
+                l = sk_msg_elem(msg, i)->length;
+
+                if (start < offset + l)
+                        break;
+                offset += l;
+                sk_msg_iter_var_next(i);
+        } while (i != msg->sg.end);
+
+        if (start >= offset + l)
+                return -EINVAL;
+
+        space = MAX_MSG_FRAGS - sk_msg_elem_used(msg);
+
+        /* If no space available will fallback to copy, we need at
+         * least one scatterlist elem available to push data into
+         * when start aligns to the beginning of an element or two
+         * when it falls inside an element. We handle the start equals
+         * offset case because its the common case for inserting a
+         * header.
+         */
+        if (!space || (space == 1 && start != offset))
+                copy = msg->sg.data[i].length;
+
+        page = alloc_pages(__GFP_NOWARN | GFP_ATOMIC | __GFP_COMP,
+                           get_order(copy + len));
+        if (unlikely(!page))
+                return -ENOMEM;
+
+        if (copy) {
+                int front, back;
+
+                raw = page_address(page);
+
+                psge = sk_msg_elem(msg, i);
+                front = start - offset;
+                back = psge->length - front;
+                from = sg_virt(psge);
+
+                if (front)
+                        memcpy(raw, from, front);
+
+                if (back) {
+                        from += front;
+                        to = raw + front + len;
+
+                        memcpy(to, from, back);
+                }
+
+                put_page(sg_page(psge));
+        } else if (start - offset) {
+                psge = sk_msg_elem(msg, i);
+                rsge = sk_msg_elem_cpy(msg, i);
+
+                psge->length = start - offset;
+                rsge.length -= psge->length;
+                rsge.offset += start;
+
+                sk_msg_iter_var_next(i);
+                sg_unmark_end(psge);
+                sk_msg_iter_next(msg, end);
+        }
+
+        /* Slot(s) to place newly allocated data */
+        new = i;
+
+        /* Shift one or two slots as needed */
+        if (!copy) {
+                sge = sk_msg_elem_cpy(msg, i);
+
+                sk_msg_iter_var_next(i);
+                sg_unmark_end(&sge);
+                sk_msg_iter_next(msg, end);
+
+                nsge = sk_msg_elem_cpy(msg, i);
+                if (rsge.length) {
+                        sk_msg_iter_var_next(i);
+                        nnsge = sk_msg_elem_cpy(msg, i);
+                }
+
+                while (i != msg->sg.end) {
+                        msg->sg.data[i] = sge;
+                        sge = nsge;
+                        sk_msg_iter_var_next(i);
+                        if (rsge.length) {
+                                nsge = nnsge;
+                                nnsge = sk_msg_elem_cpy(msg, i);
+                        } else {
+                                nsge = sk_msg_elem_cpy(msg, i);
+                        }
+                }
+        }
+
+        /* Place newly allocated data buffer */
+        sk_mem_charge(msg->sk, len);
+        msg->sg.size += len;
+        msg->sg.copy[new] = false;
+        sg_set_page(&msg->sg.data[new], page, len + copy, 0);
+        if (rsge.length) {
+                get_page(sg_page(&rsge));
+                sk_msg_iter_var_next(new);
+                msg->sg.data[new] = rsge;
+        }
+
+        sk_msg_compute_data_pointers(msg);
+        return 0;
+}
+
+static const struct bpf_func_proto bpf_msg_push_data_proto = {
+        .func           = bpf_msg_push_data,
+        .gpl_only       = false,
+        .ret_type       = RET_INTEGER,
+        .arg1_type      = ARG_PTR_TO_CTX,
+        .arg2_type      = ARG_ANYTHING,
+        .arg3_type      = ARG_ANYTHING,
+        .arg4_type      = ARG_ANYTHING,
+};
+
 BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb)
 {
         return task_get_classid(skb);
@@ -3170,6 +3186,32 @@ static int __bpf_tx_xdp(struct net_device *dev,
         return 0;
 }
 
+static noinline int
+xdp_do_redirect_slow(struct net_device *dev, struct xdp_buff *xdp,
+                     struct bpf_prog *xdp_prog, struct bpf_redirect_info *ri)
+{
+        struct net_device *fwd;
+        u32 index = ri->ifindex;
+        int err;
+
+        fwd = dev_get_by_index_rcu(dev_net(dev), index);
+        ri->ifindex = 0;
+        if (unlikely(!fwd)) {
+                err = -EINVAL;
+                goto err;
+        }
+
+        err = __bpf_tx_xdp(fwd, NULL, xdp, 0);
+        if (unlikely(err))
+                goto err;
+
+        _trace_xdp_redirect(dev, xdp_prog, index);
+        return 0;
+err:
+        _trace_xdp_redirect_err(dev, xdp_prog, index, err);
+        return err;
+}
+
 static int __bpf_tx_xdp_map(struct net_device *dev_rx, void *fwd,
                             struct bpf_map *map,
                             struct xdp_buff *xdp,
@@ -3182,7 +3224,7 @@ static int __bpf_tx_xdp_map(struct net_device *dev_rx, void *fwd,
                 struct bpf_dtab_netdev *dst = fwd;
 
                 err = dev_map_enqueue(dst, xdp, dev_rx);
-                if (err)
+                if (unlikely(err))
                         return err;
                 __dev_map_insert_ctx(map, index);
                 break;
@@ -3191,7 +3233,7 @@ static int __bpf_tx_xdp_map(struct net_device *dev_rx, void *fwd,
                 struct bpf_cpu_map_entry *rcpu = fwd;
 
                 err = cpu_map_enqueue(rcpu, xdp, dev_rx);
-                if (err)
+                if (unlikely(err))
                         return err;
                 __cpu_map_insert_ctx(map, index);
                 break;
@@ -3232,7 +3274,7 @@ void xdp_do_flush_map(void)
 }
 EXPORT_SYMBOL_GPL(xdp_do_flush_map);
 
-static void *__xdp_map_lookup_elem(struct bpf_map *map, u32 index)
+static inline void *__xdp_map_lookup_elem(struct bpf_map *map, u32 index)
 {
         switch (map->map_type) {
         case BPF_MAP_TYPE_DEVMAP:
@@ -3264,9 +3306,9 @@ void bpf_clear_redirect_map(struct bpf_map *map)
 }
 
 static int xdp_do_redirect_map(struct net_device *dev, struct xdp_buff *xdp,
-                               struct bpf_prog *xdp_prog, struct bpf_map *map)
+                               struct bpf_prog *xdp_prog, struct bpf_map *map,
+                               struct bpf_redirect_info *ri)
 {
-        struct bpf_redirect_info *ri = this_cpu_ptr(&bpf_redirect_info);
         u32 index = ri->ifindex;
         void *fwd = NULL;
         int err;
@@ -3275,11 +3317,11 @@ static int xdp_do_redirect_map(struct net_device *dev, struct xdp_buff *xdp,
         WRITE_ONCE(ri->map, NULL);
 
         fwd = __xdp_map_lookup_elem(map, index);
-        if (!fwd) {
+        if (unlikely(!fwd)) {
                 err = -EINVAL;
                 goto err;
         }
-        if (ri->map_to_flush && ri->map_to_flush != map)
+        if (ri->map_to_flush && unlikely(ri->map_to_flush != map))
                 xdp_do_flush_map();
 
         err = __bpf_tx_xdp_map(dev, fwd, map, xdp, index);
@@ -3299,29 +3341,11 @@ int xdp_do_redirect(struct net_device *dev, struct xdp_buff *xdp,
 {
         struct bpf_redirect_info *ri = this_cpu_ptr(&bpf_redirect_info);
         struct bpf_map *map = READ_ONCE(ri->map);
-        struct net_device *fwd;
-        u32 index = ri->ifindex;
-        int err;
 
-        if (map)
-                return xdp_do_redirect_map(dev, xdp, xdp_prog, map);
+        if (likely(map))
+                return xdp_do_redirect_map(dev, xdp, xdp_prog, map, ri);
 
-        fwd = dev_get_by_index_rcu(dev_net(dev), index);
-        ri->ifindex = 0;
-        if (unlikely(!fwd)) {
-                err = -EINVAL;
-                goto err;
-        }
-
-        err = __bpf_tx_xdp(fwd, NULL, xdp, 0);
-        if (unlikely(err))
-                goto err;
-
-        _trace_xdp_redirect(dev, xdp_prog, index);
-        return 0;
-err:
-        _trace_xdp_redirect_err(dev, xdp_prog, index, err);
-        return err;
+        return xdp_do_redirect_slow(dev, xdp, xdp_prog, ri);
 }
 EXPORT_SYMBOL_GPL(xdp_do_redirect);
@@ -3909,8 +3933,8 @@ BPF_CALL_5(bpf_setsockopt, struct bpf_sock_ops_kern *, bpf_sock,
                         sk->sk_userlocks |= SOCK_SNDBUF_LOCK;
                         sk->sk_sndbuf = max_t(int, val * 2, SOCK_MIN_SNDBUF);
                         break;
-                case SO_MAX_PACING_RATE:
-                        sk->sk_max_pacing_rate = val;
+                case SO_MAX_PACING_RATE: /* 32bit version */
+                        sk->sk_max_pacing_rate = (val == ~0U) ? ~0UL : val;
                         sk->sk_pacing_rate = min(sk->sk_pacing_rate,
                                                  sk->sk_max_pacing_rate);
                         break;
@@ -4007,6 +4031,12 @@ BPF_CALL_5(bpf_setsockopt, struct bpf_sock_ops_kern *, bpf_sock,
                                 tp->snd_ssthresh = val;
                         }
                         break;
+                case TCP_SAVE_SYN:
+                        if (val < 0 || val > 1)
+                                ret = -EINVAL;
+                        else
+                                tp->save_syn = val;
+                        break;
                 default:
                         ret = -EINVAL;
                 }
@@ -4036,17 +4066,29 @@ BPF_CALL_5(bpf_getsockopt, struct bpf_sock_ops_kern *, bpf_sock,
         if (!sk_fullsock(sk))
                 goto err_clear;
-
 #ifdef CONFIG_INET
         if (level == SOL_TCP && sk->sk_prot->getsockopt == tcp_getsockopt) {
-                if (optname == TCP_CONGESTION) {
-                        struct inet_connection_sock *icsk = inet_csk(sk);
+                struct inet_connection_sock *icsk;
+                struct tcp_sock *tp;
+
+                switch (optname) {
+                case TCP_CONGESTION:
+                        icsk = inet_csk(sk);
 
                         if (!icsk->icsk_ca_ops || optlen <= 1)
                                 goto err_clear;
                         strncpy(optval, icsk->icsk_ca_ops->name, optlen);
                         optval[optlen - 1] = 0;
-                } else {
+                        break;
+                case TCP_SAVED_SYN:
+                        tp = tcp_sk(sk);
+
+                        if (optlen <= 0 || !tp->saved_syn ||
+                            optlen > tp->saved_syn[0])
+                                goto err_clear;
+                        memcpy(optval, tp->saved_syn + 1, optlen);
+                        break;
+                default:
                         goto err_clear;
                 }
         } else if (level == SOL_IP) {
@@ -4781,6 +4823,149 @@ static const struct bpf_func_proto bpf_lwt_seg6_adjust_srh_proto = {
 };
 #endif /* CONFIG_IPV6_SEG6_BPF */
 
+#ifdef CONFIG_INET
+static struct sock *sk_lookup(struct net *net, struct bpf_sock_tuple *tuple,
+                              struct sk_buff *skb, u8 family, u8 proto)
+{
+        bool refcounted = false;
+        struct sock *sk = NULL;
+        int dif = 0;
+
+        if (skb->dev)
+                dif = skb->dev->ifindex;
+
+        if (family == AF_INET) {
+                __be32 src4 = tuple->ipv4.saddr;
+                __be32 dst4 = tuple->ipv4.daddr;
+                int sdif = inet_sdif(skb);
+
+                if (proto == IPPROTO_TCP)
+                        sk = __inet_lookup(net, &tcp_hashinfo, skb, 0,
+                                           src4, tuple->ipv4.sport,
+                                           dst4, tuple->ipv4.dport,
+                                           dif, sdif, &refcounted);
+                else
+                        sk = __udp4_lib_lookup(net, src4, tuple->ipv4.sport,
+                                               dst4, tuple->ipv4.dport,
+                                               dif, sdif, &udp_table, skb);
+#if IS_ENABLED(CONFIG_IPV6)
+        } else {
+                struct in6_addr *src6 = (struct in6_addr *)&tuple->ipv6.saddr;
+                struct in6_addr *dst6 = (struct in6_addr *)&tuple->ipv6.daddr;
+                u16 hnum = ntohs(tuple->ipv6.dport);
+                int sdif = inet6_sdif(skb);
+
+                if (proto == IPPROTO_TCP)
+                        sk = __inet6_lookup(net, &tcp_hashinfo, skb, 0,
+                                            src6, tuple->ipv6.sport,
+                                            dst6, hnum,
+                                            dif, sdif, &refcounted);
+                else if (likely(ipv6_bpf_stub))
+                        sk = ipv6_bpf_stub->udp6_lib_lookup(net,
+                                                            src6, tuple->ipv6.sport,
+                                                            dst6, hnum,
+                                                            dif, sdif,
+                                                            &udp_table, skb);
+#endif
+        }
+
+        if (unlikely(sk && !refcounted && !sock_flag(sk, SOCK_RCU_FREE))) {
+                WARN_ONCE(1, "Found non-RCU, unreferenced socket!");
+                sk = NULL;
+        }
+        return sk;
+}
+
+/* bpf_sk_lookup performs the core lookup for different types of sockets,
+ * taking a reference on the socket if it doesn't have the flag SOCK_RCU_FREE.
+ * Returns the socket as an 'unsigned long' to simplify the casting in the
+ * callers to satisfy BPF_CALL declarations.
+ */
+static unsigned long
+bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
+              u8 proto, u64 netns_id, u64 flags)
+{
+        struct net *caller_net;
+        struct sock *sk = NULL;
+        u8 family = AF_UNSPEC;
+        struct net *net;
+
+        family = len == sizeof(tuple->ipv4) ? AF_INET : AF_INET6;
+        if (unlikely(family == AF_UNSPEC || netns_id > U32_MAX || flags))
+                goto out;
+
+        if (skb->dev)
+                caller_net = dev_net(skb->dev);
+        else
+                caller_net = sock_net(skb->sk);
+        if (netns_id) {
+                net = get_net_ns_by_id(caller_net, netns_id);
+                if (unlikely(!net))
+                        goto out;
+                sk = sk_lookup(net, tuple, skb, family, proto);
+                put_net(net);
+        } else {
+                net = caller_net;
+                sk = sk_lookup(net, tuple, skb, family, proto);
+        }
+
+        if (sk)
+                sk = sk_to_full_sk(sk);
+out:
+        return (unsigned long) sk;
+}
+
+BPF_CALL_5(bpf_sk_lookup_tcp, struct sk_buff *, skb,
+           struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
+{
+        return bpf_sk_lookup(skb, tuple, len, IPPROTO_TCP, netns_id, flags);
+}
+
+static const struct bpf_func_proto bpf_sk_lookup_tcp_proto = {
+        .func           = bpf_sk_lookup_tcp,
+        .gpl_only       = false,
+        .pkt_access     = true,
+        .ret_type       = RET_PTR_TO_SOCKET_OR_NULL,
+        .arg1_type      = ARG_PTR_TO_CTX,
+        .arg2_type      = ARG_PTR_TO_MEM,
+        .arg3_type      = ARG_CONST_SIZE,
+        .arg4_type      = ARG_ANYTHING,
+        .arg5_type      = ARG_ANYTHING,
+};
+
+BPF_CALL_5(bpf_sk_lookup_udp, struct sk_buff *, skb,
+           struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
+{
+        return bpf_sk_lookup(skb, tuple, len, IPPROTO_UDP, netns_id, flags);
+}
+
+static const struct bpf_func_proto bpf_sk_lookup_udp_proto = {
+        .func           = bpf_sk_lookup_udp,
+        .gpl_only       = false,
+        .pkt_access     = true,
+        .ret_type       = RET_PTR_TO_SOCKET_OR_NULL,
+        .arg1_type      = ARG_PTR_TO_CTX,
+        .arg2_type      = ARG_PTR_TO_MEM,
+        .arg3_type      = ARG_CONST_SIZE,
+        .arg4_type      = ARG_ANYTHING,
+        .arg5_type      = ARG_ANYTHING,
+};
+
+BPF_CALL_1(bpf_sk_release, struct sock *, sk)
+{
+        if (!sock_flag(sk, SOCK_RCU_FREE))
+                sock_gen_put(sk);
+        return 0;
+}
+
+static const struct bpf_func_proto bpf_sk_release_proto = {
+        .func           = bpf_sk_release,
+        .gpl_only       = false,
+        .ret_type       = RET_INTEGER,
+        .arg1_type      = ARG_PTR_TO_SOCKET,
+};
+#endif /* CONFIG_INET */
+
 bool bpf_helper_changes_pkt_data(void *func)
 {
         if (func == bpf_skb_vlan_push ||
@@ -4800,6 +4985,7 @@ bool bpf_helper_changes_pkt_data(void *func)
             func == bpf_xdp_adjust_head ||
             func == bpf_xdp_adjust_meta ||
             func == bpf_msg_pull_data ||
+            func == bpf_msg_push_data ||
             func == bpf_xdp_adjust_tail ||
 #if IS_ENABLED(CONFIG_IPV6_SEG6_BPF)
             func == bpf_lwt_seg6_store_bytes ||
@@ -4822,6 +5008,12 @@ bpf_base_func_proto(enum bpf_func_id func_id)
                 return &bpf_map_update_elem_proto;
         case BPF_FUNC_map_delete_elem:
                 return &bpf_map_delete_elem_proto;
+        case BPF_FUNC_map_push_elem:
+                return &bpf_map_push_elem_proto;
+        case BPF_FUNC_map_pop_elem:
+                return &bpf_map_pop_elem_proto;
+        case BPF_FUNC_map_peek_elem:
+                return &bpf_map_peek_elem_proto;
         case BPF_FUNC_get_prandom_u32:
                 return &bpf_get_prandom_u32_proto;
         case BPF_FUNC_get_smp_processor_id:
@@ -4987,6 +5179,14 @@ tc_cls_act_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
         case BPF_FUNC_skb_ancestor_cgroup_id:
                 return &bpf_skb_ancestor_cgroup_id_proto;
 #endif
+#ifdef CONFIG_INET
+        case BPF_FUNC_sk_lookup_tcp:
+                return &bpf_sk_lookup_tcp_proto;
+        case BPF_FUNC_sk_lookup_udp:
+                return &bpf_sk_lookup_udp_proto;
+        case BPF_FUNC_sk_release:
+                return &bpf_sk_release_proto;
+#endif
         default:
                 return bpf_base_func_proto(func_id);
         }
@@ -5019,6 +5219,9 @@ xdp_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
         }
 }
 
+const struct bpf_func_proto bpf_sock_map_update_proto __weak;
+const struct bpf_func_proto bpf_sock_hash_update_proto __weak;
+
 static const struct bpf_func_proto *
 sock_ops_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
 {
@@ -5042,6 +5245,9 @@ sock_ops_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
         }
 }
 
+const struct bpf_func_proto bpf_msg_redirect_map_proto __weak;
+const struct bpf_func_proto bpf_msg_redirect_hash_proto __weak;
+
 static const struct bpf_func_proto *
 sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
 {
@@ -5056,6 +5262,8 @@ sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                 return &bpf_msg_cork_bytes_proto;
         case BPF_FUNC_msg_pull_data:
                 return &bpf_msg_pull_data_proto;
+        case BPF_FUNC_msg_push_data:
+                return &bpf_msg_push_data_proto;
         case BPF_FUNC_get_local_storage:
                 return &bpf_get_local_storage_proto;
         default:
@@ -5063,6 +5271,9 @@ sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
         }
 }
 
+const struct bpf_func_proto bpf_sk_redirect_map_proto __weak;
+const struct bpf_func_proto bpf_sk_redirect_hash_proto __weak;
+
 static const struct bpf_func_proto *
 sk_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
 {
@@ -5087,6 +5298,25 @@ sk_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                 return &bpf_sk_redirect_hash_proto;
         case BPF_FUNC_get_local_storage:
                 return &bpf_get_local_storage_proto;
+#ifdef CONFIG_INET
+        case BPF_FUNC_sk_lookup_tcp:
+                return &bpf_sk_lookup_tcp_proto;
+        case BPF_FUNC_sk_lookup_udp:
+                return &bpf_sk_lookup_udp_proto;
+        case BPF_FUNC_sk_release:
+                return &bpf_sk_release_proto;
+#endif
+        default:
+                return bpf_base_func_proto(func_id);
+        }
+}
+
+static const struct bpf_func_proto *
+flow_dissector_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
+{
+        switch (func_id) {
+        case BPF_FUNC_skb_load_bytes:
+                return &bpf_skb_load_bytes_proto;
         default:
                 return bpf_base_func_proto(func_id);
         }
@@ -5210,6 +5440,10 @@ static bool bpf_skb_is_valid_access(int off, int size, enum bpf_access_type type
                 if (size != size_default)
                         return false;
                 break;
+        case bpf_ctx_range(struct __sk_buff, flow_keys):
+                if (size != sizeof(struct bpf_flow_keys *))
+                        return false;
+                break;
         default:
                 /* Only narrow read access allowed for now. */
                 if (type == BPF_WRITE) {
@@ -5235,6 +5469,7 @@ static bool sk_filter_is_valid_access(int off, int size,
         case bpf_ctx_range(struct __sk_buff, data):
         case bpf_ctx_range(struct __sk_buff, data_meta):
         case bpf_ctx_range(struct __sk_buff, data_end):
+        case bpf_ctx_range(struct __sk_buff, flow_keys):
         case bpf_ctx_range_till(struct __sk_buff, family, local_port):
                 return false;
         }
@@ -5251,6 +5486,40 @@ static bool sk_filter_is_valid_access(int off, int size,
         return bpf_skb_is_valid_access(off, size, type, prog, info);
 }
 
+static bool cg_skb_is_valid_access(int off, int size,
+                                   enum bpf_access_type type,
+                                   const struct bpf_prog *prog,
+                                   struct bpf_insn_access_aux *info)
+{
+        switch (off) {
+        case bpf_ctx_range(struct __sk_buff, tc_classid):
+        case bpf_ctx_range(struct __sk_buff, data_meta):
+        case bpf_ctx_range(struct __sk_buff, flow_keys):
+                return false;
+        }
+        if (type == BPF_WRITE) {
+                switch (off) {
+                case bpf_ctx_range(struct __sk_buff, mark):
+                case bpf_ctx_range(struct __sk_buff, priority):
+                case bpf_ctx_range_till(struct __sk_buff, cb[0], cb[4]):
+                        break;
+                default:
+                        return false;
+                }
+        }
+
+        switch (off) {
+        case bpf_ctx_range(struct __sk_buff, data):
+                info->reg_type = PTR_TO_PACKET;
+                break;
+        case bpf_ctx_range(struct __sk_buff, data_end):
+                info->reg_type = PTR_TO_PACKET_END;
+                break;
+        }
+
+        return bpf_skb_is_valid_access(off, size, type, prog, info);
+}
+
 static bool lwt_is_valid_access(int off, int size,
                                 enum bpf_access_type type,
                                 const struct bpf_prog *prog,
@@ -5260,6 +5529,7 @@ static bool lwt_is_valid_access(int off, int size,
         case bpf_ctx_range(struct __sk_buff, tc_classid):
         case bpf_ctx_range_till(struct __sk_buff, family, local_port):
         case bpf_ctx_range(struct __sk_buff, data_meta):
+        case bpf_ctx_range(struct __sk_buff, flow_keys):
                 return false;
         }
@@ -5345,23 +5615,29 @@ static bool __sock_filter_check_size(int off, int size,
         return size == size_default;
 }
 
-static bool sock_filter_is_valid_access(int off, int size,
-                                        enum bpf_access_type type,
-                                        const struct bpf_prog *prog,
-                                        struct bpf_insn_access_aux *info)
+bool bpf_sock_is_valid_access(int off, int size, enum bpf_access_type type,
+                              struct bpf_insn_access_aux *info)
 {
         if (off < 0 || off >= sizeof(struct bpf_sock))
                 return false;
         if (off % size != 0)
                 return false;
-        if (!__sock_filter_check_attach_type(off, type,
-                                             prog->expected_attach_type))
-                return false;
         if (!__sock_filter_check_size(off, size, info))
                 return false;
         return true;
 }
 
+static bool sock_filter_is_valid_access(int off, int size,
+                                        enum bpf_access_type type,
+                                        const struct bpf_prog *prog,
+                                        struct bpf_insn_access_aux *info)
+{
+        if (!bpf_sock_is_valid_access(off, size, type, info))
+                return false;
+        return __sock_filter_check_attach_type(off, type,
+                                               prog->expected_attach_type);
+}
+
 static int bpf_unclone_prologue(struct bpf_insn *insn_buf, bool direct_write,
                                 const struct bpf_prog *prog, int drop_verdict)
 {
@@ -5470,6 +5746,7 @@ static bool tc_cls_act_is_valid_access(int off, int size,
         case bpf_ctx_range(struct __sk_buff, data_end):
                 info->reg_type = PTR_TO_PACKET_END;
                 break;
+        case bpf_ctx_range(struct __sk_buff, flow_keys):
         case bpf_ctx_range_till(struct __sk_buff, family, local_port):
                 return false;
         }
@@ -5671,6 +5948,7 @@ static bool sk_skb_is_valid_access(int off, int size,
         switch (off) {
         case bpf_ctx_range(struct __sk_buff, tc_classid):
         case bpf_ctx_range(struct __sk_buff, data_meta):
+        case bpf_ctx_range(struct __sk_buff, flow_keys):
                 return false;
         }
@@ -5730,6 +6008,39 @@ static bool sk_msg_is_valid_access(int off, int size,
         return true;
 }
 
+static bool flow_dissector_is_valid_access(int off, int size,
+                                           enum bpf_access_type type,
+                                           const struct bpf_prog *prog,
+                                           struct bpf_insn_access_aux *info)
+{
+        if (type == BPF_WRITE) {
+                switch (off) {
+                case bpf_ctx_range_till(struct __sk_buff, cb[0], cb[4]):
+                        break;
+                default:
+                        return false;
+                }
+        }
+
+        switch (off) {
+        case bpf_ctx_range(struct __sk_buff, data):
+                info->reg_type = PTR_TO_PACKET;
+                break;
+        case bpf_ctx_range(struct __sk_buff, data_end):
+                info->reg_type = PTR_TO_PACKET_END;
+                break;
+        case bpf_ctx_range(struct __sk_buff, flow_keys):
+                info->reg_type = PTR_TO_FLOW_KEYS;
+                break;
+        case bpf_ctx_range(struct __sk_buff, tc_classid):
+        case bpf_ctx_range(struct __sk_buff, data_meta):
+        case bpf_ctx_range_till(struct __sk_buff, family, local_port):
+                return false;
+        }
+
+        return bpf_skb_is_valid_access(off, size, type, prog, info);
+}
+
 static u32 bpf_convert_ctx_access(enum bpf_access_type type,
                                   const struct bpf_insn *si,
                                   struct bpf_insn *insn_buf,
@@ -6024,15 +6335,24 @@ static u32 bpf_convert_ctx_access(enum bpf_access_type type,
                                       bpf_target_off(struct sock_common,
                                                      skc_num, 2, target_size));
                 break;
+
+        case offsetof(struct __sk_buff, flow_keys):
+                off = si->off;
+                off -= offsetof(struct __sk_buff, flow_keys);
+                off += offsetof(struct sk_buff, cb);
+                off += offsetof(struct qdisc_skb_cb, flow_keys);
+                *insn++ = BPF_LDX_MEM(BPF_SIZEOF(void *), si->dst_reg,
+                                      si->src_reg, off);
+                break;
         }
 
         return insn - insn_buf;
 }
 
-static u32 sock_filter_convert_ctx_access(enum bpf_access_type type,
-                                          const struct bpf_insn *si,
-                                          struct bpf_insn *insn_buf,
-                                          struct bpf_prog *prog, u32 *target_size)
+u32 bpf_sock_convert_ctx_access(enum bpf_access_type type,
+                                const struct bpf_insn *si,
+                                struct bpf_insn *insn_buf,
+                                struct bpf_prog *prog, u32 *target_size)
 {
         struct bpf_insn *insn = insn_buf;
         int off;
@@ -6742,22 +7062,22 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
 
         switch (si->off) {
         case offsetof(struct sk_msg_md, data):
-                *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data),
+                *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg, data),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, data));
+                                      offsetof(struct sk_msg, data));
                 break;
         case offsetof(struct sk_msg_md, data_end):
-                *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data_end),
+                *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg, data_end),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, data_end));
+                                      offsetof(struct sk_msg, data_end));
                 break;
         case offsetof(struct sk_msg_md, family):
                 BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_family) != 2);
 
                 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
-                                              struct sk_msg_buff, sk),
+                                              struct sk_msg, sk),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, sk));
+                                      offsetof(struct sk_msg, sk));
                 *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
                                       offsetof(struct sock_common, skc_family));
                 break;
@@ -6766,9 +7086,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
                 BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_daddr) != 4);
 
                 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
-                                              struct sk_msg_buff, sk),
+                                              struct sk_msg, sk),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, sk));
+                                      offsetof(struct sk_msg, sk));
                 *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
                                       offsetof(struct sock_common, skc_daddr));
                 break;
@@ -6778,9 +7098,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
                                           skc_rcv_saddr) != 4);
 
                 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
-                                              struct sk_msg_buff, sk),
+                                              struct sk_msg, sk),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, sk));
+                                      offsetof(struct sk_msg, sk));
                 *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
                                       offsetof(struct sock_common,
                                                skc_rcv_saddr));
@@ -6795,9 +7115,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
                 off = si->off;
                 off -= offsetof(struct sk_msg_md, remote_ip6[0]);
                 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
-                                              struct sk_msg_buff, sk),
+                                              struct sk_msg, sk),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, sk));
+                                      offsetof(struct sk_msg, sk));
                 *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
                                       offsetof(struct sock_common,
                                                skc_v6_daddr.s6_addr32[0]) +
@@ -6816,9 +7136,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
                 off = si->off;
                 off -= offsetof(struct sk_msg_md, local_ip6[0]);
                 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
-                                              struct sk_msg_buff, sk),
+                                              struct sk_msg, sk),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, sk));
+                                      offsetof(struct sk_msg, sk));
                 *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
                                       offsetof(struct sock_common,
                                                skc_v6_rcv_saddr.s6_addr32[0]) +
@@ -6832,9 +7152,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
                 BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_dport) != 2);
 
                 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
-                                              struct sk_msg_buff, sk),
+                                              struct sk_msg, sk),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, sk));
+                                      offsetof(struct sk_msg, sk));
                 *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
                                       offsetof(struct sock_common, skc_dport));
 #ifndef __BIG_ENDIAN_BITFIELD
@@ -6846,9 +7166,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
                 BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_num) != 2);
 
                 *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
-                                              struct sk_msg_buff, sk),
+                                              struct sk_msg, sk),
                                       si->dst_reg, si->src_reg,
-                                      offsetof(struct sk_msg_buff, sk));
+                                      offsetof(struct sk_msg, sk));
                 *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
                                       offsetof(struct sock_common, skc_num));
                 break;
@@ -6892,7 +7212,7 @@ const struct bpf_prog_ops xdp_prog_ops = {
 
 const struct bpf_verifier_ops cg_skb_verifier_ops = {
         .get_func_proto         = cg_skb_func_proto,
-        .is_valid_access        = sk_filter_is_valid_access,
+        .is_valid_access        = cg_skb_is_valid_access,
         .convert_ctx_access     = bpf_convert_ctx_access,
 };
@@ -6944,7 +7264,7 @@ const struct bpf_prog_ops lwt_seg6local_prog_ops = {
 const struct bpf_verifier_ops cg_sock_verifier_ops = {
         .get_func_proto         = sock_filter_func_proto,
         .is_valid_access        = sock_filter_is_valid_access,
-        .convert_ctx_access     = sock_filter_convert_ctx_access,
+        .convert_ctx_access     = bpf_sock_convert_ctx_access,
 };
 
 const struct bpf_prog_ops cg_sock_prog_ops = {
@@ -6987,6 +7307,15 @@ const struct bpf_verifier_ops sk_msg_verifier_ops = {
 const struct bpf_prog_ops sk_msg_prog_ops = {
 };
 
+const struct bpf_verifier_ops flow_dissector_verifier_ops = {
+        .get_func_proto         = flow_dissector_func_proto,
+        .is_valid_access        = flow_dissector_is_valid_access,
+        .convert_ctx_access     = bpf_convert_ctx_access,
+};
+
+const struct bpf_prog_ops flow_dissector_prog_ops = {
+};
+
 int sk_detach_filter(struct sock *sk)
 {
         int ret = -ENOENT;
@@ -7281,7 +7610,7 @@ static u32 sk_reuseport_convert_ctx_access(enum bpf_access_type type,
                 break;
 
         case offsetof(struct sk_reuseport_md, ip_protocol):
-                BUILD_BUG_ON(hweight_long(SK_FL_PROTO_MASK) != BITS_PER_BYTE);
+                BUILD_BUG_ON(HWEIGHT32(SK_FL_PROTO_MASK) != BITS_PER_BYTE);
                 SK_REUSEPORT_LOAD_SK_FIELD_SIZE_OFF(__sk_flags_offset,
                                                     BPF_W, 0);
                 *insn++ = BPF_ALU32_IMM(BPF_AND, si->dst_reg, SK_FL_PROTO_MASK);
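
The new bpf_msg_push_data() helper above grows a msg payload by len bytes at offset start, either by copying the affected scatterlist entry into a freshly allocated page or by splitting the entry and shifting the ring. A minimal sketch of how an SK_MSG program might call it; the 8-byte header length and the program/section names are illustrative assumptions, not part of this diff.

/* Sketch: reserve 8 bytes at the front of every message so a later
 * stage (or userspace) can fill in an application header. HDR_LEN
 * is an assumption for illustration.
 */
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

#define HDR_LEN 8

SEC("sk_msg")
int msg_reserve_header(struct sk_msg_md *msg)
{
	/* flags must currently be zero; the helper returns -EINVAL
	 * otherwise and -ENOMEM if the page allocation fails.
	 */
	if (bpf_msg_push_data(msg, 0, HDR_LEN, 0))
		return SK_DROP;

	/* msg->data/msg->data_end were recomputed by
	 * sk_msg_compute_data_pointers(); the new bytes are
	 * uninitialized and must be written before use.
	 */
	return SK_PASS;
}

char _license[] SEC("license") = "GPL";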
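bpf_sk_lookup_tcp()/bpf_sk_lookup_udp() return a referenced socket unless it is SOCK_RCU_FREE, so every path must balance the lookup with bpf_sk_release(). A hedged sketch from a tc classifier with made-up tuple values; note that in the version in this diff netns_id == 0 means "the caller's network namespace" and flags must be zero.

/* Sketch: look up the TCP socket for a hard-coded IPv4 4-tuple and
 * release it again. Addresses and ports are placeholders.
 */
#include <linux/bpf.h>
#include <linux/pkt_cls.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>

SEC("tc")
int lookup_and_release(struct __sk_buff *skb)
{
	struct bpf_sock_tuple tuple = {
		.ipv4.saddr = bpf_htonl(0x7f000001),	/* 127.0.0.1 */
		.ipv4.daddr = bpf_htonl(0x7f000001),
		.ipv4.sport = bpf_htons(4000),
		.ipv4.dport = bpf_htons(8000),
	};
	struct bpf_sock *sk;

	/* sizeof(tuple.ipv4) selects the IPv4 lookup path */
	sk = bpf_sk_lookup_tcp(skb, &tuple, sizeof(tuple.ipv4), 0, 0);
	if (sk)
		bpf_sk_release(sk);
	return TC_ACT_OK;
}

char _license[] SEC("license") = "GPL";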
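The TCP_SAVE_SYN/TCP_SAVED_SYN plumbing in bpf_setsockopt()/bpf_getsockopt() lets a sock_ops program ask the stack to keep each incoming SYN and read the saved headers later. A sketch, assuming a kernel whose sock_ops hook also reports BPF_SOCK_OPS_TCP_LISTEN_CB; the 40-byte buffer is an assumption, and per the check above the read fails (and the buffer is zeroed) if fewer bytes were saved.

/* Sketch: enable SYN saving on listeners, then copy out the start of
 * the saved SYN once the passive connection is established.
 */
#include <linux/bpf.h>
#include <linux/tcp.h>
#include <bpf/bpf_helpers.h>

#ifndef SOL_TCP
#define SOL_TCP 6
#endif

SEC("sockops")
int save_syn(struct bpf_sock_ops *skops)
{
	char syn[40];	/* assumed size; must be <= saved length */
	int one = 1;

	switch (skops->op) {
	case BPF_SOCK_OPS_TCP_LISTEN_CB:
		/* TCP_SAVE_SYN accepts only 0 or 1 in this diff */
		bpf_setsockopt(skops, SOL_TCP, TCP_SAVE_SYN,
			       &one, sizeof(one));
		break;
	case BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB:
		bpf_getsockopt(skops, SOL_TCP, TCP_SAVED_SYN,
			       syn, sizeof(syn));
		break;
	}
	return 1;
}

char _license[] SEC("license") = "GPL";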
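bpf_base_func_proto() now also hands out map_push_elem/map_pop_elem/map_peek_elem, the helper side of the new queue/stack map types. A sketch using a queue map; the BTF-style map definition syntax assumes a reasonably recent libbpf, and the map name and use of skb->mark are illustrative.

/* Sketch: push the mark of every seen packet into a FIFO that
 * userspace can drain with bpf_map_lookup_and_delete_elem().
 */
#include <linux/bpf.h>
#include <linux/pkt_cls.h>
#include <bpf/bpf_helpers.h>

struct {
	__uint(type, BPF_MAP_TYPE_QUEUE);
	__uint(max_entries, 64);
	__type(value, __u32);
} marks SEC(".maps");

SEC("tc")
int queue_mark(struct __sk_buff *skb)
{
	__u32 mark = skb->mark;

	/* BPF_EXIST makes the push evict the oldest entry when the
	 * queue is full instead of failing.
	 */
	bpf_map_push_elem(&marks, &mark, BPF_EXIST);
	return TC_ACT_OK;
}

char _license[] SEC("license") = "GPL";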
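Finally, the BPF_PROG_TYPE_FLOW_DISSECTOR wiring (flow_dissector_func_proto, flow_dissector_is_valid_access, and the new flow_keys context field) exposes only bpf_skb_load_bytes() beyond the base helpers. A sketch of the skeleton such a dissector takes; a real program must walk the whole header chain and fill thoff/ip_proto, this one only classifies the IP version. Field names follow the bpf_flow_keys layout introduced alongside this diff.

/* Sketch: read the version/IHL byte at the network-header offset the
 * kernel primed in flow_keys->nhoff and bail out on non-IPv4.
 */
#include <linux/bpf.h>
#include <linux/if_ether.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>

SEC("flow_dissector")
int dissect(struct __sk_buff *skb)
{
	struct bpf_flow_keys *keys = skb->flow_keys;
	__u8 vihl;

	if (bpf_skb_load_bytes(skb, keys->nhoff, &vihl, sizeof(vihl)))
		return BPF_DROP;

	if ((vihl >> 4) != 4)	/* upper nibble is the IP version */
		return BPF_DROP;

	keys->addr_proto = ETH_P_IP;
	keys->n_proto = bpf_htons(ETH_P_IP);
	return BPF_OK;
}

char _license[] SEC("license") = "GPL";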