aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/arch/riscv/net/bpf_jit_comp64.c
diff options
context:
space:
mode:
Diffstat (limited to 'arch/riscv/net/bpf_jit_comp64.c')
-rw-r--r--arch/riscv/net/bpf_jit_comp64.c113
1 files changed, 76 insertions, 37 deletions
diff --git a/arch/riscv/net/bpf_jit_comp64.c b/arch/riscv/net/bpf_jit_comp64.c
index cc1985d8750a..6cfd164cbe88 100644
--- a/arch/riscv/net/bpf_jit_comp64.c
+++ b/arch/riscv/net/bpf_jit_comp64.c
@@ -110,6 +110,16 @@ static bool is_32b_int(s64 val)
return -(1L << 31) <= val && val < (1L << 31);
}
+static bool in_auipc_jalr_range(s64 val)
+{
+ /*
+ * auipc+jalr can reach any signed PC-relative offset in the range
+ * [-2^31 - 2^11, 2^31 - 2^11).
+ */
+ return (-(1L << 31) - (1L << 11)) <= val &&
+ val < ((1L << 31) - (1L << 11));
+}
+
static void emit_imm(u8 rd, s64 val, struct rv_jit_context *ctx)
{
/* Note that the immediate from the add is sign-extended,
@@ -380,20 +390,24 @@ static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx)
*rd = RV_REG_T2;
}
-static void emit_jump_and_link(u8 rd, s64 rvoff, bool force_jalr,
- struct rv_jit_context *ctx)
+static int emit_jump_and_link(u8 rd, s64 rvoff, bool force_jalr,
+ struct rv_jit_context *ctx)
{
s64 upper, lower;
if (rvoff && is_21b_int(rvoff) && !force_jalr) {
emit(rv_jal(rd, rvoff >> 1), ctx);
- return;
+ return 0;
+ } else if (in_auipc_jalr_range(rvoff)) {
+ upper = (rvoff + (1 << 11)) >> 12;
+ lower = rvoff & 0xfff;
+ emit(rv_auipc(RV_REG_T1, upper), ctx);
+ emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
+ return 0;
}
- upper = (rvoff + (1 << 11)) >> 12;
- lower = rvoff & 0xfff;
- emit(rv_auipc(RV_REG_T1, upper), ctx);
- emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
+ pr_err("bpf-jit: target offset 0x%llx is out of range\n", rvoff);
+ return -ERANGE;
}
static bool is_signed_bpf_cond(u8 cond)
@@ -407,18 +421,16 @@ static int emit_call(bool fixed, u64 addr, struct rv_jit_context *ctx)
s64 off = 0;
u64 ip;
u8 rd;
+ int ret;
if (addr && ctx->insns) {
ip = (u64)(long)(ctx->insns + ctx->ninsns);
off = addr - ip;
- if (!is_32b_int(off)) {
- pr_err("bpf-jit: target call addr %pK is out of range\n",
- (void *)addr);
- return -ERANGE;
- }
}
- emit_jump_and_link(RV_REG_RA, off, !fixed, ctx);
+ ret = emit_jump_and_link(RV_REG_RA, off, !fixed, ctx);
+ if (ret)
+ return ret;
rd = bpf_to_rv_reg(BPF_REG_0, ctx);
emit(rv_addi(rd, RV_REG_A0, 0), ctx);
return 0;
@@ -429,7 +441,7 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
{
bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
BPF_CLASS(insn->code) == BPF_JMP;
- int s, e, rvoff, i = insn - ctx->prog->insnsi;
+ int s, e, rvoff, ret, i = insn - ctx->prog->insnsi;
struct bpf_prog_aux *aux = ctx->prog->aux;
u8 rd = -1, rs = -1, code = insn->code;
s16 off = insn->off;
@@ -503,7 +515,7 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
case BPF_ALU | BPF_LSH | BPF_X:
case BPF_ALU64 | BPF_LSH | BPF_X:
emit(is64 ? rv_sll(rd, rd, rs) : rv_sllw(rd, rd, rs), ctx);
- if (!is64)
+ if (!is64 && !aux->verifier_zext)
emit_zext_32(rd, ctx);
break;
case BPF_ALU | BPF_RSH | BPF_X:
@@ -530,13 +542,21 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
/* dst = BSWAP##imm(dst) */
case BPF_ALU | BPF_END | BPF_FROM_LE:
- {
- int shift = 64 - imm;
-
- emit(rv_slli(rd, rd, shift), ctx);
- emit(rv_srli(rd, rd, shift), ctx);
+ switch (imm) {
+ case 16:
+ emit(rv_slli(rd, rd, 48), ctx);
+ emit(rv_srli(rd, rd, 48), ctx);
+ break;
+ case 32:
+ if (!aux->verifier_zext)
+ emit_zext_32(rd, ctx);
+ break;
+ case 64:
+ /* Do nothing */
+ break;
+ }
break;
- }
+
case BPF_ALU | BPF_END | BPF_FROM_BE:
emit(rv_addi(RV_REG_T2, RV_REG_ZERO, 0), ctx);
@@ -680,26 +700,28 @@ out_be:
case BPF_ALU | BPF_LSH | BPF_K:
case BPF_ALU64 | BPF_LSH | BPF_K:
emit(is64 ? rv_slli(rd, rd, imm) : rv_slliw(rd, rd, imm), ctx);
- if (!is64)
+ if (!is64 && !aux->verifier_zext)
emit_zext_32(rd, ctx);
break;
case BPF_ALU | BPF_RSH | BPF_K:
case BPF_ALU64 | BPF_RSH | BPF_K:
emit(is64 ? rv_srli(rd, rd, imm) : rv_srliw(rd, rd, imm), ctx);
- if (!is64)
+ if (!is64 && !aux->verifier_zext)
emit_zext_32(rd, ctx);
break;
case BPF_ALU | BPF_ARSH | BPF_K:
case BPF_ALU64 | BPF_ARSH | BPF_K:
emit(is64 ? rv_srai(rd, rd, imm) : rv_sraiw(rd, rd, imm), ctx);
- if (!is64)
+ if (!is64 && !aux->verifier_zext)
emit_zext_32(rd, ctx);
break;
/* JUMP off */
case BPF_JMP | BPF_JA:
rvoff = rv_offset(i, off, ctx);
- emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
+ ret = emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
+ if (ret)
+ return ret;
break;
/* IF (dst COND src) JUMP off */
@@ -770,11 +792,15 @@ out_be:
case BPF_JMP32 | BPF_JSGE | BPF_K:
case BPF_JMP | BPF_JSLE | BPF_K:
case BPF_JMP32 | BPF_JSLE | BPF_K:
- case BPF_JMP | BPF_JSET | BPF_K:
- case BPF_JMP32 | BPF_JSET | BPF_K:
rvoff = rv_offset(i, off, ctx);
s = ctx->ninsns;
- emit_imm(RV_REG_T1, imm, ctx);
+ if (imm) {
+ emit_imm(RV_REG_T1, imm, ctx);
+ rs = RV_REG_T1;
+ } else {
+ /* If imm is 0, simply use zero register. */
+ rs = RV_REG_ZERO;
+ }
if (!is64) {
if (is_signed_bpf_cond(BPF_OP(code)))
emit_sext_32_rd(&rd, ctx);
@@ -785,23 +811,34 @@ out_be:
/* Adjust for extra insns */
rvoff -= (e - s) << 2;
+ emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
+ break;
- if (BPF_OP(code) == BPF_JSET) {
- /* Adjust for and */
- rvoff -= 4;
- emit(rv_and(RV_REG_T1, rd, RV_REG_T1), ctx);
- emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff,
- ctx);
+ case BPF_JMP | BPF_JSET | BPF_K:
+ case BPF_JMP32 | BPF_JSET | BPF_K:
+ rvoff = rv_offset(i, off, ctx);
+ s = ctx->ninsns;
+ if (is_12b_int(imm)) {
+ emit(rv_andi(RV_REG_T1, rd, imm), ctx);
} else {
- emit_branch(BPF_OP(code), rd, RV_REG_T1, rvoff, ctx);
+ emit_imm(RV_REG_T1, imm, ctx);
+ emit(rv_and(RV_REG_T1, rd, RV_REG_T1), ctx);
}
+ /* For jset32, we should clear the upper 32 bits of t1, but
+ * sign-extension is sufficient here and saves one instruction,
+ * as t1 is used only in comparison against zero.
+ */
+ if (!is64 && imm < 0)
+ emit(rv_addiw(RV_REG_T1, RV_REG_T1, 0), ctx);
+ e = ctx->ninsns;
+ rvoff -= (e - s) << 2;
+ emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
break;
/* function call */
case BPF_JMP | BPF_CALL:
{
bool fixed;
- int ret;
u64 addr;
mark_call(ctx);
@@ -826,7 +863,9 @@ out_be:
break;
rvoff = epilogue_offset(ctx);
- emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
+ ret = emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
+ if (ret)
+ return ret;
break;
/* dst = imm64 */