aboutsummaryrefslogtreecommitdiffstats
path: root/kernel/bpf/sockmap.c
diff options
context:
space:
mode:
Diffstat (limited to 'kernel/bpf/sockmap.c')
-rw-r--r--kernel/bpf/sockmap.c733
1 files changed, 707 insertions, 26 deletions
diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c
index a927e89dad6e..69c5bccabd22 100644
--- a/kernel/bpf/sockmap.c
+++ b/kernel/bpf/sockmap.c
@@ -38,6 +38,7 @@
#include <linux/skbuff.h>
#include <linux/workqueue.h>
#include <linux/list.h>
+#include <linux/mm.h>
#include <net/strparser.h>
#include <net/tcp.h>
@@ -47,6 +48,7 @@
struct bpf_stab {
struct bpf_map map;
struct sock **sock_map;
+ struct bpf_prog *bpf_tx_msg;
struct bpf_prog *bpf_parse;
struct bpf_prog *bpf_verdict;
};
@@ -62,8 +64,7 @@ struct smap_psock_map_entry {
struct smap_psock {
struct rcu_head rcu;
- /* refcnt is used inside sk_callback_lock */
- u32 refcnt;
+ refcount_t refcnt;
/* datapath variables */
struct sk_buff_head rxqueue;
@@ -74,7 +75,16 @@ struct smap_psock {
int save_off;
struct sk_buff *save_skb;
+ /* datapath variables for tx_msg ULP */
+ struct sock *sk_redir;
+ int apply_bytes;
+ int cork_bytes;
+ int sg_size;
+ int eval;
+ struct sk_msg_buff *cork;
+
struct strparser strp;
+ struct bpf_prog *bpf_tx_msg;
struct bpf_prog *bpf_parse;
struct bpf_prog *bpf_verdict;
struct list_head maps;
@@ -92,6 +102,11 @@ struct smap_psock {
void (*save_write_space)(struct sock *sk);
};
+static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
+static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
+static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
+ int offset, size_t size, int flags);
+
static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
{
return rcu_dereference_sk_user_data(sk);
@@ -116,27 +131,41 @@ static int bpf_tcp_init(struct sock *sk)
psock->save_close = sk->sk_prot->close;
psock->sk_proto = sk->sk_prot;
+
+ if (psock->bpf_tx_msg) {
+ tcp_bpf_proto.sendmsg = bpf_tcp_sendmsg;
+ tcp_bpf_proto.sendpage = bpf_tcp_sendpage;
+ }
+
sk->sk_prot = &tcp_bpf_proto;
rcu_read_unlock();
return 0;
}
+static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
+static int free_start_sg(struct sock *sk, struct sk_msg_buff *md);
+
static void bpf_tcp_release(struct sock *sk)
{
struct smap_psock *psock;
rcu_read_lock();
psock = smap_psock_sk(sk);
+ if (unlikely(!psock))
+ goto out;
- if (likely(psock)) {
- sk->sk_prot = psock->sk_proto;
- psock->sk_proto = NULL;
+ if (psock->cork) {
+ free_start_sg(psock->sock, psock->cork);
+ kfree(psock->cork);
+ psock->cork = NULL;
}
+
+ sk->sk_prot = psock->sk_proto;
+ psock->sk_proto = NULL;
+out:
rcu_read_unlock();
}
-static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
-
static void bpf_tcp_close(struct sock *sk, long timeout)
{
void (*close_fun)(struct sock *sk, long timeout);
@@ -175,6 +204,7 @@ enum __sk_action {
__SK_DROP = 0,
__SK_PASS,
__SK_REDIRECT,
+ __SK_NONE,
};
static struct tcp_ulp_ops bpf_tcp_ulp_ops __read_mostly = {
@@ -186,10 +216,621 @@ static struct tcp_ulp_ops bpf_tcp_ulp_ops __read_mostly = {
.release = bpf_tcp_release,
};
+static int memcopy_from_iter(struct sock *sk,
+ struct sk_msg_buff *md,
+ struct iov_iter *from, int bytes)
+{
+ struct scatterlist *sg = md->sg_data;
+ int i = md->sg_curr, rc = -ENOSPC;
+
+ do {
+ int copy;
+ char *to;
+
+ if (md->sg_copybreak >= sg[i].length) {
+ md->sg_copybreak = 0;
+
+ if (++i == MAX_SKB_FRAGS)
+ i = 0;
+
+ if (i == md->sg_end)
+ break;
+ }
+
+ copy = sg[i].length - md->sg_copybreak;
+ to = sg_virt(&sg[i]) + md->sg_copybreak;
+ md->sg_copybreak += copy;
+
+ if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
+ rc = copy_from_iter_nocache(to, copy, from);
+ else
+ rc = copy_from_iter(to, copy, from);
+
+ if (rc != copy) {
+ rc = -EFAULT;
+ goto out;
+ }
+
+ bytes -= copy;
+ if (!bytes)
+ break;
+
+ md->sg_copybreak = 0;
+ if (++i == MAX_SKB_FRAGS)
+ i = 0;
+ } while (i != md->sg_end);
+out:
+ md->sg_curr = i;
+ return rc;
+}
+
+static int bpf_tcp_push(struct sock *sk, int apply_bytes,
+ struct sk_msg_buff *md,
+ int flags, bool uncharge)
+{
+ bool apply = apply_bytes;
+ struct scatterlist *sg;
+ int offset, ret = 0;
+ struct page *p;
+ size_t size;
+
+ while (1) {
+ sg = md->sg_data + md->sg_start;
+ size = (apply && apply_bytes < sg->length) ?
+ apply_bytes : sg->length;
+ offset = sg->offset;
+
+ tcp_rate_check_app_limited(sk);
+ p = sg_page(sg);
+retry:
+ ret = do_tcp_sendpages(sk, p, offset, size, flags);
+ if (ret != size) {
+ if (ret > 0) {
+ if (apply)
+ apply_bytes -= ret;
+ size -= ret;
+ offset += ret;
+ if (uncharge)
+ sk_mem_uncharge(sk, ret);
+ goto retry;
+ }
+
+ sg->length = size;
+ sg->offset = offset;
+ return ret;
+ }
+
+ if (apply)
+ apply_bytes -= ret;
+ sg->offset += ret;
+ sg->length -= ret;
+ if (uncharge)
+ sk_mem_uncharge(sk, ret);
+
+ if (!sg->length) {
+ put_page(p);
+ md->sg_start++;
+ if (md->sg_start == MAX_SKB_FRAGS)
+ md->sg_start = 0;
+ memset(sg, 0, sizeof(*sg));
+
+ if (md->sg_start == md->sg_end)
+ break;
+ }
+
+ if (apply && !apply_bytes)
+ break;
+ }
+ return 0;
+}
+
+static inline void bpf_compute_data_pointers_sg(struct sk_msg_buff *md)
+{
+ struct scatterlist *sg = md->sg_data + md->sg_start;
+
+ if (md->sg_copy[md->sg_start]) {
+ md->data = md->data_end = 0;
+ } else {
+ md->data = sg_virt(sg);
+ md->data_end = md->data + sg->length;
+ }
+}
+
+static void return_mem_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
+{
+ struct scatterlist *sg = md->sg_data;
+ int i = md->sg_start;
+
+ do {
+ int uncharge = (bytes < sg[i].length) ? bytes : sg[i].length;
+
+ sk_mem_uncharge(sk, uncharge);
+ bytes -= uncharge;
+ if (!bytes)
+ break;
+ i++;
+ if (i == MAX_SKB_FRAGS)
+ i = 0;
+ } while (i != md->sg_end);
+}
+
+static void free_bytes_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
+{
+ struct scatterlist *sg = md->sg_data;
+ int i = md->sg_start, free;
+
+ while (bytes && sg[i].length) {
+ free = sg[i].length;
+ if (bytes < free) {
+ sg[i].length -= bytes;
+ sg[i].offset += bytes;
+ sk_mem_uncharge(sk, bytes);
+ break;
+ }
+
+ sk_mem_uncharge(sk, sg[i].length);
+ put_page(sg_page(&sg[i]));
+ bytes -= sg[i].length;
+ sg[i].length = 0;
+ sg[i].page_link = 0;
+ sg[i].offset = 0;
+ i++;
+
+ if (i == MAX_SKB_FRAGS)
+ i = 0;
+ }
+}
+
+static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
+{
+ struct scatterlist *sg = md->sg_data;
+ int i = start, free = 0;
+
+ while (sg[i].length) {
+ free += sg[i].length;
+ sk_mem_uncharge(sk, sg[i].length);
+ put_page(sg_page(&sg[i]));
+ sg[i].length = 0;
+ sg[i].page_link = 0;
+ sg[i].offset = 0;
+ i++;
+
+ if (i == MAX_SKB_FRAGS)
+ i = 0;
+ }
+
+ return free;
+}
+
+static int free_start_sg(struct sock *sk, struct sk_msg_buff *md)
+{
+ int free = free_sg(sk, md->sg_start, md);
+
+ md->sg_start = md->sg_end;
+ return free;
+}
+
+static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
+{
+ return free_sg(sk, md->sg_curr, md);
+}
+
+static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
+{
+ return ((_rc == SK_PASS) ?
+ (md->map ? __SK_REDIRECT : __SK_PASS) :
+ __SK_DROP);
+}
+
+static unsigned int smap_do_tx_msg(struct sock *sk,
+ struct smap_psock *psock,
+ struct sk_msg_buff *md)
+{
+ struct bpf_prog *prog;
+ unsigned int rc, _rc;
+
+ preempt_disable();
+ rcu_read_lock();
+
+ /* If the policy was removed mid-send then default to 'accept' */
+ prog = READ_ONCE(psock->bpf_tx_msg);
+ if (unlikely(!prog)) {
+ _rc = SK_PASS;
+ goto verdict;
+ }
+
+ bpf_compute_data_pointers_sg(md);
+ rc = (*prog->bpf_func)(md, prog->insnsi);
+ psock->apply_bytes = md->apply_bytes;
+
+ /* Moving return codes from UAPI namespace into internal namespace */
+ _rc = bpf_map_msg_verdict(rc, md);
+
+ /* The psock has a refcount on the sock but not on the map and because
+ * we need to drop rcu read lock here its possible the map could be
+ * removed between here and when we need it to execute the sock
+ * redirect. So do the map lookup now for future use.
+ */
+ if (_rc == __SK_REDIRECT) {
+ if (psock->sk_redir)
+ sock_put(psock->sk_redir);
+ psock->sk_redir = do_msg_redirect_map(md);
+ if (!psock->sk_redir) {
+ _rc = __SK_DROP;
+ goto verdict;
+ }
+ sock_hold(psock->sk_redir);
+ }
+verdict:
+ rcu_read_unlock();
+ preempt_enable();
+
+ return _rc;
+}
+
+static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
+ struct sk_msg_buff *md,
+ int flags)
+{
+ struct smap_psock *psock;
+ struct scatterlist *sg;
+ int i, err, free = 0;
+
+ sg = md->sg_data;
+
+ rcu_read_lock();
+ psock = smap_psock_sk(sk);
+ if (unlikely(!psock))
+ goto out_rcu;
+
+ if (!refcount_inc_not_zero(&psock->refcnt))
+ goto out_rcu;
+
+ rcu_read_unlock();
+ lock_sock(sk);
+ err = bpf_tcp_push(sk, send, md, flags, false);
+ release_sock(sk);
+ smap_release_sock(psock, sk);
+ if (unlikely(err))
+ goto out;
+ return 0;
+out_rcu:
+ rcu_read_unlock();
+out:
+ i = md->sg_start;
+ while (sg[i].length) {
+ free += sg[i].length;
+ put_page(sg_page(&sg[i]));
+ sg[i].length = 0;
+ i++;
+ if (i == MAX_SKB_FRAGS)
+ i = 0;
+ }
+ return free;
+}
+
+static inline void bpf_md_init(struct smap_psock *psock)
+{
+ if (!psock->apply_bytes) {
+ psock->eval = __SK_NONE;
+ if (psock->sk_redir) {
+ sock_put(psock->sk_redir);
+ psock->sk_redir = NULL;
+ }
+ }
+}
+
+static void apply_bytes_dec(struct smap_psock *psock, int i)
+{
+ if (psock->apply_bytes) {
+ if (psock->apply_bytes < i)
+ psock->apply_bytes = 0;
+ else
+ psock->apply_bytes -= i;
+ }
+}
+
+static int bpf_exec_tx_verdict(struct smap_psock *psock,
+ struct sk_msg_buff *m,
+ struct sock *sk,
+ int *copied, int flags)
+{
+ bool cork = false, enospc = (m->sg_start == m->sg_end);
+ struct sock *redir;
+ int err = 0;
+ int send;
+
+more_data:
+ if (psock->eval == __SK_NONE)
+ psock->eval = smap_do_tx_msg(sk, psock, m);
+
+ if (m->cork_bytes &&
+ m->cork_bytes > psock->sg_size && !enospc) {
+ psock->cork_bytes = m->cork_bytes - psock->sg_size;
+ if (!psock->cork) {
+ psock->cork = kcalloc(1,
+ sizeof(struct sk_msg_buff),
+ GFP_ATOMIC | __GFP_NOWARN);
+
+ if (!psock->cork) {
+ err = -ENOMEM;
+ goto out_err;
+ }
+ }
+ memcpy(psock->cork, m, sizeof(*m));
+ goto out_err;
+ }
+
+ send = psock->sg_size;
+ if (psock->apply_bytes && psock->apply_bytes < send)
+ send = psock->apply_bytes;
+
+ switch (psock->eval) {
+ case __SK_PASS:
+ err = bpf_tcp_push(sk, send, m, flags, true);
+ if (unlikely(err)) {
+ *copied -= free_start_sg(sk, m);
+ break;
+ }
+
+ apply_bytes_dec(psock, send);
+ psock->sg_size -= send;
+ break;
+ case __SK_REDIRECT:
+ redir = psock->sk_redir;
+ apply_bytes_dec(psock, send);
+
+ if (psock->cork) {
+ cork = true;
+ psock->cork = NULL;
+ }
+
+ return_mem_sg(sk, send, m);
+ release_sock(sk);
+
+ err = bpf_tcp_sendmsg_do_redirect(redir, send, m, flags);
+ lock_sock(sk);
+
+ if (cork) {
+ free_start_sg(sk, m);
+ kfree(m);
+ m = NULL;
+ }
+ if (unlikely(err))
+ *copied -= err;
+ else
+ psock->sg_size -= send;
+ break;
+ case __SK_DROP:
+ default:
+ free_bytes_sg(sk, send, m);
+ apply_bytes_dec(psock, send);
+ *copied -= send;
+ psock->sg_size -= send;
+ err = -EACCES;
+ break;
+ }
+
+ if (likely(!err)) {
+ bpf_md_init(psock);
+ if (m &&
+ m->sg_data[m->sg_start].page_link &&
+ m->sg_data[m->sg_start].length)
+ goto more_data;
+ }
+
+out_err:
+ return err;
+}
+
+static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
+{
+ int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
+ struct sk_msg_buff md = {0};
+ unsigned int sg_copy = 0;
+ struct smap_psock *psock;
+ int copied = 0, err = 0;
+ struct scatterlist *sg;
+ long timeo;
+
+ /* Its possible a sock event or user removed the psock _but_ the ops
+ * have not been reprogrammed yet so we get here. In this case fallback
+ * to tcp_sendmsg. Note this only works because we _only_ ever allow
+ * a single ULP there is no hierarchy here.
+ */
+ rcu_read_lock();
+ psock = smap_psock_sk(sk);
+ if (unlikely(!psock)) {
+ rcu_read_unlock();
+ return tcp_sendmsg(sk, msg, size);
+ }
+
+ /* Increment the psock refcnt to ensure its not released while sending a
+ * message. Required because sk lookup and bpf programs are used in
+ * separate rcu critical sections. Its OK if we lose the map entry
+ * but we can't lose the sock reference.
+ */
+ if (!refcount_inc_not_zero(&psock->refcnt)) {
+ rcu_read_unlock();
+ return tcp_sendmsg(sk, msg, size);
+ }
+
+ sg = md.sg_data;
+ sg_init_table(sg, MAX_SKB_FRAGS);
+ rcu_read_unlock();
+
+ lock_sock(sk);
+ timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
+
+ while (msg_data_left(msg)) {
+ struct sk_msg_buff *m;
+ bool enospc = false;
+ int copy;
+
+ if (sk->sk_err) {
+ err = sk->sk_err;
+ goto out_err;
+ }
+
+ copy = msg_data_left(msg);
+ if (!sk_stream_memory_free(sk))
+ goto wait_for_sndbuf;
+
+ m = psock->cork_bytes ? psock->cork : &md;
+ m->sg_curr = m->sg_copybreak ? m->sg_curr : m->sg_end;
+ err = sk_alloc_sg(sk, copy, m->sg_data,
+ m->sg_start, &m->sg_end, &sg_copy,
+ m->sg_end - 1);
+ if (err) {
+ if (err != -ENOSPC)
+ goto wait_for_memory;
+ enospc = true;
+ copy = sg_copy;
+ }
+
+ err = memcopy_from_iter(sk, m, &msg->msg_iter, copy);
+ if (err < 0) {
+ free_curr_sg(sk, m);
+ goto out_err;
+ }
+
+ psock->sg_size += copy;
+ copied += copy;
+ sg_copy = 0;
+
+ /* When bytes are being corked skip running BPF program and
+ * applying verdict unless there is no more buffer space. In
+ * the ENOSPC case simply run BPF prorgram with currently
+ * accumulated data. We don't have much choice at this point
+ * we could try extending the page frags or chaining complex
+ * frags but even in these cases _eventually_ we will hit an
+ * OOM scenario. More complex recovery schemes may be
+ * implemented in the future, but BPF programs must handle
+ * the case where apply_cork requests are not honored. The
+ * canonical method to verify this is to check data length.
+ */
+ if (psock->cork_bytes) {
+ if (copy > psock->cork_bytes)
+ psock->cork_bytes = 0;
+ else
+ psock->cork_bytes -= copy;
+
+ if (psock->cork_bytes && !enospc)
+ goto out_cork;
+
+ /* All cork bytes accounted for re-run filter */
+ psock->eval = __SK_NONE;
+ psock->cork_bytes = 0;
+ }
+
+ err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
+ if (unlikely(err < 0))
+ goto out_err;
+ continue;
+wait_for_sndbuf:
+ set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+wait_for_memory:
+ err = sk_stream_wait_memory(sk, &timeo);
+ if (err)
+ goto out_err;
+ }
+out_err:
+ if (err < 0)
+ err = sk_stream_error(sk, msg->msg_flags, err);
+out_cork:
+ release_sock(sk);
+ smap_release_sock(psock, sk);
+ return copied ? copied : err;
+}
+
+static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
+ int offset, size_t size, int flags)
+{
+ struct sk_msg_buff md = {0}, *m = NULL;
+ int err = 0, copied = 0;
+ struct smap_psock *psock;
+ struct scatterlist *sg;
+ bool enospc = false;
+
+ rcu_read_lock();
+ psock = smap_psock_sk(sk);
+ if (unlikely(!psock))
+ goto accept;
+
+ if (!refcount_inc_not_zero(&psock->refcnt))
+ goto accept;
+ rcu_read_unlock();
+
+ lock_sock(sk);
+
+ if (psock->cork_bytes)
+ m = psock->cork;
+ else
+ m = &md;
+
+ /* Catch case where ring is full and sendpage is stalled. */
+ if (unlikely(m->sg_end == m->sg_start &&
+ m->sg_data[m->sg_end].length))
+ goto out_err;
+
+ psock->sg_size += size;
+ sg = &m->sg_data[m->sg_end];
+ sg_set_page(sg, page, size, offset);
+ get_page(page);
+ m->sg_copy[m->sg_end] = true;
+ sk_mem_charge(sk, size);
+ m->sg_end++;
+ copied = size;
+
+ if (m->sg_end == MAX_SKB_FRAGS)
+ m->sg_end = 0;
+
+ if (m->sg_end == m->sg_start)
+ enospc = true;
+
+ if (psock->cork_bytes) {
+ if (size > psock->cork_bytes)
+ psock->cork_bytes = 0;
+ else
+ psock->cork_bytes -= size;
+
+ if (psock->cork_bytes && !enospc)
+ goto out_err;
+
+ /* All cork bytes accounted for re-run filter */
+ psock->eval = __SK_NONE;
+ psock->cork_bytes = 0;
+ }
+
+ err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
+out_err:
+ release_sock(sk);
+ smap_release_sock(psock, sk);
+ return copied ? copied : err;
+accept:
+ rcu_read_unlock();
+ return tcp_sendpage(sk, page, offset, size, flags);
+}
+
+static void bpf_tcp_msg_add(struct smap_psock *psock,
+ struct sock *sk,
+ struct bpf_prog *tx_msg)
+{
+ struct bpf_prog *orig_tx_msg;
+
+ orig_tx_msg = xchg(&psock->bpf_tx_msg, tx_msg);
+ if (orig_tx_msg)
+ bpf_prog_put(orig_tx_msg);
+}
+
static int bpf_tcp_ulp_register(void)
{
tcp_bpf_proto = tcp_prot;
tcp_bpf_proto.close = bpf_tcp_close;
+ /* Once BPF TX ULP is registered it is never unregistered. It
+ * will be in the ULP list for the lifetime of the system. Doing
+ * duplicate registers is not a problem.
+ */
return tcp_register_ulp(&bpf_tcp_ulp_ops);
}
@@ -373,15 +1014,13 @@ static void smap_destroy_psock(struct rcu_head *rcu)
static void smap_release_sock(struct smap_psock *psock, struct sock *sock)
{
- psock->refcnt--;
- if (psock->refcnt)
- return;
-
- tcp_cleanup_ulp(sock);
- smap_stop_sock(psock, sock);
- clear_bit(SMAP_TX_RUNNING, &psock->state);
- rcu_assign_sk_user_data(sock, NULL);
- call_rcu_sched(&psock->rcu, smap_destroy_psock);
+ if (refcount_dec_and_test(&psock->refcnt)) {
+ tcp_cleanup_ulp(sock);
+ smap_stop_sock(psock, sock);
+ clear_bit(SMAP_TX_RUNNING, &psock->state);
+ rcu_assign_sk_user_data(sock, NULL);
+ call_rcu_sched(&psock->rcu, smap_destroy_psock);
+ }
}
static int smap_parse_func_strparser(struct strparser *strp,
@@ -415,7 +1054,6 @@ static int smap_parse_func_strparser(struct strparser *strp,
return rc;
}
-
static int smap_read_sock_done(struct strparser *strp, int err)
{
return err;
@@ -485,12 +1123,22 @@ static void smap_gc_work(struct work_struct *w)
bpf_prog_put(psock->bpf_parse);
if (psock->bpf_verdict)
bpf_prog_put(psock->bpf_verdict);
+ if (psock->bpf_tx_msg)
+ bpf_prog_put(psock->bpf_tx_msg);
+
+ if (psock->cork) {
+ free_start_sg(psock->sock, psock->cork);
+ kfree(psock->cork);
+ }
list_for_each_entry_safe(e, tmp, &psock->maps, list) {
list_del(&e->list);
kfree(e);
}
+ if (psock->sk_redir)
+ sock_put(psock->sk_redir);
+
sock_put(psock->sock);
kfree(psock);
}
@@ -506,12 +1154,13 @@ static struct smap_psock *smap_init_psock(struct sock *sock,
if (!psock)
return ERR_PTR(-ENOMEM);
+ psock->eval = __SK_NONE;
psock->sock = sock;
skb_queue_head_init(&psock->rxqueue);
INIT_WORK(&psock->tx_work, smap_tx_work);
INIT_WORK(&psock->gc_work, smap_gc_work);
INIT_LIST_HEAD(&psock->maps);
- psock->refcnt = 1;
+ refcount_set(&psock->refcnt, 1);
rcu_assign_sk_user_data(sock, psock);
sock_hold(sock);
@@ -714,10 +1363,11 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
{
struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
struct smap_psock_map_entry *e = NULL;
- struct bpf_prog *verdict, *parse;
+ struct bpf_prog *verdict, *parse, *tx_msg;
struct sock *osock, *sock;
struct smap_psock *psock;
u32 i = *(u32 *)key;
+ bool new = false;
int err;
if (unlikely(flags > BPF_EXIST))
@@ -740,6 +1390,7 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
*/
verdict = READ_ONCE(stab->bpf_verdict);
parse = READ_ONCE(stab->bpf_parse);
+ tx_msg = READ_ONCE(stab->bpf_tx_msg);
if (parse && verdict) {
/* bpf prog refcnt may be zero if a concurrent attach operation
@@ -758,6 +1409,17 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
}
}
+ if (tx_msg) {
+ tx_msg = bpf_prog_inc_not_zero(stab->bpf_tx_msg);
+ if (IS_ERR(tx_msg)) {
+ if (verdict)
+ bpf_prog_put(verdict);
+ if (parse)
+ bpf_prog_put(parse);
+ return PTR_ERR(tx_msg);
+ }
+ }
+
write_lock_bh(&sock->sk_callback_lock);
psock = smap_psock_sk(sock);
@@ -772,7 +1434,14 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
err = -EBUSY;
goto out_progs;
}
- psock->refcnt++;
+ if (READ_ONCE(psock->bpf_tx_msg) && tx_msg) {
+ err = -EBUSY;
+ goto out_progs;
+ }
+ if (!refcount_inc_not_zero(&psock->refcnt)) {
+ err = -EAGAIN;
+ goto out_progs;
+ }
} else {
psock = smap_init_psock(sock, stab);
if (IS_ERR(psock)) {
@@ -780,11 +1449,8 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
goto out_progs;
}
- err = tcp_set_ulp_id(sock, TCP_ULP_BPF);
- if (err)
- goto out_progs;
-
set_bit(SMAP_TX_RUNNING, &psock->state);
+ new = true;
}
e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
@@ -797,6 +1463,14 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
/* 3. At this point we have a reference to a valid psock that is
* running. Attach any BPF programs needed.
*/
+ if (tx_msg)
+ bpf_tcp_msg_add(psock, sock, tx_msg);
+ if (new) {
+ err = tcp_set_ulp_id(sock, TCP_ULP_BPF);
+ if (err)
+ goto out_free;
+ }
+
if (parse && verdict && !psock->strp_enabled) {
err = smap_init_sock(psock, sock);
if (err)
@@ -818,8 +1492,6 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
struct smap_psock *opsock = smap_psock_sk(osock);
write_lock_bh(&osock->sk_callback_lock);
- if (osock != sock && parse)
- smap_stop_sock(opsock, osock);
smap_list_remove(opsock, &stab->sock_map[i]);
smap_release_sock(opsock, osock);
write_unlock_bh(&osock->sk_callback_lock);
@@ -832,6 +1504,8 @@ out_progs:
bpf_prog_put(verdict);
if (parse)
bpf_prog_put(parse);
+ if (tx_msg)
+ bpf_prog_put(tx_msg);
write_unlock_bh(&sock->sk_callback_lock);
kfree(e);
return err;
@@ -846,6 +1520,9 @@ int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
return -EINVAL;
switch (type) {
+ case BPF_SK_MSG_VERDICT:
+ orig = xchg(&stab->bpf_tx_msg, prog);
+ break;
case BPF_SK_SKB_STREAM_PARSER:
orig = xchg(&stab->bpf_parse, prog);
break;
@@ -907,6 +1584,10 @@ static void sock_map_release(struct bpf_map *map, struct file *map_file)
orig = xchg(&stab->bpf_verdict, NULL);
if (orig)
bpf_prog_put(orig);
+
+ orig = xchg(&stab->bpf_tx_msg, NULL);
+ if (orig)
+ bpf_prog_put(orig);
}
const struct bpf_map_ops sock_map_ops = {