diff options
Diffstat (limited to 'net/core/sock_map.c')
-rw-r--r-- | net/core/sock_map.c | 141 |
1 files changed, 95 insertions, 46 deletions
diff --git a/net/core/sock_map.c b/net/core/sock_map.c index 4059f94e9bb5..119f52a99dc1 100644 --- a/net/core/sock_map.c +++ b/net/core/sock_map.c @@ -70,11 +70,49 @@ int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog) struct fd f; int ret; + if (attr->attach_flags || attr->replace_bpf_fd) + return -EINVAL; + f = fdget(ufd); map = __bpf_map_get(f); if (IS_ERR(map)) return PTR_ERR(map); - ret = sock_map_prog_update(map, prog, attr->attach_type); + ret = sock_map_prog_update(map, prog, NULL, attr->attach_type); + fdput(f); + return ret; +} + +int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype) +{ + u32 ufd = attr->target_fd; + struct bpf_prog *prog; + struct bpf_map *map; + struct fd f; + int ret; + + if (attr->attach_flags || attr->replace_bpf_fd) + return -EINVAL; + + f = fdget(ufd); + map = __bpf_map_get(f); + if (IS_ERR(map)) + return PTR_ERR(map); + + prog = bpf_prog_get(attr->attach_bpf_fd); + if (IS_ERR(prog)) { + ret = PTR_ERR(prog); + goto put_map; + } + + if (prog->type != ptype) { + ret = -EINVAL; + goto put_prog; + } + + ret = sock_map_prog_update(map, NULL, prog, attr->attach_type); +put_prog: + bpf_prog_put(prog); +put_map: fdput(f); return ret; } @@ -643,6 +681,7 @@ const struct bpf_func_proto bpf_msg_redirect_map_proto = { .arg4_type = ARG_ANYTHING, }; +static int sock_map_btf_id; const struct bpf_map_ops sock_map_ops = { .map_alloc = sock_map_alloc, .map_free = sock_map_free, @@ -653,9 +692,11 @@ const struct bpf_map_ops sock_map_ops = { .map_lookup_elem = sock_map_lookup, .map_release_uref = sock_map_release_progs, .map_check_btf = map_check_no_btf, + .map_btf_name = "bpf_stab", + .map_btf_id = &sock_map_btf_id, }; -struct bpf_htab_elem { +struct bpf_shtab_elem { struct rcu_head rcu; u32 hash; struct sock *sk; @@ -663,14 +704,14 @@ struct bpf_htab_elem { u8 key[]; }; -struct bpf_htab_bucket { +struct bpf_shtab_bucket { struct hlist_head head; raw_spinlock_t lock; }; -struct bpf_htab { +struct bpf_shtab { struct bpf_map map; - struct bpf_htab_bucket *buckets; + struct bpf_shtab_bucket *buckets; u32 buckets_num; u32 elem_size; struct sk_psock_progs progs; @@ -682,17 +723,17 @@ static inline u32 sock_hash_bucket_hash(const void *key, u32 len) return jhash(key, len, 0); } -static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab, - u32 hash) +static struct bpf_shtab_bucket *sock_hash_select_bucket(struct bpf_shtab *htab, + u32 hash) { return &htab->buckets[hash & (htab->buckets_num - 1)]; } -static struct bpf_htab_elem * +static struct bpf_shtab_elem * sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key, u32 key_size) { - struct bpf_htab_elem *elem; + struct bpf_shtab_elem *elem; hlist_for_each_entry_rcu(elem, head, node) { if (elem->hash == hash && @@ -705,10 +746,10 @@ sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key, static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key) { - struct bpf_htab *htab = container_of(map, struct bpf_htab, map); + struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map); u32 key_size = map->key_size, hash; - struct bpf_htab_bucket *bucket; - struct bpf_htab_elem *elem; + struct bpf_shtab_bucket *bucket; + struct bpf_shtab_elem *elem; WARN_ON_ONCE(!rcu_read_lock_held()); @@ -719,8 +760,8 @@ static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key) return elem ? elem->sk : NULL; } -static void sock_hash_free_elem(struct bpf_htab *htab, - struct bpf_htab_elem *elem) +static void sock_hash_free_elem(struct bpf_shtab *htab, + struct bpf_shtab_elem *elem) { atomic_dec(&htab->count); kfree_rcu(elem, rcu); @@ -729,9 +770,9 @@ static void sock_hash_free_elem(struct bpf_htab *htab, static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk, void *link_raw) { - struct bpf_htab *htab = container_of(map, struct bpf_htab, map); - struct bpf_htab_elem *elem_probe, *elem = link_raw; - struct bpf_htab_bucket *bucket; + struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map); + struct bpf_shtab_elem *elem_probe, *elem = link_raw; + struct bpf_shtab_bucket *bucket; WARN_ON_ONCE(!rcu_read_lock_held()); bucket = sock_hash_select_bucket(htab, elem->hash); @@ -753,10 +794,10 @@ static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk, static int sock_hash_delete_elem(struct bpf_map *map, void *key) { - struct bpf_htab *htab = container_of(map, struct bpf_htab, map); + struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map); u32 hash, key_size = map->key_size; - struct bpf_htab_bucket *bucket; - struct bpf_htab_elem *elem; + struct bpf_shtab_bucket *bucket; + struct bpf_shtab_elem *elem; int ret = -ENOENT; hash = sock_hash_bucket_hash(key, key_size); @@ -774,12 +815,12 @@ static int sock_hash_delete_elem(struct bpf_map *map, void *key) return ret; } -static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab, - void *key, u32 key_size, - u32 hash, struct sock *sk, - struct bpf_htab_elem *old) +static struct bpf_shtab_elem *sock_hash_alloc_elem(struct bpf_shtab *htab, + void *key, u32 key_size, + u32 hash, struct sock *sk, + struct bpf_shtab_elem *old) { - struct bpf_htab_elem *new; + struct bpf_shtab_elem *new; if (atomic_inc_return(&htab->count) > htab->map.max_entries) { if (!old) { @@ -803,10 +844,10 @@ static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab, static int sock_hash_update_common(struct bpf_map *map, void *key, struct sock *sk, u64 flags) { - struct bpf_htab *htab = container_of(map, struct bpf_htab, map); + struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map); u32 key_size = map->key_size, hash; - struct bpf_htab_elem *elem, *elem_new; - struct bpf_htab_bucket *bucket; + struct bpf_shtab_elem *elem, *elem_new; + struct bpf_shtab_bucket *bucket; struct sk_psock_link *link; struct sk_psock *psock; int ret; @@ -916,8 +957,8 @@ out: static int sock_hash_get_next_key(struct bpf_map *map, void *key, void *key_next) { - struct bpf_htab *htab = container_of(map, struct bpf_htab, map); - struct bpf_htab_elem *elem, *elem_next; + struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map); + struct bpf_shtab_elem *elem, *elem_next; u32 hash, key_size = map->key_size; struct hlist_head *head; int i = 0; @@ -931,7 +972,7 @@ static int sock_hash_get_next_key(struct bpf_map *map, void *key, goto find_first_elem; elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)), - struct bpf_htab_elem, node); + struct bpf_shtab_elem, node); if (elem_next) { memcpy(key_next, elem_next->key, key_size); return 0; @@ -943,7 +984,7 @@ find_first_elem: for (; i < htab->buckets_num; i++) { head = &sock_hash_select_bucket(htab, i)->head; elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)), - struct bpf_htab_elem, node); + struct bpf_shtab_elem, node); if (elem_next) { memcpy(key_next, elem_next->key, key_size); return 0; @@ -955,7 +996,7 @@ find_first_elem: static struct bpf_map *sock_hash_alloc(union bpf_attr *attr) { - struct bpf_htab *htab; + struct bpf_shtab *htab; int i, err; u64 cost; @@ -977,15 +1018,15 @@ static struct bpf_map *sock_hash_alloc(union bpf_attr *attr) bpf_map_init_from_attr(&htab->map, attr); htab->buckets_num = roundup_pow_of_two(htab->map.max_entries); - htab->elem_size = sizeof(struct bpf_htab_elem) + + htab->elem_size = sizeof(struct bpf_shtab_elem) + round_up(htab->map.key_size, 8); if (htab->buckets_num == 0 || - htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) { + htab->buckets_num > U32_MAX / sizeof(struct bpf_shtab_bucket)) { err = -EINVAL; goto free_htab; } - cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) + + cost = (u64) htab->buckets_num * sizeof(struct bpf_shtab_bucket) + (u64) htab->elem_size * htab->map.max_entries; if (cost >= U32_MAX - PAGE_SIZE) { err = -EINVAL; @@ -996,7 +1037,7 @@ static struct bpf_map *sock_hash_alloc(union bpf_attr *attr) goto free_htab; htab->buckets = bpf_map_area_alloc(htab->buckets_num * - sizeof(struct bpf_htab_bucket), + sizeof(struct bpf_shtab_bucket), htab->map.numa_node); if (!htab->buckets) { bpf_map_charge_finish(&htab->map.memory); @@ -1017,10 +1058,10 @@ free_htab: static void sock_hash_free(struct bpf_map *map) { - struct bpf_htab *htab = container_of(map, struct bpf_htab, map); - struct bpf_htab_bucket *bucket; + struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map); + struct bpf_shtab_bucket *bucket; struct hlist_head unlink_list; - struct bpf_htab_elem *elem; + struct bpf_shtab_elem *elem; struct hlist_node *node; int i; @@ -1096,7 +1137,7 @@ static void *sock_hash_lookup(struct bpf_map *map, void *key) static void sock_hash_release_progs(struct bpf_map *map) { - psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs); + psock_progs_drop(&container_of(map, struct bpf_shtab, map)->progs); } BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops, @@ -1176,6 +1217,7 @@ const struct bpf_func_proto bpf_msg_redirect_hash_proto = { .arg4_type = ARG_ANYTHING, }; +static int sock_hash_map_btf_id; const struct bpf_map_ops sock_hash_ops = { .map_alloc = sock_hash_alloc, .map_free = sock_hash_free, @@ -1186,6 +1228,8 @@ const struct bpf_map_ops sock_hash_ops = { .map_lookup_elem_sys_only = sock_hash_lookup_sys, .map_release_uref = sock_hash_release_progs, .map_check_btf = map_check_no_btf, + .map_btf_name = "bpf_shtab", + .map_btf_id = &sock_hash_map_btf_id, }; static struct sk_psock_progs *sock_map_progs(struct bpf_map *map) @@ -1194,7 +1238,7 @@ static struct sk_psock_progs *sock_map_progs(struct bpf_map *map) case BPF_MAP_TYPE_SOCKMAP: return &container_of(map, struct bpf_stab, map)->progs; case BPF_MAP_TYPE_SOCKHASH: - return &container_of(map, struct bpf_htab, map)->progs; + return &container_of(map, struct bpf_shtab, map)->progs; default: break; } @@ -1203,27 +1247,32 @@ static struct sk_psock_progs *sock_map_progs(struct bpf_map *map) } int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog, - u32 which) + struct bpf_prog *old, u32 which) { struct sk_psock_progs *progs = sock_map_progs(map); + struct bpf_prog **pprog; if (!progs) return -EOPNOTSUPP; switch (which) { case BPF_SK_MSG_VERDICT: - psock_set_prog(&progs->msg_parser, prog); + pprog = &progs->msg_parser; break; case BPF_SK_SKB_STREAM_PARSER: - psock_set_prog(&progs->skb_parser, prog); + pprog = &progs->skb_parser; break; case BPF_SK_SKB_STREAM_VERDICT: - psock_set_prog(&progs->skb_verdict, prog); + pprog = &progs->skb_verdict; break; default: return -EOPNOTSUPP; } + if (old) + return psock_replace_prog(pprog, prog, old); + + psock_set_prog(pprog, prog); return 0; } |