[PATCH bpf-next v2 4/4] riscv, bpf: Mixing bpf2bpf and tailcalls
Pu Lehui
pulehui at huaweicloud.com
Thu Feb 1 00:22:16 PST 2024
On 2024/1/31 1:30, Björn Töpel wrote:
> Pu Lehui <pulehui at huaweicloud.com> writes:
>
>> From: Pu Lehui <pulehui at huawei.com>
>>
>> In the current RV64 JIT, if we just don't initialize the TCC in subprog,
>> the TCC can be propagated from the parent process to the subprocess, but
>> the TCC of the parent process cannot be restored when the subprocess
>> exits. Since the RV64 TCC is initialized before saving the callee saved
>> registers into the stack, we cannot use the callee saved register to
>> pass the TCC, otherwise the original value of the callee saved register
>> will be destroyed. So we implemented mixing bpf2bpf and tailcalls
>> similar to x86_64, i.e. using a non-callee saved register to transfer
>> the TCC between functions, and saving that register to the stack to
>> protect the TCC value. At the same time, we also consider the scenario
>> of mixing trampoline.
>>
>> Tests test_bpf.ko and test_verifier have passed, as well as the relative
>> testcases of test_progs*.
>
> Ok, I'll summarize, so that I know that I get it. ;-)
>
> All BPF progs (except the main), get the current TCC passed in a6. TCC
> is stored in each BPF stack frame.
>
> During tail calls, the TCC from the stack is loaded, decremented, and
> stored to the stack again.
>
> Mixing bpf2bpf/tailcalls means that each *BPF stackframe* can perform up
> to "current TCC to max_tailscalls" number of calls.
>
> main_prog() calls subprog1(). subprog1() can perform max_tailscalls.
> subprog1() returns, and main_prog() calls subprog2(). subprog2() can
> also perform max_tailscalls.
>
> Correct?
Your summarize is the same as what I thought, A6 is a carrier. I write a
use case to verify this:
diff --git a/tools/testing/selftests/bpf/prog_tests/tailcalls.c
b/tools/testing/selftests/bpf/prog_tests/tailcalls.c
index 59993fc9c0d7..65550e24c843 100644
--- a/tools/testing/selftests/bpf/prog_tests/tailcalls.c
+++ b/tools/testing/selftests/bpf/prog_tests/tailcalls.c
@@ -975,6 +975,80 @@ static void test_tailcall_bpf2bpf_6(void)
tailcall_bpf2bpf6__destroy(obj);
}
+#include "tailcall_bpf2bpf7.skel.h"
+
+static void test_tailcall_bpf2bpf_7(void)
+{
+ int err, map_fd, prog_fd, main_fd, data_fd, i;
+ struct tailcall_bpf2bpf7__bss val;
+ struct bpf_map *prog_array, *data_map;
+ struct bpf_program *prog;
+ struct bpf_object *obj;
+ char prog_name[32];
+ LIBBPF_OPTS(bpf_test_run_opts, topts,
+ .data_in = &pkt_v4,
+ .data_size_in = sizeof(pkt_v4),
+ .repeat = 1,
+ );
+
+ err = bpf_prog_test_load("tailcall_bpf2bpf7.bpf.o",
BPF_PROG_TYPE_SCHED_CLS,
+ &obj, &prog_fd);
+ if (CHECK_FAIL(err))
+ return;
+
+ prog = bpf_object__find_program_by_name(obj, "entry");
+ if (CHECK_FAIL(!prog))
+ goto out;
+
+ main_fd = bpf_program__fd(prog);
+ if (CHECK_FAIL(main_fd < 0))
+ goto out;
+
+ prog_array = bpf_object__find_map_by_name(obj, "jmp_table");
+ if (CHECK_FAIL(!prog_array))
+ goto out;
+
+ map_fd = bpf_map__fd(prog_array);
+ if (CHECK_FAIL(map_fd < 0))
+ goto out;
+
+ for (i = 0; i < bpf_map__max_entries(prog_array); i++) {
+ snprintf(prog_name, sizeof(prog_name), "classifier_%d", i);
+
+ prog = bpf_object__find_program_by_name(obj, prog_name);
+ if (CHECK_FAIL(!prog))
+ goto out;
+
+ prog_fd = bpf_program__fd(prog);
+ if (CHECK_FAIL(prog_fd < 0))
+ goto out;
+
+ err = bpf_map_update_elem(map_fd, &i, &prog_fd, BPF_ANY);
+ if (CHECK_FAIL(err))
+ goto out;
+ }
+
+ data_map = bpf_object__find_map_by_name(obj, "tailcall.bss");
+ if (CHECK_FAIL(!data_map || !bpf_map__is_internal(data_map)))
+ goto out;
+
+ data_fd = bpf_map__fd(data_map);
+ if (CHECK_FAIL(data_fd < 0))
+ goto out;
+
+ err = bpf_prog_test_run_opts(main_fd, &topts);
+ ASSERT_OK(err, "tailcall");
+
+ i = 0;
+ err = bpf_map_lookup_elem(data_fd, &i, &val);
+ ASSERT_OK(err, "tailcall count");
+ ASSERT_EQ(val.count0, 33, "tailcall count0");
+ ASSERT_EQ(val.count1, 33, "tailcall count1");
+
+out:
+ bpf_object__close(obj);
+}
+
/* test_tailcall_bpf2bpf_fentry checks that the count value of the
tail call
* limit enforcement matches with expectations when tailcall is
preceded with
* bpf2bpf call, and the bpf2bpf call is traced by fentry.
@@ -1213,6 +1287,8 @@ void test_tailcalls(void)
test_tailcall_bpf2bpf_4(true);
if (test__start_subtest("tailcall_bpf2bpf_6"))
test_tailcall_bpf2bpf_6();
+ if (test__start_subtest("tailcall_bpf2bpf_7"))
+ test_tailcall_bpf2bpf_7();
if (test__start_subtest("tailcall_bpf2bpf_fentry"))
test_tailcall_bpf2bpf_fentry();
if (test__start_subtest("tailcall_bpf2bpf_fexit"))
diff --git a/tools/testing/selftests/bpf/progs/tailcall_bpf2bpf7.c
b/tools/testing/selftests/bpf/progs/tailcall_bpf2bpf7.c
new file mode 100644
index 000000000000..9818f4056283
--- /dev/null
+++ b/tools/testing/selftests/bpf/progs/tailcall_bpf2bpf7.c
@@ -0,0 +1,52 @@
+#include <linux/bpf.h>
+#include <bpf/bpf_helpers.h>
+
+struct {
+ __uint(type, BPF_MAP_TYPE_PROG_ARRAY);
+ __uint(max_entries, 2);
+ __uint(key_size, sizeof(__u32));
+ __uint(value_size, sizeof(__u32));
+} jmp_table SEC(".maps");
+
+int count0;
+int count1;
+
+static __noinline
+int subprog_1(struct __sk_buff *skb)
+{
+ bpf_tail_call_static(skb, &jmp_table, 1);
+ return 0;
+}
+
+static __noinline
+int subprog_0(struct __sk_buff *skb)
+{
+ bpf_tail_call_static(skb, &jmp_table, 0);
+ return 0;
+}
+
+SEC("tc")
+int classifier_1(struct __sk_buff *skb)
+{
+ count1++;
+ subprog_1(skb);
+ return 0;
+}
+
+SEC("tc")
+int classifier_0(struct __sk_buff *skb)
+{
+ count0++;
+ subprog_0(skb);
+ return 0;
+}
+
+SEC("tc")
+int entry(struct __sk_buff *skb)
+{
+ subprog_0(skb);
+ subprog_1(skb);
+ return 0;
+}
+
+char _license[] SEC("license") = "GPL";
>
> Some comments below as well.
>
>> Signed-off-by: Pu Lehui <pulehui at huawei.com>
>> ---
>> arch/riscv/net/bpf_jit.h | 1 +
>> arch/riscv/net/bpf_jit_comp64.c | 89 +++++++++++++--------------------
>> 2 files changed, 37 insertions(+), 53 deletions(-)
>>
>> diff --git a/arch/riscv/net/bpf_jit.h b/arch/riscv/net/bpf_jit.h
>> index 8b35f12a4452..d8be89dadf18 100644
>> --- a/arch/riscv/net/bpf_jit.h
>> +++ b/arch/riscv/net/bpf_jit.h
>> @@ -81,6 +81,7 @@ struct rv_jit_context {
>> int nexentries;
>> unsigned long flags;
>> int stack_size;
>> + int tcc_offset;
>> };
>>
>> /* Convert from ninsns to bytes. */
>> diff --git a/arch/riscv/net/bpf_jit_comp64.c b/arch/riscv/net/bpf_jit_comp64.c
>> index 3516d425c5eb..64e0c86d60c4 100644
>> --- a/arch/riscv/net/bpf_jit_comp64.c
>> +++ b/arch/riscv/net/bpf_jit_comp64.c
>> @@ -13,13 +13,11 @@
>> #include <asm/patch.h>
>> #include "bpf_jit.h"
>>
>> +#define RV_REG_TCC RV_REG_A6
>> #define RV_FENTRY_NINSNS 2
>> /* fentry and TCC init insns will be skipped on tailcall */
>> #define RV_TAILCALL_OFFSET ((RV_FENTRY_NINSNS + 1) * 4)
>>
>> -#define RV_REG_TCC RV_REG_A6
>> -#define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
>> -
>> static const int regmap[] = {
>> [BPF_REG_0] = RV_REG_A5,
>> [BPF_REG_1] = RV_REG_A0,
>> @@ -51,14 +49,12 @@ static const int pt_regmap[] = {
>> };
>>
>> enum {
>> - RV_CTX_F_SEEN_TAIL_CALL = 0,
>> RV_CTX_F_SEEN_CALL = RV_REG_RA,
>> RV_CTX_F_SEEN_S1 = RV_REG_S1,
>> RV_CTX_F_SEEN_S2 = RV_REG_S2,
>> RV_CTX_F_SEEN_S3 = RV_REG_S3,
>> RV_CTX_F_SEEN_S4 = RV_REG_S4,
>> RV_CTX_F_SEEN_S5 = RV_REG_S5,
>> - RV_CTX_F_SEEN_S6 = RV_REG_S6,
>> };
>>
>> static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
>> @@ -71,7 +67,6 @@ static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
>> case RV_CTX_F_SEEN_S3:
>> case RV_CTX_F_SEEN_S4:
>> case RV_CTX_F_SEEN_S5:
>> - case RV_CTX_F_SEEN_S6:
>> __set_bit(reg, &ctx->flags);
>> }
>> return reg;
>> @@ -86,7 +81,6 @@ static bool seen_reg(int reg, struct rv_jit_context *ctx)
>> case RV_CTX_F_SEEN_S3:
>> case RV_CTX_F_SEEN_S4:
>> case RV_CTX_F_SEEN_S5:
>> - case RV_CTX_F_SEEN_S6:
>> return test_bit(reg, &ctx->flags);
>> }
>> return false;
>> @@ -102,32 +96,6 @@ static void mark_call(struct rv_jit_context *ctx)
>> __set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
>> }
>>
>> -static bool seen_call(struct rv_jit_context *ctx)
>> -{
>> - return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
>> -}
>> -
>> -static void mark_tail_call(struct rv_jit_context *ctx)
>> -{
>> - __set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
>> -}
>> -
>> -static bool seen_tail_call(struct rv_jit_context *ctx)
>> -{
>> - return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
>> -}
>> -
>> -static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
>> -{
>> - mark_tail_call(ctx);
>> -
>> - if (seen_call(ctx)) {
>> - __set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
>> - return RV_REG_S6;
>> - }
>> - return RV_REG_A6;
>> -}
>> -
>> static bool is_32b_int(s64 val)
>> {
>> return -(1L << 31) <= val && val < (1L << 31);
>> @@ -252,10 +220,7 @@ static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
>> emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
>> store_offset -= 8;
>> }
>> - if (seen_reg(RV_REG_S6, ctx)) {
>> - emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
>> - store_offset -= 8;
>> - }
>> + emit_ld(RV_REG_TCC, store_offset, RV_REG_SP, ctx);
>
> Why do you need to restore RV_REG_TCC? We're passing RV_REG_TCC (a6) as
> an argument at all call-sites, and for tailcalls we're loading from the
> stack.
>
> Is this to fake the a6 argument for the tail-call? If so, it's better to
> move it to emit_bpf_tail_call(), instead of letting all programs pay for
> it.
Yes, we can remove this duplicate load. will do that at next version.
>
>>
>> emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
>> /* Set return value. */
>> @@ -343,7 +308,6 @@ static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff,
>> static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
>> {
>> int tc_ninsn, off, start_insn = ctx->ninsns;
>> - u8 tcc = rv_tail_call_reg(ctx);
>>
>> /* a0: &ctx
>> * a1: &array
>> @@ -366,9 +330,11 @@ static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
>> /* if (--TCC < 0)
>> * goto out;
>> */
>> - emit_addi(RV_REG_TCC, tcc, -1, ctx);
>> + emit_ld(RV_REG_TCC, ctx->tcc_offset, RV_REG_SP, ctx);
>> + emit_addi(RV_REG_TCC, RV_REG_TCC, -1, ctx);
>> off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
>> emit_branch(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);
>> + emit_sd(RV_REG_SP, ctx->tcc_offset, RV_REG_TCC, ctx);
>>
>> /* prog = array->ptrs[index];
>> * if (!prog)
>> @@ -767,7 +733,7 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
>> int i, ret, offset;
>> int *branches_off = NULL;
>> int stack_size = 0, nregs = m->nr_args;
>> - int retval_off, args_off, nregs_off, ip_off, run_ctx_off, sreg_off;
>> + int retval_off, args_off, nregs_off, ip_off, run_ctx_off, sreg_off, tcc_off;
>> struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
>> struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
>> struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
>> @@ -812,6 +778,8 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
>> *
>> * FP - sreg_off [ callee saved reg ]
>> *
>> + * FP - tcc_off [ tail call count ] BPF_TRAMP_F_TAIL_CALL_CTX
>> + *
>> * [ pads ] pads for 16 bytes alignment
>> */
>>
>> @@ -853,6 +821,11 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
>> stack_size += 8;
>> sreg_off = stack_size;
>>
>> + if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) {
>> + stack_size += 8;
>> + tcc_off = stack_size;
>> + }
>> +
>> stack_size = round_up(stack_size, 16);
>>
>> if (!is_struct_ops) {
>> @@ -879,6 +852,10 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
>> emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
>> }
>>
>> + /* store tail call count */
>> + if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
>> + emit_sd(RV_REG_FP, -tcc_off, RV_REG_TCC, ctx);
>> +
>> /* callee saved register S1 to pass start time */
>> emit_sd(RV_REG_FP, -sreg_off, RV_REG_S1, ctx);
>>
>> @@ -932,6 +909,9 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
>>
>> if (flags & BPF_TRAMP_F_CALL_ORIG) {
>> restore_args(nregs, args_off, ctx);
>> + /* restore TCC to RV_REG_TCC before calling the original function */
>> + if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
>> + emit_ld(RV_REG_TCC, -tcc_off, RV_REG_FP, ctx);
>> ret = emit_call((const u64)orig_call, true, ctx);
>> if (ret)
>> goto out;
>> @@ -963,6 +943,9 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
>> ret = emit_call((const u64)__bpf_tramp_exit, true, ctx);
>> if (ret)
>> goto out;
>> + } else if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) {
>> + /* restore TCC to RV_REG_TCC before calling the original function */
>> + emit_ld(RV_REG_TCC, -tcc_off, RV_REG_FP, ctx);
>> }
>>
>> if (flags & BPF_TRAMP_F_RESTORE_REGS)
>> @@ -1455,6 +1438,9 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
>> if (ret < 0)
>> return ret;
>>
>> + /* restore TCC from stack to RV_REG_TCC */
>> + emit_ld(RV_REG_TCC, ctx->tcc_offset, RV_REG_SP, ctx);
>> +
>> ret = emit_call(addr, fixed_addr, ctx);
>> if (ret)
>> return ret;
>> @@ -1733,8 +1719,7 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx)
>> stack_adjust += 8;
>> if (seen_reg(RV_REG_S5, ctx))
>> stack_adjust += 8;
>> - if (seen_reg(RV_REG_S6, ctx))
>> - stack_adjust += 8;
>> + stack_adjust += 8; /* RV_REG_TCC */
>>
>> stack_adjust = round_up(stack_adjust, 16);
>> stack_adjust += bpf_stack_adjust;
>> @@ -1749,7 +1734,8 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx)
>> * (TCC) register. This instruction is skipped for tail calls.
>> * Force using a 4-byte (non-compressed) instruction.
>> */
>> - emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
>> + if (!bpf_is_subprog(ctx->prog))
>> + emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
>
> You're conditionally emitting the instruction. Doesn't this break
> RV_TAILCALL_OFFSET?
>
This does not break RV_TAILCALL_OFFSET, because The target of tailcall
is always `main` prog, but not subprog.
>
> Björn
More information about the linux-riscv
mailing list