aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/Kconfig11
-rw-r--r--net/core/Makefile2
-rw-r--r--net/core/filter.c270
-rw-r--r--net/core/skmsg.c763
-rw-r--r--net/core/sock_map.c1002
-rw-r--r--net/ipv4/Makefile1
-rw-r--r--net/ipv4/tcp_bpf.c655
-rw-r--r--net/strparser/Kconfig4
8 files changed, 2512 insertions, 196 deletions
diff --git a/net/Kconfig b/net/Kconfig
index 228dfa382eec..f235edb593ba 100644
--- a/net/Kconfig
+++ b/net/Kconfig
@@ -300,8 +300,11 @@ config BPF_JIT
config BPF_STREAM_PARSER
bool "enable BPF STREAM_PARSER"
+ depends on INET
depends on BPF_SYSCALL
+ depends on CGROUP_BPF
select STREAM_PARSER
+ select NET_SOCK_MSG
---help---
Enabling this allows a stream parser to be used with
BPF_MAP_TYPE_SOCKMAP.
@@ -413,6 +416,14 @@ config GRO_CELLS
config SOCK_VALIDATE_XMIT
bool
+config NET_SOCK_MSG
+ bool
+ default n
+ help
+ The NET_SOCK_MSG provides a framework for plain sockets (e.g. TCP) or
+ ULPs (upper layer modules, e.g. TLS) to process L7 application data
+ with the help of BPF programs.
+
config NET_DEVLINK
tristate "Network physical/parent device Netlink interface"
help
diff --git a/net/core/Makefile b/net/core/Makefile
index 80175e6a2eb8..fccd31e0e7f7 100644
--- a/net/core/Makefile
+++ b/net/core/Makefile
@@ -16,6 +16,7 @@ obj-y += dev.o ethtool.o dev_addr_lists.o dst.o netevent.o \
obj-y += net-sysfs.o
obj-$(CONFIG_PAGE_POOL) += page_pool.o
obj-$(CONFIG_PROC_FS) += net-procfs.o
+obj-$(CONFIG_NET_SOCK_MSG) += skmsg.o
obj-$(CONFIG_NET_PKTGEN) += pktgen.o
obj-$(CONFIG_NETPOLL) += netpoll.o
obj-$(CONFIG_FIB_RULES) += fib_rules.o
@@ -27,6 +28,7 @@ obj-$(CONFIG_CGROUP_NET_PRIO) += netprio_cgroup.o
obj-$(CONFIG_CGROUP_NET_CLASSID) += netclassid_cgroup.o
obj-$(CONFIG_LWTUNNEL) += lwtunnel.o
obj-$(CONFIG_LWTUNNEL_BPF) += lwt_bpf.o
+obj-$(CONFIG_BPF_STREAM_PARSER) += sock_map.o
obj-$(CONFIG_DST_CACHE) += dst_cache.o
obj-$(CONFIG_HWBM) += hwbm.o
obj-$(CONFIG_NET_DEVLINK) += devlink.o
diff --git a/net/core/filter.c b/net/core/filter.c
index b844761b5d4c..0f5260b04bfe 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -38,6 +38,7 @@
#include <net/protocol.h>
#include <net/netlink.h>
#include <linux/skbuff.h>
+#include <linux/skmsg.h>
#include <net/sock.h>
#include <net/flow_dissector.h>
#include <linux/errno.h>
@@ -2142,123 +2143,7 @@ static const struct bpf_func_proto bpf_redirect_proto = {
.arg2_type = ARG_ANYTHING,
};
-BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
- struct bpf_map *, map, void *, key, u64, flags)
-{
- struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
-
- /* If user passes invalid input drop the packet. */
- if (unlikely(flags & ~(BPF_F_INGRESS)))
- return SK_DROP;
-
- tcb->bpf.flags = flags;
- tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
- if (!tcb->bpf.sk_redir)
- return SK_DROP;
-
- return SK_PASS;
-}
-
-static const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
- .func = bpf_sk_redirect_hash,
- .gpl_only = false,
- .ret_type = RET_INTEGER,
- .arg1_type = ARG_PTR_TO_CTX,
- .arg2_type = ARG_CONST_MAP_PTR,
- .arg3_type = ARG_PTR_TO_MAP_KEY,
- .arg4_type = ARG_ANYTHING,
-};
-
-BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
- struct bpf_map *, map, u32, key, u64, flags)
-{
- struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
-
- /* If user passes invalid input drop the packet. */
- if (unlikely(flags & ~(BPF_F_INGRESS)))
- return SK_DROP;
-
- tcb->bpf.flags = flags;
- tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
- if (!tcb->bpf.sk_redir)
- return SK_DROP;
-
- return SK_PASS;
-}
-
-struct sock *do_sk_redirect_map(struct sk_buff *skb)
-{
- struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
-
- return tcb->bpf.sk_redir;
-}
-
-static const struct bpf_func_proto bpf_sk_redirect_map_proto = {
- .func = bpf_sk_redirect_map,
- .gpl_only = false,
- .ret_type = RET_INTEGER,
- .arg1_type = ARG_PTR_TO_CTX,
- .arg2_type = ARG_CONST_MAP_PTR,
- .arg3_type = ARG_ANYTHING,
- .arg4_type = ARG_ANYTHING,
-};
-
-BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg_buff *, msg,
- struct bpf_map *, map, void *, key, u64, flags)
-{
- /* If user passes invalid input drop the packet. */
- if (unlikely(flags & ~(BPF_F_INGRESS)))
- return SK_DROP;
-
- msg->flags = flags;
- msg->sk_redir = __sock_hash_lookup_elem(map, key);
- if (!msg->sk_redir)
- return SK_DROP;
-
- return SK_PASS;
-}
-
-static const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
- .func = bpf_msg_redirect_hash,
- .gpl_only = false,
- .ret_type = RET_INTEGER,
- .arg1_type = ARG_PTR_TO_CTX,
- .arg2_type = ARG_CONST_MAP_PTR,
- .arg3_type = ARG_PTR_TO_MAP_KEY,
- .arg4_type = ARG_ANYTHING,
-};
-
-BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg_buff *, msg,
- struct bpf_map *, map, u32, key, u64, flags)
-{
- /* If user passes invalid input drop the packet. */
- if (unlikely(flags & ~(BPF_F_INGRESS)))
- return SK_DROP;
-
- msg->flags = flags;
- msg->sk_redir = __sock_map_lookup_elem(map, key);
- if (!msg->sk_redir)
- return SK_DROP;
-
- return SK_PASS;
-}
-
-struct sock *do_msg_redirect_map(struct sk_msg_buff *msg)
-{
- return msg->sk_redir;
-}
-
-static const struct bpf_func_proto bpf_msg_redirect_map_proto = {
- .func = bpf_msg_redirect_map,
- .gpl_only = false,
- .ret_type = RET_INTEGER,
- .arg1_type = ARG_PTR_TO_CTX,
- .arg2_type = ARG_CONST_MAP_PTR,
- .arg3_type = ARG_ANYTHING,
- .arg4_type = ARG_ANYTHING,
-};
-
-BPF_CALL_2(bpf_msg_apply_bytes, struct sk_msg_buff *, msg, u32, bytes)
+BPF_CALL_2(bpf_msg_apply_bytes, struct sk_msg *, msg, u32, bytes)
{
msg->apply_bytes = bytes;
return 0;
@@ -2272,7 +2157,7 @@ static const struct bpf_func_proto bpf_msg_apply_bytes_proto = {
.arg2_type = ARG_ANYTHING,
};
-BPF_CALL_2(bpf_msg_cork_bytes, struct sk_msg_buff *, msg, u32, bytes)
+BPF_CALL_2(bpf_msg_cork_bytes, struct sk_msg *, msg, u32, bytes)
{
msg->cork_bytes = bytes;
return 0;
@@ -2286,45 +2171,37 @@ static const struct bpf_func_proto bpf_msg_cork_bytes_proto = {
.arg2_type = ARG_ANYTHING,
};
-#define sk_msg_iter_var(var) \
- do { \
- var++; \
- if (var == MAX_SKB_FRAGS) \
- var = 0; \
- } while (0)
-
-BPF_CALL_4(bpf_msg_pull_data,
- struct sk_msg_buff *, msg, u32, start, u32, end, u64, flags)
+BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
+ u32, end, u64, flags)
{
- unsigned int len = 0, offset = 0, copy = 0, poffset = 0;
- int bytes = end - start, bytes_sg_total;
- struct scatterlist *sg = msg->sg_data;
- int first_sg, last_sg, i, shift;
- unsigned char *p, *to, *from;
+ u32 len = 0, offset = 0, copy = 0, poffset = 0, bytes = end - start;
+ u32 first_sge, last_sge, i, shift, bytes_sg_total;
+ struct scatterlist *sge;
+ u8 *raw, *to, *from;
struct page *page;
if (unlikely(flags || end <= start))
return -EINVAL;
/* First find the starting scatterlist element */
- i = msg->sg_start;
+ i = msg->sg.start;
do {
- len = sg[i].length;
+ len = sk_msg_elem(msg, i)->length;
if (start < offset + len)
break;
offset += len;
- sk_msg_iter_var(i);
- } while (i != msg->sg_end);
+ sk_msg_iter_var_next(i);
+ } while (i != msg->sg.end);
if (unlikely(start >= offset + len))
return -EINVAL;
- first_sg = i;
+ first_sge = i;
/* The start may point into the sg element so we need to also
* account for the headroom.
*/
bytes_sg_total = start - offset + bytes;
- if (!msg->sg_copy[i] && bytes_sg_total <= len)
+ if (!msg->sg.copy[i] && bytes_sg_total <= len)
goto out;
/* At this point we need to linearize multiple scatterlist
@@ -2338,12 +2215,12 @@ BPF_CALL_4(bpf_msg_pull_data,
* will copy the entire sg entry.
*/
do {
- copy += sg[i].length;
- sk_msg_iter_var(i);
+ copy += sk_msg_elem(msg, i)->length;
+ sk_msg_iter_var_next(i);
if (bytes_sg_total <= copy)
break;
- } while (i != msg->sg_end);
- last_sg = i;
+ } while (i != msg->sg.end);
+ last_sge = i;
if (unlikely(bytes_sg_total > copy))
return -EINVAL;
@@ -2352,63 +2229,61 @@ BPF_CALL_4(bpf_msg_pull_data,
get_order(copy));
if (unlikely(!page))
return -ENOMEM;
- p = page_address(page);
- i = first_sg;
+ raw = page_address(page);
+ i = first_sge;
do {
- from = sg_virt(&sg[i]);
- len = sg[i].length;
- to = p + poffset;
+ sge = sk_msg_elem(msg, i);
+ from = sg_virt(sge);
+ len = sge->length;
+ to = raw + poffset;
memcpy(to, from, len);
poffset += len;
- sg[i].length = 0;
- put_page(sg_page(&sg[i]));
+ sge->length = 0;
+ put_page(sg_page(sge));
- sk_msg_iter_var(i);
- } while (i != last_sg);
+ sk_msg_iter_var_next(i);
+ } while (i != last_sge);
- sg[first_sg].length = copy;
- sg_set_page(&sg[first_sg], page, copy, 0);
+ sg_set_page(&msg->sg.data[first_sge], page, copy, 0);
/* To repair sg ring we need to shift entries. If we only
* had a single entry though we can just replace it and
* be done. Otherwise walk the ring and shift the entries.
*/
- WARN_ON_ONCE(last_sg == first_sg);
- shift = last_sg > first_sg ?
- last_sg - first_sg - 1 :
- MAX_SKB_FRAGS - first_sg + last_sg - 1;
+ WARN_ON_ONCE(last_sge == first_sge);
+ shift = last_sge > first_sge ?
+ last_sge - first_sge - 1 :
+ MAX_SKB_FRAGS - first_sge + last_sge - 1;
if (!shift)
goto out;
- i = first_sg;
- sk_msg_iter_var(i);
+ i = first_sge;
+ sk_msg_iter_var_next(i);
do {
- int move_from;
+ u32 move_from;
- if (i + shift >= MAX_SKB_FRAGS)
- move_from = i + shift - MAX_SKB_FRAGS;
+ if (i + shift >= MAX_MSG_FRAGS)
+ move_from = i + shift - MAX_MSG_FRAGS;
else
move_from = i + shift;
-
- if (move_from == msg->sg_end)
+ if (move_from == msg->sg.end)
break;
- sg[i] = sg[move_from];
- sg[move_from].length = 0;
- sg[move_from].page_link = 0;
- sg[move_from].offset = 0;
-
- sk_msg_iter_var(i);
+ msg->sg.data[i] = msg->sg.data[move_from];
+ msg->sg.data[move_from].length = 0;
+ msg->sg.data[move_from].page_link = 0;
+ msg->sg.data[move_from].offset = 0;
+ sk_msg_iter_var_next(i);
} while (1);
- msg->sg_end -= shift;
- if (msg->sg_end < 0)
- msg->sg_end += MAX_SKB_FRAGS;
+
+ msg->sg.end = msg->sg.end - shift > msg->sg.end ?
+ msg->sg.end - shift + MAX_MSG_FRAGS :
+ msg->sg.end - shift;
out:
- msg->data = sg_virt(&sg[first_sg]) + start - offset;
+ msg->data = sg_virt(&msg->sg.data[first_sge]) + start - offset;
msg->data_end = msg->data + bytes;
-
return 0;
}
@@ -5203,6 +5078,9 @@ xdp_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
}
}
+const struct bpf_func_proto bpf_sock_map_update_proto __weak;
+const struct bpf_func_proto bpf_sock_hash_update_proto __weak;
+
static const struct bpf_func_proto *
sock_ops_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
{
@@ -5226,6 +5104,9 @@ sock_ops_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
}
}
+const struct bpf_func_proto bpf_msg_redirect_map_proto __weak;
+const struct bpf_func_proto bpf_msg_redirect_hash_proto __weak;
+
static const struct bpf_func_proto *
sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
{
@@ -5247,6 +5128,9 @@ sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
}
}
+const struct bpf_func_proto bpf_sk_redirect_map_proto __weak;
+const struct bpf_func_proto bpf_sk_redirect_hash_proto __weak;
+
static const struct bpf_func_proto *
sk_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
{
@@ -7001,22 +6885,22 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
switch (si->off) {
case offsetof(struct sk_msg_md, data):
- *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data),
+ *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg, data),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, data));
+ offsetof(struct sk_msg, data));
break;
case offsetof(struct sk_msg_md, data_end):
- *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data_end),
+ *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg, data_end),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, data_end));
+ offsetof(struct sk_msg, data_end));
break;
case offsetof(struct sk_msg_md, family):
BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_family) != 2);
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
- struct sk_msg_buff, sk),
+ struct sk_msg, sk),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, sk));
+ offsetof(struct sk_msg, sk));
*insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
offsetof(struct sock_common, skc_family));
break;
@@ -7025,9 +6909,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_daddr) != 4);
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
- struct sk_msg_buff, sk),
+ struct sk_msg, sk),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, sk));
+ offsetof(struct sk_msg, sk));
*insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
offsetof(struct sock_common, skc_daddr));
break;
@@ -7037,9 +6921,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
skc_rcv_saddr) != 4);
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
- struct sk_msg_buff, sk),
+ struct sk_msg, sk),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, sk));
+ offsetof(struct sk_msg, sk));
*insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
offsetof(struct sock_common,
skc_rcv_saddr));
@@ -7054,9 +6938,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
off = si->off;
off -= offsetof(struct sk_msg_md, remote_ip6[0]);
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
- struct sk_msg_buff, sk),
+ struct sk_msg, sk),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, sk));
+ offsetof(struct sk_msg, sk));
*insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
offsetof(struct sock_common,
skc_v6_daddr.s6_addr32[0]) +
@@ -7075,9 +6959,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
off = si->off;
off -= offsetof(struct sk_msg_md, local_ip6[0]);
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
- struct sk_msg_buff, sk),
+ struct sk_msg, sk),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, sk));
+ offsetof(struct sk_msg, sk));
*insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
offsetof(struct sock_common,
skc_v6_rcv_saddr.s6_addr32[0]) +
@@ -7091,9 +6975,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_dport) != 2);
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
- struct sk_msg_buff, sk),
+ struct sk_msg, sk),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, sk));
+ offsetof(struct sk_msg, sk));
*insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
offsetof(struct sock_common, skc_dport));
#ifndef __BIG_ENDIAN_BITFIELD
@@ -7105,9 +6989,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_num) != 2);
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
- struct sk_msg_buff, sk),
+ struct sk_msg, sk),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, sk));
+ offsetof(struct sk_msg, sk));
*insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
offsetof(struct sock_common, skc_num));
break;
diff --git a/net/core/skmsg.c b/net/core/skmsg.c
new file mode 100644
index 000000000000..ae2b281c9c57
--- /dev/null
+++ b/net/core/skmsg.c
@@ -0,0 +1,763 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
+
+#include <linux/skmsg.h>
+#include <linux/skbuff.h>
+#include <linux/scatterlist.h>
+
+#include <net/sock.h>
+#include <net/tcp.h>
+
+static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
+{
+ if (msg->sg.end > msg->sg.start &&
+ elem_first_coalesce < msg->sg.end)
+ return true;
+
+ if (msg->sg.end < msg->sg.start &&
+ (elem_first_coalesce > msg->sg.start ||
+ elem_first_coalesce < msg->sg.end))
+ return true;
+
+ return false;
+}
+
+int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
+ int elem_first_coalesce)
+{
+ struct page_frag *pfrag = sk_page_frag(sk);
+ int ret = 0;
+
+ len -= msg->sg.size;
+ while (len > 0) {
+ struct scatterlist *sge;
+ u32 orig_offset;
+ int use, i;
+
+ if (!sk_page_frag_refill(sk, pfrag))
+ return -ENOMEM;
+
+ orig_offset = pfrag->offset;
+ use = min_t(int, len, pfrag->size - orig_offset);
+ if (!sk_wmem_schedule(sk, use))
+ return -ENOMEM;
+
+ i = msg->sg.end;
+ sk_msg_iter_var_prev(i);
+ sge = &msg->sg.data[i];
+
+ if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
+ sg_page(sge) == pfrag->page &&
+ sge->offset + sge->length == orig_offset) {
+ sge->length += use;
+ } else {
+ if (sk_msg_full(msg)) {
+ ret = -ENOSPC;
+ break;
+ }
+
+ sge = &msg->sg.data[msg->sg.end];
+ sg_unmark_end(sge);
+ sg_set_page(sge, pfrag->page, use, orig_offset);
+ get_page(pfrag->page);
+ sk_msg_iter_next(msg, end);
+ }
+
+ sk_mem_charge(sk, use);
+ msg->sg.size += use;
+ pfrag->offset += use;
+ len -= use;
+ }
+
+ return ret;
+}
+EXPORT_SYMBOL_GPL(sk_msg_alloc);
+
+void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
+{
+ int i = msg->sg.start;
+
+ do {
+ struct scatterlist *sge = sk_msg_elem(msg, i);
+
+ if (bytes < sge->length) {
+ sge->length -= bytes;
+ sge->offset += bytes;
+ sk_mem_uncharge(sk, bytes);
+ break;
+ }
+
+ sk_mem_uncharge(sk, sge->length);
+ bytes -= sge->length;
+ sge->length = 0;
+ sge->offset = 0;
+ sk_msg_iter_var_next(i);
+ } while (bytes && i != msg->sg.end);
+ msg->sg.start = i;
+}
+EXPORT_SYMBOL_GPL(sk_msg_return_zero);
+
+void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
+{
+ int i = msg->sg.start;
+
+ do {
+ struct scatterlist *sge = &msg->sg.data[i];
+ int uncharge = (bytes < sge->length) ? bytes : sge->length;
+
+ sk_mem_uncharge(sk, uncharge);
+ bytes -= uncharge;
+ sk_msg_iter_var_next(i);
+ } while (i != msg->sg.end);
+}
+EXPORT_SYMBOL_GPL(sk_msg_return);
+
+static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
+ bool charge)
+{
+ struct scatterlist *sge = sk_msg_elem(msg, i);
+ u32 len = sge->length;
+
+ if (charge)
+ sk_mem_uncharge(sk, len);
+ if (!msg->skb)
+ put_page(sg_page(sge));
+ memset(sge, 0, sizeof(*sge));
+ return len;
+}
+
+static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
+ bool charge)
+{
+ struct scatterlist *sge = sk_msg_elem(msg, i);
+ int freed = 0;
+
+ while (msg->sg.size) {
+ msg->sg.size -= sge->length;
+ freed += sk_msg_free_elem(sk, msg, i, charge);
+ sk_msg_iter_var_next(i);
+ sk_msg_check_to_free(msg, i, msg->sg.size);
+ sge = sk_msg_elem(msg, i);
+ }
+ if (msg->skb)
+ consume_skb(msg->skb);
+ sk_msg_init(msg);
+ return freed;
+}
+
+int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
+{
+ return __sk_msg_free(sk, msg, msg->sg.start, false);
+}
+EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);
+
+int sk_msg_free(struct sock *sk, struct sk_msg *msg)
+{
+ return __sk_msg_free(sk, msg, msg->sg.start, true);
+}
+EXPORT_SYMBOL_GPL(sk_msg_free);
+
+static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
+ u32 bytes, bool charge)
+{
+ struct scatterlist *sge;
+ u32 i = msg->sg.start;
+
+ while (bytes) {
+ sge = sk_msg_elem(msg, i);
+ if (!sge->length)
+ break;
+ if (bytes < sge->length) {
+ if (charge)
+ sk_mem_uncharge(sk, bytes);
+ sge->length -= bytes;
+ sge->offset += bytes;
+ msg->sg.size -= bytes;
+ break;
+ }
+
+ msg->sg.size -= sge->length;
+ bytes -= sge->length;
+ sk_msg_free_elem(sk, msg, i, charge);
+ sk_msg_iter_var_next(i);
+ sk_msg_check_to_free(msg, i, bytes);
+ }
+ msg->sg.start = i;
+}
+
+void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
+{
+ __sk_msg_free_partial(sk, msg, bytes, true);
+}
+EXPORT_SYMBOL_GPL(sk_msg_free_partial);
+
+void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
+ u32 bytes)
+{
+ __sk_msg_free_partial(sk, msg, bytes, false);
+}
+
+void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
+{
+ int trim = msg->sg.size - len;
+ u32 i = msg->sg.end;
+
+ if (trim <= 0) {
+ WARN_ON(trim < 0);
+ return;
+ }
+
+ sk_msg_iter_var_prev(i);
+ msg->sg.size = len;
+ while (msg->sg.data[i].length &&
+ trim >= msg->sg.data[i].length) {
+ trim -= msg->sg.data[i].length;
+ sk_msg_free_elem(sk, msg, i, true);
+ sk_msg_iter_var_prev(i);
+ if (!trim)
+ goto out;
+ }
+
+ msg->sg.data[i].length -= trim;
+ sk_mem_uncharge(sk, trim);
+out:
+ /* If we trim data before curr pointer update copybreak and current
+ * so that any future copy operations start at new copy location.
+ * However trimed data that has not yet been used in a copy op
+ * does not require an update.
+ */
+ if (msg->sg.curr >= i) {
+ msg->sg.curr = i;
+ msg->sg.copybreak = msg->sg.data[i].length;
+ }
+ sk_msg_iter_var_next(i);
+ msg->sg.end = i;
+}
+EXPORT_SYMBOL_GPL(sk_msg_trim);
+
+int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
+ struct sk_msg *msg, u32 bytes)
+{
+ int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
+ const int to_max_pages = MAX_MSG_FRAGS;
+ struct page *pages[MAX_MSG_FRAGS];
+ ssize_t orig, copied, use, offset;
+
+ orig = msg->sg.size;
+ while (bytes > 0) {
+ i = 0;
+ maxpages = to_max_pages - num_elems;
+ if (maxpages == 0) {
+ ret = -EFAULT;
+ goto out;
+ }
+
+ copied = iov_iter_get_pages(from, pages, bytes, maxpages,
+ &offset);
+ if (copied <= 0) {
+ ret = -EFAULT;
+ goto out;
+ }
+
+ iov_iter_advance(from, copied);
+ bytes -= copied;
+ msg->sg.size += copied;
+
+ while (copied) {
+ use = min_t(int, copied, PAGE_SIZE - offset);
+ sg_set_page(&msg->sg.data[msg->sg.end],
+ pages[i], use, offset);
+ sg_unmark_end(&msg->sg.data[msg->sg.end]);
+ sk_mem_charge(sk, use);
+
+ offset = 0;
+ copied -= use;
+ sk_msg_iter_next(msg, end);
+ num_elems++;
+ i++;
+ }
+ /* When zerocopy is mixed with sk_msg_*copy* operations we
+ * may have a copybreak set in this case clear and prefer
+ * zerocopy remainder when possible.
+ */
+ msg->sg.copybreak = 0;
+ msg->sg.curr = msg->sg.end;
+ }
+out:
+ /* Revert iov_iter updates, msg will need to use 'trim' later if it
+ * also needs to be cleared.
+ */
+ if (ret)
+ iov_iter_revert(from, msg->sg.size - orig);
+ return ret;
+}
+EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);
+
+int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
+ struct sk_msg *msg, u32 bytes)
+{
+ int ret = -ENOSPC, i = msg->sg.curr;
+ struct scatterlist *sge;
+ u32 copy, buf_size;
+ void *to;
+
+ do {
+ sge = sk_msg_elem(msg, i);
+ /* This is possible if a trim operation shrunk the buffer */
+ if (msg->sg.copybreak >= sge->length) {
+ msg->sg.copybreak = 0;
+ sk_msg_iter_var_next(i);
+ if (i == msg->sg.end)
+ break;
+ sge = sk_msg_elem(msg, i);
+ }
+
+ buf_size = sge->length - msg->sg.copybreak;
+ copy = (buf_size > bytes) ? bytes : buf_size;
+ to = sg_virt(sge) + msg->sg.copybreak;
+ msg->sg.copybreak += copy;
+ if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
+ ret = copy_from_iter_nocache(to, copy, from);
+ else
+ ret = copy_from_iter(to, copy, from);
+ if (ret != copy) {
+ ret = -EFAULT;
+ goto out;
+ }
+ bytes -= copy;
+ if (!bytes)
+ break;
+ msg->sg.copybreak = 0;
+ sk_msg_iter_var_next(i);
+ } while (i != msg->sg.end);
+out:
+ msg->sg.curr = i;
+ return ret;
+}
+EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
+
+static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
+{
+ struct sock *sk = psock->sk;
+ int copied = 0, num_sge;
+ struct sk_msg *msg;
+
+ msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
+ if (unlikely(!msg))
+ return -EAGAIN;
+ if (!sk_rmem_schedule(sk, skb, skb->len)) {
+ kfree(msg);
+ return -EAGAIN;
+ }
+
+ sk_msg_init(msg);
+ num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
+ if (unlikely(num_sge < 0)) {
+ kfree(msg);
+ return num_sge;
+ }
+
+ sk_mem_charge(sk, skb->len);
+ copied = skb->len;
+ msg->sg.start = 0;
+ msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
+ msg->skb = skb;
+
+ sk_psock_queue_msg(psock, msg);
+ sk->sk_data_ready(sk);
+ return copied;
+}
+
+static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
+ u32 off, u32 len, bool ingress)
+{
+ if (ingress)
+ return sk_psock_skb_ingress(psock, skb);
+ else
+ return skb_send_sock_locked(psock->sk, skb, off, len);
+}
+
+static void sk_psock_backlog(struct work_struct *work)
+{
+ struct sk_psock *psock = container_of(work, struct sk_psock, work);
+ struct sk_psock_work_state *state = &psock->work_state;
+ struct sk_buff *skb;
+ bool ingress;
+ u32 len, off;
+ int ret;
+
+ /* Lock sock to avoid losing sk_socket during loop. */
+ lock_sock(psock->sk);
+ if (state->skb) {
+ skb = state->skb;
+ len = state->len;
+ off = state->off;
+ state->skb = NULL;
+ goto start;
+ }
+
+ while ((skb = skb_dequeue(&psock->ingress_skb))) {
+ len = skb->len;
+ off = 0;
+start:
+ ingress = tcp_skb_bpf_ingress(skb);
+ do {
+ ret = -EIO;
+ if (likely(psock->sk->sk_socket))
+ ret = sk_psock_handle_skb(psock, skb, off,
+ len, ingress);
+ if (ret <= 0) {
+ if (ret == -EAGAIN) {
+ state->skb = skb;
+ state->len = len;
+ state->off = off;
+ goto end;
+ }
+ /* Hard errors break pipe and stop xmit. */
+ sk_psock_report_error(psock, ret ? -ret : EPIPE);
+ sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
+ kfree_skb(skb);
+ goto end;
+ }
+ off += ret;
+ len -= ret;
+ } while (len);
+
+ if (!ingress)
+ kfree_skb(skb);
+ }
+end:
+ release_sock(psock->sk);
+}
+
+struct sk_psock *sk_psock_init(struct sock *sk, int node)
+{
+ struct sk_psock *psock = kzalloc_node(sizeof(*psock),
+ GFP_ATOMIC | __GFP_NOWARN,
+ node);
+ if (!psock)
+ return NULL;
+
+ psock->sk = sk;
+ psock->eval = __SK_NONE;
+
+ INIT_LIST_HEAD(&psock->link);
+ spin_lock_init(&psock->link_lock);
+
+ INIT_WORK(&psock->work, sk_psock_backlog);
+ INIT_LIST_HEAD(&psock->ingress_msg);
+ skb_queue_head_init(&psock->ingress_skb);
+
+ sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
+ refcount_set(&psock->refcnt, 1);
+
+ rcu_assign_sk_user_data(sk, psock);
+ sock_hold(sk);
+
+ return psock;
+}
+EXPORT_SYMBOL_GPL(sk_psock_init);
+
+struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
+{
+ struct sk_psock_link *link;
+
+ spin_lock_bh(&psock->link_lock);
+ link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
+ list);
+ if (link)
+ list_del(&link->list);
+ spin_unlock_bh(&psock->link_lock);
+ return link;
+}
+
+void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
+{
+ struct sk_msg *msg, *tmp;
+
+ list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
+ list_del(&msg->list);
+ sk_msg_free(psock->sk, msg);
+ kfree(msg);
+ }
+}
+
+static void sk_psock_zap_ingress(struct sk_psock *psock)
+{
+ __skb_queue_purge(&psock->ingress_skb);
+ __sk_psock_purge_ingress_msg(psock);
+}
+
+static void sk_psock_link_destroy(struct sk_psock *psock)
+{
+ struct sk_psock_link *link, *tmp;
+
+ list_for_each_entry_safe(link, tmp, &psock->link, list) {
+ list_del(&link->list);
+ sk_psock_free_link(link);
+ }
+}
+
+static void sk_psock_destroy_deferred(struct work_struct *gc)
+{
+ struct sk_psock *psock = container_of(gc, struct sk_psock, gc);
+
+ /* No sk_callback_lock since already detached. */
+ if (psock->parser.enabled)
+ strp_done(&psock->parser.strp);
+
+ cancel_work_sync(&psock->work);
+
+ psock_progs_drop(&psock->progs);
+
+ sk_psock_link_destroy(psock);
+ sk_psock_cork_free(psock);
+ sk_psock_zap_ingress(psock);
+
+ if (psock->sk_redir)
+ sock_put(psock->sk_redir);
+ sock_put(psock->sk);
+ kfree(psock);
+}
+
+void sk_psock_destroy(struct rcu_head *rcu)
+{
+ struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);
+
+ INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
+ schedule_work(&psock->gc);
+}
+EXPORT_SYMBOL_GPL(sk_psock_destroy);
+
+void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
+{
+ rcu_assign_sk_user_data(sk, NULL);
+ sk_psock_cork_free(psock);
+ sk_psock_restore_proto(sk, psock);
+
+ write_lock_bh(&sk->sk_callback_lock);
+ if (psock->progs.skb_parser)
+ sk_psock_stop_strp(sk, psock);
+ write_unlock_bh(&sk->sk_callback_lock);
+ sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
+
+ call_rcu_sched(&psock->rcu, sk_psock_destroy);
+}
+EXPORT_SYMBOL_GPL(sk_psock_drop);
+
+static int sk_psock_map_verd(int verdict, bool redir)
+{
+ switch (verdict) {
+ case SK_PASS:
+ return redir ? __SK_REDIRECT : __SK_PASS;
+ case SK_DROP:
+ default:
+ break;
+ }
+
+ return __SK_DROP;
+}
+
+int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
+ struct sk_msg *msg)
+{
+ struct bpf_prog *prog;
+ int ret;
+
+ preempt_disable();
+ rcu_read_lock();
+ prog = READ_ONCE(psock->progs.msg_parser);
+ if (unlikely(!prog)) {
+ ret = __SK_PASS;
+ goto out;
+ }
+
+ sk_msg_compute_data_pointers(msg);
+ msg->sk = sk;
+ ret = BPF_PROG_RUN(prog, msg);
+ ret = sk_psock_map_verd(ret, msg->sk_redir);
+ psock->apply_bytes = msg->apply_bytes;
+ if (ret == __SK_REDIRECT) {
+ if (psock->sk_redir)
+ sock_put(psock->sk_redir);
+ psock->sk_redir = msg->sk_redir;
+ if (!psock->sk_redir) {
+ ret = __SK_DROP;
+ goto out;
+ }
+ sock_hold(psock->sk_redir);
+ }
+out:
+ rcu_read_unlock();
+ preempt_enable();
+ return ret;
+}
+EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);
+
+static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
+ struct sk_buff *skb)
+{
+ int ret;
+
+ skb->sk = psock->sk;
+ bpf_compute_data_end_sk_skb(skb);
+ preempt_disable();
+ ret = BPF_PROG_RUN(prog, skb);
+ preempt_enable();
+ /* strparser clones the skb before handing it to a upper layer,
+ * meaning skb_orphan has been called. We NULL sk on the way out
+ * to ensure we don't trigger a BUG_ON() in skb/sk operations
+ * later and because we are not charging the memory of this skb
+ * to any socket yet.
+ */
+ skb->sk = NULL;
+ return ret;
+}
+
+static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
+{
+ struct sk_psock_parser *parser;
+
+ parser = container_of(strp, struct sk_psock_parser, strp);
+ return container_of(parser, struct sk_psock, parser);
+}
+
+static void sk_psock_verdict_apply(struct sk_psock *psock,
+ struct sk_buff *skb, int verdict)
+{
+ struct sk_psock *psock_other;
+ struct sock *sk_other;
+ bool ingress;
+
+ switch (verdict) {
+ case __SK_REDIRECT:
+ sk_other = tcp_skb_bpf_redirect_fetch(skb);
+ if (unlikely(!sk_other))
+ goto out_free;
+ psock_other = sk_psock(sk_other);
+ if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
+ !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))
+ goto out_free;
+ ingress = tcp_skb_bpf_ingress(skb);
+ if ((!ingress && sock_writeable(sk_other)) ||
+ (ingress &&
+ atomic_read(&sk_other->sk_rmem_alloc) <=
+ sk_other->sk_rcvbuf)) {
+ if (!ingress)
+ skb_set_owner_w(skb, sk_other);
+ skb_queue_tail(&psock_other->ingress_skb, skb);
+ schedule_work(&psock_other->work);
+ break;
+ }
+ /* fall-through */
+ case __SK_DROP:
+ /* fall-through */
+ default:
+out_free:
+ kfree_skb(skb);
+ }
+}
+
+static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
+{
+ struct sk_psock *psock = sk_psock_from_strp(strp);
+ struct bpf_prog *prog;
+ int ret = __SK_DROP;
+
+ rcu_read_lock();
+ prog = READ_ONCE(psock->progs.skb_verdict);
+ if (likely(prog)) {
+ skb_orphan(skb);
+ tcp_skb_bpf_redirect_clear(skb);
+ ret = sk_psock_bpf_run(psock, prog, skb);
+ ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
+ }
+ rcu_read_unlock();
+ sk_psock_verdict_apply(psock, skb, ret);
+}
+
+static int sk_psock_strp_read_done(struct strparser *strp, int err)
+{
+ return err;
+}
+
+static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
+{
+ struct sk_psock *psock = sk_psock_from_strp(strp);
+ struct bpf_prog *prog;
+ int ret = skb->len;
+
+ rcu_read_lock();
+ prog = READ_ONCE(psock->progs.skb_parser);
+ if (likely(prog))
+ ret = sk_psock_bpf_run(psock, prog, skb);
+ rcu_read_unlock();
+ return ret;
+}
+
+/* Called with socket lock held. */
+static void sk_psock_data_ready(struct sock *sk)
+{
+ struct sk_psock *psock;
+
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ if (likely(psock)) {
+ write_lock_bh(&sk->sk_callback_lock);
+ strp_data_ready(&psock->parser.strp);
+ write_unlock_bh(&sk->sk_callback_lock);
+ }
+ rcu_read_unlock();
+}
+
+static void sk_psock_write_space(struct sock *sk)
+{
+ struct sk_psock *psock;
+ void (*write_space)(struct sock *sk);
+
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ if (likely(psock && sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)))
+ schedule_work(&psock->work);
+ write_space = psock->saved_write_space;
+ rcu_read_unlock();
+ write_space(sk);
+}
+
+int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
+{
+ static const struct strp_callbacks cb = {
+ .rcv_msg = sk_psock_strp_read,
+ .read_sock_done = sk_psock_strp_read_done,
+ .parse_msg = sk_psock_strp_parse,
+ };
+
+ psock->parser.enabled = false;
+ return strp_init(&psock->parser.strp, sk, &cb);
+}
+
+void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
+{
+ struct sk_psock_parser *parser = &psock->parser;
+
+ if (parser->enabled)
+ return;
+
+ parser->saved_data_ready = sk->sk_data_ready;
+ sk->sk_data_ready = sk_psock_data_ready;
+ sk->sk_write_space = sk_psock_write_space;
+ parser->enabled = true;
+}
+
+void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
+{
+ struct sk_psock_parser *parser = &psock->parser;
+
+ if (!parser->enabled)
+ return;
+
+ sk->sk_data_ready = parser->saved_data_ready;
+ parser->saved_data_ready = NULL;
+ strp_stop(&parser->strp);
+ parser->enabled = false;
+}
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
new file mode 100644
index 000000000000..3c0e44cb811a
--- /dev/null
+++ b/net/core/sock_map.c
@@ -0,0 +1,1002 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
+
+#include <linux/bpf.h>
+#include <linux/filter.h>
+#include <linux/errno.h>
+#include <linux/file.h>
+#include <linux/net.h>
+#include <linux/workqueue.h>
+#include <linux/skmsg.h>
+#include <linux/list.h>
+#include <linux/jhash.h>
+
+struct bpf_stab {
+ struct bpf_map map;
+ struct sock **sks;
+ struct sk_psock_progs progs;
+ raw_spinlock_t lock;
+};
+
+#define SOCK_CREATE_FLAG_MASK \
+ (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
+
+static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
+{
+ struct bpf_stab *stab;
+ u64 cost;
+ int err;
+
+ if (!capable(CAP_NET_ADMIN))
+ return ERR_PTR(-EPERM);
+ if (attr->max_entries == 0 ||
+ attr->key_size != 4 ||
+ attr->value_size != 4 ||
+ attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
+ return ERR_PTR(-EINVAL);
+
+ stab = kzalloc(sizeof(*stab), GFP_USER);
+ if (!stab)
+ return ERR_PTR(-ENOMEM);
+
+ bpf_map_init_from_attr(&stab->map, attr);
+ raw_spin_lock_init(&stab->lock);
+
+ /* Make sure page count doesn't overflow. */
+ cost = (u64) stab->map.max_entries * sizeof(struct sock *);
+ if (cost >= U32_MAX - PAGE_SIZE) {
+ err = -EINVAL;
+ goto free_stab;
+ }
+
+ stab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
+ err = bpf_map_precharge_memlock(stab->map.pages);
+ if (err)
+ goto free_stab;
+
+ stab->sks = bpf_map_area_alloc(stab->map.max_entries *
+ sizeof(struct sock *),
+ stab->map.numa_node);
+ if (stab->sks)
+ return &stab->map;
+ err = -ENOMEM;
+free_stab:
+ kfree(stab);
+ return ERR_PTR(err);
+}
+
+int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
+{
+ u32 ufd = attr->target_fd;
+ struct bpf_map *map;
+ struct fd f;
+ int ret;
+
+ f = fdget(ufd);
+ map = __bpf_map_get(f);
+ if (IS_ERR(map))
+ return PTR_ERR(map);
+ ret = sock_map_prog_update(map, prog, attr->attach_type);
+ fdput(f);
+ return ret;
+}
+
+static void sock_map_sk_acquire(struct sock *sk)
+ __acquires(&sk->sk_lock.slock)
+{
+ lock_sock(sk);
+ preempt_disable();
+ rcu_read_lock();
+}
+
+static void sock_map_sk_release(struct sock *sk)
+ __releases(&sk->sk_lock.slock)
+{
+ rcu_read_unlock();
+ preempt_enable();
+ release_sock(sk);
+}
+
+static void sock_map_add_link(struct sk_psock *psock,
+ struct sk_psock_link *link,
+ struct bpf_map *map, void *link_raw)
+{
+ link->link_raw = link_raw;
+ link->map = map;
+ spin_lock_bh(&psock->link_lock);
+ list_add_tail(&link->list, &psock->link);
+ spin_unlock_bh(&psock->link_lock);
+}
+
+static void sock_map_del_link(struct sock *sk,
+ struct sk_psock *psock, void *link_raw)
+{
+ struct sk_psock_link *link, *tmp;
+ bool strp_stop = false;
+
+ spin_lock_bh(&psock->link_lock);
+ list_for_each_entry_safe(link, tmp, &psock->link, list) {
+ if (link->link_raw == link_raw) {
+ struct bpf_map *map = link->map;
+ struct bpf_stab *stab = container_of(map, struct bpf_stab,
+ map);
+ if (psock->parser.enabled && stab->progs.skb_parser)
+ strp_stop = true;
+ list_del(&link->list);
+ sk_psock_free_link(link);
+ }
+ }
+ spin_unlock_bh(&psock->link_lock);
+ if (strp_stop) {
+ write_lock_bh(&sk->sk_callback_lock);
+ sk_psock_stop_strp(sk, psock);
+ write_unlock_bh(&sk->sk_callback_lock);
+ }
+}
+
+static void sock_map_unref(struct sock *sk, void *link_raw)
+{
+ struct sk_psock *psock = sk_psock(sk);
+
+ if (likely(psock)) {
+ sock_map_del_link(sk, psock, link_raw);
+ sk_psock_put(sk, psock);
+ }
+}
+
+static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
+ struct sock *sk)
+{
+ struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
+ bool skb_progs, sk_psock_is_new = false;
+ struct sk_psock *psock;
+ int ret;
+
+ skb_verdict = READ_ONCE(progs->skb_verdict);
+ skb_parser = READ_ONCE(progs->skb_parser);
+ skb_progs = skb_parser && skb_verdict;
+ if (skb_progs) {
+ skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
+ if (IS_ERR(skb_verdict))
+ return PTR_ERR(skb_verdict);
+ skb_parser = bpf_prog_inc_not_zero(skb_parser);
+ if (IS_ERR(skb_parser)) {
+ bpf_prog_put(skb_verdict);
+ return PTR_ERR(skb_parser);
+ }
+ }
+
+ msg_parser = READ_ONCE(progs->msg_parser);
+ if (msg_parser) {
+ msg_parser = bpf_prog_inc_not_zero(msg_parser);
+ if (IS_ERR(msg_parser)) {
+ ret = PTR_ERR(msg_parser);
+ goto out;
+ }
+ }
+
+ psock = sk_psock_get(sk);
+ if (psock) {
+ if (!sk_has_psock(sk)) {
+ ret = -EBUSY;
+ goto out_progs;
+ }
+ if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
+ (skb_progs && READ_ONCE(psock->progs.skb_parser))) {
+ sk_psock_put(sk, psock);
+ ret = -EBUSY;
+ goto out_progs;
+ }
+ } else {
+ psock = sk_psock_init(sk, map->numa_node);
+ if (!psock) {
+ ret = -ENOMEM;
+ goto out_progs;
+ }
+ sk_psock_is_new = true;
+ }
+
+ if (msg_parser)
+ psock_set_prog(&psock->progs.msg_parser, msg_parser);
+ if (sk_psock_is_new) {
+ ret = tcp_bpf_init(sk);
+ if (ret < 0)
+ goto out_drop;
+ } else {
+ tcp_bpf_reinit(sk);
+ }
+
+ write_lock_bh(&sk->sk_callback_lock);
+ if (skb_progs && !psock->parser.enabled) {
+ ret = sk_psock_init_strp(sk, psock);
+ if (ret) {
+ write_unlock_bh(&sk->sk_callback_lock);
+ goto out_drop;
+ }
+ psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
+ psock_set_prog(&psock->progs.skb_parser, skb_parser);
+ sk_psock_start_strp(sk, psock);
+ }
+ write_unlock_bh(&sk->sk_callback_lock);
+ return 0;
+out_drop:
+ sk_psock_put(sk, psock);
+out_progs:
+ if (msg_parser)
+ bpf_prog_put(msg_parser);
+out:
+ if (skb_progs) {
+ bpf_prog_put(skb_verdict);
+ bpf_prog_put(skb_parser);
+ }
+ return ret;
+}
+
+static void sock_map_free(struct bpf_map *map)
+{
+ struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+ int i;
+
+ synchronize_rcu();
+ rcu_read_lock();
+ raw_spin_lock_bh(&stab->lock);
+ for (i = 0; i < stab->map.max_entries; i++) {
+ struct sock **psk = &stab->sks[i];
+ struct sock *sk;
+
+ sk = xchg(psk, NULL);
+ if (sk)
+ sock_map_unref(sk, psk);
+ }
+ raw_spin_unlock_bh(&stab->lock);
+ rcu_read_unlock();
+
+ bpf_map_area_free(stab->sks);
+ kfree(stab);
+}
+
+static void sock_map_release_progs(struct bpf_map *map)
+{
+ psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
+}
+
+static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
+{
+ struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+
+ WARN_ON_ONCE(!rcu_read_lock_held());
+
+ if (unlikely(key >= map->max_entries))
+ return NULL;
+ return READ_ONCE(stab->sks[key]);
+}
+
+static void *sock_map_lookup(struct bpf_map *map, void *key)
+{
+ return ERR_PTR(-EOPNOTSUPP);
+}
+
+static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
+ struct sock **psk)
+{
+ struct sock *sk;
+
+ raw_spin_lock_bh(&stab->lock);
+ sk = *psk;
+ if (!sk_test || sk_test == sk)
+ *psk = NULL;
+ raw_spin_unlock_bh(&stab->lock);
+ if (unlikely(!sk))
+ return -EINVAL;
+ sock_map_unref(sk, psk);
+ return 0;
+}
+
+static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
+ void *link_raw)
+{
+ struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+
+ __sock_map_delete(stab, sk, link_raw);
+}
+
+static int sock_map_delete_elem(struct bpf_map *map, void *key)
+{
+ struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+ u32 i = *(u32 *)key;
+ struct sock **psk;
+
+ if (unlikely(i >= map->max_entries))
+ return -EINVAL;
+
+ psk = &stab->sks[i];
+ return __sock_map_delete(stab, NULL, psk);
+}
+
+static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
+{
+ struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+ u32 i = key ? *(u32 *)key : U32_MAX;
+ u32 *key_next = next;
+
+ if (i == stab->map.max_entries - 1)
+ return -ENOENT;
+ if (i >= stab->map.max_entries)
+ *key_next = 0;
+ else
+ *key_next = i + 1;
+ return 0;
+}
+
+static int sock_map_update_common(struct bpf_map *map, u32 idx,
+ struct sock *sk, u64 flags)
+{
+ struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+ struct sk_psock_link *link;
+ struct sk_psock *psock;
+ struct sock *osk;
+ int ret;
+
+ WARN_ON_ONCE(!rcu_read_lock_held());
+ if (unlikely(flags > BPF_EXIST))
+ return -EINVAL;
+ if (unlikely(idx >= map->max_entries))
+ return -E2BIG;
+
+ link = sk_psock_init_link();
+ if (!link)
+ return -ENOMEM;
+
+ ret = sock_map_link(map, &stab->progs, sk);
+ if (ret < 0)
+ goto out_free;
+
+ psock = sk_psock(sk);
+ WARN_ON_ONCE(!psock);
+
+ raw_spin_lock_bh(&stab->lock);
+ osk = stab->sks[idx];
+ if (osk && flags == BPF_NOEXIST) {
+ ret = -EEXIST;
+ goto out_unlock;
+ } else if (!osk && flags == BPF_EXIST) {
+ ret = -ENOENT;
+ goto out_unlock;
+ }
+
+ sock_map_add_link(psock, link, map, &stab->sks[idx]);
+ stab->sks[idx] = sk;
+ if (osk)
+ sock_map_unref(osk, &stab->sks[idx]);
+ raw_spin_unlock_bh(&stab->lock);
+ return 0;
+out_unlock:
+ raw_spin_unlock_bh(&stab->lock);
+ if (psock)
+ sk_psock_put(sk, psock);
+out_free:
+ sk_psock_free_link(link);
+ return ret;
+}
+
+static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
+{
+ return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
+ ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
+}
+
+static bool sock_map_sk_is_suitable(const struct sock *sk)
+{
+ return sk->sk_type == SOCK_STREAM &&
+ sk->sk_protocol == IPPROTO_TCP;
+}
+
+static int sock_map_update_elem(struct bpf_map *map, void *key,
+ void *value, u64 flags)
+{
+ u32 ufd = *(u32 *)value;
+ u32 idx = *(u32 *)key;
+ struct socket *sock;
+ struct sock *sk;
+ int ret;
+
+ sock = sockfd_lookup(ufd, &ret);
+ if (!sock)
+ return ret;
+ sk = sock->sk;
+ if (!sk) {
+ ret = -EINVAL;
+ goto out;
+ }
+ if (!sock_map_sk_is_suitable(sk) ||
+ sk->sk_state != TCP_ESTABLISHED) {
+ ret = -EOPNOTSUPP;
+ goto out;
+ }
+
+ sock_map_sk_acquire(sk);
+ ret = sock_map_update_common(map, idx, sk, flags);
+ sock_map_sk_release(sk);
+out:
+ fput(sock->file);
+ return ret;
+}
+
+BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
+ struct bpf_map *, map, void *, key, u64, flags)
+{
+ WARN_ON_ONCE(!rcu_read_lock_held());
+
+ if (likely(sock_map_sk_is_suitable(sops->sk) &&
+ sock_map_op_okay(sops)))
+ return sock_map_update_common(map, *(u32 *)key, sops->sk,
+ flags);
+ return -EOPNOTSUPP;
+}
+
+const struct bpf_func_proto bpf_sock_map_update_proto = {
+ .func = bpf_sock_map_update,
+ .gpl_only = false,
+ .pkt_access = true,
+ .ret_type = RET_INTEGER,
+ .arg1_type = ARG_PTR_TO_CTX,
+ .arg2_type = ARG_CONST_MAP_PTR,
+ .arg3_type = ARG_PTR_TO_MAP_KEY,
+ .arg4_type = ARG_ANYTHING,
+};
+
+BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
+ struct bpf_map *, map, u32, key, u64, flags)
+{
+ struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
+
+ if (unlikely(flags & ~(BPF_F_INGRESS)))
+ return SK_DROP;
+ tcb->bpf.flags = flags;
+ tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
+ if (!tcb->bpf.sk_redir)
+ return SK_DROP;
+ return SK_PASS;
+}
+
+const struct bpf_func_proto bpf_sk_redirect_map_proto = {
+ .func = bpf_sk_redirect_map,
+ .gpl_only = false,
+ .ret_type = RET_INTEGER,
+ .arg1_type = ARG_PTR_TO_CTX,
+ .arg2_type = ARG_CONST_MAP_PTR,
+ .arg3_type = ARG_ANYTHING,
+ .arg4_type = ARG_ANYTHING,
+};
+
+BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
+ struct bpf_map *, map, u32, key, u64, flags)
+{
+ if (unlikely(flags & ~(BPF_F_INGRESS)))
+ return SK_DROP;
+ msg->flags = flags;
+ msg->sk_redir = __sock_map_lookup_elem(map, key);
+ if (!msg->sk_redir)
+ return SK_DROP;
+ return SK_PASS;
+}
+
+const struct bpf_func_proto bpf_msg_redirect_map_proto = {
+ .func = bpf_msg_redirect_map,
+ .gpl_only = false,
+ .ret_type = RET_INTEGER,
+ .arg1_type = ARG_PTR_TO_CTX,
+ .arg2_type = ARG_CONST_MAP_PTR,
+ .arg3_type = ARG_ANYTHING,
+ .arg4_type = ARG_ANYTHING,
+};
+
+const struct bpf_map_ops sock_map_ops = {
+ .map_alloc = sock_map_alloc,
+ .map_free = sock_map_free,
+ .map_get_next_key = sock_map_get_next_key,
+ .map_update_elem = sock_map_update_elem,
+ .map_delete_elem = sock_map_delete_elem,
+ .map_lookup_elem = sock_map_lookup,
+ .map_release_uref = sock_map_release_progs,
+ .map_check_btf = map_check_no_btf,
+};
+
+struct bpf_htab_elem {
+ struct rcu_head rcu;
+ u32 hash;
+ struct sock *sk;
+ struct hlist_node node;
+ u8 key[0];
+};
+
+struct bpf_htab_bucket {
+ struct hlist_head head;
+ raw_spinlock_t lock;
+};
+
+struct bpf_htab {
+ struct bpf_map map;
+ struct bpf_htab_bucket *buckets;
+ u32 buckets_num;
+ u32 elem_size;
+ struct sk_psock_progs progs;
+ atomic_t count;
+};
+
+static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
+{
+ return jhash(key, len, 0);
+}
+
+static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
+ u32 hash)
+{
+ return &htab->buckets[hash & (htab->buckets_num - 1)];
+}
+
+static struct bpf_htab_elem *
+sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
+ u32 key_size)
+{
+ struct bpf_htab_elem *elem;
+
+ hlist_for_each_entry_rcu(elem, head, node) {
+ if (elem->hash == hash &&
+ !memcmp(&elem->key, key, key_size))
+ return elem;
+ }
+
+ return NULL;
+}
+
+static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
+{
+ struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
+ u32 key_size = map->key_size, hash;
+ struct bpf_htab_bucket *bucket;
+ struct bpf_htab_elem *elem;
+
+ WARN_ON_ONCE(!rcu_read_lock_held());
+
+ hash = sock_hash_bucket_hash(key, key_size);
+ bucket = sock_hash_select_bucket(htab, hash);
+ elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
+
+ return elem ? elem->sk : NULL;
+}
+
+static void sock_hash_free_elem(struct bpf_htab *htab,
+ struct bpf_htab_elem *elem)
+{
+ atomic_dec(&htab->count);
+ kfree_rcu(elem, rcu);
+}
+
+static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
+ void *link_raw)
+{
+ struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
+ struct bpf_htab_elem *elem_probe, *elem = link_raw;
+ struct bpf_htab_bucket *bucket;
+
+ WARN_ON_ONCE(!rcu_read_lock_held());
+ bucket = sock_hash_select_bucket(htab, elem->hash);
+
+ /* elem may be deleted in parallel from the map, but access here
+ * is okay since it's going away only after RCU grace period.
+ * However, we need to check whether it's still present.
+ */
+ raw_spin_lock_bh(&bucket->lock);
+ elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
+ elem->key, map->key_size);
+ if (elem_probe && elem_probe == elem) {
+ hlist_del_rcu(&elem->node);
+ sock_map_unref(elem->sk, elem);
+ sock_hash_free_elem(htab, elem);
+ }
+ raw_spin_unlock_bh(&bucket->lock);
+}
+
+static int sock_hash_delete_elem(struct bpf_map *map, void *key)
+{
+ struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
+ u32 hash, key_size = map->key_size;
+ struct bpf_htab_bucket *bucket;
+ struct bpf_htab_elem *elem;
+ int ret = -ENOENT;
+
+ hash = sock_hash_bucket_hash(key, key_size);
+ bucket = sock_hash_select_bucket(htab, hash);
+
+ raw_spin_lock_bh(&bucket->lock);
+ elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
+ if (elem) {
+ hlist_del_rcu(&elem->node);
+ sock_map_unref(elem->sk, elem);
+ sock_hash_free_elem(htab, elem);
+ ret = 0;
+ }
+ raw_spin_unlock_bh(&bucket->lock);
+ return ret;
+}
+
+static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
+ void *key, u32 key_size,
+ u32 hash, struct sock *sk,
+ struct bpf_htab_elem *old)
+{
+ struct bpf_htab_elem *new;
+
+ if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
+ if (!old) {
+ atomic_dec(&htab->count);
+ return ERR_PTR(-E2BIG);
+ }
+ }
+
+ new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
+ htab->map.numa_node);
+ if (!new) {
+ atomic_dec(&htab->count);
+ return ERR_PTR(-ENOMEM);
+ }
+ memcpy(new->key, key, key_size);
+ new->sk = sk;
+ new->hash = hash;
+ return new;
+}
+
+static int sock_hash_update_common(struct bpf_map *map, void *key,
+ struct sock *sk, u64 flags)
+{
+ struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
+ u32 key_size = map->key_size, hash;
+ struct bpf_htab_elem *elem, *elem_new;
+ struct bpf_htab_bucket *bucket;
+ struct sk_psock_link *link;
+ struct sk_psock *psock;
+ int ret;
+
+ WARN_ON_ONCE(!rcu_read_lock_held());
+ if (unlikely(flags > BPF_EXIST))
+ return -EINVAL;
+
+ link = sk_psock_init_link();
+ if (!link)
+ return -ENOMEM;
+
+ ret = sock_map_link(map, &htab->progs, sk);
+ if (ret < 0)
+ goto out_free;
+
+ psock = sk_psock(sk);
+ WARN_ON_ONCE(!psock);
+
+ hash = sock_hash_bucket_hash(key, key_size);
+ bucket = sock_hash_select_bucket(htab, hash);
+
+ raw_spin_lock_bh(&bucket->lock);
+ elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
+ if (elem && flags == BPF_NOEXIST) {
+ ret = -EEXIST;
+ goto out_unlock;
+ } else if (!elem && flags == BPF_EXIST) {
+ ret = -ENOENT;
+ goto out_unlock;
+ }
+
+ elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
+ if (IS_ERR(elem_new)) {
+ ret = PTR_ERR(elem_new);
+ goto out_unlock;
+ }
+
+ sock_map_add_link(psock, link, map, elem_new);
+ /* Add new element to the head of the list, so that
+ * concurrent search will find it before old elem.
+ */
+ hlist_add_head_rcu(&elem_new->node, &bucket->head);
+ if (elem) {
+ hlist_del_rcu(&elem->node);
+ sock_map_unref(elem->sk, elem);
+ sock_hash_free_elem(htab, elem);
+ }
+ raw_spin_unlock_bh(&bucket->lock);
+ return 0;
+out_unlock:
+ raw_spin_unlock_bh(&bucket->lock);
+ sk_psock_put(sk, psock);
+out_free:
+ sk_psock_free_link(link);
+ return ret;
+}
+
+static int sock_hash_update_elem(struct bpf_map *map, void *key,
+ void *value, u64 flags)
+{
+ u32 ufd = *(u32 *)value;
+ struct socket *sock;
+ struct sock *sk;
+ int ret;
+
+ sock = sockfd_lookup(ufd, &ret);
+ if (!sock)
+ return ret;
+ sk = sock->sk;
+ if (!sk) {
+ ret = -EINVAL;
+ goto out;
+ }
+ if (!sock_map_sk_is_suitable(sk) ||
+ sk->sk_state != TCP_ESTABLISHED) {
+ ret = -EOPNOTSUPP;
+ goto out;
+ }
+
+ sock_map_sk_acquire(sk);
+ ret = sock_hash_update_common(map, key, sk, flags);
+ sock_map_sk_release(sk);
+out:
+ fput(sock->file);
+ return ret;
+}
+
+static int sock_hash_get_next_key(struct bpf_map *map, void *key,
+ void *key_next)
+{
+ struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
+ struct bpf_htab_elem *elem, *elem_next;
+ u32 hash, key_size = map->key_size;
+ struct hlist_head *head;
+ int i = 0;
+
+ if (!key)
+ goto find_first_elem;
+ hash = sock_hash_bucket_hash(key, key_size);
+ head = &sock_hash_select_bucket(htab, hash)->head;
+ elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
+ if (!elem)
+ goto find_first_elem;
+
+ elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
+ struct bpf_htab_elem, node);
+ if (elem_next) {
+ memcpy(key_next, elem_next->key, key_size);
+ return 0;
+ }
+
+ i = hash & (htab->buckets_num - 1);
+ i++;
+find_first_elem:
+ for (; i < htab->buckets_num; i++) {
+ head = &sock_hash_select_bucket(htab, i)->head;
+ elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
+ struct bpf_htab_elem, node);
+ if (elem_next) {
+ memcpy(key_next, elem_next->key, key_size);
+ return 0;
+ }
+ }
+
+ return -ENOENT;
+}
+
+static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
+{
+ struct bpf_htab *htab;
+ int i, err;
+ u64 cost;
+
+ if (!capable(CAP_NET_ADMIN))
+ return ERR_PTR(-EPERM);
+ if (attr->max_entries == 0 ||
+ attr->key_size == 0 ||
+ attr->value_size != 4 ||
+ attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
+ return ERR_PTR(-EINVAL);
+ if (attr->key_size > MAX_BPF_STACK)
+ return ERR_PTR(-E2BIG);
+
+ htab = kzalloc(sizeof(*htab), GFP_USER);
+ if (!htab)
+ return ERR_PTR(-ENOMEM);
+
+ bpf_map_init_from_attr(&htab->map, attr);
+
+ htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
+ htab->elem_size = sizeof(struct bpf_htab_elem) +
+ round_up(htab->map.key_size, 8);
+ if (htab->buckets_num == 0 ||
+ htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
+ err = -EINVAL;
+ goto free_htab;
+ }
+
+ cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
+ (u64) htab->elem_size * htab->map.max_entries;
+ if (cost >= U32_MAX - PAGE_SIZE) {
+ err = -EINVAL;
+ goto free_htab;
+ }
+
+ htab->buckets = bpf_map_area_alloc(htab->buckets_num *
+ sizeof(struct bpf_htab_bucket),
+ htab->map.numa_node);
+ if (!htab->buckets) {
+ err = -ENOMEM;
+ goto free_htab;
+ }
+
+ for (i = 0; i < htab->buckets_num; i++) {
+ INIT_HLIST_HEAD(&htab->buckets[i].head);
+ raw_spin_lock_init(&htab->buckets[i].lock);
+ }
+
+ return &htab->map;
+free_htab:
+ kfree(htab);
+ return ERR_PTR(err);
+}
+
+static void sock_hash_free(struct bpf_map *map)
+{
+ struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
+ struct bpf_htab_bucket *bucket;
+ struct bpf_htab_elem *elem;
+ struct hlist_node *node;
+ int i;
+
+ synchronize_rcu();
+ rcu_read_lock();
+ for (i = 0; i < htab->buckets_num; i++) {
+ bucket = sock_hash_select_bucket(htab, i);
+ raw_spin_lock_bh(&bucket->lock);
+ hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
+ hlist_del_rcu(&elem->node);
+ sock_map_unref(elem->sk, elem);
+ }
+ raw_spin_unlock_bh(&bucket->lock);
+ }
+ rcu_read_unlock();
+
+ bpf_map_area_free(htab->buckets);
+ kfree(htab);
+}
+
+static void sock_hash_release_progs(struct bpf_map *map)
+{
+ psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
+}
+
+BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
+ struct bpf_map *, map, void *, key, u64, flags)
+{
+ WARN_ON_ONCE(!rcu_read_lock_held());
+
+ if (likely(sock_map_sk_is_suitable(sops->sk) &&
+ sock_map_op_okay(sops)))
+ return sock_hash_update_common(map, key, sops->sk, flags);
+ return -EOPNOTSUPP;
+}
+
+const struct bpf_func_proto bpf_sock_hash_update_proto = {
+ .func = bpf_sock_hash_update,
+ .gpl_only = false,
+ .pkt_access = true,
+ .ret_type = RET_INTEGER,
+ .arg1_type = ARG_PTR_TO_CTX,
+ .arg2_type = ARG_CONST_MAP_PTR,
+ .arg3_type = ARG_PTR_TO_MAP_KEY,
+ .arg4_type = ARG_ANYTHING,
+};
+
+BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
+ struct bpf_map *, map, void *, key, u64, flags)
+{
+ struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
+
+ if (unlikely(flags & ~(BPF_F_INGRESS)))
+ return SK_DROP;
+ tcb->bpf.flags = flags;
+ tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
+ if (!tcb->bpf.sk_redir)
+ return SK_DROP;
+ return SK_PASS;
+}
+
+const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
+ .func = bpf_sk_redirect_hash,
+ .gpl_only = false,
+ .ret_type = RET_INTEGER,
+ .arg1_type = ARG_PTR_TO_CTX,
+ .arg2_type = ARG_CONST_MAP_PTR,
+ .arg3_type = ARG_PTR_TO_MAP_KEY,
+ .arg4_type = ARG_ANYTHING,
+};
+
+BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
+ struct bpf_map *, map, void *, key, u64, flags)
+{
+ if (unlikely(flags & ~(BPF_F_INGRESS)))
+ return SK_DROP;
+ msg->flags = flags;
+ msg->sk_redir = __sock_hash_lookup_elem(map, key);
+ if (!msg->sk_redir)
+ return SK_DROP;
+ return SK_PASS;
+}
+
+const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
+ .func = bpf_msg_redirect_hash,
+ .gpl_only = false,
+ .ret_type = RET_INTEGER,
+ .arg1_type = ARG_PTR_TO_CTX,
+ .arg2_type = ARG_CONST_MAP_PTR,
+ .arg3_type = ARG_PTR_TO_MAP_KEY,
+ .arg4_type = ARG_ANYTHING,
+};
+
+const struct bpf_map_ops sock_hash_ops = {
+ .map_alloc = sock_hash_alloc,
+ .map_free = sock_hash_free,
+ .map_get_next_key = sock_hash_get_next_key,
+ .map_update_elem = sock_hash_update_elem,
+ .map_delete_elem = sock_hash_delete_elem,
+ .map_lookup_elem = sock_map_lookup,
+ .map_release_uref = sock_hash_release_progs,
+ .map_check_btf = map_check_no_btf,
+};
+
+static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
+{
+ switch (map->map_type) {
+ case BPF_MAP_TYPE_SOCKMAP:
+ return &container_of(map, struct bpf_stab, map)->progs;
+ case BPF_MAP_TYPE_SOCKHASH:
+ return &container_of(map, struct bpf_htab, map)->progs;
+ default:
+ break;
+ }
+
+ return NULL;
+}
+
+int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
+ u32 which)
+{
+ struct sk_psock_progs *progs = sock_map_progs(map);
+
+ if (!progs)
+ return -EOPNOTSUPP;
+
+ switch (which) {
+ case BPF_SK_MSG_VERDICT:
+ psock_set_prog(&progs->msg_parser, prog);
+ break;
+ case BPF_SK_SKB_STREAM_PARSER:
+ psock_set_prog(&progs->skb_parser, prog);
+ break;
+ case BPF_SK_SKB_STREAM_VERDICT:
+ psock_set_prog(&progs->skb_verdict, prog);
+ break;
+ default:
+ return -EOPNOTSUPP;
+ }
+
+ return 0;
+}
+
+void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
+{
+ switch (link->map->map_type) {
+ case BPF_MAP_TYPE_SOCKMAP:
+ return sock_map_delete_from_link(link->map, sk,
+ link->link_raw);
+ case BPF_MAP_TYPE_SOCKHASH:
+ return sock_hash_delete_from_link(link->map, sk,
+ link->link_raw);
+ default:
+ break;
+ }
+}
diff --git a/net/ipv4/Makefile b/net/ipv4/Makefile
index 7446b98661d8..58629314eae9 100644
--- a/net/ipv4/Makefile
+++ b/net/ipv4/Makefile
@@ -63,6 +63,7 @@ obj-$(CONFIG_TCP_CONG_SCALABLE) += tcp_scalable.o
obj-$(CONFIG_TCP_CONG_LP) += tcp_lp.o
obj-$(CONFIG_TCP_CONG_YEAH) += tcp_yeah.o
obj-$(CONFIG_TCP_CONG_ILLINOIS) += tcp_illinois.o
+obj-$(CONFIG_NET_SOCK_MSG) += tcp_bpf.o
obj-$(CONFIG_NETLABEL) += cipso_ipv4.o
obj-$(CONFIG_XFRM) += xfrm4_policy.o xfrm4_state.o xfrm4_input.o \
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
new file mode 100644
index 000000000000..80debb0daf37
--- /dev/null
+++ b/net/ipv4/tcp_bpf.c
@@ -0,0 +1,655 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
+
+#include <linux/skmsg.h>
+#include <linux/filter.h>
+#include <linux/bpf.h>
+#include <linux/init.h>
+#include <linux/wait.h>
+
+#include <net/inet_common.h>
+
+static bool tcp_bpf_stream_read(const struct sock *sk)
+{
+ struct sk_psock *psock;
+ bool empty = true;
+
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ if (likely(psock))
+ empty = list_empty(&psock->ingress_msg);
+ rcu_read_unlock();
+ return !empty;
+}
+
+static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
+ int flags, long timeo, int *err)
+{
+ DEFINE_WAIT_FUNC(wait, woken_wake_function);
+ int ret;
+
+ add_wait_queue(sk_sleep(sk), &wait);
+ sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+ ret = sk_wait_event(sk, &timeo,
+ !list_empty(&psock->ingress_msg) ||
+ !skb_queue_empty(&sk->sk_receive_queue), &wait);
+ sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+ remove_wait_queue(sk_sleep(sk), &wait);
+ return ret;
+}
+
+int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
+ struct msghdr *msg, int len)
+{
+ struct iov_iter *iter = &msg->msg_iter;
+ int i, ret, copied = 0;
+
+ while (copied != len) {
+ struct scatterlist *sge;
+ struct sk_msg *msg_rx;
+
+ msg_rx = list_first_entry_or_null(&psock->ingress_msg,
+ struct sk_msg, list);
+ if (unlikely(!msg_rx))
+ break;
+
+ i = msg_rx->sg.start;
+ do {
+ struct page *page;
+ int copy;
+
+ sge = sk_msg_elem(msg_rx, i);
+ copy = sge->length;
+ page = sg_page(sge);
+ if (copied + copy > len)
+ copy = len - copied;
+ ret = copy_page_to_iter(page, sge->offset, copy, iter);
+ if (ret != copy) {
+ msg_rx->sg.start = i;
+ return -EFAULT;
+ }
+
+ copied += copy;
+ sge->offset += copy;
+ sge->length -= copy;
+ sk_mem_uncharge(sk, copy);
+ if (!sge->length) {
+ i++;
+ if (i == MAX_SKB_FRAGS)
+ i = 0;
+ if (!msg_rx->skb)
+ put_page(page);
+ }
+
+ if (copied == len)
+ break;
+ } while (i != msg_rx->sg.end);
+
+ msg_rx->sg.start = i;
+ if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
+ list_del(&msg_rx->list);
+ if (msg_rx->skb)
+ consume_skb(msg_rx->skb);
+ kfree(msg_rx);
+ }
+ }
+
+ return copied;
+}
+EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);
+
+int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
+ int nonblock, int flags, int *addr_len)
+{
+ struct sk_psock *psock;
+ int copied, ret;
+
+ if (unlikely(flags & MSG_ERRQUEUE))
+ return inet_recv_error(sk, msg, len, addr_len);
+ if (!skb_queue_empty(&sk->sk_receive_queue))
+ return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
+
+ psock = sk_psock_get(sk);
+ if (unlikely(!psock))
+ return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
+ lock_sock(sk);
+msg_bytes_ready:
+ copied = __tcp_bpf_recvmsg(sk, psock, msg, len);
+ if (!copied) {
+ int data, err = 0;
+ long timeo;
+
+ timeo = sock_rcvtimeo(sk, nonblock);
+ data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
+ if (data) {
+ if (skb_queue_empty(&sk->sk_receive_queue))
+ goto msg_bytes_ready;
+ release_sock(sk);
+ sk_psock_put(sk, psock);
+ return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
+ }
+ if (err) {
+ ret = err;
+ goto out;
+ }
+ }
+ ret = copied;
+out:
+ release_sock(sk);
+ sk_psock_put(sk, psock);
+ return ret;
+}
+
+static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
+ struct sk_msg *msg, u32 apply_bytes, int flags)
+{
+ bool apply = apply_bytes;
+ struct scatterlist *sge;
+ u32 size, copied = 0;
+ struct sk_msg *tmp;
+ int i, ret = 0;
+
+ tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
+ if (unlikely(!tmp))
+ return -ENOMEM;
+
+ lock_sock(sk);
+ tmp->sg.start = msg->sg.start;
+ i = msg->sg.start;
+ do {
+ sge = sk_msg_elem(msg, i);
+ size = (apply && apply_bytes < sge->length) ?
+ apply_bytes : sge->length;
+ if (!sk_wmem_schedule(sk, size)) {
+ if (!copied)
+ ret = -ENOMEM;
+ break;
+ }
+
+ sk_mem_charge(sk, size);
+ sk_msg_xfer(tmp, msg, i, size);
+ copied += size;
+ if (sge->length)
+ get_page(sk_msg_page(tmp, i));
+ sk_msg_iter_var_next(i);
+ tmp->sg.end = i;
+ if (apply) {
+ apply_bytes -= size;
+ if (!apply_bytes)
+ break;
+ }
+ } while (i != msg->sg.end);
+
+ if (!ret) {
+ msg->sg.start = i;
+ msg->sg.size -= apply_bytes;
+ sk_psock_queue_msg(psock, tmp);
+ sk->sk_data_ready(sk);
+ } else {
+ sk_msg_free(sk, tmp);
+ kfree(tmp);
+ }
+
+ release_sock(sk);
+ return ret;
+}
+
+static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
+ int flags, bool uncharge)
+{
+ bool apply = apply_bytes;
+ struct scatterlist *sge;
+ struct page *page;
+ int size, ret = 0;
+ u32 off;
+
+ while (1) {
+ sge = sk_msg_elem(msg, msg->sg.start);
+ size = (apply && apply_bytes < sge->length) ?
+ apply_bytes : sge->length;
+ off = sge->offset;
+ page = sg_page(sge);
+
+ tcp_rate_check_app_limited(sk);
+retry:
+ ret = do_tcp_sendpages(sk, page, off, size, flags);
+ if (ret <= 0)
+ return ret;
+ if (apply)
+ apply_bytes -= ret;
+ msg->sg.size -= ret;
+ sge->offset += ret;
+ sge->length -= ret;
+ if (uncharge)
+ sk_mem_uncharge(sk, ret);
+ if (ret != size) {
+ size -= ret;
+ off += ret;
+ goto retry;
+ }
+ if (!sge->length) {
+ put_page(page);
+ sk_msg_iter_next(msg, start);
+ sg_init_table(sge, 1);
+ if (msg->sg.start == msg->sg.end)
+ break;
+ }
+ if (apply && !apply_bytes)
+ break;
+ }
+
+ return 0;
+}
+
+static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
+ u32 apply_bytes, int flags, bool uncharge)
+{
+ int ret;
+
+ lock_sock(sk);
+ ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
+ release_sock(sk);
+ return ret;
+}
+
+int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
+ u32 bytes, int flags)
+{
+ bool ingress = sk_msg_to_ingress(msg);
+ struct sk_psock *psock = sk_psock_get(sk);
+ int ret;
+
+ if (unlikely(!psock)) {
+ sk_msg_free(sk, msg);
+ return 0;
+ }
+ ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
+ tcp_bpf_push_locked(sk, msg, bytes, flags, false);
+ sk_psock_put(sk, psock);
+ return ret;
+}
+EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
+
+static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
+ struct sk_msg *msg, int *copied, int flags)
+{
+ bool cork = false, enospc = msg->sg.start == msg->sg.end;
+ struct sock *sk_redir;
+ u32 tosend;
+ int ret;
+
+more_data:
+ if (psock->eval == __SK_NONE)
+ psock->eval = sk_psock_msg_verdict(sk, psock, msg);
+
+ if (msg->cork_bytes &&
+ msg->cork_bytes > msg->sg.size && !enospc) {
+ psock->cork_bytes = msg->cork_bytes - msg->sg.size;
+ if (!psock->cork) {
+ psock->cork = kzalloc(sizeof(*psock->cork),
+ GFP_ATOMIC | __GFP_NOWARN);
+ if (!psock->cork)
+ return -ENOMEM;
+ }
+ memcpy(psock->cork, msg, sizeof(*msg));
+ return 0;
+ }
+
+ tosend = msg->sg.size;
+ if (psock->apply_bytes && psock->apply_bytes < tosend)
+ tosend = psock->apply_bytes;
+
+ switch (psock->eval) {
+ case __SK_PASS:
+ ret = tcp_bpf_push(sk, msg, tosend, flags, true);
+ if (unlikely(ret)) {
+ *copied -= sk_msg_free(sk, msg);
+ break;
+ }
+ sk_msg_apply_bytes(psock, tosend);
+ break;
+ case __SK_REDIRECT:
+ sk_redir = psock->sk_redir;
+ sk_msg_apply_bytes(psock, tosend);
+ if (psock->cork) {
+ cork = true;
+ psock->cork = NULL;
+ }
+ sk_msg_return(sk, msg, tosend);
+ release_sock(sk);
+ ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
+ lock_sock(sk);
+ if (unlikely(ret < 0)) {
+ int free = sk_msg_free_nocharge(sk, msg);
+
+ if (!cork)
+ *copied -= free;
+ }
+ if (cork) {
+ sk_msg_free(sk, msg);
+ kfree(msg);
+ msg = NULL;
+ ret = 0;
+ }
+ break;
+ case __SK_DROP:
+ default:
+ sk_msg_free_partial(sk, msg, tosend);
+ sk_msg_apply_bytes(psock, tosend);
+ *copied -= tosend;
+ return -EACCES;
+ }
+
+ if (likely(!ret)) {
+ if (!psock->apply_bytes) {
+ psock->eval = __SK_NONE;
+ if (psock->sk_redir) {
+ sock_put(psock->sk_redir);
+ psock->sk_redir = NULL;
+ }
+ }
+ if (msg &&
+ msg->sg.data[msg->sg.start].page_link &&
+ msg->sg.data[msg->sg.start].length)
+ goto more_data;
+ }
+ return ret;
+}
+
+static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
+{
+ struct sk_msg tmp, *msg_tx = NULL;
+ int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
+ int copied = 0, err = 0;
+ struct sk_psock *psock;
+ long timeo;
+
+ psock = sk_psock_get(sk);
+ if (unlikely(!psock))
+ return tcp_sendmsg(sk, msg, size);
+
+ lock_sock(sk);
+ timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
+ while (msg_data_left(msg)) {
+ bool enospc = false;
+ u32 copy, osize;
+
+ if (sk->sk_err) {
+ err = -sk->sk_err;
+ goto out_err;
+ }
+
+ copy = msg_data_left(msg);
+ if (!sk_stream_memory_free(sk))
+ goto wait_for_sndbuf;
+ if (psock->cork) {
+ msg_tx = psock->cork;
+ } else {
+ msg_tx = &tmp;
+ sk_msg_init(msg_tx);
+ }
+
+ osize = msg_tx->sg.size;
+ err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
+ if (err) {
+ if (err != -ENOSPC)
+ goto wait_for_memory;
+ enospc = true;
+ copy = msg_tx->sg.size - osize;
+ }
+
+ err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
+ copy);
+ if (err < 0) {
+ sk_msg_trim(sk, msg_tx, osize);
+ goto out_err;
+ }
+
+ copied += copy;
+ if (psock->cork_bytes) {
+ if (size > psock->cork_bytes)
+ psock->cork_bytes = 0;
+ else
+ psock->cork_bytes -= size;
+ if (psock->cork_bytes && !enospc)
+ goto out_err;
+ /* All cork bytes are accounted, rerun the prog. */
+ psock->eval = __SK_NONE;
+ psock->cork_bytes = 0;
+ }
+
+ err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
+ if (unlikely(err < 0))
+ goto out_err;
+ continue;
+wait_for_sndbuf:
+ set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+wait_for_memory:
+ err = sk_stream_wait_memory(sk, &timeo);
+ if (err) {
+ if (msg_tx && msg_tx != psock->cork)
+ sk_msg_free(sk, msg_tx);
+ goto out_err;
+ }
+ }
+out_err:
+ if (err < 0)
+ err = sk_stream_error(sk, msg->msg_flags, err);
+ release_sock(sk);
+ sk_psock_put(sk, psock);
+ return copied ? copied : err;
+}
+
+static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
+ size_t size, int flags)
+{
+ struct sk_msg tmp, *msg = NULL;
+ int err = 0, copied = 0;
+ struct sk_psock *psock;
+ bool enospc = false;
+
+ psock = sk_psock_get(sk);
+ if (unlikely(!psock))
+ return tcp_sendpage(sk, page, offset, size, flags);
+
+ lock_sock(sk);
+ if (psock->cork) {
+ msg = psock->cork;
+ } else {
+ msg = &tmp;
+ sk_msg_init(msg);
+ }
+
+ /* Catch case where ring is full and sendpage is stalled. */
+ if (unlikely(sk_msg_full(msg)))
+ goto out_err;
+
+ sk_msg_page_add(msg, page, size, offset);
+ sk_mem_charge(sk, size);
+ copied = size;
+ if (sk_msg_full(msg))
+ enospc = true;
+ if (psock->cork_bytes) {
+ if (size > psock->cork_bytes)
+ psock->cork_bytes = 0;
+ else
+ psock->cork_bytes -= size;
+ if (psock->cork_bytes && !enospc)
+ goto out_err;
+ /* All cork bytes are accounted, rerun the prog. */
+ psock->eval = __SK_NONE;
+ psock->cork_bytes = 0;
+ }
+
+ err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
+out_err:
+ release_sock(sk);
+ sk_psock_put(sk, psock);
+ return copied ? copied : err;
+}
+
+static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
+{
+ struct sk_psock_link *link;
+
+ sk_psock_cork_free(psock);
+ __sk_psock_purge_ingress_msg(psock);
+ while ((link = sk_psock_link_pop(psock))) {
+ sk_psock_unlink(sk, link);
+ sk_psock_free_link(link);
+ }
+}
+
+static void tcp_bpf_unhash(struct sock *sk)
+{
+ void (*saved_unhash)(struct sock *sk);
+ struct sk_psock *psock;
+
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ if (unlikely(!psock)) {
+ rcu_read_unlock();
+ if (sk->sk_prot->unhash)
+ sk->sk_prot->unhash(sk);
+ return;
+ }
+
+ saved_unhash = psock->saved_unhash;
+ tcp_bpf_remove(sk, psock);
+ rcu_read_unlock();
+ saved_unhash(sk);
+}
+
+static void tcp_bpf_close(struct sock *sk, long timeout)
+{
+ void (*saved_close)(struct sock *sk, long timeout);
+ struct sk_psock *psock;
+
+ lock_sock(sk);
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ if (unlikely(!psock)) {
+ rcu_read_unlock();
+ release_sock(sk);
+ return sk->sk_prot->close(sk, timeout);
+ }
+
+ saved_close = psock->saved_close;
+ tcp_bpf_remove(sk, psock);
+ rcu_read_unlock();
+ release_sock(sk);
+ saved_close(sk, timeout);
+}
+
+enum {
+ TCP_BPF_IPV4,
+ TCP_BPF_IPV6,
+ TCP_BPF_NUM_PROTS,
+};
+
+enum {
+ TCP_BPF_BASE,
+ TCP_BPF_TX,
+ TCP_BPF_NUM_CFGS,
+};
+
+static struct proto *tcpv6_prot_saved __read_mostly;
+static DEFINE_SPINLOCK(tcpv6_prot_lock);
+static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];
+
+static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
+ struct proto *base)
+{
+ prot[TCP_BPF_BASE] = *base;
+ prot[TCP_BPF_BASE].unhash = tcp_bpf_unhash;
+ prot[TCP_BPF_BASE].close = tcp_bpf_close;
+ prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
+ prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read;
+
+ prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
+ prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
+ prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;
+}
+
+static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
+{
+ if (sk->sk_family == AF_INET6 &&
+ unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
+ spin_lock_bh(&tcpv6_prot_lock);
+ if (likely(ops != tcpv6_prot_saved)) {
+ tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
+ smp_store_release(&tcpv6_prot_saved, ops);
+ }
+ spin_unlock_bh(&tcpv6_prot_lock);
+ }
+}
+
+static int __init tcp_bpf_v4_build_proto(void)
+{
+ tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
+ return 0;
+}
+core_initcall(tcp_bpf_v4_build_proto);
+
+static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
+{
+ int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
+ int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
+
+ sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
+}
+
+static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
+{
+ int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
+ int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
+
+ /* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
+ * or added requiring sk_prot hook updates. We keep original saved
+ * hooks in this case.
+ */
+ sk->sk_prot = &tcp_bpf_prots[family][config];
+}
+
+static int tcp_bpf_assert_proto_ops(struct proto *ops)
+{
+ /* In order to avoid retpoline, we make assumptions when we call
+ * into ops if e.g. a psock is not present. Make sure they are
+ * indeed valid assumptions.
+ */
+ return ops->recvmsg == tcp_recvmsg &&
+ ops->sendmsg == tcp_sendmsg &&
+ ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
+}
+
+void tcp_bpf_reinit(struct sock *sk)
+{
+ struct sk_psock *psock;
+
+ sock_owned_by_me(sk);
+
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ tcp_bpf_reinit_sk_prot(sk, psock);
+ rcu_read_unlock();
+}
+
+int tcp_bpf_init(struct sock *sk)
+{
+ struct proto *ops = READ_ONCE(sk->sk_prot);
+ struct sk_psock *psock;
+
+ sock_owned_by_me(sk);
+
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ if (unlikely(!psock || psock->sk_proto ||
+ tcp_bpf_assert_proto_ops(ops))) {
+ rcu_read_unlock();
+ return -EINVAL;
+ }
+ tcp_bpf_check_v6_needs_rebuild(sk, ops);
+ tcp_bpf_update_sk_prot(sk, psock);
+ rcu_read_unlock();
+ return 0;
+}
diff --git a/net/strparser/Kconfig b/net/strparser/Kconfig
index 6cff3f6d0c3a..94da19a2a220 100644
--- a/net/strparser/Kconfig
+++ b/net/strparser/Kconfig
@@ -1,4 +1,2 @@
-
config STREAM_PARSER
- tristate
- default n
+ def_bool n