[PATCH v6 6/9] crypto: arm64/aes-xctr: Improve readability of XCTR and CTR modes

Ard Biesheuvel ardb at kernel.org
Wed May 4 23:49:44 PDT 2022


Hi Nathan,

Thanks for cleaning this up.

On Wed, 4 May 2022 at 02:18, Nathan Huckleberry <nhuck at google.com> wrote:
>
> Added some clarifying comments, changed the register allocations to make
> the code clearer, and added register aliases.
>
> Signed-off-by: Nathan Huckleberry <nhuck at google.com>

With one comment below addressed:

Reviewed-by: Ard Biesheuvel <ardb at kernel.org>

> ---
>  arch/arm64/crypto/aes-modes.S | 193 ++++++++++++++++++++++------------
>  1 file changed, 128 insertions(+), 65 deletions(-)
>
> diff --git a/arch/arm64/crypto/aes-modes.S b/arch/arm64/crypto/aes-modes.S
> index 55df157fce3a..da7c9f3380f8 100644
> --- a/arch/arm64/crypto/aes-modes.S
> +++ b/arch/arm64/crypto/aes-modes.S
> @@ -322,32 +322,60 @@ AES_FUNC_END(aes_cbc_cts_decrypt)
>          * This macro generates the code for CTR and XCTR mode.
>          */
>  .macro ctr_encrypt xctr
> +       // Arguments
> +       OUT             .req x0
> +       IN              .req x1
> +       KEY             .req x2
> +       ROUNDS_W        .req w3
> +       BYTES_W         .req w4
> +       IV              .req x5
> +       BYTE_CTR_W      .req w6         // XCTR only
> +       // Intermediate values
> +       CTR_W           .req w11        // XCTR only
> +       CTR             .req x11        // XCTR only
> +       IV_PART         .req x12
> +       BLOCKS          .req x13
> +       BLOCKS_W        .req w13
> +
>         stp             x29, x30, [sp, #-16]!
>         mov             x29, sp
>
> -       enc_prepare     w3, x2, x12
> -       ld1             {vctr.16b}, [x5]
> +       enc_prepare     ROUNDS_W, KEY, IV_PART
> +       ld1             {vctr.16b}, [IV]
>
> +       /*
> +        * Keep 64 bits of the IV in a register.  For CTR mode this lets us
> +        * easily increment the IV.  For XCTR mode this lets us efficiently XOR
> +        * the 64-bit counter with the IV.
> +        */
>         .if \xctr
> -               umov            x12, vctr.d[0]
> -               lsr             w11, w6, #4
> +               umov            IV_PART, vctr.d[0]
> +               lsr             CTR_W, BYTE_CTR_W, #4
>         .else
> -               umov            x12, vctr.d[1] /* keep swabbed ctr in reg */
> -               rev             x12, x12
> +               umov            IV_PART, vctr.d[1]
> +               rev             IV_PART, IV_PART
>         .endif
>
>  .LctrloopNx\xctr:
> -       add             w7, w4, #15
> -       sub             w4, w4, #MAX_STRIDE << 4
> -       lsr             w7, w7, #4
> +       add             BLOCKS_W, BYTES_W, #15
> +       sub             BYTES_W, BYTES_W, #MAX_STRIDE << 4
> +       lsr             BLOCKS_W, BLOCKS_W, #4
>         mov             w8, #MAX_STRIDE
> -       cmp             w7, w8
> -       csel            w7, w7, w8, lt
> +       cmp             BLOCKS_W, w8
> +       csel            BLOCKS_W, BLOCKS_W, w8, lt
>
> +       /*
> +        * Set up the counter values in v0-v4.
> +        *
> +        * If we are encrypting less than MAX_STRIDE blocks, the tail block
> +        * handling code expects the last keystream block to be in v4.  For
> +        * example: if encrypting two blocks with MAX_STRIDE=5, then v3 and v4
> +        * should have the next two counter blocks.
> +        */
>         .if \xctr
> -               add             x11, x11, x7
> +               add             CTR, CTR, BLOCKS
>         .else
> -               adds            x12, x12, x7
> +               adds            IV_PART, IV_PART, BLOCKS
>         .endif
>         mov             v0.16b, vctr.16b
>         mov             v1.16b, vctr.16b
> @@ -355,16 +383,16 @@ AES_FUNC_END(aes_cbc_cts_decrypt)
>         mov             v3.16b, vctr.16b
>  ST5(   mov             v4.16b, vctr.16b                )
>         .if \xctr
> -               sub             x6, x11, #MAX_STRIDE - 1
> -               sub             x7, x11, #MAX_STRIDE - 2
> -               sub             x8, x11, #MAX_STRIDE - 3
> -               sub             x9, x11, #MAX_STRIDE - 4
> -ST5(           sub             x10, x11, #MAX_STRIDE - 5       )
> -               eor             x6, x6, x12
> -               eor             x7, x7, x12
> -               eor             x8, x8, x12
> -               eor             x9, x9, x12
> -               eor             x10, x10, x12
> +               sub             x6, CTR, #MAX_STRIDE - 1
> +               sub             x7, CTR, #MAX_STRIDE - 2
> +               sub             x8, CTR, #MAX_STRIDE - 3
> +               sub             x9, CTR, #MAX_STRIDE - 4
> +ST5(           sub             x10, CTR, #MAX_STRIDE - 5       )
> +               eor             x6, x6, IV_PART
> +               eor             x7, x7, IV_PART
> +               eor             x8, x8, IV_PART
> +               eor             x9, x9, IV_PART
> +               eor             x10, x10, IV_PART
>                 mov             v0.d[0], x6
>                 mov             v1.d[0], x7
>                 mov             v2.d[0], x8
> @@ -381,9 +409,9 @@ ST5(                mov             v4.d[0], x10                    )
>                 ins             vctr.d[0], x8
>
>                 /* apply carry to N counter blocks for N := x12 */

Please update this comment as well. And while at it, it might make
sense to clarify that doing a conditional branch here is fine wrt time
invariance, given that the IV is not part of the key or the plaintext,
and this code rarely triggers in practice anyway.

> -               cbz             x12, 2f
> +               cbz             IV_PART, 2f
>                 adr             x16, 1f
> -               sub             x16, x16, x12, lsl #3
> +               sub             x16, x16, IV_PART, lsl #3
>                 br              x16
>                 bti             c
>                 mov             v0.d[0], vctr.d[0]
> @@ -398,71 +426,88 @@ ST5(              mov             v4.d[0], vctr.d[0]              )
>  1:             b               2f
>                 .previous
>
> -2:             rev             x7, x12
> +2:             rev             x7, IV_PART
>                 ins             vctr.d[1], x7
> -               sub             x7, x12, #MAX_STRIDE - 1
> -               sub             x8, x12, #MAX_STRIDE - 2
> -               sub             x9, x12, #MAX_STRIDE - 3
> +               sub             x7, IV_PART, #MAX_STRIDE - 1
> +               sub             x8, IV_PART, #MAX_STRIDE - 2
> +               sub             x9, IV_PART, #MAX_STRIDE - 3
>                 rev             x7, x7
>                 rev             x8, x8
>                 mov             v1.d[1], x7
>                 rev             x9, x9
> -ST5(           sub             x10, x12, #MAX_STRIDE - 4       )
> +ST5(           sub             x10, IV_PART, #MAX_STRIDE - 4   )
>                 mov             v2.d[1], x8
>  ST5(           rev             x10, x10                        )
>                 mov             v3.d[1], x9
>  ST5(           mov             v4.d[1], x10                    )
>         .endif
> -       tbnz            w4, #31, .Lctrtail\xctr
> -       ld1             {v5.16b-v7.16b}, [x1], #48
> +
> +       /*
> +        * If there are at least MAX_STRIDE blocks left, XOR the plaintext with
> +        * keystream and store.  Otherwise jump to tail handling.
> +        */
> +       tbnz            BYTES_W, #31, .Lctrtail\xctr
> +       ld1             {v5.16b-v7.16b}, [IN], #48
>  ST4(   bl              aes_encrypt_block4x             )
>  ST5(   bl              aes_encrypt_block5x             )
>         eor             v0.16b, v5.16b, v0.16b
> -ST4(   ld1             {v5.16b}, [x1], #16             )
> +ST4(   ld1             {v5.16b}, [IN], #16             )
>         eor             v1.16b, v6.16b, v1.16b
> -ST5(   ld1             {v5.16b-v6.16b}, [x1], #32      )
> +ST5(   ld1             {v5.16b-v6.16b}, [IN], #32      )
>         eor             v2.16b, v7.16b, v2.16b
>         eor             v3.16b, v5.16b, v3.16b
>  ST5(   eor             v4.16b, v6.16b, v4.16b          )
> -       st1             {v0.16b-v3.16b}, [x0], #64
> -ST5(   st1             {v4.16b}, [x0], #16             )
> -       cbz             w4, .Lctrout\xctr
> +       st1             {v0.16b-v3.16b}, [OUT], #64
> +ST5(   st1             {v4.16b}, [OUT], #16            )
> +       cbz             BYTES_W, .Lctrout\xctr
>         b               .LctrloopNx\xctr
>
>  .Lctrout\xctr:
>         .if !\xctr
> -               st1             {vctr.16b}, [x5] /* return next CTR value */
> +               st1             {vctr.16b}, [IV] /* return next CTR value */
>         .endif
>         ldp             x29, x30, [sp], #16
>         ret
>
>  .Lctrtail\xctr:
> +       /*
> +        * Handle up to MAX_STRIDE * 16 - 1 bytes of plaintext
> +        *
> +        * This code expects the last keystream block to be in v4.  For example:
> +        * if encrypting two blocks with MAX_STRIDE=5, then v3 and v4 should
> +        * have the next two counter blocks.
> +        *
> +        * This allows us to store the ciphertext by writing to overlapping
> +        * regions of memory.  Any invalid ciphertext blocks get overwritten by
> +        * correctly computed blocks.  This approach avoids extra conditional
> +        * branches.
> +        */

Nit: Without overlapping stores, we'd have to load and store smaller
quantities and use a loop here, or bounce it via a temp buffer and
memcpy() it from the C code. So it's not just some extra branches.

>         mov             x16, #16
> -       ands            x6, x4, #0xf
> -       csel            x13, x6, x16, ne
> +       ands            w7, BYTES_W, #0xf
> +       csel            x13, x7, x16, ne
>
> -ST5(   cmp             w4, #64 - (MAX_STRIDE << 4)     )
> +ST5(   cmp             BYTES_W, #64 - (MAX_STRIDE << 4))
>  ST5(   csel            x14, x16, xzr, gt               )
> -       cmp             w4, #48 - (MAX_STRIDE << 4)
> +       cmp             BYTES_W, #48 - (MAX_STRIDE << 4)
>         csel            x15, x16, xzr, gt
> -       cmp             w4, #32 - (MAX_STRIDE << 4)
> +       cmp             BYTES_W, #32 - (MAX_STRIDE << 4)
>         csel            x16, x16, xzr, gt
> -       cmp             w4, #16 - (MAX_STRIDE << 4)
> +       cmp             BYTES_W, #16 - (MAX_STRIDE << 4)
>
> -       adr_l           x12, .Lcts_permute_table
> -       add             x12, x12, x13
> +       adr_l           x9, .Lcts_permute_table
> +       add             x9, x9, x13
>         ble             .Lctrtail1x\xctr
>
> -ST5(   ld1             {v5.16b}, [x1], x14             )
> -       ld1             {v6.16b}, [x1], x15
> -       ld1             {v7.16b}, [x1], x16
> +ST5(   ld1             {v5.16b}, [IN], x14             )
> +       ld1             {v6.16b}, [IN], x15
> +       ld1             {v7.16b}, [IN], x16
>
>  ST4(   bl              aes_encrypt_block4x             )
>  ST5(   bl              aes_encrypt_block5x             )
>
> -       ld1             {v8.16b}, [x1], x13
> -       ld1             {v9.16b}, [x1]
> -       ld1             {v10.16b}, [x12]
> +       ld1             {v8.16b}, [IN], x13
> +       ld1             {v9.16b}, [IN]
> +       ld1             {v10.16b}, [x9]
>
>  ST4(   eor             v6.16b, v6.16b, v0.16b          )
>  ST4(   eor             v7.16b, v7.16b, v1.16b          )
> @@ -477,30 +522,48 @@ ST5(      eor             v7.16b, v7.16b, v2.16b          )
>  ST5(   eor             v8.16b, v8.16b, v3.16b          )
>  ST5(   eor             v9.16b, v9.16b, v4.16b          )
>
> -ST5(   st1             {v5.16b}, [x0], x14             )
> -       st1             {v6.16b}, [x0], x15
> -       st1             {v7.16b}, [x0], x16
> -       add             x13, x13, x0
> +ST5(   st1             {v5.16b}, [OUT], x14            )
> +       st1             {v6.16b}, [OUT], x15
> +       st1             {v7.16b}, [OUT], x16
> +       add             x13, x13, OUT
>         st1             {v9.16b}, [x13]         // overlapping stores
> -       st1             {v8.16b}, [x0]
> +       st1             {v8.16b}, [OUT]
>         b               .Lctrout\xctr
>
>  .Lctrtail1x\xctr:
> -       sub             x7, x6, #16
> -       csel            x6, x6, x7, eq
> -       add             x1, x1, x6
> -       add             x0, x0, x6
> -       ld1             {v5.16b}, [x1]
> -       ld1             {v6.16b}, [x0]
> +       /*
> +        * Handle <= 16 bytes of plaintext
> +        */
> +       sub             x8, x7, #16
> +       csel            x7, x7, x8, eq
> +       add             IN, IN, x7
> +       add             OUT, OUT, x7
> +       ld1             {v5.16b}, [IN]
> +       ld1             {v6.16b}, [OUT]
>  ST5(   mov             v3.16b, v4.16b                  )
>         encrypt_block   v3, w3, x2, x8, w7
> -       ld1             {v10.16b-v11.16b}, [x12]
> +       ld1             {v10.16b-v11.16b}, [x9]
>         tbl             v3.16b, {v3.16b}, v10.16b
>         sshr            v11.16b, v11.16b, #7
>         eor             v5.16b, v5.16b, v3.16b
>         bif             v5.16b, v6.16b, v11.16b
> -       st1             {v5.16b}, [x0]
> +       st1             {v5.16b}, [OUT]
>         b               .Lctrout\xctr
> +
> +       // Arguments
> +       .unreq OUT
> +       .unreq IN
> +       .unreq KEY
> +       .unreq ROUNDS_W
> +       .unreq BYTES_W
> +       .unreq IV
> +       .unreq BYTE_CTR_W       // XCTR only
> +       // Intermediate values
> +       .unreq CTR_W            // XCTR only
> +       .unreq CTR              // XCTR only
> +       .unreq IV_PART
> +       .unreq BLOCKS
> +       .unreq BLOCKS_W
>  .endm
>
>         /*
> --
> 2.36.0.464.gb9c8b46e94-goog
>



More information about the linux-arm-kernel mailing list