[PATCH riscv/for-next] crypto: riscv - add vector crypto accelerated AES-CBC-CTS
Ard Biesheuvel
ardb at kernel.org
Wed Feb 14 08:34:03 PST 2024
On Tue, 13 Feb 2024 at 06:57, Eric Biggers <ebiggers at kernel.org> wrote:
>
> From: Eric Biggers <ebiggers at google.com>
>
> Add an implementation of cts(cbc(aes)) accelerated using the Zvkned
> RISC-V vector crypto extension. This is mainly useful for fscrypt,
> where cts(cbc(aes)) is the "default" filenames encryption algorithm. In
> that use case, typically most messages are short and are block-aligned.
Does this mean the storage space for filenames is rounded up to the
AES block size?
> The CBC-CTS variant implemented is CS3; this is the variant Linux uses.
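
For other readers: CS3 is the variant that unconditionally swaps the
last two ciphertext blocks, even for block-aligned messages. To make
sure we're talking about the same thing, here is a minimal, unoptimized
C model of the encryption side. This is only a reviewer sketch:
cbc_cts_cs3_encrypt is a made-up name, but aes_encrypt() and
crypto_xor() are the library helpers this driver already selects via
CRYPTO_LIB_AES. It also assumes non-overlapping buffers, unlike the
real code below, which loads P[n] before storing C[n] so that in-place
operation works:

#include <crypto/aes.h>         /* struct crypto_aes_ctx, aes_encrypt() */
#include <crypto/algapi.h>      /* crypto_xor() */
#include <linux/string.h>

/* Reviewer sketch only, not for merging.  len must be >= AES_BLOCK_SIZE. */
static void cbc_cts_cs3_encrypt(const struct crypto_aes_ctx *key,
                                const u8 *in, u8 *out, size_t len,
                                const u8 iv[AES_BLOCK_SIZE])
{
        size_t tail = len % AES_BLOCK_SIZE ?: AES_BLOCK_SIZE;
        u8 e[AES_BLOCK_SIZE];

        memcpy(e, iv, AES_BLOCK_SIZE);

        /* Plain CBC over everything before the last two blocks */
        while (len > AES_BLOCK_SIZE + tail) {
                crypto_xor(e, in, AES_BLOCK_SIZE);
                aes_encrypt(key, e, e);
                memcpy(out, e, AES_BLOCK_SIZE);
                in += AES_BLOCK_SIZE;
                out += AES_BLOCK_SIZE;
                len -= AES_BLOCK_SIZE;
        }

        /* e = Encrypt(P[n-1] ^ C[n-2]); for a single-block message
         * this is already the whole ciphertext. */
        crypto_xor(e, in, AES_BLOCK_SIZE);
        aes_encrypt(key, e, e);
        if (len == AES_BLOCK_SIZE) {
                memcpy(out, e, AES_BLOCK_SIZE);
                return;
        }
        in += AES_BLOCK_SIZE;

        /* CS3: C[n] is e truncated to tail bytes, and C[n-1] is
         * Encrypt(e ^ zero-padded P[n]); the two are stored swapped. */
        memcpy(out + AES_BLOCK_SIZE, e, tail);          /* C[n] */
        crypto_xor(e, in, tail);  /* only the first tail bytes change */
        aes_encrypt(key, e, e);
        memcpy(out, e, AES_BLOCK_SIZE);                 /* C[n-1] */
}
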
>
> To perform well on short messages, the new implementation processes the
> full message in one call to the assembly function if the data is
> contiguous. Otherwise it falls back to CBC operations followed by CTS
> at the end. For decryption, to further improve performance on short
> messages, especially block-aligned messages, the CBC-CTS assembly
> function parallelizes the AES decryption of all full blocks.
Nice!
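(For anyone else following along: this parallelization works because
CBC decryption has no chaining dependency through the AES core itself,
only through the final XOR:

    P[i] = Decrypt(C[i]) ^ C[i-1]

so all the Decrypt() operations, including those for the two blocks
that the CTS step later remixes, are independent and can be issued as
one wide vector operation, with the XORs fixed up afterwards.)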
> This
> improves on the arm64 implementation of cts(cbc(aes)), which always
> splits the CBC part(s) from the CTS part, doing the AES decryptions for
> the last two blocks serially and usually loading the round keys twice.
>
So is the overhead of this sub-optimal approach mostly in the
redundant loading of the round keys? Or does the new approach have
other significant benefits?
If it does, I suppose we might port this improvement to x86 too, but
otherwise, I guess it'll only make sense for arm64.
> Tested in QEMU with CONFIG_CRYPTO_MANAGER_EXTRA_TESTS=y.
>
> Signed-off-by: Eric Biggers <ebiggers at google.com>
> ---
> arch/riscv/crypto/Kconfig | 4 +-
> arch/riscv/crypto/aes-riscv64-glue.c | 93 ++++++++++++++-
> arch/riscv/crypto/aes-riscv64-zvkned.S | 153 +++++++++++++++++++++++++
> 3 files changed, 245 insertions(+), 5 deletions(-)
>
> diff --git a/arch/riscv/crypto/Kconfig b/arch/riscv/crypto/Kconfig
> index 2ad44e1d464a..ad58dad9a580 100644
> --- a/arch/riscv/crypto/Kconfig
> +++ b/arch/riscv/crypto/Kconfig
> @@ -1,23 +1,23 @@
> # SPDX-License-Identifier: GPL-2.0
>
> menu "Accelerated Cryptographic Algorithms for CPU (riscv)"
>
> config CRYPTO_AES_RISCV64
> - tristate "Ciphers: AES, modes: ECB, CBC, CTR, XTS"
> + tristate "Ciphers: AES, modes: ECB, CBC, CTS, CTR, XTS"
> depends on 64BIT && RISCV_ISA_V && TOOLCHAIN_HAS_VECTOR_CRYPTO
> select CRYPTO_ALGAPI
> select CRYPTO_LIB_AES
> select CRYPTO_SKCIPHER
> help
> Block cipher: AES cipher algorithms
> - Length-preserving ciphers: AES with ECB, CBC, CTR, XTS
> + Length-preserving ciphers: AES with ECB, CBC, CTS, CTR, XTS
>
> Architecture: riscv64 using:
> - Zvkned vector crypto extension
> - Zvbb vector extension (XTS)
> - Zvkb vector crypto extension (CTR)
> - Zvkg vector crypto extension (XTS)
>
> config CRYPTO_CHACHA_RISCV64
> tristate "Ciphers: ChaCha"
> depends on 64BIT && RISCV_ISA_V && TOOLCHAIN_HAS_VECTOR_CRYPTO
> diff --git a/arch/riscv/crypto/aes-riscv64-glue.c b/arch/riscv/crypto/aes-riscv64-glue.c
> index 37bc6ef0be40..f814ee048555 100644
> --- a/arch/riscv/crypto/aes-riscv64-glue.c
> +++ b/arch/riscv/crypto/aes-riscv64-glue.c
> @@ -1,20 +1,22 @@
> // SPDX-License-Identifier: GPL-2.0-only
> /*
> * AES using the RISC-V vector crypto extensions. Includes the bare block
> - * cipher and the ECB, CBC, CTR, and XTS modes.
> + * cipher and the ECB, CBC, CBC-CTS, CTR, and XTS modes.
> *
> * Copyright (C) 2023 VRULL GmbH
> * Author: Heiko Stuebner <heiko.stuebner at vrull.eu>
> *
> * Copyright (C) 2023 SiFive, Inc.
> * Author: Jerry Shih <jerry.shih at sifive.com>
> + *
> + * Copyright 2024 Google LLC
> */
>
> #include <asm/simd.h>
> #include <asm/vector.h>
> #include <crypto/aes.h>
> #include <crypto/internal/cipher.h>
> #include <crypto/internal/simd.h>
> #include <crypto/internal/skcipher.h>
> #include <crypto/scatterwalk.h>
> #include <crypto/xts.h>
> @@ -33,20 +35,24 @@ asmlinkage void aes_ecb_encrypt_zvkned(const struct crypto_aes_ctx *key,
> asmlinkage void aes_ecb_decrypt_zvkned(const struct crypto_aes_ctx *key,
> const u8 *in, u8 *out, size_t len);
>
> asmlinkage void aes_cbc_encrypt_zvkned(const struct crypto_aes_ctx *key,
> const u8 *in, u8 *out, size_t len,
> u8 iv[AES_BLOCK_SIZE]);
> asmlinkage void aes_cbc_decrypt_zvkned(const struct crypto_aes_ctx *key,
> const u8 *in, u8 *out, size_t len,
> u8 iv[AES_BLOCK_SIZE]);
>
> +asmlinkage void aes_cbc_cts_crypt_zvkned(const struct crypto_aes_ctx *key,
> + const u8 *in, u8 *out, size_t len,
> + const u8 iv[AES_BLOCK_SIZE], bool enc);
> +
> asmlinkage void aes_ctr32_crypt_zvkned_zvkb(const struct crypto_aes_ctx *key,
> const u8 *in, u8 *out, size_t len,
> u8 iv[AES_BLOCK_SIZE]);
>
> asmlinkage void aes_xts_encrypt_zvkned_zvbb_zvkg(
> const struct crypto_aes_ctx *key,
> const u8 *in, u8 *out, size_t len,
> u8 tweak[AES_BLOCK_SIZE]);
>
> asmlinkage void aes_xts_decrypt_zvkned_zvbb_zvkg(
> @@ -157,21 +163,21 @@ static int riscv64_aes_ecb_encrypt(struct skcipher_request *req)
> return riscv64_aes_ecb_crypt(req, true);
> }
>
> static int riscv64_aes_ecb_decrypt(struct skcipher_request *req)
> {
> return riscv64_aes_ecb_crypt(req, false);
> }
>
> /* AES-CBC */
>
> -static inline int riscv64_aes_cbc_crypt(struct skcipher_request *req, bool enc)
> +static int riscv64_aes_cbc_crypt(struct skcipher_request *req, bool enc)
> {
> struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
> const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
> struct skcipher_walk walk;
> unsigned int nbytes;
> int err;
>
> err = skcipher_walk_virt(&walk, req, false);
> while ((nbytes = walk.nbytes) != 0) {
> kernel_vector_begin();
> @@ -195,20 +201,84 @@ static inline int riscv64_aes_cbc_crypt(struct skcipher_request *req, bool enc)
> static int riscv64_aes_cbc_encrypt(struct skcipher_request *req)
> {
> return riscv64_aes_cbc_crypt(req, true);
> }
>
> static int riscv64_aes_cbc_decrypt(struct skcipher_request *req)
> {
> return riscv64_aes_cbc_crypt(req, false);
> }
>
> +/* AES-CBC-CTS */
> +
> +static int riscv64_aes_cbc_cts_crypt(struct skcipher_request *req, bool enc)
> +{
> + struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
> + const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
> + struct scatterlist sg_src[2], sg_dst[2];
> + struct skcipher_request subreq;
> + struct scatterlist *src, *dst;
> + struct skcipher_walk walk;
> + unsigned int cbc_len;
> + int err;
> +
> + if (req->cryptlen < AES_BLOCK_SIZE)
> + return -EINVAL;
> +
> + err = skcipher_walk_virt(&walk, req, false);
> + if (err)
> + return err;
> + /*
> + * If the full message is available in one step, decrypt it in one call
> + * to the CBC-CTS assembly function. This reduces overhead, especially
> + * on short messages. Otherwise, fall back to doing CBC up to the last
> + * two blocks, then invoke CTS just for the ciphertext stealing.
> + */
> + if (unlikely(walk.nbytes != req->cryptlen)) {
> + cbc_len = round_down(req->cryptlen - AES_BLOCK_SIZE - 1,
> + AES_BLOCK_SIZE);
> + skcipher_walk_abort(&walk);
> + skcipher_request_set_tfm(&subreq, tfm);
> + skcipher_request_set_callback(&subreq,
> + skcipher_request_flags(req),
> + NULL, NULL);
> + skcipher_request_set_crypt(&subreq, req->src, req->dst,
> + cbc_len, req->iv);
> + err = riscv64_aes_cbc_crypt(&subreq, enc);
> + if (err)
> + return err;
> + dst = src = scatterwalk_ffwd(sg_src, req->src, cbc_len);
> + if (req->dst != req->src)
> + dst = scatterwalk_ffwd(sg_dst, req->dst, cbc_len);
> + skcipher_request_set_crypt(&subreq, src, dst,
> + req->cryptlen - cbc_len, req->iv);
> + err = skcipher_walk_virt(&walk, &subreq, false);
> + if (err)
> + return err;
> + }
> + kernel_vector_begin();
> + aes_cbc_cts_crypt_zvkned(ctx, walk.src.virt.addr, walk.dst.virt.addr,
> + walk.nbytes, req->iv, enc);
> + kernel_vector_end();
> + return skcipher_walk_done(&walk, 0);
> +}
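
I worked the cbc_len computation through a couple of examples to
convince myself it always leaves the last two (possibly partial)
blocks for the CTS call:

    cryptlen = 100: cbc_len = round_down(100 - 16 - 1, 16) = 80,
                    leaving 20 bytes (one full block + 4-byte partial)
    cryptlen =  64: cbc_len = round_down(64 - 16 - 1, 16) = 32,
                    leaving exactly the last two full blocks

Looks correct to me.
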
> +
> +static int riscv64_aes_cbc_cts_encrypt(struct skcipher_request *req)
> +{
> + return riscv64_aes_cbc_cts_crypt(req, true);
> +}
> +
> +static int riscv64_aes_cbc_cts_decrypt(struct skcipher_request *req)
> +{
> + return riscv64_aes_cbc_cts_crypt(req, false);
> +}
> +
> /* AES-CTR */
>
> static int riscv64_aes_ctr_crypt(struct skcipher_request *req)
> {
> struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
> const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
> unsigned int nbytes, p1_nbytes;
> struct skcipher_walk walk;
> u32 ctr32, nblocks;
> int err;
> @@ -427,20 +497,36 @@ static struct skcipher_alg riscv64_zvkned_aes_skcipher_algs[] = {
> .max_keysize = AES_MAX_KEY_SIZE,
> .ivsize = AES_BLOCK_SIZE,
> .base = {
> .cra_blocksize = AES_BLOCK_SIZE,
> .cra_ctxsize = sizeof(struct crypto_aes_ctx),
> .cra_priority = 300,
> .cra_name = "cbc(aes)",
> .cra_driver_name = "cbc-aes-riscv64-zvkned",
> .cra_module = THIS_MODULE,
> },
> + }, {
> + .setkey = riscv64_aes_setkey_skcipher,
> + .encrypt = riscv64_aes_cbc_cts_encrypt,
> + .decrypt = riscv64_aes_cbc_cts_decrypt,
> + .min_keysize = AES_MIN_KEY_SIZE,
> + .max_keysize = AES_MAX_KEY_SIZE,
> + .ivsize = AES_BLOCK_SIZE,
> + .walksize = 4 * AES_BLOCK_SIZE, /* matches LMUL=4 */
> + .base = {
> + .cra_blocksize = AES_BLOCK_SIZE,
> + .cra_ctxsize = sizeof(struct crypto_aes_ctx),
> + .cra_priority = 300,
> + .cra_name = "cts(cbc(aes))",
> + .cra_driver_name = "cts-cbc-aes-riscv64-zvkned",
> + .cra_module = THIS_MODULE,
> + },
> }
> };
>
> static struct skcipher_alg riscv64_zvkned_zvkb_aes_skcipher_alg = {
> .setkey = riscv64_aes_setkey_skcipher,
> .encrypt = riscv64_aes_ctr_crypt,
> .decrypt = riscv64_aes_ctr_crypt,
> .min_keysize = AES_MIN_KEY_SIZE,
> .max_keysize = AES_MAX_KEY_SIZE,
> .ivsize = AES_BLOCK_SIZE,
> @@ -533,18 +619,19 @@ static void __exit riscv64_aes_mod_exit(void)
> if (riscv_isa_extension_available(NULL, ZVKB))
> crypto_unregister_skcipher(&riscv64_zvkned_zvkb_aes_skcipher_alg);
> crypto_unregister_skciphers(riscv64_zvkned_aes_skcipher_algs,
> ARRAY_SIZE(riscv64_zvkned_aes_skcipher_algs));
> crypto_unregister_alg(&riscv64_zvkned_aes_cipher_alg);
> }
>
> module_init(riscv64_aes_mod_init);
> module_exit(riscv64_aes_mod_exit);
>
> -MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS (RISC-V accelerated)");
> +MODULE_DESCRIPTION("AES-ECB/CBC/CTS/CTR/XTS (RISC-V accelerated)");
> MODULE_AUTHOR("Jerry Shih <jerry.shih at sifive.com>");
> MODULE_LICENSE("GPL");
> MODULE_ALIAS_CRYPTO("aes");
> MODULE_ALIAS_CRYPTO("ecb(aes)");
> MODULE_ALIAS_CRYPTO("cbc(aes)");
> +MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
> MODULE_ALIAS_CRYPTO("ctr(aes)");
> MODULE_ALIAS_CRYPTO("xts(aes)");
> diff --git a/arch/riscv/crypto/aes-riscv64-zvkned.S b/arch/riscv/crypto/aes-riscv64-zvkned.S
> index 43541aad6386..23d063f94ce6 100644
> --- a/arch/riscv/crypto/aes-riscv64-zvkned.S
> +++ b/arch/riscv/crypto/aes-riscv64-zvkned.S
> @@ -177,10 +177,163 @@ SYM_FUNC_END(aes_cbc_encrypt_zvkned)
>
> // Same prototype and calling convention as the encryption function
> SYM_FUNC_START(aes_cbc_decrypt_zvkned)
> aes_begin KEYP, 128f, 192f
> aes_cbc_decrypt 256
> 128:
> aes_cbc_decrypt 128
> 192:
> aes_cbc_decrypt 192
> SYM_FUNC_END(aes_cbc_decrypt_zvkned)
> +
> +.macro aes_cbc_cts_encrypt keylen
> +
> + // CBC-encrypt all blocks except the last. But don't store the
> + // second-to-last block to the output buffer yet, since it will be
> + // handled specially in the ciphertext stealing step. Exception: if the
> + // message is single-block, still encrypt the last (and only) block.
> + li t0, 16
> + j 2f
> +1:
> + vse32.v v16, (OUTP) // Store ciphertext block
> + addi OUTP, OUTP, 16
> +2:
> + vle32.v v17, (INP) // Load plaintext block
> + vxor.vv v16, v16, v17 // XOR with IV or prev ciphertext block
> + aes_encrypt v16, \keylen // Encrypt
> + addi INP, INP, 16
> + addi LEN, LEN, -16
> + bgt LEN, t0, 1b // Repeat if more than one block remains
> +
> + // Special case: if the message is a single block, just do CBC.
> + beqz LEN, .Lcts_encrypt_done\@
> +
> + // Encrypt the last two blocks using ciphertext stealing as follows:
> + // C[n-1] = Encrypt(Encrypt(P[n-1] ^ C[n-2]) ^ P[n])
> + // C[n] = Encrypt(P[n-1] ^ C[n-2])[0..LEN]
> + //
> + // C[i] denotes the i'th ciphertext block, and likewise P[i] the i'th
> + // plaintext block. Block n, the last block, may be partial; its length
> + // is 1 <= LEN <= 16. If there are only 2 blocks, C[n-2] means the IV.
> + //
> + // v16 already contains Encrypt(P[n-1] ^ C[n-2]).
> + // INP points to P[n]. OUTP points to where C[n-1] should go.
> + // To support in-place encryption, load P[n] before storing C[n].
> + addi t0, OUTP, 16 // Get pointer to where C[n] should go
> + vsetvli zero, LEN, e8, m1, tu, ma
> + vle8.v v17, (INP) // Load P[n]
> + vse8.v v16, (t0) // Store C[n]
> + vxor.vv v16, v16, v17 // v16 = Encrypt(P[n-1] ^ C[n-2]) ^ P[n]
> + vsetivli zero, 4, e32, m1, ta, ma
> + aes_encrypt v16, \keylen
> +.Lcts_encrypt_done\@:
> + vse32.v v16, (OUTP) // Store C[n-1] (or C[n] in single-block case)
> + ret
> +.endm
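
Nice use of the tail-undisturbed (tu) policy here. To spell it out:
after the vxor under the LEN-byte vsetvli, with
E1 = Encrypt(P[n-1] ^ C[n-2]), the register holds

    v16[0..LEN]  = P[n] ^ E1[0..LEN]
    v16[LEN..16] = E1[LEN..16]

which is exactly the classical formulation of CTS, where P[n] is
zero-padded to a full block before the XOR with E1.
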
> +
> +#define LEN32 t4 // Length of remaining full blocks in 32-bit words
> +#define LEN_MOD16 t5 // Length of message in bytes mod 16
> +
> +.macro aes_cbc_cts_decrypt keylen
> + andi LEN32, LEN, ~15
> + srli LEN32, LEN32, 2
> + andi LEN_MOD16, LEN, 15
> +
> + // Save C[n-2] in v28 so that it's available later during the ciphertext
> + // stealing step. If there are fewer than three blocks, C[n-2] means
> + // the IV, otherwise it means the third-to-last ciphertext block.
> + vmv.v.v v28, v16 // IV
> + add t0, LEN, -33
> + bltz t0, .Lcts_decrypt_loop\@
> + andi t0, t0, ~15
> + add t0, t0, INP
> + vle32.v v28, (t0)
> +
> + // CBC-decrypt all full blocks. For the last full block, or the last 2
> + // full blocks if the message is block-aligned, this doesn't write the
> + // correct output blocks (unless the message is only a single block),
> + // because it XORs the wrong values with the raw AES plaintexts. But we
> + // fix this after this loop without redoing the AES decryptions. This
> + // approach allows more of the AES decryptions to be parallelized.
> +.Lcts_decrypt_loop\@:
> + vsetvli t0, LEN32, e32, m4, ta, ma
> + addi t1, t0, -4
> + vle32.v v20, (INP) // Load next set of ciphertext blocks
> + vmv.v.v v24, v16 // Get IV or last ciphertext block of prev set
> + vslideup.vi v24, v20, 4 // Setup prev ciphertext blocks
> + vslidedown.vx v16, v20, t1 // Save last ciphertext block of this set
> + aes_decrypt v20, \keylen // Decrypt this set of blocks
> + vxor.vv v24, v24, v20 // XOR prev ciphertext blocks with decrypted blocks
> + vse32.v v24, (OUTP) // Store this set of plaintext blocks
> + sub LEN32, LEN32, t0
> + slli t0, t0, 2 // Words to bytes
> + add INP, INP, t0
> + add OUTP, OUTP, t0
> + bnez LEN32, .Lcts_decrypt_loop\@
> +
> + vsetivli zero, 4, e32, m4, ta, ma
> + vslidedown.vx v20, v20, t1 // Extract raw plaintext of last full block
> + addi t0, OUTP, -16 // Get pointer to last full plaintext block
> + bnez LEN_MOD16, .Lcts_decrypt_non_block_aligned\@
> +
> + // Special case: if the message is a single block, just do CBC.
> + li t1, 16
> + beq LEN, t1, .Lcts_decrypt_done\@
> +
> + // Block-aligned message. Just fix up the last 2 blocks. We need:
> + //
> + // P[n-1] = Decrypt(C[n]) ^ C[n-2]
> + // P[n] = Decrypt(C[n-1]) ^ C[n]
> + //
> + // We have C[n] in v16, Decrypt(C[n]) in v20, and C[n-2] in v28.
> + // Together with Decrypt(C[n-1]) ^ C[n-2] from the output buffer, this
> + // is everything needed to fix the output without re-decrypting blocks.
> + addi t1, OUTP, -32 // Get pointer to where P[n-1] should go
> + vxor.vv v20, v20, v28 // Decrypt(C[n]) ^ C[n-2] == P[n-1]
> + vle32.v v24, (t1) // Decrypt(C[n-1]) ^ C[n-2]
> + vse32.v v20, (t1) // Store P[n-1]
> + vxor.vv v20, v24, v16 // Decrypt(C[n-1]) ^ C[n-2] ^ C[n] == P[n] ^ C[n-2]
> + j .Lcts_decrypt_finish\@
> +
> +.Lcts_decrypt_non_block_aligned\@:
> + // Decrypt the last two blocks using ciphertext stealing as follows:
> + //
> + // P[n-1] = Decrypt(C[n] || Decrypt(C[n-1])[LEN_MOD16..16]) ^ C[n-2]
> + // P[n] = (Decrypt(C[n-1]) ^ C[n])[0..LEN_MOD16]
> + //
> + // We already have Decrypt(C[n-1]) in v20 and C[n-2] in v28.
> + vmv.v.v v16, v20 // v16 = Decrypt(C[n-1])
> + vsetvli zero, LEN_MOD16, e8, m1, tu, ma
> + vle8.v v20, (INP) // v20 = C[n] || Decrypt(C[n-1])[LEN_MOD16..16]
> + vxor.vv v16, v16, v20 // v16 = Decrypt(C[n-1]) ^ C[n]
> + vse8.v v16, (OUTP) // Store P[n]
> + vsetivli zero, 4, e32, m1, ta, ma
> + aes_decrypt v20, \keylen // v20 = Decrypt(C[n] || Decrypt(C[n-1])[LEN_MOD16..16])
> +.Lcts_decrypt_finish\@:
> + vxor.vv v20, v20, v28 // XOR with C[n-2]
> + vse32.v v20, (t0) // Store last full plaintext block
> +.Lcts_decrypt_done\@:
> + ret
> +.endm
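
I also checked the algebra of the block-aligned fixup: at
.Lcts_decrypt_finish\@ the register holds

    v20 = Decrypt(C[n-1]) ^ C[n-2] ^ C[n]

and the final XOR with v28 = C[n-2] cancels the C[n-2] term, leaving
Decrypt(C[n-1]) ^ C[n] = P[n], while P[n-1] = Decrypt(C[n]) ^ C[n-2]
was already stored directly. Checks out.
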
> +
> +.macro aes_cbc_cts_crypt keylen
> + vle32.v v16, (IVP) // Load IV
> + beqz a5, .Lcts_decrypt\@
> + aes_cbc_cts_encrypt \keylen
> +.Lcts_decrypt\@:
> + aes_cbc_cts_decrypt \keylen
> +.endm
> +
> +// void aes_cbc_cts_crypt_zvkned(const struct crypto_aes_ctx *key,
> +// const u8 *in, u8 *out, size_t len,
> +// const u8 iv[16], bool enc);
> +//
> +// Encrypts or decrypts a message with the CS3 variant of AES-CBC-CTS.
> +// This is the variant that unconditionally swaps the last two blocks.
> +SYM_FUNC_START(aes_cbc_cts_crypt_zvkned)
> + aes_begin KEYP, 128f, 192f
> + aes_cbc_cts_crypt 256
> +128:
> + aes_cbc_cts_crypt 128
> +192:
> + aes_cbc_cts_crypt 192
> +SYM_FUNC_END(aes_cbc_cts_crypt_zvkned)
>
> base-commit: cb4ede926134a65bc3bf90ed58dace8451d7e759
> prerequisite-patch-id: 2a69e1270be0fa567cc43269826171d6e46d65de
> --
> 2.43.0
>