aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--arch/x86/net/bpf_jit_comp.c237
-rw-r--r--include/linux/bpf.h3
-rw-r--r--include/linux/bpf_verifier.h1
-rw-r--r--kernel/bpf/arraymap.c40
-rw-r--r--kernel/bpf/core.c2
-rw-r--r--kernel/bpf/verifier.c16
6 files changed, 244 insertions, 55 deletions
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index 7b0ff169c9a0..26f43279b78b 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -221,14 +221,48 @@ struct jit_context {
/* Number of bytes emit_patch() needs to generate instructions */
#define X86_PATCH_SIZE 5
+/* Number of bytes that will be skipped on tailcall */
+#define X86_TAIL_CALL_OFFSET 11
-#define PROLOGUE_SIZE 25
+static void push_callee_regs(u8 **pprog, bool *callee_regs_used)
+{
+ u8 *prog = *pprog;
+ int cnt = 0;
+
+ if (callee_regs_used[0])
+ EMIT1(0x53); /* push rbx */
+ if (callee_regs_used[1])
+ EMIT2(0x41, 0x55); /* push r13 */
+ if (callee_regs_used[2])
+ EMIT2(0x41, 0x56); /* push r14 */
+ if (callee_regs_used[3])
+ EMIT2(0x41, 0x57); /* push r15 */
+ *pprog = prog;
+}
+
+static void pop_callee_regs(u8 **pprog, bool *callee_regs_used)
+{
+ u8 *prog = *pprog;
+ int cnt = 0;
+
+ if (callee_regs_used[3])
+ EMIT2(0x41, 0x5F); /* pop r15 */
+ if (callee_regs_used[2])
+ EMIT2(0x41, 0x5E); /* pop r14 */
+ if (callee_regs_used[1])
+ EMIT2(0x41, 0x5D); /* pop r13 */
+ if (callee_regs_used[0])
+ EMIT1(0x5B); /* pop rbx */
+ *pprog = prog;
+}
/*
- * Emit x86-64 prologue code for BPF program and check its size.
- * bpf_tail_call helper will skip it while jumping into another program
+ * Emit x86-64 prologue code for BPF program.
+ * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes
+ * while jumping to another program
*/
-static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
+static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
+ bool tail_call_reachable, bool is_subprog)
{
u8 *prog = *pprog;
int cnt = X86_PATCH_SIZE;
@@ -238,19 +272,18 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
*/
memcpy(prog, ideal_nops[NOP_ATOMIC5], cnt);
prog += cnt;
+ if (!ebpf_from_cbpf) {
+ if (tail_call_reachable && !is_subprog)
+ EMIT2(0x31, 0xC0); /* xor eax, eax */
+ else
+ EMIT2(0x66, 0x90); /* nop2 */
+ }
EMIT1(0x55); /* push rbp */
EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
/* sub rsp, rounded_stack_depth */
EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
- EMIT1(0x53); /* push rbx */
- EMIT2(0x41, 0x55); /* push r13 */
- EMIT2(0x41, 0x56); /* push r14 */
- EMIT2(0x41, 0x57); /* push r15 */
- if (!ebpf_from_cbpf) {
- /* zero init tail_call_cnt */
- EMIT2(0x6a, 0x00);
- BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
- }
+ if (tail_call_reachable)
+ EMIT1(0x50); /* push rax */
*pprog = prog;
}
@@ -314,13 +347,14 @@ static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
mutex_lock(&text_mutex);
if (memcmp(ip, old_insn, X86_PATCH_SIZE))
goto out;
+ ret = 1;
if (memcmp(ip, new_insn, X86_PATCH_SIZE)) {
if (text_live)
text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL);
else
memcpy(ip, new_insn, X86_PATCH_SIZE);
+ ret = 0;
}
- ret = 0;
out:
mutex_unlock(&text_mutex);
return ret;
@@ -337,6 +371,22 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true);
}
+static int get_pop_bytes(bool *callee_regs_used)
+{
+ int bytes = 0;
+
+ if (callee_regs_used[3])
+ bytes += 2;
+ if (callee_regs_used[2])
+ bytes += 2;
+ if (callee_regs_used[1])
+ bytes += 2;
+ if (callee_regs_used[0])
+ bytes += 1;
+
+ return bytes;
+}
+
/*
* Generate the following code:
*
@@ -351,12 +401,26 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
* goto *(prog->bpf_func + prologue_size);
* out:
*/
-static void emit_bpf_tail_call_indirect(u8 **pprog)
+static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
+ u32 stack_depth)
{
+ int tcc_off = -4 - round_up(stack_depth, 8);
u8 *prog = *pprog;
- int label1, label2, label3;
+ int pop_bytes = 0;
+ int off1 = 49;
+ int off2 = 38;
+ int off3 = 16;
int cnt = 0;
+ /* count the additional bytes used for popping callee regs from stack
+ * that need to be taken into account for each of the offsets that
+ * are used for bailing out of the tail call
+ */
+ pop_bytes = get_pop_bytes(callee_regs_used);
+ off1 += pop_bytes;
+ off2 += pop_bytes;
+ off3 += pop_bytes;
+
/*
* rdi - pointer to ctx
* rsi - pointer to bpf_array
@@ -370,21 +434,19 @@ static void emit_bpf_tail_call_indirect(u8 **pprog)
EMIT2(0x89, 0xD2); /* mov edx, edx */
EMIT3(0x39, 0x56, /* cmp dword ptr [rsi + 16], edx */
offsetof(struct bpf_array, map.max_entries));
-#define OFFSET1 (41 + RETPOLINE_RCX_BPF_JIT_SIZE) /* Number of bytes to jump */
+#define OFFSET1 (off1 + RETPOLINE_RCX_BPF_JIT_SIZE) /* Number of bytes to jump */
EMIT2(X86_JBE, OFFSET1); /* jbe out */
- label1 = cnt;
/*
* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
* goto out;
*/
- EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */
+ EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */
EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */
-#define OFFSET2 (30 + RETPOLINE_RCX_BPF_JIT_SIZE)
+#define OFFSET2 (off2 + RETPOLINE_RCX_BPF_JIT_SIZE)
EMIT2(X86_JA, OFFSET2); /* ja out */
- label2 = cnt;
EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */
- EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */
+ EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */
/* prog = array->ptrs[index]; */
EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6, /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */
@@ -394,48 +456,84 @@ static void emit_bpf_tail_call_indirect(u8 **pprog)
* if (prog == NULL)
* goto out;
*/
- EMIT3(0x48, 0x85, 0xC9); /* test rcx,rcx */
-#define OFFSET3 (8 + RETPOLINE_RCX_BPF_JIT_SIZE)
+ EMIT3(0x48, 0x85, 0xC9); /* test rcx,rcx */
+#define OFFSET3 (off3 + RETPOLINE_RCX_BPF_JIT_SIZE)
EMIT2(X86_JE, OFFSET3); /* je out */
- label3 = cnt;
- /* goto *(prog->bpf_func + prologue_size); */
+ *pprog = prog;
+ pop_callee_regs(pprog, callee_regs_used);
+ prog = *pprog;
+
+ EMIT1(0x58); /* pop rax */
+ EMIT3_off32(0x48, 0x81, 0xC4, /* add rsp, sd */
+ round_up(stack_depth, 8));
+
+ /* goto *(prog->bpf_func + X86_TAIL_CALL_OFFSET); */
EMIT4(0x48, 0x8B, 0x49, /* mov rcx, qword ptr [rcx + 32] */
offsetof(struct bpf_prog, bpf_func));
- EMIT4(0x48, 0x83, 0xC1, PROLOGUE_SIZE); /* add rcx, prologue_size */
-
+ EMIT4(0x48, 0x83, 0xC1, /* add rcx, X86_TAIL_CALL_OFFSET */
+ X86_TAIL_CALL_OFFSET);
/*
* Now we're ready to jump into next BPF program
* rdi == ctx (1st arg)
- * rcx == prog->bpf_func + prologue_size
+ * rcx == prog->bpf_func + X86_TAIL_CALL_OFFSET
*/
RETPOLINE_RCX_BPF_JIT();
/* out: */
- BUILD_BUG_ON(cnt - label1 != OFFSET1);
- BUILD_BUG_ON(cnt - label2 != OFFSET2);
- BUILD_BUG_ON(cnt - label3 != OFFSET3);
*pprog = prog;
}
static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
- u8 **pprog, int addr, u8 *image)
+ u8 **pprog, int addr, u8 *image,
+ bool *callee_regs_used, u32 stack_depth)
{
+ int tcc_off = -4 - round_up(stack_depth, 8);
u8 *prog = *pprog;
+ int pop_bytes = 0;
+ int off1 = 27;
+ int poke_off;
int cnt = 0;
+ /* count the additional bytes used for popping callee regs to stack
+ * that need to be taken into account for jump offset that is used for
+ * bailing out from of the tail call when limit is reached
+ */
+ pop_bytes = get_pop_bytes(callee_regs_used);
+ off1 += pop_bytes;
+
+ /*
+ * total bytes for:
+ * - nop5/ jmpq $off
+ * - pop callee regs
+ * - sub rsp, $val
+ * - pop rax
+ */
+ poke_off = X86_PATCH_SIZE + pop_bytes + 7 + 1;
+
/*
* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
* goto out;
*/
- EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */
+ EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */
EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */
- EMIT2(X86_JA, 14); /* ja out */
+ EMIT2(X86_JA, off1); /* ja out */
EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */
- EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */
+ EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */
+ poke->tailcall_bypass = image + (addr - poke_off - X86_PATCH_SIZE);
+ poke->adj_off = X86_TAIL_CALL_OFFSET;
poke->tailcall_target = image + (addr - X86_PATCH_SIZE);
- poke->adj_off = PROLOGUE_SIZE;
+ poke->bypass_addr = (u8 *)poke->tailcall_target + X86_PATCH_SIZE;
+
+ emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE,
+ poke->tailcall_bypass);
+
+ *pprog = prog;
+ pop_callee_regs(pprog, callee_regs_used);
+ prog = *pprog;
+ EMIT1(0x58); /* pop rax */
+ EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8));
memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE);
prog += X86_PATCH_SIZE;
@@ -476,6 +574,11 @@ static void bpf_tail_call_direct_fixup(struct bpf_prog *prog)
(u8 *)target->bpf_func +
poke->adj_off, false);
BUG_ON(ret < 0);
+ ret = __bpf_arch_text_poke(poke->tailcall_bypass,
+ BPF_MOD_JUMP,
+ (u8 *)poke->tailcall_target +
+ X86_PATCH_SIZE, NULL, false);
+ BUG_ON(ret < 0);
}
WRITE_ONCE(poke->tailcall_target_stable, true);
mutex_unlock(&array->aux->poke_mutex);
@@ -654,19 +757,49 @@ static bool ex_handler_bpf(const struct exception_table_entry *x,
return true;
}
+static void detect_reg_usage(struct bpf_insn *insn, int insn_cnt,
+ bool *regs_used, bool *tail_call_seen)
+{
+ int i;
+
+ for (i = 1; i <= insn_cnt; i++, insn++) {
+ if (insn->code == (BPF_JMP | BPF_TAIL_CALL))
+ *tail_call_seen = true;
+ if (insn->dst_reg == BPF_REG_6 || insn->src_reg == BPF_REG_6)
+ regs_used[0] = true;
+ if (insn->dst_reg == BPF_REG_7 || insn->src_reg == BPF_REG_7)
+ regs_used[1] = true;
+ if (insn->dst_reg == BPF_REG_8 || insn->src_reg == BPF_REG_8)
+ regs_used[2] = true;
+ if (insn->dst_reg == BPF_REG_9 || insn->src_reg == BPF_REG_9)
+ regs_used[3] = true;
+ }
+}
+
static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
int oldproglen, struct jit_context *ctx)
{
+ bool tail_call_reachable = bpf_prog->aux->tail_call_reachable;
struct bpf_insn *insn = bpf_prog->insnsi;
+ bool callee_regs_used[4] = {};
int insn_cnt = bpf_prog->len;
+ bool tail_call_seen = false;
bool seen_exit = false;
u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
int i, cnt = 0, excnt = 0;
int proglen = 0;
u8 *prog = temp;
+ detect_reg_usage(insn, insn_cnt, callee_regs_used,
+ &tail_call_seen);
+
+ /* tail call's presence in current prog implies it is reachable */
+ tail_call_reachable |= tail_call_seen;
+
emit_prologue(&prog, bpf_prog->aux->stack_depth,
- bpf_prog_was_classic(bpf_prog));
+ bpf_prog_was_classic(bpf_prog), tail_call_reachable,
+ bpf_prog->aux->func_idx != 0);
+ push_callee_regs(&prog, callee_regs_used);
addrs[0] = prog - temp;
for (i = 1; i <= insn_cnt; i++, insn++) {
@@ -1104,16 +1237,27 @@ xadd: if (is_imm8(insn->off))
/* call */
case BPF_JMP | BPF_CALL:
func = (u8 *) __bpf_call_base + imm32;
- if (!imm32 || emit_call(&prog, func, image + addrs[i - 1]))
- return -EINVAL;
+ if (tail_call_reachable) {
+ EMIT3_off32(0x48, 0x8B, 0x85,
+ -(bpf_prog->aux->stack_depth + 8));
+ if (!imm32 || emit_call(&prog, func, image + addrs[i - 1] + 7))
+ return -EINVAL;
+ } else {
+ if (!imm32 || emit_call(&prog, func, image + addrs[i - 1]))
+ return -EINVAL;
+ }
break;
case BPF_JMP | BPF_TAIL_CALL:
if (imm32)
emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1],
- &prog, addrs[i], image);
+ &prog, addrs[i], image,
+ callee_regs_used,
+ bpf_prog->aux->stack_depth);
else
- emit_bpf_tail_call_indirect(&prog);
+ emit_bpf_tail_call_indirect(&prog,
+ callee_regs_used,
+ bpf_prog->aux->stack_depth);
break;
/* cond jump */
@@ -1296,12 +1440,9 @@ emit_jmp:
seen_exit = true;
/* Update cleanup_addr */
ctx->cleanup_addr = proglen;
- if (!bpf_prog_was_classic(bpf_prog))
- EMIT1(0x5B); /* get rid of tail_call_cnt */
- EMIT2(0x41, 0x5F); /* pop r15 */
- EMIT2(0x41, 0x5E); /* pop r14 */
- EMIT2(0x41, 0x5D); /* pop r13 */
- EMIT1(0x5B); /* pop rbx */
+ pop_callee_regs(&prog, callee_regs_used);
+ if (tail_call_reachable)
+ EMIT1(0x59); /* pop rcx, get rid of tail_call_cnt */
EMIT1(0xC9); /* leave */
EMIT1(0xC3); /* ret */
break;
diff --git a/include/linux/bpf.h b/include/linux/bpf.h
index f3790c9cf542..d7c5a6ed87e3 100644
--- a/include/linux/bpf.h
+++ b/include/linux/bpf.h
@@ -698,6 +698,8 @@ enum bpf_jit_poke_reason {
/* Descriptor of pokes pointing /into/ the JITed image. */
struct bpf_jit_poke_descriptor {
void *tailcall_target;
+ void *tailcall_bypass;
+ void *bypass_addr;
union {
struct {
struct bpf_map *map;
@@ -738,6 +740,7 @@ struct bpf_prog_aux {
bool attach_btf_trace; /* true if attaching to BTF-enabled raw tp */
bool func_proto_unreliable;
bool sleepable;
+ bool tail_call_reachable;
enum bpf_tramp_prog_type trampoline_prog_type;
struct bpf_trampoline *trampoline;
struct hlist_node tramp_hlist;
diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
index 5026b75db972..fbc964526ba3 100644
--- a/include/linux/bpf_verifier.h
+++ b/include/linux/bpf_verifier.h
@@ -359,6 +359,7 @@ struct bpf_subprog_info {
u32 linfo_idx; /* The idx to the main_prog->aux->linfo */
u16 stack_depth; /* max. stack depth used by this function */
bool has_tail_call;
+ bool tail_call_reachable;
};
/* single container for all structs
diff --git a/kernel/bpf/arraymap.c b/kernel/bpf/arraymap.c
index 60abf7fe12de..e5fd31268ae0 100644
--- a/kernel/bpf/arraymap.c
+++ b/kernel/bpf/arraymap.c
@@ -898,6 +898,7 @@ static void prog_array_map_poke_run(struct bpf_map *map, u32 key,
struct bpf_prog *old,
struct bpf_prog *new)
{
+ u8 *old_addr, *new_addr, *old_bypass_addr;
struct prog_poke_elem *elem;
struct bpf_array_aux *aux;
@@ -949,12 +950,39 @@ static void prog_array_map_poke_run(struct bpf_map *map, u32 key,
poke->tail_call.key != key)
continue;
- ret = bpf_arch_text_poke(poke->tailcall_target, BPF_MOD_JUMP,
- old ? (u8 *)old->bpf_func +
- poke->adj_off : NULL,
- new ? (u8 *)new->bpf_func +
- poke->adj_off : NULL);
- BUG_ON(ret < 0 && ret != -EINVAL);
+ old_bypass_addr = old ? NULL : poke->bypass_addr;
+ old_addr = old ? (u8 *)old->bpf_func + poke->adj_off : NULL;
+ new_addr = new ? (u8 *)new->bpf_func + poke->adj_off : NULL;
+
+ if (new) {
+ ret = bpf_arch_text_poke(poke->tailcall_target,
+ BPF_MOD_JUMP,
+ old_addr, new_addr);
+ BUG_ON(ret < 0 && ret != -EINVAL);
+ if (!old) {
+ ret = bpf_arch_text_poke(poke->tailcall_bypass,
+ BPF_MOD_JUMP,
+ poke->bypass_addr,
+ NULL);
+ BUG_ON(ret < 0 && ret != -EINVAL);
+ }
+ } else {
+ ret = bpf_arch_text_poke(poke->tailcall_bypass,
+ BPF_MOD_JUMP,
+ old_bypass_addr,
+ poke->bypass_addr);
+ BUG_ON(ret < 0 && ret != -EINVAL);
+ /* let other CPUs finish the execution of program
+ * so that it will not possible to expose them
+ * to invalid nop, stack unwind, nop state
+ */
+ if (!ret)
+ synchronize_rcu();
+ ret = bpf_arch_text_poke(poke->tailcall_target,
+ BPF_MOD_JUMP,
+ old_addr, NULL);
+ BUG_ON(ret < 0 && ret != -EINVAL);
+ }
}
}
}
diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c
index 2e00ac028d38..c4811b139caa 100644
--- a/kernel/bpf/core.c
+++ b/kernel/bpf/core.c
@@ -776,7 +776,7 @@ int bpf_jit_add_poke_descriptor(struct bpf_prog *prog,
if (size > poke_tab_max)
return -ENOSPC;
if (poke->tailcall_target || poke->tailcall_target_stable ||
- poke->adj_off)
+ poke->tailcall_bypass || poke->adj_off || poke->bypass_addr)
return -EINVAL;
switch (poke->reason) {
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 0958fba48d59..172e12df9eaa 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -2983,8 +2983,10 @@ static int check_max_stack_depth(struct bpf_verifier_env *env)
int depth = 0, frame = 0, idx = 0, i = 0, subprog_end;
struct bpf_subprog_info *subprog = env->subprog_info;
struct bpf_insn *insn = env->prog->insnsi;
+ bool tail_call_reachable = false;
int ret_insn[MAX_CALL_FRAMES];
int ret_prog[MAX_CALL_FRAMES];
+ int j;
process_func:
/* protect against potential stack overflow that might happen when
@@ -3040,6 +3042,10 @@ continue_func:
i);
return -EFAULT;
}
+
+ if (subprog[idx].has_tail_call)
+ tail_call_reachable = true;
+
frame++;
if (frame >= MAX_CALL_FRAMES) {
verbose(env, "the call stack of %d frames is too deep !\n",
@@ -3048,6 +3054,15 @@ continue_func:
}
goto process_func;
}
+ /* if tail call got detected across bpf2bpf calls then mark each of the
+ * currently present subprog frames as tail call reachable subprogs;
+ * this info will be utilized by JIT so that we will be preserving the
+ * tail call counter throughout bpf2bpf calls combined with tailcalls
+ */
+ if (tail_call_reachable)
+ for (j = 0; j < frame; j++)
+ subprog[ret_prog[j]].tail_call_reachable = true;
+
/* end of for() loop means the last insn of the 'subprog'
* was reached. Doesn't matter whether it was JA or EXIT
*/
@@ -10322,6 +10337,7 @@ static int jit_subprogs(struct bpf_verifier_env *env)
num_exentries++;
}
func[i]->aux->num_exentries = num_exentries;
+ func[i]->aux->tail_call_reachable = env->subprog_info[i].tail_call_reachable;
func[i] = bpf_int_jit_compile(func[i]);
if (!func[i]->jited) {
err = -ENOTSUPP;