[PATCH bpf-next 4/4] riscv, bpf: Mixing bpf2bpf and tailcalls

Pu Lehui pulehui at huawei.com
Mon Jan 29 19:26:22 PST 2024



On 2023/9/19 11:57, Pu Lehui wrote:
> From: Pu Lehui <pulehui at huawei.com>
> 
> In the current RV64 JIT, if we simply don't initialize the TCC in a subprog,
> the TCC can be propagated from the calling program to the subprog, but
> the TCC of the calling program cannot be restored when the subprog
> exits. Since the RV64 TCC is initialized before saving the callee saved
> registers into the stack, we cannot use the callee saved register to
> pass the TCC, otherwise the original value of the callee saved register
> will be destroyed. So we implemented mixing bpf2bpf and tailcalls
> similar to x86_64, i.e. using a non-callee saved register to transfer
> the TCC between functions, and saving that register to the stack to
> protect the TCC value. At the same time, we also consider the scenario
> of mixing trampoline.
> 
> Tests test_bpf.ko and test_verifier have passed, as well as the related
> test cases of test_progs*.
> 
> Signed-off-by: Pu Lehui <pulehui at huawei.com>
> ---
>   arch/riscv/net/bpf_jit.h        |  1 +
>   arch/riscv/net/bpf_jit_comp64.c | 91 ++++++++++++++-------------------
>   2 files changed, 39 insertions(+), 53 deletions(-)
> 
> diff --git a/arch/riscv/net/bpf_jit.h b/arch/riscv/net/bpf_jit.h
> index d21c6c92a..ca518846c 100644
> --- a/arch/riscv/net/bpf_jit.h
> +++ b/arch/riscv/net/bpf_jit.h
> @@ -75,6 +75,7 @@ struct rv_jit_context {
>   	int nexentries;
>   	unsigned long flags;
>   	int stack_size;
> +	int tcc_offset;
>   };
>   
>   /* Convert from ninsns to bytes. */
> diff --git a/arch/riscv/net/bpf_jit_comp64.c b/arch/riscv/net/bpf_jit_comp64.c
> index f2ded1151..f37be4911 100644
> --- a/arch/riscv/net/bpf_jit_comp64.c
> +++ b/arch/riscv/net/bpf_jit_comp64.c
> @@ -13,13 +13,11 @@
>   #include <asm/patch.h>
>   #include "bpf_jit.h"
>   
> +#define RV_REG_TCC		RV_REG_A6
>   #define RV_FENTRY_NINSNS	2
>   /* fentry and TCC init insns will be skipped on tailcall */
>   #define RV_TAILCALL_OFFSET	((RV_FENTRY_NINSNS + 1) * 4)
>   
> -#define RV_REG_TCC RV_REG_A6
> -#define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
> -
>   static const int regmap[] = {
>   	[BPF_REG_0] =	RV_REG_A5,
>   	[BPF_REG_1] =	RV_REG_A0,
> @@ -51,14 +49,12 @@ static const int pt_regmap[] = {
>   };
>   
>   enum {
> -	RV_CTX_F_SEEN_TAIL_CALL =	0,
>   	RV_CTX_F_SEEN_CALL =		RV_REG_RA,
>   	RV_CTX_F_SEEN_S1 =		RV_REG_S1,
>   	RV_CTX_F_SEEN_S2 =		RV_REG_S2,
>   	RV_CTX_F_SEEN_S3 =		RV_REG_S3,
>   	RV_CTX_F_SEEN_S4 =		RV_REG_S4,
>   	RV_CTX_F_SEEN_S5 =		RV_REG_S5,
> -	RV_CTX_F_SEEN_S6 =		RV_REG_S6,
>   };
>   
>   static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
> @@ -71,7 +67,6 @@ static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
>   	case RV_CTX_F_SEEN_S3:
>   	case RV_CTX_F_SEEN_S4:
>   	case RV_CTX_F_SEEN_S5:
> -	case RV_CTX_F_SEEN_S6:
>   		__set_bit(reg, &ctx->flags);
>   	}
>   	return reg;
> @@ -86,7 +81,6 @@ static bool seen_reg(int reg, struct rv_jit_context *ctx)
>   	case RV_CTX_F_SEEN_S3:
>   	case RV_CTX_F_SEEN_S4:
>   	case RV_CTX_F_SEEN_S5:
> -	case RV_CTX_F_SEEN_S6:
>   		return test_bit(reg, &ctx->flags);
>   	}
>   	return false;
> @@ -102,32 +96,6 @@ static void mark_call(struct rv_jit_context *ctx)
>   	__set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
>   }
>   
> -static bool seen_call(struct rv_jit_context *ctx)
> -{
> -	return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
> -}
> -
> -static void mark_tail_call(struct rv_jit_context *ctx)
> -{
> -	__set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
> -}
> -
> -static bool seen_tail_call(struct rv_jit_context *ctx)
> -{
> -	return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
> -}
> -
> -static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
> -{
> -	mark_tail_call(ctx);
> -
> -	if (seen_call(ctx)) {
> -		__set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
> -		return RV_REG_S6;
> -	}
> -	return RV_REG_A6;
> -}
> -
>   static bool is_32b_int(s64 val)
>   {
>   	return -(1L << 31) <= val && val < (1L << 31);
> @@ -235,10 +203,7 @@ static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
>   		emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
>   		store_offset -= 8;
>   	}
> -	if (seen_reg(RV_REG_S6, ctx)) {
> -		emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
> -		store_offset -= 8;
> -	}
> +	emit_ld(RV_REG_TCC, store_offset, RV_REG_SP, ctx);
>   
>   	emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
>   	/* Set return value. */
> @@ -332,7 +297,6 @@ static void emit_zext_32(u8 reg, struct rv_jit_context *ctx)
>   static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
>   {
>   	int tc_ninsn, off, start_insn = ctx->ninsns;
> -	u8 tcc = rv_tail_call_reg(ctx);
>   
>   	/* a0: &ctx
>   	 * a1: &array
> @@ -355,9 +319,11 @@ static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
>   	/* if (--TCC < 0)
>   	 *     goto out;
>   	 */
> -	emit_addi(RV_REG_TCC, tcc, -1, ctx);
> +	emit_ld(RV_REG_TCC, ctx->tcc_offset, RV_REG_SP, ctx);
> +	emit_addi(RV_REG_TCC, RV_REG_TCC, -1, ctx);
>   	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
>   	emit_branch(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);
> +	emit_sd(RV_REG_SP, ctx->tcc_offset, RV_REG_TCC, ctx);
>   
>   	/* prog = array->ptrs[index];
>   	 * if (!prog)
> @@ -763,7 +729,7 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
>   	int i, ret, offset;
>   	int *branches_off = NULL;
>   	int stack_size = 0, nregs = m->nr_args;
> -	int retval_off, args_off, nregs_off, ip_off, run_ctx_off, sreg_off;
> +	int retval_off, args_off, nregs_off, ip_off, run_ctx_off, sreg_off, tcc_off;
>   	struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
>   	struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
>   	struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
> @@ -807,6 +773,8 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
>   	 *
>   	 * FP - sreg_off    [ callee saved reg	]
>   	 *
> +	 * FP - tcc_off     [ tail call count	] BPF_TRAMP_F_TAIL_CALL_CTX
> +	 *
>   	 *		    [ pads              ] pads for 16 bytes alignment
>   	 */
>   
> @@ -848,6 +816,11 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
>   	stack_size += 8;
>   	sreg_off = stack_size;
>   
> +	if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) {
> +		stack_size += 8;
> +		tcc_off = stack_size;
> +	}
> +
>   	stack_size = round_up(stack_size, 16);
>   
>   	if (func_addr) {
> @@ -874,6 +847,10 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
>   		emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
>   	}
>   
> +	/* store tail call count */
> +	if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
> +		emit_sd(RV_REG_FP, -tcc_off, RV_REG_TCC, ctx);
> +
>   	/* callee saved register S1 to pass start time */
>   	emit_sd(RV_REG_FP, -sreg_off, RV_REG_S1, ctx);
>   
> @@ -927,6 +904,9 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
>   
>   	if (flags & BPF_TRAMP_F_CALL_ORIG) {
>   		restore_args(nregs, args_off, ctx);
> +		/* restore TCC to RV_REG_TCC before calling the original function */
> +		if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
> +			emit_ld(RV_REG_TCC, -tcc_off, RV_REG_FP, ctx);
>   		ret = emit_call((const u64)orig_call, true, ctx);
>   		if (ret)
>   			goto out;
> @@ -967,6 +947,10 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
>   
>   	emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx);
>   
> +	/* restore TCC to RV_REG_TCC before calling the original function */
> +	if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
> +		emit_ld(RV_REG_TCC, -tcc_off, RV_REG_FP, ctx);

Sorry, everyone. This TCC restore will be emitted twice when
`flags & BPF_TRAMP_F_CALL_ORIG` is set. I will fix it in the next version.

> +
>   	if (func_addr) {
>   		/* trampoline called from function entry */
>   		emit_ld(RV_REG_T0, stack_size - 8, RV_REG_SP, ctx);
> @@ -1476,6 +1460,9 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
>   		if (ret < 0)
>   			return ret;
>   
> +		/* restore TCC from stack to RV_REG_TCC */
> +		emit_ld(RV_REG_TCC, ctx->tcc_offset, RV_REG_SP, ctx);
> +
>   		ret = emit_call(addr, fixed_addr, ctx);
>   		if (ret)
>   			return ret;
> @@ -1735,6 +1722,7 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
>   void bpf_jit_build_prologue(struct rv_jit_context *ctx)
>   {
>   	int i, stack_adjust = 0, store_offset, bpf_stack_adjust;
> +	bool is_main = ctx->prog->aux->func_idx == 0;
>   
>   	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
>   	if (bpf_stack_adjust)
> @@ -1753,8 +1741,7 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx)
>   		stack_adjust += 8;
>   	if (seen_reg(RV_REG_S5, ctx))
>   		stack_adjust += 8;
> -	if (seen_reg(RV_REG_S6, ctx))
> -		stack_adjust += 8;
> +	stack_adjust += 8; /* RV_REG_TCC */
>   
>   	stack_adjust = round_up(stack_adjust, 16);
>   	stack_adjust += bpf_stack_adjust;
> @@ -1769,7 +1756,8 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx)
>   	 * (TCC) register. This instruction is skipped for tail calls.
>   	 * Force using a 4-byte (non-compressed) instruction.
>   	 */
> -	emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
> +	if (is_main)
> +		emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
>   
>   	emit_addi(RV_REG_SP, RV_REG_SP, -stack_adjust, ctx);
>   
> @@ -1799,22 +1787,14 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx)
>   		emit_sd(RV_REG_SP, store_offset, RV_REG_S5, ctx);
>   		store_offset -= 8;
>   	}
> -	if (seen_reg(RV_REG_S6, ctx)) {
> -		emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
> -		store_offset -= 8;
> -	}
> +	emit_sd(RV_REG_SP, store_offset, RV_REG_TCC, ctx);
> +	ctx->tcc_offset = store_offset;
>   
>   	emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
>   
>   	if (bpf_stack_adjust)
>   		emit_addi(RV_REG_S5, RV_REG_SP, bpf_stack_adjust, ctx);
>   
> -	/* Program contains calls and tail calls, so RV_REG_TCC need
> -	 * to be saved across calls.
> -	 */
> -	if (seen_tail_call(ctx) && seen_call(ctx))
> -		emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
> -
>   	ctx->stack_size = stack_adjust;
>   }
>   
> @@ -1827,3 +1807,8 @@ bool bpf_jit_supports_kfunc_call(void)
>   {
>   	return true;
>   }
> +
> +bool bpf_jit_supports_subprog_tailcalls(void)
> +{
> +	return true;
> +}



More information about the linux-riscv mailing list