[PATCH bpf-next] bpf, arm64: Support up to 12 function arguments

Florent Revest revest at chromium.org
Wed Sep 20 15:17:54 PDT 2023


On Sun, Sep 17, 2023 at 5:09 PM Xu Kuohai <xukuohai at huaweicloud.com> wrote:
>
> From: Xu Kuohai <xukuohai at huawei.com>
>
> Currently arm64 bpf trampoline supports up to 8 function arguments.
> According to the statistics from commit
> 473e3150e30a ("bpf, x86: allow function arguments up to 12 for TRACING"),
> there are about 200 functions accept 9 to 12 arguments, so adding support
> for up to 12 function arguments.

Thank you Xu, this will be a nice addition! :)

> Due to bpf only supports function arguments up to 16 bytes, according to
> AAPCS64, starting from the first argument, each argument is first
> attempted to be loaded to 1 or 2 smallest registers from x0-x7, if there
> are no enough registers to hold the entire argument, then all remaining
> arguments starting from this one are pushed to the stack for passing.

If I read section 6.8.2 of the AAPCS64 correctly, there is a
corner case which I believe isn't covered by this logic.

void f(u128 a, u128 b, u128 c, u64 d, u128 e, u64 f) {}
- a goes on x0 and x1
- b goes on x2 and x3
- c goes on x4 and x5
- d goes on x6
- e spills on the stack because it doesn't fit in the remaining regs
- f goes on x7

Maybe it would be good to add something pathological like this to the
selftests?

Otherwise I only have minor nitpicks.

> Signed-off-by: Xu Kuohai <xukuohai at huawei.com>
> ---
>  arch/arm64/net/bpf_jit_comp.c                | 171 ++++++++++++++-----
>  tools/testing/selftests/bpf/DENYLIST.aarch64 |   2 -
>  2 files changed, 131 insertions(+), 42 deletions(-)
>
> diff --git a/arch/arm64/net/bpf_jit_comp.c b/arch/arm64/net/bpf_jit_comp.c
> index 7d4af64e3982..a0cf526b07ea 100644
> --- a/arch/arm64/net/bpf_jit_comp.c
> +++ b/arch/arm64/net/bpf_jit_comp.c
> @@ -1705,7 +1705,7 @@ bool bpf_jit_supports_subprog_tailcalls(void)
>  }
>
>  static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l,
> -                           int args_off, int retval_off, int run_ctx_off,
> +                           int bargs_off, int retval_off, int run_ctx_off,
>                             bool save_ret)
>  {
>         __le32 *branch;
> @@ -1747,7 +1747,7 @@ static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l,
>         /* save return value to callee saved register x20 */
>         emit(A64_MOV(1, A64_R(20), A64_R(0)), ctx);
>
> -       emit(A64_ADD_I(1, A64_R(0), A64_SP, args_off), ctx);
> +       emit(A64_ADD_I(1, A64_R(0), A64_SP, bargs_off), ctx);
>         if (!p->jited)
>                 emit_addr_mov_i64(A64_R(1), (const u64)p->insnsi, ctx);
>
> @@ -1772,7 +1772,7 @@ static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l,
>  }
>
>  static void invoke_bpf_mod_ret(struct jit_ctx *ctx, struct bpf_tramp_links *tl,
> -                              int args_off, int retval_off, int run_ctx_off,
> +                              int bargs_off, int retval_off, int run_ctx_off,
>                                __le32 **branches)
>  {
>         int i;
> @@ -1782,7 +1782,7 @@ static void invoke_bpf_mod_ret(struct jit_ctx *ctx, struct bpf_tramp_links *tl,
>          */
>         emit(A64_STR64I(A64_ZR, A64_SP, retval_off), ctx);
>         for (i = 0; i < tl->nr_links; i++) {
> -               invoke_bpf_prog(ctx, tl->links[i], args_off, retval_off,
> +               invoke_bpf_prog(ctx, tl->links[i], bargs_off, retval_off,
>                                 run_ctx_off, true);
>                 /* if (*(u64 *)(sp + retval_off) !=  0)
>                  *      goto do_fexit;
> @@ -1796,23 +1796,111 @@ static void invoke_bpf_mod_ret(struct jit_ctx *ctx, struct bpf_tramp_links *tl,
>         }
>  }
>
> -static void save_args(struct jit_ctx *ctx, int args_off, int nregs)
> +struct arg_aux {
> +       /* how many args are passed through registers, the rest args are

the rest of the* args

> +        * passed through stack
> +        */
> +       int args_in_reg;

Maybe args_in_regs, since args can go in multiple regs?

> +       /* how many registers used for passing arguments */

are* used

> +       int regs_for_arg;

And here regs_for_args, since it's the number of registers used for all args?

> +       /* how many stack slots used for arguments, each slot is 8 bytes */

are* used

> +       int stack_slots_for_arg;

And here stack_slots_for_args, for the same reason as above?

> +};
> +
> +static void calc_arg_aux(const struct btf_func_model *m,
> +                        struct arg_aux *a)
>  {
>         int i;
> +       int nregs;
> +       int slots;
> +       int stack_slots;
> +
> +       /* verifier ensures m->nr_args <= MAX_BPF_FUNC_ARGS */
> +       for (i = 0, nregs = 0; i < m->nr_args; i++) {
> +               slots = (m->arg_size[i] + 7) / 8;
> +               if (nregs + slots <= 8) /* passed through register ? */
> +                       nregs += slots;
> +               else
> +                       break;
> +       }
> +
> +       a->args_in_reg = i;
> +       a->regs_for_arg = nregs;
>
> -       for (i = 0; i < nregs; i++) {
> -               emit(A64_STR64I(i, A64_SP, args_off), ctx);
> -               args_off += 8;
> +       /* the rest arguments are passed through stack */
> +       for (stack_slots = 0; i < m->nr_args; i++)
> +               stack_slots += (m->arg_size[i] + 7) / 8;
> +
> +       a->stack_slots_for_arg = stack_slots;
> +}
> +
> +static void clear_garbage(struct jit_ctx *ctx, int reg, int effective_bytes)
> +{
> +       if (effective_bytes) {
> +               int garbage_bits = 64 - 8 * effective_bytes;
> +#ifdef CONFIG_CPU_BIG_ENDIAN
> +               /* garbage bits are at the right end */
> +               emit(A64_LSR(1, reg, reg, garbage_bits), ctx);
> +               emit(A64_LSL(1, reg, reg, garbage_bits), ctx);
> +#else
> +               /* garbage bits are at the left end */
> +               emit(A64_LSL(1, reg, reg, garbage_bits), ctx);
> +               emit(A64_LSR(1, reg, reg, garbage_bits), ctx);
> +#endif
>         }
>  }
>
> -static void restore_args(struct jit_ctx *ctx, int args_off, int nregs)
> +static void save_args(struct jit_ctx *ctx, int bargs_off, int oargs_off,
> +                     const struct btf_func_model *m,
> +                     const struct arg_aux *a,
> +                     bool for_call_origin)
>  {
>         int i;
> +       int reg;
> +       int doff;
> +       int soff;
> +       int slots;
> +       u8 tmp = bpf2a64[TMP_REG_1];
> +
> +       /* store argument registers to stack for call bpf, or restore argument

to* call bpf or "for the bpf program"

> +        * registers from stack for the original function
> +        */
> +       for (reg = 0; reg < a->regs_for_arg; reg++) {
> +               emit(for_call_origin ?
> +                    A64_LDR64I(reg, A64_SP, bargs_off) :
> +                    A64_STR64I(reg, A64_SP, bargs_off),
> +                    ctx);
> +               bargs_off += 8;
> +       }
>
> -       for (i = 0; i < nregs; i++) {
> -               emit(A64_LDR64I(i, A64_SP, args_off), ctx);
> -               args_off += 8;
> +       soff = 32; /* on stack arguments start from FP + 32 */
> +       doff = (for_call_origin ? oargs_off : bargs_off);
> +
> +       /* save on stack arguments */
> +       for (i = a->args_in_reg; i < m->nr_args; i++) {
> +               slots = (m->arg_size[i] + 7) / 8;
> +               /* verifier ensures arg_size <= 16, so slots equals 1 or 2 */
> +               while (slots-- > 0) {
> +                       emit(A64_LDR64I(tmp, A64_FP, soff), ctx);
> +                       /* if there is unused space in the last slot, clear
> +                        * the garbage contained in the space.
> +                        */
> +                       if (slots == 0 && !for_call_origin)
> +                               clear_garbage(ctx, tmp, m->arg_size[i] % 8);
> +                       emit(A64_STR64I(tmp, A64_SP, doff), ctx);
> +                       soff += 8;
> +                       doff += 8;
> +               }
> +       }
> +}
> +
> +static void restore_args(struct jit_ctx *ctx, int bargs_off, int nregs)
> +{
> +       int reg;
> +
> +       for (reg = 0; reg < nregs; reg++) {
> +               emit(A64_LDR64I(reg, A64_SP, bargs_off), ctx);
> +               bargs_off += 8;
>         }
>  }
>
> @@ -1829,17 +1917,21 @@ static void restore_args(struct jit_ctx *ctx, int args_off, int nregs)
>   */
>  static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
>                               struct bpf_tramp_links *tlinks, void *orig_call,
> -                             int nregs, u32 flags)
> +                             const struct btf_func_model *m,
> +                             const struct arg_aux *a,
> +                             u32 flags)
>  {
>         int i;
>         int stack_size;
>         int retaddr_off;
>         int regs_off;
>         int retval_off;
> -       int args_off;
> +       int bargs_off;
>         int nregs_off;
>         int ip_off;
>         int run_ctx_off;
> +       int oargs_off;
> +       int nregs;
>         struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
>         struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
>         struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
> @@ -1859,19 +1951,26 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
>          *
>          * SP + retval_off  [ return value      ] BPF_TRAMP_F_CALL_ORIG or
>          *                                        BPF_TRAMP_F_RET_FENTRY_RET
> -        *
>          *                  [ arg reg N         ]
>          *                  [ ...               ]
> -        * SP + args_off    [ arg reg 1         ]
> +        * SP + bargs_off   [ arg reg 1         ] for bpf
>          *
>          * SP + nregs_off   [ arg regs count    ]
>          *
>          * SP + ip_off      [ traced function   ] BPF_TRAMP_F_IP_ARG flag
>          *
>          * SP + run_ctx_off [ bpf_tramp_run_ctx ]
> +        *
> +        *                  [ stack arg N       ]
> +        *                  [ ...               ]
> +        * SP + oargs_off   [ stack arg 1       ] for original func
>          */
>
>         stack_size = 0;
> +       oargs_off = stack_size;
> +       if (flags & BPF_TRAMP_F_CALL_ORIG)
> +               stack_size += 8 * a->stack_slots_for_arg;
> +
>         run_ctx_off = stack_size;
>         /* room for bpf_tramp_run_ctx */
>         stack_size += round_up(sizeof(struct bpf_tramp_run_ctx), 8);
> @@ -1885,9 +1984,10 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
>         /* room for args count */
>         stack_size += 8;
>
> -       args_off = stack_size;
> +       bargs_off = stack_size;
>         /* room for args */
> -       stack_size += nregs * 8;
> +       nregs = a->regs_for_arg + a->stack_slots_for_arg;

Maybe the name "nregs" no longer makes sense, since it now counts stack slots as well as registers?



More information about the linux-arm-kernel mailing list