diff options
Diffstat (limited to 'arch/x86/net/bpf_jit_comp.c')
-rw-r--r-- | arch/x86/net/bpf_jit_comp.c | 1564 |
1 files changed, 1156 insertions, 408 deletions
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c index 9ba08e9abc09..00127abd89ee 100644 --- a/arch/x86/net/bpf_jit_comp.c +++ b/arch/x86/net/bpf_jit_comp.c @@ -1,9 +1,9 @@ // SPDX-License-Identifier: GPL-2.0-only /* - * bpf_jit_comp.c: BPF JIT compiler + * BPF JIT compiler * * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com) - * Internal BPF Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com + * Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com */ #include <linux/netdevice.h> #include <linux/filter.h> @@ -11,11 +11,11 @@ #include <linux/bpf.h> #include <linux/memory.h> #include <linux/sort.h> +#include <linux/init.h> #include <asm/extable.h> #include <asm/set_memory.h> #include <asm/nospec-branch.h> #include <asm/text-patching.h> -#include <asm/asm-prototypes.h> static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len) { @@ -31,7 +31,7 @@ static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len) } #define EMIT(bytes, len) \ - do { prog = emit_code(prog, bytes, len); cnt += len; } while (0) + do { prog = emit_code(prog, bytes, len); } while (0) #define EMIT1(b1) EMIT(b1, 1) #define EMIT2(b1, b2) EMIT((b1) + ((b2) << 8), 2) @@ -47,6 +47,12 @@ static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len) #define EMIT4_off32(b1, b2, b3, b4, off) \ do { EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0) +#ifdef CONFIG_X86_KERNEL_IBT +#define EMIT_ENDBR() EMIT(gen_endbr(), 4) +#else +#define EMIT_ENDBR() +#endif + static bool is_imm8(int value) { return value <= 127 && value >= -128; @@ -158,6 +164,19 @@ static bool is_ereg(u32 reg) BIT(BPF_REG_AX)); } +/* + * is_ereg_8l() == true if BPF register 'reg' is mapped to access x86-64 + * lower 8-bit registers dil,sil,bpl,spl,r8b..r15b, which need extra byte + * of encoding. al,cl,dl,bl have simpler encoding. + */ +static bool is_ereg_8l(u32 reg) +{ + return is_ereg(reg) || + (1 << reg) & (BIT(BPF_REG_1) | + BIT(BPF_REG_2) | + BIT(BPF_REG_FP)); +} + static bool is_axreg(u32 reg) { return reg == BPF_REG_0; @@ -192,14 +211,39 @@ static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg) return byte + reg2hex[dst_reg] + (reg2hex[src_reg] << 3); } +/* Some 1-byte opcodes for binary ALU operations */ +static u8 simple_alu_opcodes[] = { + [BPF_ADD] = 0x01, + [BPF_SUB] = 0x29, + [BPF_AND] = 0x21, + [BPF_OR] = 0x09, + [BPF_XOR] = 0x31, + [BPF_LSH] = 0xE0, + [BPF_RSH] = 0xE8, + [BPF_ARSH] = 0xF8, +}; + static void jit_fill_hole(void *area, unsigned int size) { /* Fill whole space with INT3 instructions */ memset(area, 0xcc, size); } +int bpf_arch_text_invalidate(void *dst, size_t len) +{ + return IS_ERR_OR_NULL(text_poke_set(dst, 0xcc, len)); +} + struct jit_context { int cleanup_addr; /* Epilogue code offset */ + + /* + * Program specific offsets of labels in the code; these rely on the + * JIT doing at least 2 passes, recording the position on the first + * pass, only to generate the correct offset on the second pass. + */ + int tail_call_direct_label; + int tail_call_indirect_label; }; /* Maximum number of bytes emitted while JITing one eBPF insn */ @@ -208,43 +252,78 @@ struct jit_context { /* Number of bytes emit_patch() needs to generate instructions */ #define X86_PATCH_SIZE 5 +/* Number of bytes that will be skipped on tailcall */ +#define X86_TAIL_CALL_OFFSET (11 + ENDBR_INSN_SIZE) -#define PROLOGUE_SIZE 25 +static void push_callee_regs(u8 **pprog, bool *callee_regs_used) +{ + u8 *prog = *pprog; + + if (callee_regs_used[0]) + EMIT1(0x53); /* push rbx */ + if (callee_regs_used[1]) + EMIT2(0x41, 0x55); /* push r13 */ + if (callee_regs_used[2]) + EMIT2(0x41, 0x56); /* push r14 */ + if (callee_regs_used[3]) + EMIT2(0x41, 0x57); /* push r15 */ + *pprog = prog; +} + +static void pop_callee_regs(u8 **pprog, bool *callee_regs_used) +{ + u8 *prog = *pprog; + + if (callee_regs_used[3]) + EMIT2(0x41, 0x5F); /* pop r15 */ + if (callee_regs_used[2]) + EMIT2(0x41, 0x5E); /* pop r14 */ + if (callee_regs_used[1]) + EMIT2(0x41, 0x5D); /* pop r13 */ + if (callee_regs_used[0]) + EMIT1(0x5B); /* pop rbx */ + *pprog = prog; +} /* - * Emit x86-64 prologue code for BPF program and check its size. - * bpf_tail_call helper will skip it while jumping into another program + * Emit x86-64 prologue code for BPF program. + * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes + * while jumping to another program */ -static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf) +static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, + bool tail_call_reachable, bool is_subprog) { u8 *prog = *pprog; - int cnt = X86_PATCH_SIZE; /* BPF trampoline can be made to work without these nops, * but let's waste 5 bytes for now and optimize later */ - memcpy(prog, ideal_nops[NOP_ATOMIC5], cnt); - prog += cnt; + EMIT_ENDBR(); + memcpy(prog, x86_nops[5], X86_PATCH_SIZE); + prog += X86_PATCH_SIZE; + if (!ebpf_from_cbpf) { + if (tail_call_reachable && !is_subprog) + EMIT2(0x31, 0xC0); /* xor eax, eax */ + else + EMIT2(0x66, 0x90); /* nop2 */ + } EMIT1(0x55); /* push rbp */ EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */ + + /* X86_TAIL_CALL_OFFSET is here */ + EMIT_ENDBR(); + /* sub rsp, rounded_stack_depth */ - EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8)); - EMIT1(0x53); /* push rbx */ - EMIT2(0x41, 0x55); /* push r13 */ - EMIT2(0x41, 0x56); /* push r14 */ - EMIT2(0x41, 0x57); /* push r15 */ - if (!ebpf_from_cbpf) { - /* zero init tail_call_cnt */ - EMIT2(0x6a, 0x00); - BUILD_BUG_ON(cnt != PROLOGUE_SIZE); - } + if (stack_depth) + EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8)); + if (tail_call_reachable) + EMIT1(0x50); /* push rax */ *pprog = prog; } static int emit_patch(u8 **pprog, void *func, void *ip, u8 opcode) { u8 *prog = *pprog; - int cnt = 0; s64 offset; offset = func - (ip + X86_PATCH_SIZE); @@ -268,10 +347,9 @@ static int emit_jump(u8 **pprog, void *func, void *ip) } static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, - void *old_addr, void *new_addr, - const bool text_live) + void *old_addr, void *new_addr) { - const u8 *nop_insn = ideal_nops[NOP_ATOMIC5]; + const u8 *nop_insn = x86_nops[5]; u8 old_insn[X86_PATCH_SIZE]; u8 new_insn[X86_PATCH_SIZE]; u8 *prog; @@ -301,18 +379,28 @@ static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, mutex_lock(&text_mutex); if (memcmp(ip, old_insn, X86_PATCH_SIZE)) goto out; + ret = 1; if (memcmp(ip, new_insn, X86_PATCH_SIZE)) { - if (text_live) - text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL); - else - memcpy(ip, new_insn, X86_PATCH_SIZE); + text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL); + ret = 0; } - ret = 0; out: mutex_unlock(&text_mutex); return ret; } +int __init bpf_arch_init_dispatcher_early(void *ip) +{ + const u8 *nop_insn = x86_nops[5]; + + if (is_endbr(*(u32 *)ip)) + ip += ENDBR_INSN_SIZE; + + if (memcmp(ip, nop_insn, X86_PATCH_SIZE)) + text_poke_early(ip, nop_insn, X86_PATCH_SIZE); + return 0; +} + int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, void *old_addr, void *new_addr) { @@ -321,7 +409,50 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, /* BPF poking in modules is not supported */ return -EINVAL; - return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true); + /* + * See emit_prologue(), for IBT builds the trampoline hook is preceded + * with an ENDBR instruction. + */ + if (is_endbr(*(u32 *)ip)) + ip += ENDBR_INSN_SIZE; + + return __bpf_arch_text_poke(ip, t, old_addr, new_addr); +} + +#define EMIT_LFENCE() EMIT3(0x0F, 0xAE, 0xE8) + +static void emit_indirect_jump(u8 **pprog, int reg, u8 *ip) +{ + u8 *prog = *pprog; + + if (cpu_feature_enabled(X86_FEATURE_RETPOLINE_LFENCE)) { + EMIT_LFENCE(); + EMIT2(0xFF, 0xE0 + reg); + } else if (cpu_feature_enabled(X86_FEATURE_RETPOLINE)) { + OPTIMIZER_HIDE_VAR(reg); + emit_jump(&prog, &__x86_indirect_thunk_array[reg], ip); + } else { + EMIT2(0xFF, 0xE0 + reg); /* jmp *%\reg */ + if (IS_ENABLED(CONFIG_RETPOLINE) || IS_ENABLED(CONFIG_SLS)) + EMIT1(0xCC); /* int3 */ + } + + *pprog = prog; +} + +static void emit_return(u8 **pprog, u8 *ip) +{ + u8 *prog = *pprog; + + if (cpu_feature_enabled(X86_FEATURE_RETHUNK)) { + emit_jump(&prog, &__x86_return_thunk, ip); + } else { + EMIT1(0xC3); /* ret */ + if (IS_ENABLED(CONFIG_SLS)) + EMIT1(0xCC); /* int3 */ + } + + *pprog = prog; } /* @@ -330,7 +461,7 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ... * if (index >= array->map.max_entries) * goto out; - * if (++tail_call_cnt > MAX_TAIL_CALL_CNT) + * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) * goto out; * prog = array->ptrs[index]; * if (prog == NULL) @@ -338,11 +469,13 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, * goto *(prog->bpf_func + prologue_size); * out: */ -static void emit_bpf_tail_call_indirect(u8 **pprog) +static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used, + u32 stack_depth, u8 *ip, + struct jit_context *ctx) { - u8 *prog = *pprog; - int label1, label2, label3; - int cnt = 0; + int tcc_off = -4 - round_up(stack_depth, 8); + u8 *prog = *pprog, *start = *pprog; + int offset; /* * rdi - pointer to ctx @@ -357,76 +490,98 @@ static void emit_bpf_tail_call_indirect(u8 **pprog) EMIT2(0x89, 0xD2); /* mov edx, edx */ EMIT3(0x39, 0x56, /* cmp dword ptr [rsi + 16], edx */ offsetof(struct bpf_array, map.max_entries)); -#define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* Number of bytes to jump */ - EMIT2(X86_JBE, OFFSET1); /* jbe out */ - label1 = cnt; + + offset = ctx->tail_call_indirect_label - (prog + 2 - start); + EMIT2(X86_JBE, offset); /* jbe out */ /* - * if (tail_call_cnt > MAX_TAIL_CALL_CNT) + * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) * goto out; */ - EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */ + EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ -#define OFFSET2 (30 + RETPOLINE_RAX_BPF_JIT_SIZE) - EMIT2(X86_JA, OFFSET2); /* ja out */ - label2 = cnt; + + offset = ctx->tail_call_indirect_label - (prog + 2 - start); + EMIT2(X86_JAE, offset); /* jae out */ EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ - EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */ + EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ /* prog = array->ptrs[index]; */ - EMIT4_off32(0x48, 0x8B, 0x84, 0xD6, /* mov rax, [rsi + rdx * 8 + offsetof(...)] */ + EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6, /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */ offsetof(struct bpf_array, ptrs)); /* * if (prog == NULL) * goto out; */ - EMIT3(0x48, 0x85, 0xC0); /* test rax,rax */ -#define OFFSET3 (8 + RETPOLINE_RAX_BPF_JIT_SIZE) - EMIT2(X86_JE, OFFSET3); /* je out */ - label3 = cnt; + EMIT3(0x48, 0x85, 0xC9); /* test rcx,rcx */ - /* goto *(prog->bpf_func + prologue_size); */ - EMIT4(0x48, 0x8B, 0x40, /* mov rax, qword ptr [rax + 32] */ - offsetof(struct bpf_prog, bpf_func)); - EMIT4(0x48, 0x83, 0xC0, PROLOGUE_SIZE); /* add rax, prologue_size */ + offset = ctx->tail_call_indirect_label - (prog + 2 - start); + EMIT2(X86_JE, offset); /* je out */ + + pop_callee_regs(&prog, callee_regs_used); + + EMIT1(0x58); /* pop rax */ + if (stack_depth) + EMIT3_off32(0x48, 0x81, 0xC4, /* add rsp, sd */ + round_up(stack_depth, 8)); + /* goto *(prog->bpf_func + X86_TAIL_CALL_OFFSET); */ + EMIT4(0x48, 0x8B, 0x49, /* mov rcx, qword ptr [rcx + 32] */ + offsetof(struct bpf_prog, bpf_func)); + EMIT4(0x48, 0x83, 0xC1, /* add rcx, X86_TAIL_CALL_OFFSET */ + X86_TAIL_CALL_OFFSET); /* - * Wow we're ready to jump into next BPF program + * Now we're ready to jump into next BPF program * rdi == ctx (1st arg) - * rax == prog->bpf_func + prologue_size + * rcx == prog->bpf_func + X86_TAIL_CALL_OFFSET */ - RETPOLINE_RAX_BPF_JIT(); + emit_indirect_jump(&prog, 1 /* rcx */, ip + (prog - start)); /* out: */ - BUILD_BUG_ON(cnt - label1 != OFFSET1); - BUILD_BUG_ON(cnt - label2 != OFFSET2); - BUILD_BUG_ON(cnt - label3 != OFFSET3); + ctx->tail_call_indirect_label = prog - start; *pprog = prog; } static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke, - u8 **pprog, int addr, u8 *image) + u8 **pprog, u8 *ip, + bool *callee_regs_used, u32 stack_depth, + struct jit_context *ctx) { - u8 *prog = *pprog; - int cnt = 0; + int tcc_off = -4 - round_up(stack_depth, 8); + u8 *prog = *pprog, *start = *pprog; + int offset; /* - * if (tail_call_cnt > MAX_TAIL_CALL_CNT) + * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) * goto out; */ - EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */ + EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ - EMIT2(X86_JA, 14); /* ja out */ + + offset = ctx->tail_call_direct_label - (prog + 2 - start); + EMIT2(X86_JAE, offset); /* jae out */ EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ - EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */ + EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ - poke->ip = image + (addr - X86_PATCH_SIZE); - poke->adj_off = PROLOGUE_SIZE; + poke->tailcall_bypass = ip + (prog - start); + poke->adj_off = X86_TAIL_CALL_OFFSET; + poke->tailcall_target = ip + ctx->tail_call_direct_label - X86_PATCH_SIZE; + poke->bypass_addr = (u8 *)poke->tailcall_target + X86_PATCH_SIZE; - memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE); + emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE, + poke->tailcall_bypass); + + pop_callee_regs(&prog, callee_regs_used); + EMIT1(0x58); /* pop rax */ + if (stack_depth) + EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8)); + + memcpy(prog, x86_nops[5], X86_PATCH_SIZE); prog += X86_PATCH_SIZE; + /* out: */ + ctx->tail_call_direct_label = prog - start; *pprog = prog; } @@ -440,7 +595,10 @@ static void bpf_tail_call_direct_fixup(struct bpf_prog *prog) for (i = 0; i < prog->aux->size_poke_tab; i++) { poke = &prog->aux->poke_tab[i]; - WARN_ON_ONCE(READ_ONCE(poke->ip_stable)); + if (poke->aux && poke->aux != prog->aux) + continue; + + WARN_ON_ONCE(READ_ONCE(poke->tailcall_target_stable)); if (poke->reason != BPF_POKE_REASON_TAIL_CALL) continue; @@ -449,20 +607,18 @@ static void bpf_tail_call_direct_fixup(struct bpf_prog *prog) mutex_lock(&array->aux->poke_mutex); target = array->ptrs[poke->tail_call.key]; if (target) { - /* Plain memcpy is used when image is not live yet - * and still not locked as read-only. Once poke - * location is active (poke->ip_stable), any parallel - * bpf_arch_text_poke() might occur still on the - * read-write image until we finally locked it as - * read-only. Both modifications on the given image - * are under text_mutex to avoid interference. - */ - ret = __bpf_arch_text_poke(poke->ip, BPF_MOD_JUMP, NULL, + ret = __bpf_arch_text_poke(poke->tailcall_target, + BPF_MOD_JUMP, NULL, (u8 *)target->bpf_func + - poke->adj_off, false); + poke->adj_off); + BUG_ON(ret < 0); + ret = __bpf_arch_text_poke(poke->tailcall_bypass, + BPF_MOD_JUMP, + (u8 *)poke->tailcall_target + + X86_PATCH_SIZE, NULL); BUG_ON(ret < 0); } - WRITE_ONCE(poke->ip_stable, true); + WRITE_ONCE(poke->tailcall_target_stable, true); mutex_unlock(&array->aux->poke_mutex); } } @@ -472,7 +628,6 @@ static void emit_mov_imm32(u8 **pprog, bool sign_propagate, { u8 *prog = *pprog; u8 b1, b2, b3; - int cnt = 0; /* * Optimization: if imm32 is positive, use 'mov %eax, imm32' @@ -512,7 +667,6 @@ static void emit_mov_imm64(u8 **pprog, u32 dst_reg, const u32 imm32_hi, const u32 imm32_lo) { u8 *prog = *pprog; - int cnt = 0; if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) { /* @@ -523,7 +677,7 @@ static void emit_mov_imm64(u8 **pprog, u32 dst_reg, */ emit_mov_imm32(&prog, false, dst_reg, imm32_lo); } else { - /* movabsq %rax, imm64 */ + /* movabsq rax, imm64 */ EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg)); EMIT(imm32_lo, 4); EMIT(imm32_hi, 4); @@ -535,7 +689,6 @@ static void emit_mov_imm64(u8 **pprog, u32 dst_reg, static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg) { u8 *prog = *pprog; - int cnt = 0; if (is64) { /* mov dst, src */ @@ -550,11 +703,58 @@ static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg) *pprog = prog; } +/* Emit the suffix (ModR/M etc) for addressing *(ptr_reg + off) and val_reg */ +static void emit_insn_suffix(u8 **pprog, u32 ptr_reg, u32 val_reg, int off) +{ + u8 *prog = *pprog; + + if (is_imm8(off)) { + /* 1-byte signed displacement. + * + * If off == 0 we could skip this and save one extra byte, but + * special case of x86 R13 which always needs an offset is not + * worth the hassle + */ + EMIT2(add_2reg(0x40, ptr_reg, val_reg), off); + } else { + /* 4-byte signed displacement */ + EMIT1_off32(add_2reg(0x80, ptr_reg, val_reg), off); + } + *pprog = prog; +} + +/* + * Emit a REX byte if it will be necessary to address these registers + */ +static void maybe_emit_mod(u8 **pprog, u32 dst_reg, u32 src_reg, bool is64) +{ + u8 *prog = *pprog; + + if (is64) + EMIT1(add_2mod(0x48, dst_reg, src_reg)); + else if (is_ereg(dst_reg) || is_ereg(src_reg)) + EMIT1(add_2mod(0x40, dst_reg, src_reg)); + *pprog = prog; +} + +/* + * Similar version of maybe_emit_mod() for a single register + */ +static void maybe_emit_1mod(u8 **pprog, u32 reg, bool is64) +{ + u8 *prog = *pprog; + + if (is64) + EMIT1(add_1mod(0x48, reg)); + else if (is_ereg(reg)) + EMIT1(add_1mod(0x40, reg)); + *pprog = prog; +} + /* LDX: dst_reg = *(u8*)(src_reg + off) */ static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off) { u8 *prog = *pprog; - int cnt = 0; switch (size) { case BPF_B: @@ -577,15 +777,7 @@ static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off) EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B); break; } - /* - * If insn->off == 0 we can save one extra byte, but - * special case of x86 R13 which always needs an offset - * is not worth the hassle - */ - if (is_imm8(off)) - EMIT2(add_2reg(0x40, src_reg, dst_reg), off); - else - EMIT1_off32(add_2reg(0x80, src_reg, dst_reg), off); + emit_insn_suffix(&prog, src_reg, dst_reg, off); *pprog = prog; } @@ -593,14 +785,12 @@ static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off) static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off) { u8 *prog = *pprog; - int cnt = 0; switch (size) { case BPF_B: /* Emit 'mov byte ptr [rax + off], al' */ - if (is_ereg(dst_reg) || is_ereg(src_reg) || - /* We have to add extra byte for x86 SIL, DIL regs */ - src_reg == BPF_REG_1 || src_reg == BPF_REG_2) + if (is_ereg(dst_reg) || is_ereg_8l(src_reg)) + /* Add extra byte for eregs or SIL,DIL,BPL in src_reg */ EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88); else EMIT1(0x88); @@ -621,16 +811,52 @@ static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off) EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89); break; } - if (is_imm8(off)) - EMIT2(add_2reg(0x40, dst_reg, src_reg), off); - else - EMIT1_off32(add_2reg(0x80, dst_reg, src_reg), off); + emit_insn_suffix(&prog, dst_reg, src_reg, off); + *pprog = prog; +} + +static int emit_atomic(u8 **pprog, u8 atomic_op, + u32 dst_reg, u32 src_reg, s16 off, u8 bpf_size) +{ + u8 *prog = *pprog; + + EMIT1(0xF0); /* lock prefix */ + + maybe_emit_mod(&prog, dst_reg, src_reg, bpf_size == BPF_DW); + + /* emit opcode */ + switch (atomic_op) { + case BPF_ADD: + case BPF_AND: + case BPF_OR: + case BPF_XOR: + /* lock *(u32/u64*)(dst_reg + off) <op>= src_reg */ + EMIT1(simple_alu_opcodes[atomic_op]); + break; + case BPF_ADD | BPF_FETCH: + /* src_reg = atomic_fetch_add(dst_reg + off, src_reg); */ + EMIT2(0x0F, 0xC1); + break; + case BPF_XCHG: + /* src_reg = atomic_xchg(dst_reg + off, src_reg); */ + EMIT1(0x87); + break; + case BPF_CMPXCHG: + /* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */ + EMIT2(0x0F, 0xB1); + break; + default: + pr_err("bpf_jit: unknown atomic opcode %02x\n", atomic_op); + return -EFAULT; + } + + emit_insn_suffix(&prog, dst_reg, src_reg, off); + *pprog = prog; + return 0; } -static bool ex_handler_bpf(const struct exception_table_entry *x, - struct pt_regs *regs, int trapnr, - unsigned long error_code, unsigned long fault_addr) +bool ex_handler_bpf(const struct exception_table_entry *x, struct pt_regs *regs) { u32 reg = x->fixup >> 8; @@ -640,30 +866,89 @@ static bool ex_handler_bpf(const struct exception_table_entry *x, return true; } -static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, - int oldproglen, struct jit_context *ctx) +static void detect_reg_usage(struct bpf_insn *insn, int insn_cnt, + bool *regs_used, bool *tail_call_seen) +{ + int i; + + for (i = 1; i <= insn_cnt; i++, insn++) { + if (insn->code == (BPF_JMP | BPF_TAIL_CALL)) + *tail_call_seen = true; + if (insn->dst_reg == BPF_REG_6 || insn->src_reg == BPF_REG_6) + regs_used[0] = true; + if (insn->dst_reg == BPF_REG_7 || insn->src_reg == BPF_REG_7) + regs_used[1] = true; + if (insn->dst_reg == BPF_REG_8 || insn->src_reg == BPF_REG_8) + regs_used[2] = true; + if (insn->dst_reg == BPF_REG_9 || insn->src_reg == BPF_REG_9) + regs_used[3] = true; + } +} + +static void emit_nops(u8 **pprog, int len) +{ + u8 *prog = *pprog; + int i, noplen; + + while (len > 0) { + noplen = len; + + if (noplen > ASM_NOP_MAX) + noplen = ASM_NOP_MAX; + + for (i = 0; i < noplen; i++) + EMIT1(x86_nops[noplen][i]); + len -= noplen; + } + + *pprog = prog; +} + +#define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp))) + +static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image, + int oldproglen, struct jit_context *ctx, bool jmp_padding) { + bool tail_call_reachable = bpf_prog->aux->tail_call_reachable; struct bpf_insn *insn = bpf_prog->insnsi; + bool callee_regs_used[4] = {}; int insn_cnt = bpf_prog->len; + bool tail_call_seen = false; bool seen_exit = false; u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY]; - int i, cnt = 0, excnt = 0; - int proglen = 0; + int i, excnt = 0; + int ilen, proglen = 0; u8 *prog = temp; + int err; + + detect_reg_usage(insn, insn_cnt, callee_regs_used, + &tail_call_seen); + + /* tail call's presence in current prog implies it is reachable */ + tail_call_reachable |= tail_call_seen; emit_prologue(&prog, bpf_prog->aux->stack_depth, - bpf_prog_was_classic(bpf_prog)); - addrs[0] = prog - temp; + bpf_prog_was_classic(bpf_prog), tail_call_reachable, + bpf_prog->aux->func_idx != 0); + push_callee_regs(&prog, callee_regs_used); + + ilen = prog - temp; + if (rw_image) + memcpy(rw_image + proglen, temp, ilen); + proglen += ilen; + addrs[0] = proglen; + prog = temp; for (i = 1; i <= insn_cnt; i++, insn++) { const s32 imm32 = insn->imm; u32 dst_reg = insn->dst_reg; u32 src_reg = insn->src_reg; u8 b2 = 0, b3 = 0; + u8 *start_of_ldx; s64 jmp_offset; u8 jmp_cond; - int ilen; u8 *func; + int nops; switch (insn->code) { /* ALU */ @@ -677,17 +962,9 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, case BPF_ALU64 | BPF_AND | BPF_X: case BPF_ALU64 | BPF_OR | BPF_X: case BPF_ALU64 | BPF_XOR | BPF_X: - switch (BPF_OP(insn->code)) { - case BPF_ADD: b2 = 0x01; break; - case BPF_SUB: b2 = 0x29; break; - case BPF_AND: b2 = 0x21; break; - case BPF_OR: b2 = 0x09; break; - case BPF_XOR: b2 = 0x31; break; - } - if (BPF_CLASS(insn->code) == BPF_ALU64) - EMIT1(add_2mod(0x48, dst_reg, src_reg)); - else if (is_ereg(dst_reg) || is_ereg(src_reg)) - EMIT1(add_2mod(0x40, dst_reg, src_reg)); + maybe_emit_mod(&prog, dst_reg, src_reg, + BPF_CLASS(insn->code) == BPF_ALU64); + b2 = simple_alu_opcodes[BPF_OP(insn->code)]; EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg)); break; @@ -701,10 +978,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, /* neg dst */ case BPF_ALU | BPF_NEG: case BPF_ALU64 | BPF_NEG: - if (BPF_CLASS(insn->code) == BPF_ALU64) - EMIT1(add_1mod(0x48, dst_reg)); - else if (is_ereg(dst_reg)) - EMIT1(add_1mod(0x40, dst_reg)); + maybe_emit_1mod(&prog, dst_reg, + BPF_CLASS(insn->code) == BPF_ALU64); EMIT2(0xF7, add_1reg(0xD8, dst_reg)); break; @@ -718,10 +993,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, case BPF_ALU64 | BPF_AND | BPF_K: case BPF_ALU64 | BPF_OR | BPF_K: case BPF_ALU64 | BPF_XOR | BPF_K: - if (BPF_CLASS(insn->code) == BPF_ALU64) - EMIT1(add_1mod(0x48, dst_reg)); - else if (is_ereg(dst_reg)) - EMIT1(add_1mod(0x40, dst_reg)); + maybe_emit_1mod(&prog, dst_reg, + BPF_CLASS(insn->code) == BPF_ALU64); /* * b3 holds 'normal' opcode, b2 short form only valid @@ -778,19 +1051,30 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, case BPF_ALU64 | BPF_MOD | BPF_X: case BPF_ALU64 | BPF_DIV | BPF_X: case BPF_ALU64 | BPF_MOD | BPF_K: - case BPF_ALU64 | BPF_DIV | BPF_K: - EMIT1(0x50); /* push rax */ - EMIT1(0x52); /* push rdx */ + case BPF_ALU64 | BPF_DIV | BPF_K: { + bool is64 = BPF_CLASS(insn->code) == BPF_ALU64; - if (BPF_SRC(insn->code) == BPF_X) - /* mov r11, src_reg */ - EMIT_mov(AUX_REG, src_reg); - else + if (dst_reg != BPF_REG_0) + EMIT1(0x50); /* push rax */ + if (dst_reg != BPF_REG_3) + EMIT1(0x52); /* push rdx */ + + if (BPF_SRC(insn->code) == BPF_X) { + if (src_reg == BPF_REG_0 || + src_reg == BPF_REG_3) { + /* mov r11, src_reg */ + EMIT_mov(AUX_REG, src_reg); + src_reg = AUX_REG; + } + } else { /* mov r11, imm32 */ EMIT3_off32(0x49, 0xC7, 0xC3, imm32); + src_reg = AUX_REG; + } - /* mov rax, dst_reg */ - EMIT_mov(BPF_REG_0, dst_reg); + if (dst_reg != BPF_REG_0) + /* mov rax, dst_reg */ + emit_mov_reg(&prog, is64, BPF_REG_0, dst_reg); /* * xor edx, edx @@ -798,63 +1082,51 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, */ EMIT2(0x31, 0xd2); - if (BPF_CLASS(insn->code) == BPF_ALU64) - /* div r11 */ - EMIT3(0x49, 0xF7, 0xF3); - else - /* div r11d */ - EMIT3(0x41, 0xF7, 0xF3); - - if (BPF_OP(insn->code) == BPF_MOD) - /* mov r11, rdx */ - EMIT3(0x49, 0x89, 0xD3); - else - /* mov r11, rax */ - EMIT3(0x49, 0x89, 0xC3); + /* div src_reg */ + maybe_emit_1mod(&prog, src_reg, is64); + EMIT2(0xF7, add_1reg(0xF0, src_reg)); - EMIT1(0x5A); /* pop rdx */ - EMIT1(0x58); /* pop rax */ + if (BPF_OP(insn->code) == BPF_MOD && + dst_reg != BPF_REG_3) + /* mov dst_reg, rdx */ + emit_mov_reg(&prog, is64, dst_reg, BPF_REG_3); + else if (BPF_OP(insn->code) == BPF_DIV && + dst_reg != BPF_REG_0) + /* mov dst_reg, rax */ + emit_mov_reg(&prog, is64, dst_reg, BPF_REG_0); - /* mov dst_reg, r11 */ - EMIT_mov(dst_reg, AUX_REG); + if (dst_reg != BPF_REG_3) + EMIT1(0x5A); /* pop rdx */ + if (dst_reg != BPF_REG_0) + EMIT1(0x58); /* pop rax */ break; + } case BPF_ALU | BPF_MUL | BPF_K: - case BPF_ALU | BPF_MUL | BPF_X: case BPF_ALU64 | BPF_MUL | BPF_K: - case BPF_ALU64 | BPF_MUL | BPF_X: - { - bool is64 = BPF_CLASS(insn->code) == BPF_ALU64; - - if (dst_reg != BPF_REG_0) - EMIT1(0x50); /* push rax */ - if (dst_reg != BPF_REG_3) - EMIT1(0x52); /* push rdx */ - - /* mov r11, dst_reg */ - EMIT_mov(AUX_REG, dst_reg); + maybe_emit_mod(&prog, dst_reg, dst_reg, + BPF_CLASS(insn->code) == BPF_ALU64); - if (BPF_SRC(insn->code) == BPF_X) - emit_mov_reg(&prog, is64, BPF_REG_0, src_reg); + if (is_imm8(imm32)) + /* imul dst_reg, dst_reg, imm8 */ + EMIT3(0x6B, add_2reg(0xC0, dst_reg, dst_reg), + imm32); else - emit_mov_imm32(&prog, is64, BPF_REG_0, imm32); + /* imul dst_reg, dst_reg, imm32 */ + EMIT2_off32(0x69, + add_2reg(0xC0, dst_reg, dst_reg), + imm32); + break; - if (is64) - EMIT1(add_1mod(0x48, AUX_REG)); - else if (is_ereg(AUX_REG)) - EMIT1(add_1mod(0x40, AUX_REG)); - /* mul(q) r11 */ - EMIT2(0xF7, add_1reg(0xE0, AUX_REG)); + case BPF_ALU | BPF_MUL | BPF_X: + case BPF_ALU64 | BPF_MUL | BPF_X: + maybe_emit_mod(&prog, src_reg, dst_reg, + BPF_CLASS(insn->code) == BPF_ALU64); - if (dst_reg != BPF_REG_3) - EMIT1(0x5A); /* pop rdx */ - if (dst_reg != BPF_REG_0) { - /* mov dst_reg, rax */ - EMIT_mov(dst_reg, BPF_REG_0); - EMIT1(0x58); /* pop rax */ - } + /* imul dst_reg, src_reg */ + EMIT3(0x0F, 0xAF, add_2reg(0xC0, src_reg, dst_reg)); break; - } + /* Shifts */ case BPF_ALU | BPF_LSH | BPF_K: case BPF_ALU | BPF_RSH | BPF_K: @@ -862,17 +1134,10 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, case BPF_ALU64 | BPF_LSH | BPF_K: case BPF_ALU64 | BPF_RSH | BPF_K: case BPF_ALU64 | BPF_ARSH | BPF_K: - if (BPF_CLASS(insn->code) == BPF_ALU64) - EMIT1(add_1mod(0x48, dst_reg)); - else if (is_ereg(dst_reg)) - EMIT1(add_1mod(0x40, dst_reg)); - - switch (BPF_OP(insn->code)) { - case BPF_LSH: b3 = 0xE0; break; - case BPF_RSH: b3 = 0xE8; break; - case BPF_ARSH: b3 = 0xF8; break; - } + maybe_emit_1mod(&prog, dst_reg, + BPF_CLASS(insn->code) == BPF_ALU64); + b3 = simple_alu_opcodes[BPF_OP(insn->code)]; if (imm32 == 1) EMIT2(0xD1, add_1reg(b3, dst_reg)); else @@ -901,16 +1166,10 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, } /* shl %rax, %cl | shr %rax, %cl | sar %rax, %cl */ - if (BPF_CLASS(insn->code) == BPF_ALU64) - EMIT1(add_1mod(0x48, dst_reg)); - else if (is_ereg(dst_reg)) - EMIT1(add_1mod(0x40, dst_reg)); + maybe_emit_1mod(&prog, dst_reg, + BPF_CLASS(insn->code) == BPF_ALU64); - switch (BPF_OP(insn->code)) { - case BPF_LSH: b3 = 0xE0; break; - case BPF_RSH: b3 = 0xE8; break; - case BPF_ARSH: b3 = 0xF8; break; - } + b3 = simple_alu_opcodes[BPF_OP(insn->code)]; EMIT2(0xD3, add_1reg(b3, dst_reg)); if (src_reg != BPF_REG_4) @@ -978,6 +1237,12 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, } break; + /* speculation barrier */ + case BPF_ST | BPF_NOSPEC: + if (boot_cpu_has(X86_FEATURE_XMM2)) + EMIT_LFENCE(); + break; + /* ST: *(u8*)(dst_reg + off) = imm */ case BPF_ST | BPF_MEM | BPF_B: if (is_ereg(dst_reg)) @@ -1025,12 +1290,65 @@ st: if (is_imm8(insn->off)) case BPF_LDX | BPF_PROBE_MEM | BPF_W: case BPF_LDX | BPF_MEM | BPF_DW: case BPF_LDX | BPF_PROBE_MEM | BPF_DW: + if (BPF_MODE(insn->code) == BPF_PROBE_MEM) { + /* Though the verifier prevents negative insn->off in BPF_PROBE_MEM + * add abs(insn->off) to the limit to make sure that negative + * offset won't be an issue. + * insn->off is s16, so it won't affect valid pointers. + */ + u64 limit = TASK_SIZE_MAX + PAGE_SIZE + abs(insn->off); + u8 *end_of_jmp1, *end_of_jmp2; + + /* Conservatively check that src_reg + insn->off is a kernel address: + * 1. src_reg + insn->off >= limit + * 2. src_reg + insn->off doesn't become small positive. + * Cannot do src_reg + insn->off >= limit in one branch, + * since it needs two spare registers, but JIT has only one. + */ + + /* movabsq r11, limit */ + EMIT2(add_1mod(0x48, AUX_REG), add_1reg(0xB8, AUX_REG)); + EMIT((u32)limit, 4); + EMIT(limit >> 32, 4); + /* cmp src_reg, r11 */ + maybe_emit_mod(&prog, src_reg, AUX_REG, true); + EMIT2(0x39, add_2reg(0xC0, src_reg, AUX_REG)); + /* if unsigned '<' goto end_of_jmp2 */ + EMIT2(X86_JB, 0); + end_of_jmp1 = prog; + + /* mov r11, src_reg */ + emit_mov_reg(&prog, true, AUX_REG, src_reg); + /* add r11, insn->off */ + maybe_emit_1mod(&prog, AUX_REG, true); + EMIT2_off32(0x81, add_1reg(0xC0, AUX_REG), insn->off); + /* jmp if not carry to start_of_ldx + * Otherwise ERR_PTR(-EINVAL) + 128 will be the user addr + * that has to be rejected. + */ + EMIT2(0x73 /* JNC */, 0); + end_of_jmp2 = prog; + + /* xor dst_reg, dst_reg */ + emit_mov_imm32(&prog, false, dst_reg, 0); + /* jmp byte_after_ldx */ + EMIT2(0xEB, 0); + + /* populate jmp_offset for JB above to jump to xor dst_reg */ + end_of_jmp1[-1] = end_of_jmp2 - end_of_jmp1; + /* populate jmp_offset for JNC above to jump to start_of_ldx */ + start_of_ldx = prog; + end_of_jmp2[-1] = start_of_ldx - end_of_jmp2; + } emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off); if (BPF_MODE(insn->code) == BPF_PROBE_MEM) { struct exception_table_entry *ex; - u8 *_insn = image + proglen; + u8 *_insn = image + proglen + (start_of_ldx - temp); s64 delta; + /* populate jmp_offset for JMP above */ + start_of_ldx[-1] = prog - start_of_ldx; + if (!bpf_prog->aux->extable) break; @@ -1045,14 +1363,12 @@ st: if (is_imm8(insn->off)) pr_err("extable->insn doesn't fit into 32-bit\n"); return -EFAULT; } + /* switch ex to rw buffer for writes */ + ex = (void *)rw_image + ((void *)ex - (void *)image); + ex->insn = delta; - delta = (u8 *)ex_handler_bpf - (u8 *)&ex->handler; - if (!is_simm32(delta)) { - pr_err("extable->handler doesn't fit into 32-bit\n"); - return -EFAULT; - } - ex->handler = delta; + ex->data = EX_TYPE_BPF; if (dst_reg > BPF_REG_9) { pr_err("verifier error\n"); @@ -1066,40 +1382,97 @@ st: if (is_imm8(insn->off)) * End result: x86 insn "mov rbx, qword ptr [rax+0x14]" * of 4 bytes will be ignored and rbx will be zero inited. */ - ex->fixup = (prog - temp) | (reg2pt_regs[dst_reg] << 8); + ex->fixup = (prog - start_of_ldx) | (reg2pt_regs[dst_reg] << 8); } break; - /* STX XADD: lock *(u32*)(dst_reg + off) += src_reg */ - case BPF_STX | BPF_XADD | BPF_W: - /* Emit 'lock add dword ptr [rax + off], eax' */ - if (is_ereg(dst_reg) || is_ereg(src_reg)) - EMIT3(0xF0, add_2mod(0x40, dst_reg, src_reg), 0x01); - else - EMIT2(0xF0, 0x01); - goto xadd; - case BPF_STX | BPF_XADD | BPF_DW: - EMIT3(0xF0, add_2mod(0x48, dst_reg, src_reg), 0x01); -xadd: if (is_imm8(insn->off)) - EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off); - else - EMIT1_off32(add_2reg(0x80, dst_reg, src_reg), - insn->off); + case BPF_STX | BPF_ATOMIC | BPF_W: + case BPF_STX | BPF_ATOMIC | BPF_DW: + if (insn->imm == (BPF_AND | BPF_FETCH) || + insn->imm == (BPF_OR | BPF_FETCH) || + insn->imm == (BPF_XOR | BPF_FETCH)) { + bool is64 = BPF_SIZE(insn->code) == BPF_DW; + u32 real_src_reg = src_reg; + u32 real_dst_reg = dst_reg; + u8 *branch_target; + + /* + * Can't be implemented with a single x86 insn. + * Need to do a CMPXCHG loop. + */ + + /* Will need RAX as a CMPXCHG operand so save R0 */ + emit_mov_reg(&prog, true, BPF_REG_AX, BPF_REG_0); + if (src_reg == BPF_REG_0) + real_src_reg = BPF_REG_AX; + if (dst_reg == BPF_REG_0) + real_dst_reg = BPF_REG_AX; + + branch_target = prog; + /* Load old value */ + emit_ldx(&prog, BPF_SIZE(insn->code), + BPF_REG_0, real_dst_reg, insn->off); + /* + * Perform the (commutative) operation locally, + * put the result in the AUX_REG. + */ + emit_mov_reg(&prog, is64, AUX_REG, BPF_REG_0); + maybe_emit_mod(&prog, AUX_REG, real_src_reg, is64); + EMIT2(simple_alu_opcodes[BPF_OP(insn->imm)], + add_2reg(0xC0, AUX_REG, real_src_reg)); + /* Attempt to swap in new value */ + err = emit_atomic(&prog, BPF_CMPXCHG, + real_dst_reg, AUX_REG, + insn->off, + BPF_SIZE(insn->code)); + if (WARN_ON(err)) + return err; + /* + * ZF tells us whether we won the race. If it's + * cleared we need to try again. + */ + EMIT2(X86_JNE, -(prog - branch_target) - 2); + /* Return the pre-modification value */ + emit_mov_reg(&prog, is64, real_src_reg, BPF_REG_0); + /* Restore R0 after clobbering RAX */ + emit_mov_reg(&prog, true, BPF_REG_0, BPF_REG_AX); + break; + } + + err = emit_atomic(&prog, insn->imm, dst_reg, src_reg, + insn->off, BPF_SIZE(insn->code)); + if (err) + return err; break; /* call */ case BPF_JMP | BPF_CALL: func = (u8 *) __bpf_call_base + imm32; - if (!imm32 || emit_call(&prog, func, image + addrs[i - 1])) - return -EINVAL; + if (tail_call_reachable) { + /* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */ + EMIT3_off32(0x48, 0x8B, 0x85, + -round_up(bpf_prog->aux->stack_depth, 8) - 8); + if (!imm32 || emit_call(&prog, func, image + addrs[i - 1] + 7)) + return -EINVAL; + } else { + if (!imm32 || emit_call(&prog, func, image + addrs[i - 1])) + return -EINVAL; + } break; case BPF_JMP | BPF_TAIL_CALL: if (imm32) emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1], - &prog, addrs[i], image); + &prog, image + addrs[i - 1], + callee_regs_used, + bpf_prog->aux->stack_depth, + ctx); else - emit_bpf_tail_call_indirect(&prog); + emit_bpf_tail_call_indirect(&prog, + callee_regs_used, + bpf_prog->aux->stack_depth, + image + addrs[i - 1], + ctx); break; /* cond jump */ @@ -1124,30 +1497,24 @@ xadd: if (is_imm8(insn->off)) case BPF_JMP32 | BPF_JSGE | BPF_X: case BPF_JMP32 | BPF_JSLE | BPF_X: /* cmp dst_reg, src_reg */ - if (BPF_CLASS(insn->code) == BPF_JMP) - EMIT1(add_2mod(0x48, dst_reg, src_reg)); - else if (is_ereg(dst_reg) || is_ereg(src_reg)) - EMIT1(add_2mod(0x40, dst_reg, src_reg)); + maybe_emit_mod(&prog, dst_reg, src_reg, + BPF_CLASS(insn->code) == BPF_JMP); EMIT2(0x39, add_2reg(0xC0, dst_reg, src_reg)); goto emit_cond_jmp; case BPF_JMP | BPF_JSET | BPF_X: case BPF_JMP32 | BPF_JSET | BPF_X: /* test dst_reg, src_reg */ - if (BPF_CLASS(insn->code) == BPF_JMP) - EMIT1(add_2mod(0x48, dst_reg, src_reg)); - else if (is_ereg(dst_reg) || is_ereg(src_reg)) - EMIT1(add_2mod(0x40, dst_reg, src_reg)); + maybe_emit_mod(&prog, dst_reg, src_reg, + BPF_CLASS(insn->code) == BPF_JMP); EMIT2(0x85, add_2reg(0xC0, dst_reg, src_reg)); goto emit_cond_jmp; case BPF_JMP | BPF_JSET | BPF_K: case BPF_JMP32 | BPF_JSET | BPF_K: /* test dst_reg, imm32 */ - if (BPF_CLASS(insn->code) == BPF_JMP) - EMIT1(add_1mod(0x48, dst_reg)); - else if (is_ereg(dst_reg)) - EMIT1(add_1mod(0x40, dst_reg)); + maybe_emit_1mod(&prog, dst_reg, + BPF_CLASS(insn->code) == BPF_JMP); EMIT2_off32(0xF7, add_1reg(0xC0, dst_reg), imm32); goto emit_cond_jmp; @@ -1173,19 +1540,15 @@ xadd: if (is_imm8(insn->off)) case BPF_JMP32 | BPF_JSLE | BPF_K: /* test dst_reg, dst_reg to save one extra byte */ if (imm32 == 0) { - if (BPF_CLASS(insn->code) == BPF_JMP) - EMIT1(add_2mod(0x48, dst_reg, dst_reg)); - else if (is_ereg(dst_reg)) - EMIT1(add_2mod(0x40, dst_reg, dst_reg)); + maybe_emit_mod(&prog, dst_reg, dst_reg, + BPF_CLASS(insn->code) == BPF_JMP); EMIT2(0x85, add_2reg(0xC0, dst_reg, dst_reg)); goto emit_cond_jmp; } /* cmp dst_reg, imm8/32 */ - if (BPF_CLASS(insn->code) == BPF_JMP) - EMIT1(add_1mod(0x48, dst_reg)); - else if (is_ereg(dst_reg)) - EMIT1(add_1mod(0x40, dst_reg)); + maybe_emit_1mod(&prog, dst_reg, + BPF_CLASS(insn->code) == BPF_JMP); if (is_imm8(imm32)) EMIT3(0x83, add_1reg(0xF8, dst_reg), imm32); @@ -1238,6 +1601,30 @@ emit_cond_jmp: /* Convert BPF opcode to x86 */ } jmp_offset = addrs[i + insn->off] - addrs[i]; if (is_imm8(jmp_offset)) { + if (jmp_padding) { + /* To keep the jmp_offset valid, the extra bytes are + * padded before the jump insn, so we subtract the + * 2 bytes of jmp_cond insn from INSN_SZ_DIFF. + * + * If the previous pass already emits an imm8 + * jmp_cond, then this BPF insn won't shrink, so + * "nops" is 0. + * + * On the other hand, if the previous pass emits an + * imm32 jmp_cond, the extra 4 bytes(*) is padded to + * keep the image from shrinking further. + * + * (*) imm32 jmp_cond is 6 bytes, and imm8 jmp_cond + * is 2 bytes, so the size difference is 4 bytes. + */ + nops = INSN_SZ_DIFF - 2; + if (nops != 0 && nops != 4) { + pr_err("unexpected jmp_cond padding: %d bytes\n", + nops); + return -EFAULT; + } + emit_nops(&prog, nops); + } EMIT2(jmp_cond, jmp_offset); } else if (is_simm32(jmp_offset)) { EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset); @@ -1260,11 +1647,55 @@ emit_cond_jmp: /* Convert BPF opcode to x86 */ else jmp_offset = addrs[i + insn->off] - addrs[i]; - if (!jmp_offset) - /* Optimize out nop jumps */ + if (!jmp_offset) { + /* + * If jmp_padding is enabled, the extra nops will + * be inserted. Otherwise, optimize out nop jumps. + */ + if (jmp_padding) { + /* There are 3 possible conditions. + * (1) This BPF_JA is already optimized out in + * the previous run, so there is no need + * to pad any extra byte (0 byte). + * (2) The previous pass emits an imm8 jmp, + * so we pad 2 bytes to match the previous + * insn size. + * (3) Similarly, the previous pass emits an + * imm32 jmp, and 5 bytes is padded. + */ + nops = INSN_SZ_DIFF; + if (nops != 0 && nops != 2 && nops != 5) { + pr_err("unexpected nop jump padding: %d bytes\n", + nops); + return -EFAULT; + } + emit_nops(&prog, nops); + } break; + } emit_jmp: if (is_imm8(jmp_offset)) { + if (jmp_padding) { + /* To avoid breaking jmp_offset, the extra bytes + * are padded before the actual jmp insn, so + * 2 bytes is subtracted from INSN_SZ_DIFF. + * + * If the previous pass already emits an imm8 + * jmp, there is nothing to pad (0 byte). + * + * If it emits an imm32 jmp (5 bytes) previously + * and now an imm8 jmp (2 bytes), then we pad + * (5 - 2 = 3) bytes to stop the image from + * shrinking further. + */ + nops = INSN_SZ_DIFF - 2; + if (nops != 0 && nops != 3) { + pr_err("unexpected jump padding: %d bytes\n", + nops); + return -EFAULT; + } + emit_nops(&prog, INSN_SZ_DIFF - 2); + } EMIT2(0xEB, jmp_offset); } else if (is_simm32(jmp_offset)) { EMIT1_off32(0xE9, jmp_offset); @@ -1282,14 +1713,9 @@ emit_jmp: seen_exit = true; /* Update cleanup_addr */ ctx->cleanup_addr = proglen; - if (!bpf_prog_was_classic(bpf_prog)) - EMIT1(0x5B); /* get rid of tail_call_cnt */ - EMIT2(0x41, 0x5F); /* pop r15 */ - EMIT2(0x41, 0x5E); /* pop r14 */ - EMIT2(0x41, 0x5D); /* pop r13 */ - EMIT1(0x5B); /* pop rbx */ + pop_callee_regs(&prog, callee_regs_used); EMIT1(0xC9); /* leave */ - EMIT1(0xC3); /* ret */ + emit_return(&prog, image + addrs[i - 1] + (prog - temp)); break; default: @@ -1310,11 +1736,20 @@ emit_jmp: } if (image) { - if (unlikely(proglen + ilen > oldproglen)) { + /* + * When populating the image, assert that: + * + * i) We do not write beyond the allocated space, and + * ii) addrs[i] did not change from the prior run, in order + * to validate assumptions made for computing branch + * displacements. + */ + if (unlikely(proglen + ilen > oldproglen || + proglen + ilen != addrs[i])) { pr_err("bpf_jit: fatal error\n"); return -EFAULT; } - memcpy(image + proglen, temp, ilen); + memcpy(rw_image + proglen, temp, ilen); } proglen += ilen; addrs[i] = proglen; @@ -1331,67 +1766,230 @@ emit_jmp: static void save_regs(const struct btf_func_model *m, u8 **prog, int nr_args, int stack_size) { - int i; + int i, j, arg_size, nr_regs; /* Store function arguments to stack. * For a function that accepts two pointers the sequence will be: * mov QWORD PTR [rbp-0x10],rdi * mov QWORD PTR [rbp-0x8],rsi */ - for (i = 0; i < min(nr_args, 6); i++) - emit_stx(prog, bytes_to_bpf_size(m->arg_size[i]), - BPF_REG_FP, - i == 5 ? X86_REG_R9 : BPF_REG_1 + i, - -(stack_size - i * 8)); + for (i = 0, j = 0; i < min(nr_args, 6); i++) { + if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG) { + nr_regs = (m->arg_size[i] + 7) / 8; + arg_size = 8; + } else { + nr_regs = 1; + arg_size = m->arg_size[i]; + } + + while (nr_regs) { + emit_stx(prog, bytes_to_bpf_size(arg_size), + BPF_REG_FP, + j == 5 ? X86_REG_R9 : BPF_REG_1 + j, + -(stack_size - j * 8)); + nr_regs--; + j++; + } + } } static void restore_regs(const struct btf_func_model *m, u8 **prog, int nr_args, int stack_size) { - int i; + int i, j, arg_size, nr_regs; /* Restore function arguments from stack. * For a function that accepts two pointers the sequence will be: * EMIT4(0x48, 0x8B, 0x7D, 0xF0); mov rdi,QWORD PTR [rbp-0x10] * EMIT4(0x48, 0x8B, 0x75, 0xF8); mov rsi,QWORD PTR [rbp-0x8] */ - for (i = 0; i < min(nr_args, 6); i++) - emit_ldx(prog, bytes_to_bpf_size(m->arg_size[i]), - i == 5 ? X86_REG_R9 : BPF_REG_1 + i, - BPF_REG_FP, - -(stack_size - i * 8)); + for (i = 0, j = 0; i < min(nr_args, 6); i++) { + if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG) { + nr_regs = (m->arg_size[i] + 7) / 8; + arg_size = 8; + } else { + nr_regs = 1; + arg_size = m->arg_size[i]; + } + + while (nr_regs) { + emit_ldx(prog, bytes_to_bpf_size(arg_size), + j == 5 ? X86_REG_R9 : BPF_REG_1 + j, + BPF_REG_FP, + -(stack_size - j * 8)); + nr_regs--; + j++; + } + } +} + +static int invoke_bpf_prog(const struct btf_func_model *m, u8 **pprog, + struct bpf_tramp_link *l, int stack_size, + int run_ctx_off, bool save_ret) +{ + void (*exit)(struct bpf_prog *prog, u64 start, + struct bpf_tramp_run_ctx *run_ctx) = __bpf_prog_exit; + u64 (*enter)(struct bpf_prog *prog, + struct bpf_tramp_run_ctx *run_ctx) = __bpf_prog_enter; + u8 *prog = *pprog; + u8 *jmp_insn; + int ctx_cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie); + struct bpf_prog *p = l->link.prog; + u64 cookie = l->cookie; + + /* mov rdi, cookie */ + emit_mov_imm64(&prog, BPF_REG_1, (long) cookie >> 32, (u32) (long) cookie); + + /* Prepare struct bpf_tramp_run_ctx. + * + * bpf_tramp_run_ctx is already preserved by + * arch_prepare_bpf_trampoline(). + * + * mov QWORD PTR [rbp - run_ctx_off + ctx_cookie_off], rdi + */ + emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_1, -run_ctx_off + ctx_cookie_off); + + if (p->aux->sleepable) { + enter = __bpf_prog_enter_sleepable; + exit = __bpf_prog_exit_sleepable; + } else if (p->type == BPF_PROG_TYPE_STRUCT_OPS) { + enter = __bpf_prog_enter_struct_ops; + exit = __bpf_prog_exit_struct_ops; + } else if (p->expected_attach_type == BPF_LSM_CGROUP) { + enter = __bpf_prog_enter_lsm_cgroup; + exit = __bpf_prog_exit_lsm_cgroup; + } + + /* arg1: mov rdi, progs[i] */ + emit_mov_imm64(&prog, BPF_REG_1, (long) p >> 32, (u32) (long) p); + /* arg2: lea rsi, [rbp - ctx_cookie_off] */ + EMIT4(0x48, 0x8D, 0x75, -run_ctx_off); + + if (emit_call(&prog, enter, prog)) + return -EINVAL; + /* remember prog start time returned by __bpf_prog_enter */ + emit_mov_reg(&prog, true, BPF_REG_6, BPF_REG_0); + + /* if (__bpf_prog_enter*(prog) == 0) + * goto skip_exec_of_prog; + */ + EMIT3(0x48, 0x85, 0xC0); /* test rax,rax */ + /* emit 2 nops that will be replaced with JE insn */ + jmp_insn = prog; + emit_nops(&prog, 2); + + /* arg1: lea rdi, [rbp - stack_size] */ + EMIT4(0x48, 0x8D, 0x7D, -stack_size); + /* arg2: progs[i]->insnsi for interpreter */ + if (!p->jited) + emit_mov_imm64(&prog, BPF_REG_2, + (long) p->insnsi >> 32, + (u32) (long) p->insnsi); + /* call JITed bpf program or interpreter */ + if (emit_call(&prog, p->bpf_func, prog)) + return -EINVAL; + + /* + * BPF_TRAMP_MODIFY_RETURN trampolines can modify the return + * of the previous call which is then passed on the stack to + * the next BPF program. + * + * BPF_TRAMP_FENTRY trampoline may need to return the return + * value of BPF_PROG_TYPE_STRUCT_OPS prog. + */ + if (save_ret) + emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8); + + /* replace 2 nops with JE insn, since jmp target is known */ + jmp_insn[0] = X86_JE; + jmp_insn[1] = prog - jmp_insn - 2; + + /* arg1: mov rdi, progs[i] */ + emit_mov_imm64(&prog, BPF_REG_1, (long) p >> 32, (u32) (long) p); + /* arg2: mov rsi, rbx <- start time in nsec */ + emit_mov_reg(&prog, true, BPF_REG_2, BPF_REG_6); + /* arg3: lea rdx, [rbp - run_ctx_off] */ + EMIT4(0x48, 0x8D, 0x55, -run_ctx_off); + if (emit_call(&prog, exit, prog)) + return -EINVAL; + + *pprog = prog; + return 0; +} + +static void emit_align(u8 **pprog, u32 align) +{ + u8 *target, *prog = *pprog; + + target = PTR_ALIGN(prog, align); + if (target != prog) + emit_nops(&prog, target - prog); + + *pprog = prog; +} + +static int emit_cond_near_jump(u8 **pprog, void *func, void *ip, u8 jmp_cond) +{ + u8 *prog = *pprog; + s64 offset; + + offset = func - (ip + 2 + 4); + if (!is_simm32(offset)) { + pr_err("Target %p is out of range\n", func); + return -EINVAL; + } + EMIT2_off32(0x0F, jmp_cond + 0x10, offset); + *pprog = prog; + return 0; } static int invoke_bpf(const struct btf_func_model *m, u8 **pprog, - struct bpf_prog **progs, int prog_cnt, int stack_size) + struct bpf_tramp_links *tl, int stack_size, + int run_ctx_off, bool save_ret) { + int i; u8 *prog = *pprog; - int cnt = 0, i; - for (i = 0; i < prog_cnt; i++) { - if (emit_call(&prog, __bpf_prog_enter, prog)) - return -EINVAL; - /* remember prog start time returned by __bpf_prog_enter */ - emit_mov_reg(&prog, true, BPF_REG_6, BPF_REG_0); - - /* arg1: lea rdi, [rbp - stack_size] */ - EMIT4(0x48, 0x8D, 0x7D, -stack_size); - /* arg2: progs[i]->insnsi for interpreter */ - if (!progs[i]->jited) - emit_mov_imm64(&prog, BPF_REG_2, - (long) progs[i]->insnsi >> 32, - (u32) (long) progs[i]->insnsi); - /* call JITed bpf program or interpreter */ - if (emit_call(&prog, progs[i]->bpf_func, prog)) + for (i = 0; i < tl->nr_links; i++) { + if (invoke_bpf_prog(m, &prog, tl->links[i], stack_size, + run_ctx_off, save_ret)) return -EINVAL; + } + *pprog = prog; + return 0; +} + +static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog, + struct bpf_tramp_links *tl, int stack_size, + int run_ctx_off, u8 **branches) +{ + u8 *prog = *pprog; + int i; - /* arg1: mov rdi, progs[i] */ - emit_mov_imm64(&prog, BPF_REG_1, (long) progs[i] >> 32, - (u32) (long) progs[i]); - /* arg2: mov rsi, rbx <- start time in nsec */ - emit_mov_reg(&prog, true, BPF_REG_2, BPF_REG_6); - if (emit_call(&prog, __bpf_prog_exit, prog)) + /* The first fmod_ret program will receive a garbage return value. + * Set this to 0 to avoid confusing the program. + */ + emit_mov_imm32(&prog, false, BPF_REG_0, 0); + emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8); + for (i = 0; i < tl->nr_links; i++) { + if (invoke_bpf_prog(m, &prog, tl->links[i], stack_size, run_ctx_off, true)) return -EINVAL; + + /* mod_ret prog stored return value into [rbp - 8]. Emit: + * if (*(u64 *)(rbp - 8) != 0) + * goto do_fexit; + */ + /* cmp QWORD PTR [rbp - 0x8], 0x0 */ + EMIT4(0x48, 0x83, 0x7d, 0xf8); EMIT1(0x00); + + /* Save the location of the branch and Generate 6 nops + * (4 bytes for an offset and 2 bytes for the jump) These nops + * are replaced with a conditional jump once do_fexit (i.e. the + * start of the fexit invocation) is finalized. + */ + branches[i] = prog; + emit_nops(&prog, 4 + 2); } + *pprog = prog; return 0; } @@ -1456,66 +2054,192 @@ static int invoke_bpf(const struct btf_func_model *m, u8 **pprog, * add rsp, 8 // skip eth_type_trans's frame * ret // return to its caller */ -int arch_prepare_bpf_trampoline(void *image, void *image_end, +int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *image_end, const struct btf_func_model *m, u32 flags, - struct bpf_prog **fentry_progs, int fentry_cnt, - struct bpf_prog **fexit_progs, int fexit_cnt, - void *orig_call) + struct bpf_tramp_links *tlinks, + void *func_addr) { - int cnt = 0, nr_args = m->nr_args; - int stack_size = nr_args * 8; + int ret, i, nr_args = m->nr_args, extra_nregs = 0; + int regs_off, ip_off, args_off, stack_size = nr_args * 8, run_ctx_off; + struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY]; + struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT]; + struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN]; + void *orig_call = func_addr; + u8 **branches = NULL; u8 *prog; + bool save_ret; /* x86-64 supports up to 6 arguments. 7+ can be added in the future */ if (nr_args > 6) return -ENOTSUPP; - if ((flags & BPF_TRAMP_F_RESTORE_REGS) && - (flags & BPF_TRAMP_F_SKIP_FRAME)) - return -EINVAL; + for (i = 0; i < MAX_BPF_FUNC_ARGS; i++) { + if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG) + extra_nregs += (m->arg_size[i] + 7) / 8 - 1; + } + if (nr_args + extra_nregs > 6) + return -ENOTSUPP; + stack_size += extra_nregs * 8; + + /* Generated trampoline stack layout: + * + * RBP + 8 [ return address ] + * RBP + 0 [ RBP ] + * + * RBP - 8 [ return value ] BPF_TRAMP_F_CALL_ORIG or + * BPF_TRAMP_F_RET_FENTRY_RET flags + * + * [ reg_argN ] always + * [ ... ] + * RBP - regs_off [ reg_arg1 ] program's ctx pointer + * + * RBP - args_off [ arg regs count ] always + * + * RBP - ip_off [ traced function ] BPF_TRAMP_F_IP_ARG flag + * + * RBP - run_ctx_off [ bpf_tramp_run_ctx ] + */ - if (flags & BPF_TRAMP_F_CALL_ORIG) - stack_size += 8; /* room for return value of orig_call */ + /* room for return value of orig_call or fentry prog */ + save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET); + if (save_ret) + stack_size += 8; - if (flags & BPF_TRAMP_F_SKIP_FRAME) + regs_off = stack_size; + + /* args count */ + stack_size += 8; + args_off = stack_size; + + if (flags & BPF_TRAMP_F_IP_ARG) + stack_size += 8; /* room for IP address argument */ + + ip_off = stack_size; + + stack_size += (sizeof(struct bpf_tramp_run_ctx) + 7) & ~0x7; + run_ctx_off = stack_size; + + if (flags & BPF_TRAMP_F_SKIP_FRAME) { /* skip patched call instruction and point orig_call to actual * body of the kernel function. */ + if (is_endbr(*(u32 *)orig_call)) + orig_call += ENDBR_INSN_SIZE; orig_call += X86_PATCH_SIZE; + } prog = image; + EMIT_ENDBR(); EMIT1(0x55); /* push rbp */ EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */ EMIT4(0x48, 0x83, 0xEC, stack_size); /* sub rsp, stack_size */ EMIT1(0x53); /* push rbx */ - save_regs(m, &prog, nr_args, stack_size); + /* Store number of argument registers of the traced function: + * mov rax, nr_args + extra_nregs + * mov QWORD PTR [rbp - args_off], rax + */ + emit_mov_imm64(&prog, BPF_REG_0, 0, (u32) nr_args + extra_nregs); + emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -args_off); - if (fentry_cnt) - if (invoke_bpf(m, &prog, fentry_progs, fentry_cnt, stack_size)) - return -EINVAL; + if (flags & BPF_TRAMP_F_IP_ARG) { + /* Store IP address of the traced function: + * movabsq rax, func_addr + * mov QWORD PTR [rbp - ip_off], rax + */ + emit_mov_imm64(&prog, BPF_REG_0, (long) func_addr >> 32, (u32) (long) func_addr); + emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -ip_off); + } + + save_regs(m, &prog, nr_args, regs_off); if (flags & BPF_TRAMP_F_CALL_ORIG) { - if (fentry_cnt) - restore_regs(m, &prog, nr_args, stack_size); + /* arg1: mov rdi, im */ + emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im); + if (emit_call(&prog, __bpf_tramp_enter, prog)) { + ret = -EINVAL; + goto cleanup; + } + } - /* call original function */ - if (emit_call(&prog, orig_call, prog)) + if (fentry->nr_links) + if (invoke_bpf(m, &prog, fentry, regs_off, run_ctx_off, + flags & BPF_TRAMP_F_RET_FENTRY_RET)) return -EINVAL; + + if (fmod_ret->nr_links) { + branches = kcalloc(fmod_ret->nr_links, sizeof(u8 *), + GFP_KERNEL); + if (!branches) + return -ENOMEM; + + if (invoke_bpf_mod_ret(m, &prog, fmod_ret, regs_off, + run_ctx_off, branches)) { + ret = -EINVAL; + goto cleanup; + } + } + + if (flags & BPF_TRAMP_F_CALL_ORIG) { + restore_regs(m, &prog, nr_args, regs_off); + + if (flags & BPF_TRAMP_F_ORIG_STACK) { + emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, 8); + EMIT2(0xff, 0xd0); /* call *rax */ + } else { + /* call original function */ + if (emit_call(&prog, orig_call, prog)) { + ret = -EINVAL; + goto cleanup; + } + } /* remember return value in a stack for bpf prog to access */ emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8); + im->ip_after_call = prog; + memcpy(prog, x86_nops[5], X86_PATCH_SIZE); + prog += X86_PATCH_SIZE; } - if (fexit_cnt) - if (invoke_bpf(m, &prog, fexit_progs, fexit_cnt, stack_size)) - return -EINVAL; + if (fmod_ret->nr_links) { + /* From Intel 64 and IA-32 Architectures Optimization + * Reference Manual, 3.4.1.4 Code Alignment, Assembly/Compiler + * Coding Rule 11: All branch targets should be 16-byte + * aligned. + */ + emit_align(&prog, 16); + /* Update the branches saved in invoke_bpf_mod_ret with the + * aligned address of do_fexit. + */ + for (i = 0; i < fmod_ret->nr_links; i++) + emit_cond_near_jump(&branches[i], prog, branches[i], + X86_JNE); + } + + if (fexit->nr_links) + if (invoke_bpf(m, &prog, fexit, regs_off, run_ctx_off, false)) { + ret = -EINVAL; + goto cleanup; + } if (flags & BPF_TRAMP_F_RESTORE_REGS) - restore_regs(m, &prog, nr_args, stack_size); + restore_regs(m, &prog, nr_args, regs_off); - if (flags & BPF_TRAMP_F_CALL_ORIG) - /* restore original return value back into RAX */ + /* This needs to be done regardless. If there were fmod_ret programs, + * the return value is only updated on the stack and still needs to be + * restored to R0. + */ + if (flags & BPF_TRAMP_F_CALL_ORIG) { + im->ip_epilogue = prog; + /* arg1: mov rdi, im */ + emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im); + if (emit_call(&prog, __bpf_tramp_exit, prog)) { + ret = -EINVAL; + goto cleanup; + } + } + /* restore return value of orig_call or fentry prog back into RAX */ + if (save_ret) emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, -8); EMIT1(0x5B); /* pop rbx */ @@ -1523,73 +2247,23 @@ int arch_prepare_bpf_trampoline(void *image, void *image_end, if (flags & BPF_TRAMP_F_SKIP_FRAME) /* skip our return address and return to parent */ EMIT4(0x48, 0x83, 0xC4, 8); /* add rsp, 8 */ - EMIT1(0xC3); /* ret */ + emit_return(&prog, prog); /* Make sure the trampoline generation logic doesn't overflow */ - if (WARN_ON_ONCE(prog > (u8 *)image_end - BPF_INSN_SAFETY)) - return -EFAULT; - return prog - (u8 *)image; -} - -static int emit_cond_near_jump(u8 **pprog, void *func, void *ip, u8 jmp_cond) -{ - u8 *prog = *pprog; - int cnt = 0; - s64 offset; - - offset = func - (ip + 2 + 4); - if (!is_simm32(offset)) { - pr_err("Target %p is out of range\n", func); - return -EINVAL; - } - EMIT2_off32(0x0F, jmp_cond + 0x10, offset); - *pprog = prog; - return 0; -} - -static void emit_nops(u8 **pprog, unsigned int len) -{ - unsigned int i, noplen; - u8 *prog = *pprog; - int cnt = 0; - - while (len > 0) { - noplen = len; - - if (noplen > ASM_NOP_MAX) - noplen = ASM_NOP_MAX; - - for (i = 0; i < noplen; i++) - EMIT1(ideal_nops[noplen][i]); - len -= noplen; + if (WARN_ON_ONCE(prog > (u8 *)image_end - BPF_INSN_SAFETY)) { + ret = -EFAULT; + goto cleanup; } + ret = prog - (u8 *)image; - *pprog = prog; -} - -static int emit_fallback_jump(u8 **pprog) -{ - u8 *prog = *pprog; - int err = 0; - -#ifdef CONFIG_RETPOLINE - /* Note that this assumes the the compiler uses external - * thunks for indirect calls. Both clang and GCC use the same - * naming convention for external thunks. - */ - err = emit_jump(&prog, __x86_indirect_thunk_rdx, prog); -#else - int cnt = 0; - - EMIT2(0xFF, 0xE2); /* jmp rdx */ -#endif - *pprog = prog; - return err; +cleanup: + kfree(branches); + return ret; } -static int emit_bpf_dispatcher(u8 **pprog, int a, int b, s64 *progs) +static int emit_bpf_dispatcher(u8 **pprog, int a, int b, s64 *progs, u8 *image, u8 *buf) { - u8 *jg_reloc, *jg_target, *prog = *pprog; - int pivot, err, jg_bytes = 1, cnt = 0; + u8 *jg_reloc, *prog = *pprog; + int pivot, err, jg_bytes = 1; s64 jg_offset; if (a == b) { @@ -1602,14 +2276,12 @@ static int emit_bpf_dispatcher(u8 **pprog, int a, int b, s64 *progs) EMIT2_off32(0x81, add_1reg(0xF8, BPF_REG_3), progs[a]); err = emit_cond_near_jump(&prog, /* je func */ - (void *)progs[a], prog, + (void *)progs[a], image + (prog - buf), X86_JE); if (err) return err; - err = emit_fallback_jump(&prog); /* jmp thunk/indirect */ - if (err) - return err; + emit_indirect_jump(&prog, 2 /* rdx */, image + (prog - buf)); *pprog = prog; return 0; @@ -1634,7 +2306,7 @@ static int emit_bpf_dispatcher(u8 **pprog, int a, int b, s64 *progs) jg_reloc = prog; err = emit_bpf_dispatcher(&prog, a, a + pivot, /* emit lower_part */ - progs); + progs, image, buf); if (err) return err; @@ -1643,14 +2315,12 @@ static int emit_bpf_dispatcher(u8 **pprog, int a, int b, s64 *progs) * Coding Rule 11: All branch targets should be 16-byte * aligned. */ - jg_target = PTR_ALIGN(prog, 16); - if (jg_target != prog) - emit_nops(&prog, jg_target - prog); + emit_align(&prog, 16); jg_offset = prog - jg_reloc; emit_code(jg_reloc - jg_bytes, jg_offset, jg_bytes); err = emit_bpf_dispatcher(&prog, a + pivot + 1, /* emit upper_part */ - b, progs); + b, progs, image, buf); if (err) return err; @@ -1670,15 +2340,16 @@ static int cmp_ips(const void *a, const void *b) return 0; } -int arch_prepare_bpf_dispatcher(void *image, s64 *funcs, int num_funcs) +int arch_prepare_bpf_dispatcher(void *image, void *buf, s64 *funcs, int num_funcs) { - u8 *prog = image; + u8 *prog = buf; sort(funcs, num_funcs, sizeof(funcs[0]), cmp_ips, NULL); - return emit_bpf_dispatcher(&prog, 0, num_funcs - 1, funcs); + return emit_bpf_dispatcher(&prog, 0, num_funcs - 1, funcs, image, buf); } struct x64_jit_data { + struct bpf_binary_header *rw_header; struct bpf_binary_header *header; int *addrs; u8 *image; @@ -1686,8 +2357,12 @@ struct x64_jit_data { struct jit_context ctx; }; +#define MAX_PASSES 20 +#define PADDING_PASSES (MAX_PASSES - 5) + struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) { + struct bpf_binary_header *rw_header = NULL; struct bpf_binary_header *header = NULL; struct bpf_prog *tmp, *orig_prog = prog; struct x64_jit_data *jit_data; @@ -1695,6 +2370,8 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) struct jit_context ctx = {}; bool tmp_blinded = false; bool extra_pass = false; + bool padding = false; + u8 *rw_image = NULL; u8 *image = NULL; int *addrs; int pass; @@ -1730,10 +2407,13 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) oldproglen = jit_data->proglen; image = jit_data->image; header = jit_data->header; + rw_header = jit_data->rw_header; + rw_image = (void *)rw_header + ((void *)image - (void *)header); extra_pass = true; + padding = true; goto skip_init_addrs; } - addrs = kmalloc_array(prog->len + 1, sizeof(*addrs), GFP_KERNEL); + addrs = kvmalloc_array(prog->len + 1, sizeof(*addrs), GFP_KERNEL); if (!addrs) { prog = orig_prog; goto out_addrs; @@ -1756,14 +2436,25 @@ skip_init_addrs: * may converge on the last pass. In such case do one more * pass to emit the final image. */ - for (pass = 0; pass < 20 || image; pass++) { - proglen = do_jit(prog, addrs, image, oldproglen, &ctx); + for (pass = 0; pass < MAX_PASSES || image; pass++) { + if (!padding && pass >= PADDING_PASSES) + padding = true; + proglen = do_jit(prog, addrs, image, rw_image, oldproglen, &ctx, padding); if (proglen <= 0) { out_image: image = NULL; - if (header) - bpf_jit_binary_free(header); + if (header) { + bpf_arch_text_copy(&header->size, &rw_header->size, + sizeof(rw_header->size)); + bpf_jit_binary_pack_free(header, rw_header); + } + /* Fall back to interpreter mode */ prog = orig_prog; + if (extra_pass) { + prog->bpf_func = NULL; + prog->jited = 0; + prog->jited_len = 0; + } goto out_addrs; } if (image) { @@ -1786,8 +2477,9 @@ out_image: sizeof(struct exception_table_entry); /* allocate module memory for x86 insns and extable */ - header = bpf_jit_binary_alloc(roundup(proglen, align) + extable_size, - &image, align, jit_fill_hole); + header = bpf_jit_binary_pack_alloc(roundup(proglen, align) + extable_size, + &image, align, &rw_header, &rw_image, + jit_fill_hole); if (!header) { prog = orig_prog; goto out_addrs; @@ -1803,14 +2495,27 @@ out_image: if (image) { if (!prog->is_func || extra_pass) { + /* + * bpf_jit_binary_pack_finalize fails in two scenarios: + * 1) header is not pointing to proper module memory; + * 2) the arch doesn't support bpf_arch_text_copy(). + * + * Both cases are serious bugs and justify WARN_ON. + */ + if (WARN_ON(bpf_jit_binary_pack_finalize(prog, header, rw_header))) { + /* header has been freed */ + header = NULL; + goto out_image; + } + bpf_tail_call_direct_fixup(prog); - bpf_jit_binary_lock_ro(header); } else { jit_data->addrs = addrs; jit_data->ctx = ctx; jit_data->proglen = proglen; jit_data->image = image; jit_data->header = header; + jit_data->rw_header = rw_header; } prog->bpf_func = (void *)image; prog->jited = 1; @@ -1823,7 +2528,7 @@ out_image: if (image) bpf_prog_fill_jited_linfo(prog, addrs + 1); out_addrs: - kfree(addrs); + kvfree(addrs); kfree(jit_data); prog->aux->jit_data = NULL; } @@ -1833,3 +2538,46 @@ out: tmp : orig_prog); return prog; } + +bool bpf_jit_supports_kfunc_call(void) +{ + return true; +} + +void *bpf_arch_text_copy(void *dst, void *src, size_t len) +{ + if (text_poke_copy(dst, src, len) == NULL) + return ERR_PTR(-EINVAL); + return dst; +} + +/* Indicate the JIT backend supports mixing bpf2bpf and tailcalls. */ +bool bpf_jit_supports_subprog_tailcalls(void) +{ + return true; +} + +void bpf_jit_free(struct bpf_prog *prog) +{ + if (prog->jited) { + struct x64_jit_data *jit_data = prog->aux->jit_data; + struct bpf_binary_header *hdr; + + /* + * If we fail the final pass of JIT (from jit_subprogs), + * the program may not be finalized yet. Call finalize here + * before freeing it. + */ + if (jit_data) { + bpf_jit_binary_pack_finalize(prog, jit_data->header, + jit_data->rw_header); + kvfree(jit_data->addrs); + kfree(jit_data); + } + hdr = bpf_jit_binary_pack_hdr(prog); + bpf_jit_binary_pack_free(hdr, NULL); + WARN_ON_ONCE(!bpf_prog_kallsyms_verify_off(prog)); + } + + bpf_prog_unlock_free(prog); +} |