[PATCH v7 01/15] genpt: Generic Page Table base API

Samiullah Khawaja skhawaja at google.com
Mon Oct 27 09:35:10 PDT 2025


On Thu, Oct 23, 2025 at 11:21 AM Jason Gunthorpe <jgg at nvidia.com> wrote:
>
> The generic API is intended to be separated from the implementation of
> page table algorithms. It contains only accessors for walking and
> manipulating the table and helpers that are useful for building an
> implementation. Memory management is not in the generic API, but part of
> the implementation.
>
> Using a multi-compilation approach the implementation module would include
> headers in this order:
>
>   common.h
>   defs_FMT.h
>   pt_defs.h
>   FMT.h
>   pt_common.h
>   IMPLEMENTATION.h
>
> Where each compilation unit would have a combination of FMT and
> IMPLEMENTATION to produce a per-format per-implementation module.
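> 
> As a rough sketch (file names are hypothetical and only illustrate the
> ordering), a compilation unit combining a format "foo" with an "iommu"
> implementation could include:
> 
>   #include <linux/generic_pt/common.h>  /* common.h */
>   #include "fmt/defs_foo.h"             /* defs_FMT.h */
>   #include "pt_defs.h"
>   #include "fmt/foo.h"                  /* FMT.h */
>   #include "pt_common.h"
>   #include "iommu_pt.h"                 /* IMPLEMENTATION.h */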
>
> The API is designed so that the format headers have minimal logic, and
> default implementations are provided if the format doesn't include one.
>
> Generally formats provide their code via an inline function using the
> pattern:
>
>   static inline FMTpt_XX(..) {}
>   #define pt_XX FMTpt_XX
>
> The common code then enforces a function signature so that there is no
> drift in function arguments, or accidental polymorphic functions (as has
> been slightly troublesome in mm). Use of function-like #defines is
> avoided in the format even though many of the functions are small enough.
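> 
> For example a hypothetical format "foo" that only allows leaf entries at
> levels 0 and 1 would provide (illustrative sketch, not part of this patch):
> 
>   static inline bool foopt_can_have_leaf(const struct pt_state *pts)
>   {
>           return pts->level <= 1;
>   }
>   #define pt_can_have_leaf foopt_can_have_leaf
> 
> and pt_common.h carries the kdoc along with the enforced declaration:
> 
>   static inline bool pt_can_have_leaf(const struct pt_state *pts);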
>
> Provide kdocs for the API surface.
>
> This is enough to implement the 8 initial format variations with all of
> their features:
>  * Entries comprised of contiguous blocks of IO PTEs for larger page
>    sizes (AMDv1, ARMv8)
>  * Multi-level tables, up to 6 levels. Runtime selected top level
>  * The size of the top table level can be selected at runtime (ARM's
>    concatenated tables)
>  * The number of levels in the table can optionally increase dynamically
>    during map (AMDv1)
>  * Optional leaf entries at any level
>  * 32 bit/64 bit virtual and output addresses, using every bit
>  * Sign extended addressing (x86)
>  * Dirty tracking
>
> A basic format takes about 200 lines to declare the required inline
> functions.
>
> Tested-by: Alejandro Jimenez <alejandro.j.jimenez at oracle.com>
> Reviewed-by: Kevin Tian <kevin.tian at intel.com>
> Signed-off-by: Jason Gunthorpe <jgg at nvidia.com>
> ---
>  .clang-format                              |   1 +
>  drivers/iommu/Kconfig                      |   2 +
>  drivers/iommu/generic_pt/Kconfig           |  20 +
>  drivers/iommu/generic_pt/pt_common.h       | 358 ++++++++++++
>  drivers/iommu/generic_pt/pt_defs.h         | 329 +++++++++++
>  drivers/iommu/generic_pt/pt_fmt_defaults.h | 233 ++++++++
>  drivers/iommu/generic_pt/pt_iter.h         | 636 +++++++++++++++++++++
>  drivers/iommu/generic_pt/pt_log2.h         | 122 ++++
>  include/linux/generic_pt/common.h          | 135 +++++
>  9 files changed, 1836 insertions(+)
>  create mode 100644 drivers/iommu/generic_pt/Kconfig
>  create mode 100644 drivers/iommu/generic_pt/pt_common.h
>  create mode 100644 drivers/iommu/generic_pt/pt_defs.h
>  create mode 100644 drivers/iommu/generic_pt/pt_fmt_defaults.h
>  create mode 100644 drivers/iommu/generic_pt/pt_iter.h
>  create mode 100644 drivers/iommu/generic_pt/pt_log2.h
>  create mode 100644 include/linux/generic_pt/common.h
>
> diff --git a/.clang-format b/.clang-format
> index f371a13b4d192d..9e6a9177f8fb32 100644
> --- a/.clang-format
> +++ b/.clang-format
> @@ -415,6 +415,7 @@ ForEachMacros:
>    - 'for_each_prop_dlc_cpus'
>    - 'for_each_prop_dlc_platforms'
>    - 'for_each_property_of_node'
> +  - 'for_each_pt_level_entry'
>    - 'for_each_rdt_resource'
>    - 'for_each_reg'
>    - 'for_each_reg_filtered'
> diff --git a/drivers/iommu/Kconfig b/drivers/iommu/Kconfig
> index 70d29b14d85196..c9ae3221cd6f50 100644
> --- a/drivers/iommu/Kconfig
> +++ b/drivers/iommu/Kconfig
> @@ -384,3 +384,5 @@ config SPRD_IOMMU
>           Say Y here if you want to use the multimedia devices listed above.
>
>  endif # IOMMU_SUPPORT
> +
> +source "drivers/iommu/generic_pt/Kconfig"
> diff --git a/drivers/iommu/generic_pt/Kconfig b/drivers/iommu/generic_pt/Kconfig
> new file mode 100644
> index 00000000000000..fb0f431ddba0a8
> --- /dev/null
> +++ b/drivers/iommu/generic_pt/Kconfig
> @@ -0,0 +1,20 @@
> +# SPDX-License-Identifier: GPL-2.0-only
> +
> +menuconfig GENERIC_PT
> +       bool "Generic Radix Page Table"
> +       help
> +         Generic library for building radix tree page tables.
> +
> +         Generic PT provides a set of HW page table formats and a common
> +         set of APIs to work with them.
> +
> +if GENERIC_PT
> +config DEBUG_GENERIC_PT
> +       bool "Extra debugging checks for GENERIC_PT"
> +       help
> +         Enable extra run time debugging checks for GENERIC_PT code. This
> +         incurs a runtime cost and should not be enabled for production
> +         kernels.
> +
> +         The kunit tests require this to be enabled to get full coverage.
> +endif
> diff --git a/drivers/iommu/generic_pt/pt_common.h b/drivers/iommu/generic_pt/pt_common.h
> new file mode 100644
> index 00000000000000..f64f800725dbb7
> --- /dev/null
> +++ b/drivers/iommu/generic_pt/pt_common.h
> @@ -0,0 +1,358 @@
> +/* SPDX-License-Identifier: GPL-2.0-only */
> +/*
> + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES
> + *
> + * This header is included after the format. It contains definitions
> + * that build on the format definitions to create the basic format API.
> + *
> + * The format API is listed here, with kdocs. The functions without bodies are
> + * implemented in the format using the pattern:
> + *     static inline FMTpt_XXX(..) {..}
> + *     #define pt_XXX FMTpt_XXX
> + *
> + * If the format doesn't implement a function then pt_fmt_defaults.h can provide
> + * a generic version.
> + *
> + * The routines marked "@pts: Entry to query" operate on the entire contiguous
> + * entry and can be called with a pts->index pointing to any sub item that makes
> + * up that entry.
> + *
> + * The header order is:
> + *  pt_defs.h
> + *  FMT.h
> + *  pt_common.h
> + */
> +#ifndef __GENERIC_PT_PT_COMMON_H
> +#define __GENERIC_PT_PT_COMMON_H
> +
> +#include "pt_defs.h"
> +#include "pt_fmt_defaults.h"
> +
> +/**
> + * pt_attr_from_entry() - Convert the permission bits back to attrs
> + * @pts: Entry to convert from
> + * @attrs: Resulting attrs
> + *
> + * Fill in the attrs with the permission bits encoded in the current leaf entry.
> + * The attrs should be usable with pt_install_leaf_entry() to reconstruct the
> + * same entry.
> + */
> +static inline void pt_attr_from_entry(const struct pt_state *pts,
> +                                     struct pt_write_attrs *attrs);
> +
> +/**
> + * pt_can_have_leaf() - True if the current level can have an OA entry
> + * @pts: The current level
> + *
> + * True if the current level can support pt_install_leaf_entry(). A leaf
> + * entry produces an OA.
> + */
> +static inline bool pt_can_have_leaf(const struct pt_state *pts);
> +
> +/**
> + * pt_can_have_table() - True if the current level can have a lower table
> + * @pts: The current level
> + *
> + * Every level except 0 is allowed to have a lower table.
> + */
> +static inline bool pt_can_have_table(const struct pt_state *pts)
> +{
> +       /* No further tables at level 0 */
> +       return pts->level > 0;
> +}
> +
> +/**
> + * pt_clear_entries() - Make entries empty (non-present)
> + * @pts: Starting table index
> + * @num_contig_lg2: Number of contiguous items to clear
> + *
> + * Clear a run of entries. A cleared entry will load back as PT_ENTRY_EMPTY
> + * and does not have any effect on table walking. The starting index must be
> + * aligned to num_contig_lg2.
> + */
> +static inline void pt_clear_entries(struct pt_state *pts,
> +                                   unsigned int num_contig_lg2);
> +
> +/**
> + * pt_entry_make_write_dirty() - Make an entry dirty
> + * @pts: Table entry to change
> + *
> + * Make pt_entry_is_write_dirty() return true for this entry. This can be called
> + * asynchronously with any other table manipulation under an RCU lock and must
> + * not corrupt the table.
> + */
> +static inline bool pt_entry_make_write_dirty(struct pt_state *pts);
> +
> +/**
> + * pt_entry_make_write_clean() - Make the entry write clean
> + * @pts: Table entry to change
> + *
> + * Modify the entry so that pt_entry_is_write_dirty() == false. The HW will
> + * eventually be notified of this change via a TLB flush, which is the point
> + * that the HW must become synchronized. Any "write dirty" prior to the TLB
> + * flush can be lost, but once the TLB flush completes all writes must make
> + * their entries write dirty.
> + *
> + * The format should alter the entry in a way that is compatible with any
> + * concurrent update from HW. The entire contiguous entry is changed.
> + */
> +static inline void pt_entry_make_write_clean(struct pt_state *pts);
> +
> +/**
> + * pt_entry_is_write_dirty() - True if the entry has been written to
> + * @pts: Entry to query
> + *
> + * "write dirty" means that the HW has written to the OA translated
> + * by this entry. If the entry is contiguous then the consolidated
> + * "write dirty" for all the items must be returned.
> + */
> +static inline bool pt_entry_is_write_dirty(const struct pt_state *pts);
> +
> +/**
> + * pt_dirty_supported() - True if the page table supports dirty tracking
> + * @common: Page table to query
> + */
> +static inline bool pt_dirty_supported(struct pt_common *common);
> +
> +/**
> + * pt_entry_num_contig_lg2() - Number of contiguous items for this leaf entry
> + * @pts: Entry to query
> + *
> + * Return the number of contiguous items this leaf entry spans. If the entry
> + * is a single item it returns ilog2(1).
> + */
> +static inline unsigned int pt_entry_num_contig_lg2(const struct pt_state *pts);
> +
> +/**
> + * pt_entry_oa() - Output Address for this leaf entry
> + * @pts: Entry to query
> + *
> + * Return the output address for the start of the entry. If the entry
> + * is contiguous this returns the same value for each sub-item. I.e.::
> + *
> + *    log2_mod(pt_entry_oa(), pt_entry_oa_lg2sz()) == 0
> + *
> + * See pt_item_oa(). The format should implement one of these two functions
> + * depending on how it stores the OAs in the table.
> + */
> +static inline pt_oaddr_t pt_entry_oa(const struct pt_state *pts);
> +
> +/**
> + * pt_entry_oa_lg2sz() - Return the size of an OA entry
> + * @pts: Entry to query
> + *
> + * If the entry is not contiguous this returns pt_table_item_lg2sz(), otherwise
> + * it returns the total VA/OA size of the entire contiguous entry.
> + */
> +static inline unsigned int pt_entry_oa_lg2sz(const struct pt_state *pts)
> +{
> +       return pt_entry_num_contig_lg2(pts) + pt_table_item_lg2sz(pts);
> +}
> +
> +/**
> + * pt_entry_oa_exact() - Return the complete OA for an entry
> + * @pts: Entry to query
> + *
> + * During iteration the first entry could have a VA with an offset from the
> + * natural start of the entry. Return the exact OA including the pts's VA
> + * offset.
> + */
> +static inline pt_oaddr_t pt_entry_oa_exact(const struct pt_state *pts)
> +{
> +       return _pt_entry_oa_fast(pts) |
> +              log2_mod(pts->range->va, pt_entry_oa_lg2sz(pts));
> +}
> +
> +/**
> + * pt_full_va_prefix() - The top bits of the VA
> + * @common: Page table to query
> + *
> + * This is usually 0, but some formats have their VA space going downward from
> + * PT_VADDR_MAX, and will return that instead. This value must always be
> + * adjusted by struct pt_common max_vasz_lg2.
> + */
> +static inline pt_vaddr_t pt_full_va_prefix(const struct pt_common *common);
> +
> +/**
> + * pt_has_system_page_size() - True if level 0 can install a PAGE_SHIFT entry
> + * @common: Page table to query
> + *
> + * If true the caller can use, at level 0, pt_install_leaf_entry(PAGE_SHIFT).
> + * This is useful to create optimized paths for common cases of PAGE_SIZE
> + * mappings.
> + */
> +static inline bool pt_has_system_page_size(const struct pt_common *common);
> +
> +/**
> + * pt_install_leaf_entry() - Write a leaf entry to the table
> + * @pts: Table index to change
> + * @oa: Output Address for this leaf
> + * @oasz_lg2: Size in VA/OA for this leaf
> + * @attrs: Attributes to modify the entry
> + *
> + * A leaf OA entry will return PT_ENTRY_OA from pt_load_entry(). It translates
> + * the VA indicated by pts to the given OA.
> + *
> + * For a single item non-contiguous entry oasz_lg2 is pt_table_item_lg2sz().
> + * For a contiguous entry it is pt_table_item_lg2sz() + num_contig_lg2.
> + *
> + * This must not be called if pt_can_have_leaf() == false. Contiguous sizes
> + * not indicated by pt_possible_sizes() must not be specified.
> + */
> +static inline void pt_install_leaf_entry(struct pt_state *pts, pt_oaddr_t oa,
> +                                        unsigned int oasz_lg2,
> +                                        const struct pt_write_attrs *attrs);
> +
> +/**
> + * pt_install_table() - Write a table entry to the table
> + * @pts: Table index to change
> + * @table_pa: CPU physical address of the lower table's memory
> + * @attrs: Attributes to modify the table index
> + *
> + * A table entry will return PT_ENTRY_TABLE from pt_load_entry(). The table_pa
> + * is the table at pts->level - 1. This is done by cmpxchg so pts must have the
> + * current entry loaded. The pts is updated with the installed entry.
> + *
> + * This must not be called if pt_can_have_table() == false.
> + *
> + * Returns: true if the table was installed successfully.
> + */
> +static inline bool pt_install_table(struct pt_state *pts, pt_oaddr_t table_pa,
> +                                   const struct pt_write_attrs *attrs);
> +
> +/**
> + * pt_item_oa() - Output Address for this leaf item
> + * @pts: Item to query
> + *
> + * Return the output address for this item. If the item is part of a contiguous
> + * entry it returns the value of the OA for this individual sub item.
> + *
> + * See pt_entry_oa(). The format should implement one of these two functions
> + * depending on how it stores the OAs in the table.
> + */
> +static inline pt_oaddr_t pt_item_oa(const struct pt_state *pts);
> +
> +/**
> + * pt_load_entry_raw() - Read from the location pts points at into the pts
> + * @pts: Table index to load
> + *
> + * Return the type of entry that was loaded. pts->entry will be filled in with
> + * the entry's content. See pt_load_entry()
> + */
> +static inline enum pt_entry_type pt_load_entry_raw(struct pt_state *pts);
> +
> +/**
> + * pt_max_oa_lg2() - Return the maximum OA the table format can hold
> + * @common: Page table to query
> + *
> + * The value oalog2_to_max_int(pt_max_oa_lg2()) is the MAX for the
> + * OA. This is the absolute maximum address the table can hold. struct pt_common
> + * max_oasz_lg2 sets a lower dynamic maximum based on HW capability.
> + */
> +static inline unsigned int
> +pt_max_oa_lg2(const struct pt_common *common);
> +
> +/**
> + * pt_num_items_lg2() - Return the number of items in this table level
> + * @pts: The current level
> + *
> + * The number of items in a table level defines the number of bits this level
> + * decodes from the VA. This function is not called for the top level,
> + * so it does not need to compute a special value for the top case. The
> + * result for the top is based on pt_common max_vasz_lg2.
> + *
> + * The value is used as part of determining the table indexes via the
> + * equation::
> + *
> + *   log2_mod(log2_div(VA, pt_table_item_lg2sz()), pt_num_items_lg2())
> + */
> +static inline unsigned int pt_num_items_lg2(const struct pt_state *pts);
> +
> +/**
> + * pt_pgsz_lg2_to_level() - Return the level that maps the page size
> + * @common: Page table to query
> + * @pgsize_lg2: Log2 page size
> + *
> + * Returns the table level that will map the given page size. The page
> + * size must be part of the pt_possible_sizes() for some level.
> + */
> +static inline unsigned int pt_pgsz_lg2_to_level(struct pt_common *common,
> +                                               unsigned int pgsize_lg2);
> +
> +/**
> + * pt_possible_sizes() - Return a bitmap of possible output sizes at this level
> + * @pts: The current level
> + *
> + * Each level has a list of possible output sizes that can be installed as
> + * leaf entries. If pt_can_have_leaf() is false returns zero.
> + *
> + * Otherwise the bit in position pt_table_item_lg2sz() should be set indicating
> + * that a non-contiguous single item leaf entry is supported. The following
> + * pt_num_items_lg2() number of bits can be set indicating contiguous entries
> + * are supported. Bit pt_table_item_lg2sz() + pt_num_items_lg2() must not be
> + * set, contiguous entries cannot span the entire table.
> + *
> + * The OR of pt_possible_sizes() of all levels is the typical bitmask of all
> + * supported sizes in the entire table.
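> + *
> + * For example (illustrative): a level with 4 KiB items that also supports
> + * 16-item contiguous entries would return BIT(12) | BIT(16).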
> + */
> +static inline pt_vaddr_t pt_possible_sizes(const struct pt_state *pts);
> +
> +/**
> + * pt_table_item_lg2sz() - Size of a single item entry in this table level
> + * @pts: The current level
> + *
> + * The size of the item specifies how much VA and OA a single item occupies.
> + *
> + * See pt_entry_oa_lg2sz() for the same value including the effect of contiguous
> + * entries.
> + */
> +static inline unsigned int pt_table_item_lg2sz(const struct pt_state *pts);
> +
> +/**
> + * pt_table_oa_lg2sz() - Return the VA/OA size of the entire table
> + * @pts: The current level
> + *
> + * Return the size of VA decoded by the entire table level.
> + */
> +static inline unsigned int pt_table_oa_lg2sz(const struct pt_state *pts)
> +{
> +       if (pts->range->top_level == pts->level)
> +               return pts->range->max_vasz_lg2;
> +       return min_t(unsigned int, pts->range->common->max_vasz_lg2,
> +                    pt_num_items_lg2(pts) + pt_table_item_lg2sz(pts));
> +}
> +
> +/**
> + * pt_table_pa() - Return the CPU physical address of the table entry
> + * @pts: Entry to query
> + *
> + * This is only ever called on PT_ENTRY_TABLE entries. Must return the same
> + * value passed to pt_install_table().
> + */
> +static inline pt_oaddr_t pt_table_pa(const struct pt_state *pts);
> +
> +/**
> + * pt_table_ptr() - Return a CPU pointer for a table item
> + * @pts: Entry to query
> + *
> + * Same as pt_table_pa() but returns a CPU pointer.
> + */
> +static inline struct pt_table_p *pt_table_ptr(const struct pt_state *pts)
> +{
> +       return __va(pt_table_pa(pts));
> +}
> +
> +/**
> + * pt_load_entry() - Read from the location pts points at into the pts
> + * @pts: Table index to load
> + *
> + * Set the type of entry that was loaded. pts->entry and pts->table_lower
> + * will be filled in with the entry's content.
> + */
> +static inline void pt_load_entry(struct pt_state *pts)
> +{
> +       pts->type = pt_load_entry_raw(pts);
> +       if (pts->type == PT_ENTRY_TABLE)
> +               pts->table_lower = pt_table_ptr(pts);
> +}
> +#endif
> diff --git a/drivers/iommu/generic_pt/pt_defs.h b/drivers/iommu/generic_pt/pt_defs.h
> new file mode 100644
> index 00000000000000..819057de50d82c
> --- /dev/null
> +++ b/drivers/iommu/generic_pt/pt_defs.h
> @@ -0,0 +1,329 @@
> +/* SPDX-License-Identifier: GPL-2.0-only */
> +/*
> + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES
> + *
> + * This header is included before the format. It contains definitions
> + * that are required to compile the format. The header order is:
> + *  pt_defs.h
> + *  fmt_XX.h
> + *  pt_common.h
> + */
> +#ifndef __GENERIC_PT_DEFS_H
> +#define __GENERIC_PT_DEFS_H
> +
> +#include <linux/generic_pt/common.h>
> +
> +#include <linux/types.h>
> +#include <linux/atomic.h>
> +#include <linux/bits.h>
> +#include <linux/limits.h>
> +#include <linux/bug.h>
> +#include <linux/kconfig.h>
> +#include "pt_log2.h"
> +
> +/* Header self-compile default defines */
> +#ifndef pt_write_attrs
> +typedef u64 pt_vaddr_t;
> +typedef u64 pt_oaddr_t;
> +#endif
> +
> +struct pt_table_p;
> +
> +enum {
> +       PT_VADDR_MAX = sizeof(pt_vaddr_t) == 8 ? U64_MAX : U32_MAX,
> +       PT_VADDR_MAX_LG2 = sizeof(pt_vaddr_t) == 8 ? 64 : 32,
> +       PT_OADDR_MAX = sizeof(pt_oaddr_t) == 8 ? U64_MAX : U32_MAX,
> +       PT_OADDR_MAX_LG2 = sizeof(pt_oaddr_t) == 8 ? 64 : 32,
> +};
> +
> +/*
> + * The format instantiation can have features wired off or on to optimize the
> + * code gen. Supported features are just a reflection of what the current set of
> + * kernel users want to use.
> + */
> +#ifndef PT_SUPPORTED_FEATURES
> +#define PT_SUPPORTED_FEATURES 0
> +#endif
> +
> +/*
> + * When in debug mode we compile all formats with all features. This allows the
> + * kunit to test the full matrix. SIGN_EXTEND can't co-exist with DYNAMIC_TOP or
> + * FULL_VA.
> + */
> +#if IS_ENABLED(CONFIG_DEBUG_GENERIC_PT)
> +enum {
> +       PT_ORIG_SUPPORTED_FEATURES = PT_SUPPORTED_FEATURES,
> +       PT_DEBUG_SUPPORTED_FEATURES =
> +               UINT_MAX &
> +               ~((PT_ORIG_SUPPORTED_FEATURES & BIT(PT_FEAT_SIGN_EXTEND)) ?
> +                         BIT(PT_FEAT_DYNAMIC_TOP) | BIT(PT_FEAT_FULL_VA) :
> +                         BIT(PT_FEAT_SIGN_EXTEND)),
> +};
> +#undef PT_SUPPORTED_FEATURES
> +#define PT_SUPPORTED_FEATURES PT_DEBUG_SUPPORTED_FEATURES
> +#endif
> +
> +#ifndef PT_FORCE_ENABLED_FEATURES
> +#define PT_FORCE_ENABLED_FEATURES 0
> +#endif
> +
> +/**
> + * DOC: Generic Page Table Language
> + *
> + * Language used in Generic Page Table
> + *  VA
> + *     The input address to the page table, often the virtual address.
> + *  OA
> + *     The output address from the page table, often the physical address.
> + *  leaf
> + *     An entry that results in an output address.
> + *  start/end
> + *     A half-open range, e.g. [0,0) refers to no VA.
> + *  start/last
> + *     An inclusive closed range, e.g. [0,0] refers to the VA 0.
> + *  common
> + *     The generic page table container struct pt_common
> + *  level
> + *     Level 0 is always a table of only leaves with no further table pointers.
> + *     Increasing levels increase the size of the table items. The least
> + *     significant VA bits used to index page tables are used to index the Level
> + *     0 table. The various labels for table levels used by HW descriptions are
> + *     not used.
> + *  top_level
> + *     The inclusive highest level of the table. A two-level table
> + *     has a top level of 1.
> + *  table
> + *     A linear array of translation items for that level.
> + *  index
> + *     The position in a table of an element: item = table[index]
> + *  item
> + *     The element at a single index in a table
> + *  entry
> + *     A single logical element in a table. If contiguous pages are not
> + *     supported then item and entry are the same thing, otherwise entry refers
> + *     to all the items that comprise a single contiguous translation.
> + *  item/entry_size
> + *     The number of bytes of VA the table index translates for.
> + *     If the item is a table entry then the next table covers
> + *     this size. If the entry translates to an output address then the
> + *     full OA is: OA | (VA % entry_size)
> + *  contig_count
> + *     The number of consecutive items fused into a single entry.
> + *     item_size * contig_count is the size of that entry's translation.
> + *  lg2
> + *     Indicates the value is encoded as log2, i.e. 1<<x is the actual value.
> + *     Normally the compiler is fine to optimize divide and mod with log2 values
> + *     automatically when inlining, however if the values are not constant
> + *     expressions it can't. So we do it by hand; we want to avoid 64-bit
> + *     divmod.
> + */
> +
> +/* Returned by pt_load_entry() and for_each_pt_level_entry() */
> +enum pt_entry_type {
> +       PT_ENTRY_EMPTY,
> +       /* Entry is valid and points to a lower table level */
> +       PT_ENTRY_TABLE,
> +       /* Entry is valid and returns an output address */
> +       PT_ENTRY_OA,
> +};
> +
> +struct pt_range {
> +       struct pt_common *common;
> +       struct pt_table_p *top_table;
> +       pt_vaddr_t va;
> +       pt_vaddr_t last_va;
> +       u8 top_level;
> +       u8 max_vasz_lg2;
> +};
> +
> +/*
> + * Similar to xa_state, this records information about an in-progress parse at a
> + * single level.
> + */
> +struct pt_state {
> +       struct pt_range *range;
> +       struct pt_table_p *table;
> +       struct pt_table_p *table_lower;
> +       u64 entry;
> +       enum pt_entry_type type;
> +       unsigned short index;
> +       unsigned short end_index;
> +       u8 level;
> +};
> +
> +#define pt_cur_table(pts, type) ((type *)((pts)->table))
> +
> +/*
> + * Try to install a new table pointer. The locking methodology requires this to
> + * be atomic (multiple threads can race to install a pointer). The losing
> + * threads will fail the atomic and return false. They should free any memory
> + * and reparse the table level again.
> + */
> +#if !IS_ENABLED(CONFIG_GENERIC_ATOMIC64)
> +static inline bool pt_table_install64(struct pt_state *pts, u64 table_entry)
> +{
> +       u64 *entryp = pt_cur_table(pts, u64) + pts->index;
> +       u64 old_entry = pts->entry;
> +       bool ret;
> +
> +       /*
> +        * Ensure the zero'd table content itself is visible before its PTE can
> +        * be. release is a NOP on !SMP, but the HW is still doing an acquire.
> +        */
> +       if (!IS_ENABLED(CONFIG_SMP))
> +               dma_wmb();
> +       ret = try_cmpxchg64_release(entryp, &old_entry, table_entry);
> +       if (ret)
> +               pts->entry = table_entry;
> +       return ret;
> +}
> +#endif
> +
> +static inline bool pt_table_install32(struct pt_state *pts, u32 table_entry)
> +{
> +       u32 *entryp = pt_cur_table(pts, u32) + pts->index;
> +       u32 old_entry = pts->entry;
> +       bool ret;
> +
> +       /*
> +        * Ensure the zero'd table content itself is visible before its PTE can
> +        * be. release is a NOP on !SMP, but the HW is still doing an acquire.
> +        */
> +       if (!IS_ENABLED(CONFIG_SMP))
> +               dma_wmb();
> +       ret = try_cmpxchg_release(entryp, &old_entry, table_entry);
> +       if (ret)
> +               pts->entry = table_entry;
> +       return ret;
> +}
> +
> +#define PT_SUPPORTED_FEATURE(feature_nr) (PT_SUPPORTED_FEATURES & BIT(feature_nr))
> +
> +static inline bool pt_feature(const struct pt_common *common,
> +                             unsigned int feature_nr)
> +{
> +       if (PT_FORCE_ENABLED_FEATURES & BIT(feature_nr))
> +               return true;
> +       if (!PT_SUPPORTED_FEATURE(feature_nr))
> +               return false;
> +       return common->features & BIT(feature_nr);
> +}
> +
> +static inline bool pts_feature(const struct pt_state *pts,
> +                              unsigned int feature_nr)
> +{
> +       return pt_feature(pts->range->common, feature_nr);
> +}
> +
> +/*
> + * PT_WARN_ON is used for invariants that the kunit tests should verify can
> + * never happen.
> + */
> +#if IS_ENABLED(CONFIG_DEBUG_GENERIC_PT)
> +#define PT_WARN_ON WARN_ON
> +#else
> +static inline bool PT_WARN_ON(bool condition)
> +{
> +       return false;
> +}
> +#endif
> +
> +/* These all work on the VA type */
> +#define log2_to_int(a_lg2) log2_to_int_t(pt_vaddr_t, a_lg2)
> +#define log2_to_max_int(a_lg2) log2_to_max_int_t(pt_vaddr_t, a_lg2)
> +#define log2_div(a, b_lg2) log2_div_t(pt_vaddr_t, a, b_lg2)
> +#define log2_div_eq(a, b, c_lg2) log2_div_eq_t(pt_vaddr_t, a, b, c_lg2)
> +#define log2_mod(a, b_lg2) log2_mod_t(pt_vaddr_t, a, b_lg2)
> +#define log2_mod_eq_max(a, b_lg2) log2_mod_eq_max_t(pt_vaddr_t, a, b_lg2)
> +#define log2_set_mod(a, val, b_lg2) log2_set_mod_t(pt_vaddr_t, a, val, b_lg2)
> +#define log2_set_mod_max(a, b_lg2) log2_set_mod_max_t(pt_vaddr_t, a, b_lg2)
> +#define log2_mul(a, b_lg2) log2_mul_t(pt_vaddr_t, a, b_lg2)
> +#define vaffs(a) ffs_t(pt_vaddr_t, a)
> +#define vafls(a) fls_t(pt_vaddr_t, a)
> +#define vaffz(a) ffz_t(pt_vaddr_t, a)
> +
> +/*
> + * The full VA (fva) versions permit the lg2 value to be == PT_VADDR_MAX_LG2 and
> + * generate a useful defined result. The non-fva versions will malfunction at
> + * this extreme.
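> + *
> + * For example, with PT_FEAT_FULL_VA supported, fvalog2_div(a,
> + * PT_VADDR_MAX_LG2) returns 0, while the plain log2_div() would shift by the
> + * full width of pt_vaddr_t.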
> + */
> +static inline pt_vaddr_t fvalog2_div(pt_vaddr_t a, unsigned int b_lg2)
> +{
> +       if (PT_SUPPORTED_FEATURE(PT_FEAT_FULL_VA) && b_lg2 == PT_VADDR_MAX_LG2)
> +               return 0;
> +       return log2_div_t(pt_vaddr_t, a, b_lg2);
> +}
> +
> +static inline pt_vaddr_t fvalog2_mod(pt_vaddr_t a, unsigned int b_lg2)
> +{
> +       if (PT_SUPPORTED_FEATURE(PT_FEAT_FULL_VA) && b_lg2 == PT_VADDR_MAX_LG2)
> +               return a;
> +       return log2_mod_t(pt_vaddr_t, a, b_lg2);
> +}
> +
> +static inline bool fvalog2_div_eq(pt_vaddr_t a, pt_vaddr_t b,
> +                                 unsigned int c_lg2)
> +{
> +       if (PT_SUPPORTED_FEATURE(PT_FEAT_FULL_VA) && c_lg2 == PT_VADDR_MAX_LG2)
> +               return true;
> +       return log2_div_eq_t(pt_vaddr_t, a, b, c_lg2);
> +}
> +
> +static inline pt_vaddr_t fvalog2_set_mod(pt_vaddr_t a, pt_vaddr_t val,
> +                                        unsigned int b_lg2)
> +{
> +       if (PT_SUPPORTED_FEATURE(PT_FEAT_FULL_VA) && b_lg2 == PT_VADDR_MAX_LG2)
> +               return val;
> +       return log2_set_mod_t(pt_vaddr_t, a, val, b_lg2);
> +}
> +
> +static inline pt_vaddr_t fvalog2_set_mod_max(pt_vaddr_t a, unsigned int b_lg2)
> +{
> +       if (PT_SUPPORTED_FEATURE(PT_FEAT_FULL_VA) && b_lg2 == PT_VADDR_MAX_LG2)
> +               return PT_VADDR_MAX;
> +       return log2_set_mod_max_t(pt_vaddr_t, a, b_lg2);
> +}
> +
> +/* These all work on the OA type */
> +#define oalog2_to_int(a_lg2) log2_to_int_t(pt_oaddr_t, a_lg2)
> +#define oalog2_to_max_int(a_lg2) log2_to_max_int_t(pt_oaddr_t, a_lg2)
> +#define oalog2_div(a, b_lg2) log2_div_t(pt_oaddr_t, a, b_lg2)
> +#define oalog2_div_eq(a, b, c_lg2) log2_div_eq_t(pt_oaddr_t, a, b, c_lg2)
> +#define oalog2_mod(a, b_lg2) log2_mod_t(pt_oaddr_t, a, b_lg2)
> +#define oalog2_mod_eq_max(a, b_lg2) log2_mod_eq_max_t(pt_oaddr_t, a, b_lg2)
> +#define oalog2_set_mod(a, val, b_lg2) log2_set_mod_t(pt_oaddr_t, a, val, b_lg2)
> +#define oalog2_set_mod_max(a, b_lg2) log2_set_mod_max_t(pt_oaddr_t, a, b_lg2)
> +#define oalog2_mul(a, b_lg2) log2_mul_t(pt_oaddr_t, a, b_lg2)
> +#define oaffs(a) ffs_t(pt_oaddr_t, a)
> +#define oafls(a) fls_t(pt_oaddr_t, a)
> +#define oaffz(a) ffz_t(pt_oaddr_t, a)
> +
> +static inline uintptr_t _pt_top_set(struct pt_table_p *table_mem,
> +                                   unsigned int top_level)
> +{
> +       return top_level | (uintptr_t)table_mem;
> +}
> +
> +static inline void pt_top_set(struct pt_common *common,
> +                             struct pt_table_p *table_mem,
> +                             unsigned int top_level)
> +{
> +       WRITE_ONCE(common->top_of_table, _pt_top_set(table_mem, top_level));
> +}
> +
> +static inline void pt_top_set_level(struct pt_common *common,
> +                                   unsigned int top_level)
> +{
> +       pt_top_set(common, NULL, top_level);
> +}
> +
> +static inline unsigned int pt_top_get_level(const struct pt_common *common)
> +{
> +       return READ_ONCE(common->top_of_table) % (1 << PT_TOP_LEVEL_BITS);
> +}
> +
> +static inline bool pt_check_install_leaf_args(struct pt_state *pts,
> +                                             pt_oaddr_t oa,
> +                                             unsigned int oasz_lg2);
> +
> +#endif
> diff --git a/drivers/iommu/generic_pt/pt_fmt_defaults.h b/drivers/iommu/generic_pt/pt_fmt_defaults.h
> new file mode 100644
> index 00000000000000..60d594bbb1063e
> --- /dev/null
> +++ b/drivers/iommu/generic_pt/pt_fmt_defaults.h
> @@ -0,0 +1,233 @@
> +/* SPDX-License-Identifier: GPL-2.0-only */
> +/*
> + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES
> + *
> + * Default definitions for formats that don't define these functions.
> + */
> +#ifndef __GENERIC_PT_PT_FMT_DEFAULTS_H
> +#define __GENERIC_PT_PT_FMT_DEFAULTS_H
> +
> +#include "pt_defs.h"
> +#include <linux/log2.h>
> +
> +/* Header self-compile default defines */
> +#ifndef pt_load_entry_raw
> +#include "fmt/amdv1.h"
> +#endif
> +
> +/*
> + * The format must provide PT_GRANULE_LG2SZ, PT_TABLEMEM_LG2SZ, and
> + * PT_ITEM_WORD_SIZE. They must be the same at every level excluding the top.
> + */
> +#ifndef pt_table_item_lg2sz
> +static inline unsigned int pt_table_item_lg2sz(const struct pt_state *pts)
> +{
> +       return PT_GRANULE_LG2SZ +
> +              (PT_TABLEMEM_LG2SZ - ilog2(PT_ITEM_WORD_SIZE)) * pts->level;
> +}
> +#endif
> +
> +#ifndef pt_pgsz_lg2_to_level
> +static inline unsigned int pt_pgsz_lg2_to_level(struct pt_common *common,
> +                                               unsigned int pgsize_lg2)
> +{
> +       return ((unsigned int)(pgsize_lg2 - PT_GRANULE_LG2SZ)) /
> +              (PT_TABLEMEM_LG2SZ - ilog2(PT_ITEM_WORD_SIZE));
> +}
> +#endif
> +
> +/*
> + * If not supplied by the format then contiguous pages are not supported.
> + *
> + * If contiguous pages are supported then the format must also provide
> + * pt_contig_count_lg2() if it supports a single contiguous size per level,
> + * or pt_possible_sizes() if it supports multiple sizes per level.
> + */
> +#ifndef pt_entry_num_contig_lg2
> +static inline unsigned int pt_entry_num_contig_lg2(const struct pt_state *pts)
> +{
> +       return ilog2(1);
> +}
> +
> +/*
> + * Return the number of contiguous OA items forming an entry at this table level
> + */
> +static inline unsigned short pt_contig_count_lg2(const struct pt_state *pts)
> +{
> +       return ilog2(1);
> +}
> +#endif
> +
> +/* If not supplied by the format then dirty tracking is not supported */
> +#ifndef pt_entry_is_write_dirty
> +static inline bool pt_entry_is_write_dirty(const struct pt_state *pts)
> +{
> +       return false;
> +}
> +
> +static inline void pt_entry_make_write_clean(struct pt_state *pts)
> +{
> +}
> +
> +static inline bool pt_dirty_supported(struct pt_common *common)
> +{
> +       return false;
> +}
> +#else
> +/* If not supplied then dirty tracking is always enabled */
> +#ifndef pt_dirty_supported
> +static inline bool pt_dirty_supported(struct pt_common *common)
> +{
> +       return true;
> +}
> +#endif
> +#endif
> +
> +#ifndef pt_entry_make_write_dirty
> +static inline bool pt_entry_make_write_dirty(struct pt_state *pts)
> +{
> +       return false;
> +}
> +#endif
> +
> +/*
> + * Format supplies either:
> + *   pt_entry_oa - OA is at the start of a contiguous entry
> + * or
> + *   pt_item_oa  - OA is adjusted for every item in a contiguous entry
> + *
> + * Build the missing one.
> + *
> + * The internal helper _pt_entry_oa_fast() allows generating
> + * an efficient pt_entry_oa_exact(); it doesn't care which
> + * option is selected.
> + */
> +#ifdef pt_entry_oa
> +static inline pt_oaddr_t pt_item_oa(const struct pt_state *pts)
> +{
> +       return pt_entry_oa(pts) |
> +              log2_mul(pts->index, pt_table_item_lg2sz(pts));
> +}
> +#define _pt_entry_oa_fast pt_entry_oa
> +#endif
> +
> +#ifdef pt_item_oa
> +static inline pt_oaddr_t pt_entry_oa(const struct pt_state *pts)
> +{
> +       return log2_set_mod(pt_item_oa(pts), 0,
> +                           pt_entry_num_contig_lg2(pts) +
> +                                   pt_table_item_lg2sz(pts));
> +}
> +#define _pt_entry_oa_fast pt_item_oa
> +#endif
> +
> +/*
> + * If not supplied by the format then use the constant
> + * PT_MAX_OUTPUT_ADDRESS_LG2.
> + */
> +#ifndef pt_max_oa_lg2
> +static inline unsigned int
> +pt_max_oa_lg2(const struct pt_common *common)
> +{
> +       return PT_MAX_OUTPUT_ADDRESS_LG2;
> +}
> +#endif
> +
> +#ifndef pt_has_system_page_size
> +static inline bool pt_has_system_page_size(const struct pt_common *common)
> +{
> +       return PT_GRANULE_LG2SZ == PAGE_SHIFT;
> +}
> +#endif
> +
> +/*
> + * If not supplied by the format then assume only one contiguous size determined
> + * by pt_contig_count_lg2()
> + */
> +#ifndef pt_possible_sizes
> +static inline unsigned short pt_contig_count_lg2(const struct pt_state *pts);
> +
> +/* Return a bitmap of possible leaf page sizes at this level */
> +static inline pt_vaddr_t pt_possible_sizes(const struct pt_state *pts)
> +{
> +       unsigned int isz_lg2 = pt_table_item_lg2sz(pts);
> +
> +       if (!pt_can_have_leaf(pts))
> +               return 0;
> +       return log2_to_int(isz_lg2) |
> +              log2_to_int(pt_contig_count_lg2(pts) + isz_lg2);
> +}
> +#endif
> +
> +/* If not supplied by the format then use 0. */
> +#ifndef pt_full_va_prefix
> +static inline pt_vaddr_t pt_full_va_prefix(const struct pt_common *common)
> +{
> +       return 0;
> +}
> +#endif
> +
> +/* If not supplied by the format then zero fill using PT_ITEM_WORD_SIZE */
> +#ifndef pt_clear_entries
> +static inline void pt_clear_entries64(struct pt_state *pts,
> +                                     unsigned int num_contig_lg2)
> +{
> +       u64 *tablep = pt_cur_table(pts, u64) + pts->index;
> +       u64 *end = tablep + log2_to_int(num_contig_lg2);
> +
> +       PT_WARN_ON(log2_mod(pts->index, num_contig_lg2));
> +       for (; tablep != end; tablep++)
> +               WRITE_ONCE(*tablep, 0);
> +}
> +
> +static inline void pt_clear_entries32(struct pt_state *pts,
> +                                     unsigned int num_contig_lg2)
> +{
> +       u32 *tablep = pt_cur_table(pts, u32) + pts->index;
> +       u32 *end = tablep + log2_to_int(num_contig_lg2);
> +
> +       PT_WARN_ON(log2_mod(pts->index, num_contig_lg2));
> +       for (; tablep != end; tablep++)
> +               WRITE_ONCE(*tablep, 0);
> +}
> +
> +static inline void pt_clear_entries(struct pt_state *pts,
> +                                   unsigned int num_contig_lg2)
> +{
> +       if (PT_ITEM_WORD_SIZE == sizeof(u32))
> +               pt_clear_entries32(pts, num_contig_lg2);
> +       else
> +               pt_clear_entries64(pts, num_contig_lg2);
> +}
> +#define pt_clear_entries pt_clear_entries
> +#endif
> +
> +/*
> + * The format can call this in its pt_install_leaf_entry() to check that the
> + * arguments are all aligned correctly.
> + */
> +static inline bool pt_check_install_leaf_args(struct pt_state *pts,
> +                                             pt_oaddr_t oa,
> +                                             unsigned int oasz_lg2)
> +{
> +       unsigned int isz_lg2 = pt_table_item_lg2sz(pts);
> +
> +       if (PT_WARN_ON(oalog2_mod(oa, oasz_lg2)))
> +               return false;
> +
> +#ifdef pt_possible_sizes
> +       if (PT_WARN_ON(isz_lg2 > oasz_lg2 ||
> +                      oasz_lg2 > isz_lg2 + pt_num_items_lg2(pts)))
> +               return false;
> +#else
> +       if (PT_WARN_ON(oasz_lg2 != isz_lg2 &&
> +                      oasz_lg2 != isz_lg2 + pt_contig_count_lg2(pts)))
> +               return false;
> +#endif
> +
> +       if (PT_WARN_ON(oalog2_mod(pts->index, oasz_lg2 - isz_lg2)))
> +               return false;
> +       return true;
> +}
> +
> +#endif
> diff --git a/drivers/iommu/generic_pt/pt_iter.h b/drivers/iommu/generic_pt/pt_iter.h
> new file mode 100644
> index 00000000000000..87f4a26c1a417a
> --- /dev/null
> +++ b/drivers/iommu/generic_pt/pt_iter.h
> @@ -0,0 +1,636 @@
> +/* SPDX-License-Identifier: GPL-2.0-only */
> +/*
> + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES
> + *
> + * Iterators for Generic Page Table
> + */
> +#ifndef __GENERIC_PT_PT_ITER_H
> +#define __GENERIC_PT_PT_ITER_H
> +
> +#include "pt_common.h"
> +
> +#include <linux/errno.h>
> +
> +/*
> + * Used to mangle symbols so that backtraces and the symbol table are
> + * understandable. Any non-inlined function should get mangled like this.
> + */
> +#define NS(fn) CONCATENATE(PTPFX, fn)
> +
> +/**
> + * pt_check_range() - Validate the range can be iterated
> + * @range: Range to validate
> + *
> + * Check that VA and last_va fall within the permitted range of VAs. If the
> + * format is using PT_FEAT_SIGN_EXTEND then this also checks the sign extension
> + * is correct.
> + */
> +static inline int pt_check_range(struct pt_range *range)
> +{
> +       pt_vaddr_t prefix;
> +
> +       PT_WARN_ON(!range->max_vasz_lg2);
> +
> +       if (pt_feature(range->common, PT_FEAT_SIGN_EXTEND)) {
> +               PT_WARN_ON(range->common->max_vasz_lg2 != range->max_vasz_lg2);
> +               prefix = fvalog2_div(range->va, range->max_vasz_lg2 - 1) ?
> +                                PT_VADDR_MAX :
> +                                0;
> +       } else {
> +               prefix = pt_full_va_prefix(range->common);
> +       }
> +
> +       if (!fvalog2_div_eq(range->va, prefix, range->max_vasz_lg2) ||
> +           !fvalog2_div_eq(range->last_va, prefix, range->max_vasz_lg2))
> +               return -ERANGE;
> +       return 0;
> +}
> +
> +/**
> + * pt_index_to_va() - Update range->va to the current pts->index
> + * @pts: Iteration State
> + *
> + * Adjust range->va to match the current index. This is done in a lazy manner
> + * since computing the VA takes several instructions and is rarely required.
> + */
> +static inline void pt_index_to_va(struct pt_state *pts)
> +{
> +       pt_vaddr_t lower_va;
> +
> +       lower_va = log2_mul(pts->index, pt_table_item_lg2sz(pts));
> +       pts->range->va = fvalog2_set_mod(pts->range->va, lower_va,
> +                                        pt_table_oa_lg2sz(pts));
> +}
> +
> +/*
> + * Add index_count_lg2 number of entries to pts's VA and index. The VA will be
> + * adjusted to the end of the contiguous block if it is currently in the middle.
> + */
> +static inline void _pt_advance(struct pt_state *pts,
> +                              unsigned int index_count_lg2)
> +{
> +       pts->index = log2_set_mod(pts->index + log2_to_int(index_count_lg2), 0,
> +                                 index_count_lg2);
> +}
> +
> +/**
> + * pt_entry_fully_covered() - Check if the item or entry is entirely contained
> + *                            within pts->range
> + * @pts: Iteration State
> + * @oasz_lg2: The size of the item to check, pt_table_item_lg2sz() or
> + *            pt_entry_oa_lg2sz()
> + *
> + * Returns: true if the item is fully enclosed by the pts->range.
> + */
> +static inline bool pt_entry_fully_covered(const struct pt_state *pts,
> +                                         unsigned int oasz_lg2)
> +{
> +       struct pt_range *range = pts->range;
> +
> +       /* Range begins at the start of the entry */
> +       if (log2_mod(pts->range->va, oasz_lg2))
> +               return false;
> +
> +       /* Range ends past the end of the entry */
> +       if (!log2_div_eq(range->va, range->last_va, oasz_lg2))
> +               return true;
> +
> +       /* Range ends at the end of the entry */
> +       return log2_mod_eq_max(range->last_va, oasz_lg2);
> +}
> +
> +/**
> + * pt_range_to_index() - Starting index for an iteration
> + * @pts: Iteration State
> + *
> + * Return: the starting index for the iteration in pts.
> + */
> +static inline unsigned int pt_range_to_index(const struct pt_state *pts)
> +{
> +       unsigned int isz_lg2 = pt_table_item_lg2sz(pts);
> +
> +       PT_WARN_ON(pts->level > pts->range->top_level);
> +       if (pts->range->top_level == pts->level)
> +               return log2_div(fvalog2_mod(pts->range->va,
> +                                           pts->range->max_vasz_lg2),
> +                               isz_lg2);
> +       return log2_mod(log2_div(pts->range->va, isz_lg2),
> +                       pt_num_items_lg2(pts));
> +}
> +
> +/**
> + * pt_range_to_end_index() - Ending index for an iteration
> + * @pts: Iteration State
> + *
> + * Return: the ending index (exclusive) for the iteration in pts.
> + */
> +static inline unsigned int pt_range_to_end_index(const struct pt_state *pts)
> +{
> +       unsigned int isz_lg2 = pt_table_item_lg2sz(pts);
> +       struct pt_range *range = pts->range;
> +       unsigned int num_entries_lg2;
> +
> +       if (range->va == range->last_va)
> +               return pts->index + 1;
> +
> +       if (pts->range->top_level == pts->level)
> +               return log2_div(fvalog2_mod(pts->range->last_va,
> +                                           pts->range->max_vasz_lg2),
> +                               isz_lg2) +
> +                      1;
> +
> +       num_entries_lg2 = pt_num_items_lg2(pts);
> +
> +       /* last_va falls within this table */
> +       if (log2_div_eq(range->va, range->last_va, num_entries_lg2 + isz_lg2))
> +               return log2_mod(log2_div(pts->range->last_va, isz_lg2),
> +                               num_entries_lg2) +
> +                      1;
> +
> +       return log2_to_int(num_entries_lg2);
> +}
> +
> +static inline void _pt_iter_first(struct pt_state *pts)
> +{
> +       pts->index = pt_range_to_index(pts);
> +       pts->end_index = pt_range_to_end_index(pts);
> +       PT_WARN_ON(pts->index > pts->end_index);
> +}
> +
> +static inline bool _pt_iter_load(struct pt_state *pts)
> +{
> +       if (pts->index >= pts->end_index)
> +               return false;
> +       pt_load_entry(pts);
> +       return true;
> +}
> +
> +/**
> + * pt_next_entry() - Advance pts to the next entry
> + * @pts: Iteration State
> + *
> + * Update pts to go to the next index at this level. If pts is pointing at a
> + * contiguous entry then the index may advance by more than one.
> + */
> +static inline void pt_next_entry(struct pt_state *pts)
> +{
> +       if (pts->type == PT_ENTRY_OA &&
> +           !__builtin_constant_p(pt_entry_num_contig_lg2(pts) == 0))
> +               _pt_advance(pts, pt_entry_num_contig_lg2(pts));
> +       else
> +               pts->index++;
> +       pt_index_to_va(pts);
> +}
> +
> +/**
> + * for_each_pt_level_entry() - For loop wrapper over entries in the range
> + * @pts: Iteration State
> + *
> + * This is the basic iteration primitive. It iterates over all the entries in
> + * pts->range that fall within the pts's current table level. Each step does
> + * pt_load_entry(pts).
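> + *
> + * A walker level function might use it like this (sketch only)::
> + *
> + *    struct pt_state pts = pt_init(range, level, table);
> + *
> + *    for_each_pt_level_entry(&pts) {
> + *            if (pts.type == PT_ENTRY_TABLE)
> + *                    return pt_descend(&pts, arg, fn);
> + *    }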
> + */
> +#define for_each_pt_level_entry(pts) \
> +       for (_pt_iter_first(pts); _pt_iter_load(pts); pt_next_entry(pts))
> +
> +/**
> + * pt_load_single_entry() - Version of pt_load_entry() usable within a walker
> + * @pts: Iteration State
> + *
> + * Alternative to for_each_pt_level_entry() if the walker function uses only a
> + * single entry.
> + */
> +static inline enum pt_entry_type pt_load_single_entry(struct pt_state *pts)
> +{
> +       pts->index = pt_range_to_index(pts);
> +       pt_load_entry(pts);
> +       return pts->type;
> +}
> +
> +static __always_inline struct pt_range _pt_top_range(struct pt_common *common,
> +                                                    uintptr_t top_of_table)
> +{
> +       struct pt_range range = {
> +               .common = common,
> +               .top_table =
> +                       (struct pt_table_p *)(top_of_table &
> +                                             ~(uintptr_t)PT_TOP_LEVEL_MASK),
> +               .top_level = top_of_table % (1 << PT_TOP_LEVEL_BITS),
> +       };
> +       struct pt_state pts = { .range = &range, .level = range.top_level };
> +       unsigned int max_vasz_lg2;
> +
> +       max_vasz_lg2 = common->max_vasz_lg2;
> +       if (pt_feature(common, PT_FEAT_DYNAMIC_TOP) &&
> +           pts.level != PT_MAX_TOP_LEVEL)
> +               max_vasz_lg2 = min_t(unsigned int, common->max_vasz_lg2,
> +                                    pt_num_items_lg2(&pts) +
> +                                            pt_table_item_lg2sz(&pts));
> +
> +       /*
> +        * The top range will default to the lower region only with sign extend.
> +        */
> +       range.max_vasz_lg2 = max_vasz_lg2;
> +       if (pt_feature(common, PT_FEAT_SIGN_EXTEND))
> +               max_vasz_lg2--;
> +
> +       range.va = fvalog2_set_mod(pt_full_va_prefix(common), 0, max_vasz_lg2);
> +       range.last_va =
> +               fvalog2_set_mod_max(pt_full_va_prefix(common), max_vasz_lg2);
> +       return range;
> +}
> +
> +/**
> + * pt_top_range() - Return a range that spans part of the top level
> + * @common: Table
> + *
> + * For PT_FEAT_SIGN_EXTEND this will return the lower range, and cover half the
> + * total page table. Otherwise it returns the entire page table.
> + */
> +static __always_inline struct pt_range pt_top_range(struct pt_common *common)
> +{
> +       /*
> +        * The top pointer can change without locking. We capture the value and
> +        * its level here and are safe to walk it so long as both values are
> +        * captured without tearing.
> +        */
> +       return _pt_top_range(common, READ_ONCE(common->top_of_table));
> +}
> +
> +/**
> + * pt_all_range() - Return a range that spans the entire page table
> + * @common: Table
> + *
> + * The returned range spans the whole page table. Due to how PT_FEAT_SIGN_EXTEND
> + * is supported range->va and range->last_va will be incorrect during the
> + * iteration and must not be accessed.
> + */
> +static inline struct pt_range pt_all_range(struct pt_common *common)
> +{
> +       struct pt_range range = pt_top_range(common);
> +
> +       if (!pt_feature(common, PT_FEAT_SIGN_EXTEND))
> +               return range;
> +
> +       /*
> +        * Pretend the table is linear from 0 without a sign extension. This
> +        * generates the correct indexes for iteration.
> +        */
> +       range.last_va = fvalog2_set_mod_max(0, range.max_vasz_lg2);
> +       return range;
> +}
> +
> +/**
> + * pt_upper_range() - Return a range that spans part of the top level
> + * @common: Table
> + *
> + * For PT_FEAT_SIGN_EXTEND this will return the upper range, and cover half the
> + * total page table. Otherwise it returns the entire page table.
> + */
> +static inline struct pt_range pt_upper_range(struct pt_common *common)
> +{
> +       struct pt_range range = pt_top_range(common);
> +
> +       if (!pt_feature(common, PT_FEAT_SIGN_EXTEND))
> +               return range;
> +
> +       range.va = fvalog2_set_mod(PT_VADDR_MAX, 0, range.max_vasz_lg2 - 1);
> +       range.last_va = PT_VADDR_MAX;
> +       return range;
> +}
> +
> +/**
> + * pt_make_range() - Return a range that spans part of the table
> + * @common: Table
> + * @va: Start address
> + * @last_va: Last address
> + *
> + * The caller must validate the range with pt_check_range() before using it.
> + */
> +static __always_inline struct pt_range
> +pt_make_range(struct pt_common *common, pt_vaddr_t va, pt_vaddr_t last_va)
> +{
> +       struct pt_range range =
> +               _pt_top_range(common, READ_ONCE(common->top_of_table));
> +
> +       range.va = va;
> +       range.last_va = last_va;
> +
> +       return range;
> +}
> +
> +/*
> + * Span a slice of the table starting at a lower table level from an active
> + * walk.
> + */
> +static __always_inline struct pt_range
> +pt_make_child_range(const struct pt_range *parent, pt_vaddr_t va,
> +                   pt_vaddr_t last_va)
> +{
> +       struct pt_range range = *parent;
> +
> +       range.va = va;
> +       range.last_va = last_va;
> +
> +       PT_WARN_ON(last_va < va);
> +       PT_WARN_ON(pt_check_range(&range));
> +
> +       return range;
> +}
> +
> +/**
> + * pt_init() - Initialize a pt_state on the stack
> + * @range: Range pointer to embed in the state
> + * @level: Table level for the state
> + * @table: Pointer to the table memory at level
> + *
> + * Helper to initialize the on-stack pt_state from walker arguments.
> + */
> +static __always_inline struct pt_state
> +pt_init(struct pt_range *range, unsigned int level, struct pt_table_p *table)
> +{
> +       struct pt_state pts = {
> +               .range = range,
> +               .table = table,
> +               .level = level,
> +       };
> +       return pts;
> +}
> +
> +/**
> + * pt_init_top() - Initialize a pt_state on the stack
> + * @range: Range pointer to embed in the state
> + *
> + * The pt_state points to the top most level.
> + */
> +static __always_inline struct pt_state pt_init_top(struct pt_range *range)
> +{
> +       return pt_init(range, range->top_level, range->top_table);
> +}
> +
> +typedef int (*pt_level_fn_t)(struct pt_range *range, void *arg,
> +                            unsigned int level, struct pt_table_p *table);
> +
> +/**
> + * pt_descend() - Recursively invoke the walker for the lower level
> + * @pts: Iteration State
> + * @arg: Value to pass to the function
> + * @fn: Walker function to call
> + *
> + * pts must point to a table item. Invoke fn as a walker on the table
> + * pts points to.
> + */
> +static __always_inline int pt_descend(struct pt_state *pts, void *arg,
> +                                     pt_level_fn_t fn)
> +{
> +       int ret;
> +
> +       if (PT_WARN_ON(!pts->table_lower))
> +               return -EINVAL;
> +
> +       ret = (*fn)(pts->range, arg, pts->level - 1, pts->table_lower);
> +       return ret;
> +}
> +
> +/**
> + * pt_walk_range() - Walk over a VA range
> + * @range: Range pointer
> + * @fn: Walker function to call
> + * @arg: Value to pass to the function
> + *
> + * Walk over a VA range. The caller should have done a validity check, at
> + * least calling pt_check_range(), when building range. The walk will
> + * start at the topmost table.
> + */
> +static __always_inline int pt_walk_range(struct pt_range *range,
> +                                        pt_level_fn_t fn, void *arg)
> +{
> +       return fn(range, arg, range->top_level, range->top_table);
> +}
> +
> +/*
> + * pt_walk_descend() - Recursively invoke the walker for a slice of a lower
> + *                     level
> + * @pts: Iteration State
> + * @va: Start address
> + * @last_va: Last address
> + * @fn: Walker function to call
> + * @arg: Value to pass to the function
> + *
> + * With pts pointing at a table item this will descend and iterate over a
> + * slice of the lower table. The caller must ensure that va/last_va are within
> + * the table item. This creates a new walk and does not alter pts or pts->range.
> + */
> +static __always_inline int pt_walk_descend(const struct pt_state *pts,
> +                                          pt_vaddr_t va, pt_vaddr_t last_va,
> +                                          pt_level_fn_t fn, void *arg)
> +{
> +       struct pt_range range = pt_make_child_range(pts->range, va, last_va);
> +
> +       if (PT_WARN_ON(!pt_can_have_table(pts)) ||
> +           PT_WARN_ON(!pts->table_lower))
> +               return -EINVAL;
> +
> +       return fn(&range, arg, pts->level - 1, pts->table_lower);
> +}
> +
> +/*
> + * pt_walk_descend_all() - Recursively invoke the walker for a table item
> + * @parent_pts: Iteration State
> + * @fn: Walker function to call
> + * @arg: Value to pass to the function
> + *
> + * With pts pointing at a table item this will descend and iterate over the
> + * entire lower table. This creates a new walk and does not alter pts or
> + * pts->range.
> + */
> +static __always_inline int
> +pt_walk_descend_all(const struct pt_state *parent_pts, pt_level_fn_t fn,
> +                   void *arg)
> +{
> +       unsigned int isz_lg2 = pt_table_item_lg2sz(parent_pts);
> +
> +       return pt_walk_descend(parent_pts,
> +                              log2_set_mod(parent_pts->range->va, 0, isz_lg2),
> +                              log2_set_mod_max(parent_pts->range->va, isz_lg2),
> +                              fn, arg);
> +}
> +
> +/**
> + * pt_range_slice() - Return a range that spans indexes
> + * @pts: Iteration State
> + * @start_index: Starting index within pts
> + * @end_index: Ending index within pts
> + *
> + * Create a range that spans an index range of the current table level the
> + * pt_state points at.
> + */
> +static inline struct pt_range pt_range_slice(const struct pt_state *pts,
> +                                            unsigned int start_index,
> +                                            unsigned int end_index)
> +{
> +       unsigned int table_lg2sz = pt_table_oa_lg2sz(pts);
> +       pt_vaddr_t last_va;
> +       pt_vaddr_t va;
> +
> +       va = fvalog2_set_mod(pts->range->va,
> +                            log2_mul(start_index, pt_table_item_lg2sz(pts)),
> +                            table_lg2sz);
> +       last_va = fvalog2_set_mod(
> +               pts->range->va,
> +               log2_mul(end_index, pt_table_item_lg2sz(pts)) - 1, table_lg2sz);
> +       return pt_make_child_range(pts->range, va, last_va);
> +}
> +
> +/**
> + * pt_top_memsize_lg2() - Allocation size of the top table
> + * @common: Table
> + * @top_of_table: Top of table value from _pt_top_set()
> + *
> + * Compute the allocation size of the top table. For PT_FEAT_DYNAMIC_TOP this
> + * will compute the top size assuming the table will grow.
> + */
> +static inline unsigned int pt_top_memsize_lg2(struct pt_common *common,
> +                                             uintptr_t top_of_table)
> +{
> +       struct pt_range range = _pt_top_range(common, top_of_table);
> +       struct pt_state pts = pt_init_top(&range);
> +       unsigned int num_items_lg2;
> +
> +       num_items_lg2 = common->max_vasz_lg2 - pt_table_item_lg2sz(&pts);
> +       if (range.top_level != PT_MAX_TOP_LEVEL &&
> +           pt_feature(common, PT_FEAT_DYNAMIC_TOP))
> +               num_items_lg2 = min(num_items_lg2, pt_num_items_lg2(&pts));
> +
> +       /* Round up the allocation size to the minimum alignment */
> +       return max(ffs_t(u64, PT_TOP_PHYS_MASK),
> +                  num_items_lg2 + ilog2(PT_ITEM_WORD_SIZE));
> +}
> +
> +/**
> + * pt_compute_best_pgsize() - Determine the best page size for leaf entries
> + * @pgsz_bitmap: Permitted page sizes
> + * @va: Starting virtual address for the leaf entry
> + * @last_va: Last virtual address for the leaf entry, sets the max page size
> + * @oa: Starting output address for the leaf entry
> + *
> + * Compute the largest page size for va, last_va, and oa together and return it
> + * in lg2. The largest page size depends on the format's supported page sizes at
> + * this level, and the relative alignment of the VA and OA addresses. 0 means
> + * the OA cannot be stored with the provided pgsz_bitmap.
> + */
> +static inline unsigned int pt_compute_best_pgsize(pt_vaddr_t pgsz_bitmap,
> +                                                 pt_vaddr_t va,
> +                                                 pt_vaddr_t last_va,
> +                                                 pt_oaddr_t oa)
> +{
> +       unsigned int best_pgsz_lg2;
> +       unsigned int pgsz_lg2;
> +       pt_vaddr_t len = last_va - va + 1;
> +       pt_vaddr_t mask;
> +
> +       if (PT_WARN_ON(va >= last_va))
> +               return 0;
> +
> +       /*
> +        * Given a VA/OA pair the best page size is the largest page size
> +        * where:
> +        *
> +        * 1) VA and OA start at the page. Bitwise this is the count of least
> +        *    significant 0 bits.
> +        *    This also implies that last_va/oa has the same prefix as va/oa.
> +        */
> +       mask = va | oa;
> +
> +       /*
> +        * 2) The page size is not larger than the last_va (length). Since page
> +        *    sizes are always power of two this can't be larger than the
> +        *    largest power of two factor of the length.
> +        */
> +       mask |= log2_to_int(vafls(len) - 1);
> +
> +       best_pgsz_lg2 = vaffs(mask);
> +
> +       /* Choose the highest bit <= best_pgsz_lg2 */
> +       if (best_pgsz_lg2 < PT_VADDR_MAX_LG2 - 1)
> +               pgsz_bitmap = log2_mod(pgsz_bitmap, best_pgsz_lg2 + 1);
> +
> +       pgsz_lg2 = vafls(pgsz_bitmap);
> +       if (!pgsz_lg2)
> +               return 0;
> +
> +       pgsz_lg2--;
> +
> +       PT_WARN_ON(log2_mod(va, pgsz_lg2) != 0);
> +       PT_WARN_ON(oalog2_mod(oa, pgsz_lg2) != 0);
> +       PT_WARN_ON(va + log2_to_int(pgsz_lg2) - 1 > last_va);
> +       PT_WARN_ON(!log2_div_eq(va, va + log2_to_int(pgsz_lg2) - 1, pgsz_lg2));
> +       PT_WARN_ON(
> +               !oalog2_div_eq(oa, oa + log2_to_int(pgsz_lg2) - 1, pgsz_lg2));
> +       return pgsz_lg2;
> +}
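
The best-pgsize math checks out with a quick example (assuming a
pgsz_bitmap with only the 4k and 2M bits set): for va = 0x40200000,
oa = 0x80200000 and last_va = 0x405fffff (a 4M request), va | oa has 21
trailing zero bits and the 4M length caps things at 2^22, so best_pgsz_lg2
comes out as 21 and the 2M size is selected. If oa were 0x80201000 instead,
the OA alignment drops it to 12, i.e. 4k pages.
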
> +
> +#define _PT_MAKE_CALL_LEVEL(fn)                                          \
> +       static __always_inline int fn(struct pt_range *range, void *arg, \
> +                                     unsigned int level,                \
> +                                     struct pt_table_p *table)          \
> +       {                                                                \
> +               static_assert(PT_MAX_TOP_LEVEL <= 5);                    \
> +               if (level == 0)                                          \
> +                       return CONCATENATE(fn, 0)(range, arg, 0, table); \
> +               if (level == 1 || PT_MAX_TOP_LEVEL == 1)                 \
> +                       return CONCATENATE(fn, 1)(range, arg, 1, table); \
> +               if (level == 2 || PT_MAX_TOP_LEVEL == 2)                 \
> +                       return CONCATENATE(fn, 2)(range, arg, 2, table); \
> +               if (level == 3 || PT_MAX_TOP_LEVEL == 3)                 \
> +                       return CONCATENATE(fn, 3)(range, arg, 3, table); \
> +               if (level == 4 || PT_MAX_TOP_LEVEL == 4)                 \
> +                       return CONCATENATE(fn, 4)(range, arg, 4, table); \
> +               return CONCATENATE(fn, 5)(range, arg, 5, table);         \
> +       }
> +
> +static inline int __pt_make_level_fn_err(struct pt_range *range, void *arg,
> +                                        unsigned int unused_level,
> +                                        struct pt_table_p *table)
> +{
> +       static_assert(PT_MAX_TOP_LEVEL <= 5);
> +       return -EPROTOTYPE;
> +}
> +
> +#define __PT_MAKE_LEVEL_FN(fn, level, descend_fn, do_fn)            \
> +       static inline int fn(struct pt_range *range, void *arg,     \
> +                            unsigned int unused_level,             \
> +                            struct pt_table_p *table)              \
> +       {                                                           \
> +               return do_fn(range, arg, level, table, descend_fn); \
> +       }
> +
> +/**
> + * PT_MAKE_LEVELS() - Build an unwound walker
> + * @fn: Name of the walker function
> + * @do_fn: Function to call at each level
> + *
> + * This builds a function call tree that can be fully inlined.
> + * The caller must provide a function body in an __always_inline function::
> + *
> + *  static __always_inline int do(struct pt_range *range, void *arg,
> + *         unsigned int level, struct pt_table_p *table,
> + *         pt_level_fn_t descend_fn)
> + *
> + * An inline function will be created for each table level that calls do_fn with
> + * a compile time constant for level and a pointer to the next lower function.
> + * This generates an optimally inlined walk where each of the functions sees a
> + * constant level and can codegen the exact constants/etc for that level.
> + *
> + * Note this can produce a lot of code!
> + */
> +#define PT_MAKE_LEVELS(fn, do_fn)                                             \
> +       __PT_MAKE_LEVEL_FN(CONCATENATE(fn, 0), 0, __pt_make_level_fn_err,     \
> +                          do_fn);                                            \
> +       __PT_MAKE_LEVEL_FN(CONCATENATE(fn, 1), 1, CONCATENATE(fn, 0), do_fn); \
> +       __PT_MAKE_LEVEL_FN(CONCATENATE(fn, 2), 2, CONCATENATE(fn, 1), do_fn); \
> +       __PT_MAKE_LEVEL_FN(CONCATENATE(fn, 3), 3, CONCATENATE(fn, 2), do_fn); \
> +       __PT_MAKE_LEVEL_FN(CONCATENATE(fn, 4), 4, CONCATENATE(fn, 3), do_fn); \
> +       __PT_MAKE_LEVEL_FN(CONCATENATE(fn, 5), 5, CONCATENATE(fn, 4), do_fn); \
> +       _PT_MAKE_CALL_LEVEL(fn)
> +
> +#endif
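
For my own benefit I sketched how a walker gets wired up through
PT_MAKE_LEVELS(). Treat this as a rough sketch only - pt_init(),
for_each_pt_level_entry() and the PT_ENTRY_* values are my reading of the
iteration helpers in pt_iter.h and the details may well be off:

	/* Sketch: count leaf (OA) items below the starting range, assuming
	 * the range fully covers every table item it touches.
	 */
	struct count_arg {
		u64 leaves;
	};

	static __always_inline int __do_count(struct pt_range *range, void *arg,
					      unsigned int level,
					      struct pt_table_p *table,
					      pt_level_fn_t descend_fn)
	{
		struct count_arg *count = arg;
		struct pt_state pts = pt_init(range, level, table);
		int ret;

		for_each_pt_level_entry(&pts) {
			if (pts.type == PT_ENTRY_TABLE) {
				/* Recurse via the next lower unwound level */
				ret = pt_walk_descend_all(&pts, descend_fn, arg);
				if (ret)
					return ret;
			} else if (pts.type == PT_ENTRY_OA) {
				count->leaves++;
			}
		}
		return 0;
	}
	PT_MAKE_LEVELS(pt_count_leaves, __do_count);

The generated pt_count_leaves() then has the pt_level_fn_t signature, so it
can be handed to the top-level walk entry point, and every unwound level
sees a compile-time constant level exactly as described above.
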
> diff --git a/drivers/iommu/generic_pt/pt_log2.h b/drivers/iommu/generic_pt/pt_log2.h
> new file mode 100644
> index 00000000000000..6dbbed11923853
> --- /dev/null
> +++ b/drivers/iommu/generic_pt/pt_log2.h
> @@ -0,0 +1,122 @@
> +/* SPDX-License-Identifier: GPL-2.0-only */
> +/*
> + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES
> + *
> + * Helper macros for working with log2 values
> + *
> + */
> +#ifndef __GENERIC_PT_LOG2_H
> +#define __GENERIC_PT_LOG2_H
> +#include <linux/bitops.h>
> +#include <linux/limits.h>
> +
> +/* Compute a */
> +#define log2_to_int_t(type, a_lg2) ((type)(((type)1) << (a_lg2)))
> +static_assert(log2_to_int_t(unsigned int, 0) == 1);
> +
> +/* Compute a - 1 (aka all low bits set) */
> +#define log2_to_max_int_t(type, a_lg2) ((type)(log2_to_int_t(type, a_lg2) - 1))
> +
> +/* Compute a / b */
> +#define log2_div_t(type, a, b_lg2) ((type)(((type)a) >> (b_lg2)))
> +static_assert(log2_div_t(unsigned int, 4, 2) == 1);
> +
> +/*
> + * Compute:
> + *   a / c == b / c
> + * aka the high bits are equal
> + */
> +#define log2_div_eq_t(type, a, b, c_lg2) \
> +       (log2_div_t(type, (a) ^ (b), c_lg2) == 0)
> +static_assert(log2_div_eq_t(unsigned int, 1, 1, 2));
> +
> +/* Compute a % b */
> +#define log2_mod_t(type, a, b_lg2) \
> +       ((type)(((type)a) & log2_to_max_int_t(type, b_lg2)))
> +static_assert(log2_mod_t(unsigned int, 1, 2) == 1);
> +
> +/*
> + * Compute:
> + *   a % b == b - 1
> + * aka the low bits are all 1s
> + */
> +#define log2_mod_eq_max_t(type, a, b_lg2) \
> +       (log2_mod_t(type, a, b_lg2) == log2_to_max_int_t(type, b_lg2))
> +static_assert(log2_mod_eq_max_t(unsigned int, 3, 2));
> +
> +/*
> + * Return a value such that:
> + *    a / b == ret / b
> + *    ret % b == val
> + * aka set the low bits to val. val must be < b
> + */
> +#define log2_set_mod_t(type, a, val, b_lg2) \
> +       ((((type)(a)) & (~log2_to_max_int_t(type, b_lg2))) | ((type)(val)))
> +static_assert(log2_set_mod_t(unsigned int, 3, 1, 2) == 1);
> +
> +/* Return a value such that:
> + *    a / b == ret / b
> + *    ret % b == b - 1
> + * aka set the low bits to all 1s
> + */
> +#define log2_set_mod_max_t(type, a, b_lg2) \
> +       (((type)(a)) | log2_to_max_int_t(type, b_lg2))
> +static_assert(log2_set_mod_max_t(unsigned int, 2, 2) == 3);
> +
> +/* Compute a * b */
> +#define log2_mul_t(type, a, b_lg2) ((type)(((type)a) << (b_lg2)))
> +static_assert(log2_mul_t(unsigned int, 2, 2) == 8);
> +
> +#define _dispatch_sz(type, fn, a) \
> +       (sizeof(type) == 4 ? fn##32((u32)a) : fn##64(a))
> +
> +/*
> + * Return the highest value such that:
> + *    fls_t(u32, 0) == 0
> + *    fls_t(u32, 1) == 1
> + *    a >= log2_to_int(ret - 1)
> + * aka find last set bit
> + */
> +static inline unsigned int fls32(u32 a)
> +{
> +       return fls(a);
> +}
> +#define fls_t(type, a) _dispatch_sz(type, fls, a)
> +
> +/*
> + * Return the highest value such that:
> + *    ffs_t(u32, 0) == UNDEFINED
> + *    ffs_t(u32, 1) == 0
> + *    log2_mod(a, ret) == 0
> + * aka find first set bit
> + */
> +static inline unsigned int __ffs32(u32 a)
> +{
> +       return __ffs(a);
> +}
> +#define ffs_t(type, a) _dispatch_sz(type, __ffs, a)
> +
> +/*
> + * Return the highest value such that:
> + *    ffz_t(u32, U32_MAX) == UNDEFINED
> + *    ffz_t(u32, 0) == 0
> + *    ffz_t(u32, 1) == 1
> + *    log2_mod(a, ret) == log2_to_max_int(ret)
> + * aka find first zero bit
> + */
> +static inline unsigned int ffz32(u32 a)
> +{
> +       return ffz(a);
> +}
> +static inline unsigned int ffz64(u64 a)
> +{
> +       if (sizeof(u64) == sizeof(unsigned long))
> +               return ffz(a);
> +
> +       if ((u32)a == U32_MAX)
> +               return ffz32(a >> 32) + 32;
> +       return ffz32(a);
> +}
> +#define ffz_t(type, a) _dispatch_sz(type, ffz, a)
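
The ffz64() open coding for 32-bit builds looks right to me, e.g.
ffz64(0x1ffffffffULL) takes the low-word-all-ones branch and returns
ffz32(0x1) + 32 == 33, matching what a native 64-bit ffz() gives.
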
> +
> +#endif
> diff --git a/include/linux/generic_pt/common.h b/include/linux/generic_pt/common.h
> new file mode 100644
> index 00000000000000..e69a75511313cb
> --- /dev/null
> +++ b/include/linux/generic_pt/common.h
> @@ -0,0 +1,135 @@
> +/* SPDX-License-Identifier: GPL-2.0-only */
> +/*
> + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES
> + */
> +#ifndef __GENERIC_PT_COMMON_H
> +#define __GENERIC_PT_COMMON_H
> +
> +#include <linux/types.h>
> +#include <linux/build_bug.h>
> +#include <linux/bits.h>
> +
> +/**
> + * DOC: Generic Radix Page Table
> + *
> + * Generic Radix Page Table is a set of functions and helpers to efficiently
> + * parse radix style page tables typically seen in HW implementations. The
> + * interface is built to deliver similar code generation as the mm's pte/pmd/etc
> + * system by fully inlining the exact code required to handle each table level.
> + *
> + * Like the mm subsystem each format contributes its parsing implementation
> + * under common names and the common code implements the required algorithms.
> + *
> + * The system is divided into three logical levels:
> + *
> + *  - The page table format and its manipulation functions
> + *  - Generic helpers to give a consistent API regardless of underlying format
> + *  - An algorithm implementation (e.g. IOMMU/DRM/KVM/MM)
> + *
> + * Multiple implementations are supported. The intention is to have the generic
> + * format code be re-usable for whatever specialized implementation is required.
> + * The generic code is solely about the format of the radix tree; it does not
> + * include memory allocation or higher level decisions that are left for the
> + * implementation.
> + *
> + * The generic framework supports a superset of functions across many HW
> + * implementations:
> + *
> + *  - Entries comprised of contiguous blocks of IO PTEs for larger page sizes
> + *  - Multi-level tables, up to 6 levels. Runtime selected top level
> + *  - Runtime variable table level size (ARM's concatenated tables)
> + *  - Expandable top level allowing dynamic sizing of table levels
> + *  - Optional leaf entries at any level
> + *  - 32-bit/64-bit virtual and output addresses, using every address bit
> + *  - Dirty tracking
> + *  - Sign extended addressing
> + */
> +
> +/**
> + * struct pt_common - struct for all page table implementations
> + */
> +struct pt_common {
> +       /**
> +        * @top_of_table: Encodes the table top pointer and the top level in a
> +        * single value. Must use READ_ONCE/WRITE_ONCE to access it. The lower
> +        * bits of the aligned table pointer are used for the level.
> +        */
> +       uintptr_t top_of_table;
> +       /**
> +        * @max_oasz_lg2: Maximum number of bits the OA can contain. Upper bits
> +        * must be zero. This may be less than what the page table format
> +        * supports, but must not be more.
> +        */
> +       u8 max_oasz_lg2;
> +       /**
> +        * @max_vasz_lg2: Maximum number of bits the VA can contain. Upper bits
> +        * are 0 or 1 depending on pt_full_va_prefix(). This may be less than
> +        * what the page table format supports, but must not be more. When
> +        * PT_FEAT_DYNAMIC_TOP is set this reflects the maximum VA capability.
> +        */
> +       u8 max_vasz_lg2;
> +       /**
> +        * @features: Bitmap of `enum pt_features`
> +        */
> +       unsigned int features;
> +};
> +
> +/* Encoding parameters for top_of_table */
> +enum {
> +       PT_TOP_LEVEL_BITS = 3,
> +       PT_TOP_LEVEL_MASK = GENMASK(PT_TOP_LEVEL_BITS - 1, 0),
> +};
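
The top_of_table encoding decodes nicely. If I read it right the accessors
boil down to roughly the following (hand-written here, the real helpers are
presumably the _pt_top_set()/_pt_top_range() family in pt_defs.h):

	uintptr_t tot = READ_ONCE(common->top_of_table);
	unsigned int top_level = tot & PT_TOP_LEVEL_MASK;
	struct pt_table_p *top_table =
		(struct pt_table_p *)(tot & ~(uintptr_t)PT_TOP_LEVEL_MASK);

which relies on the top table allocation being at least
2^PT_TOP_LEVEL_BITS byte aligned.
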
> +
> +/**
> + * enum pt_features - Features turned on in the table. Each symbol is a bit
> + * position.
> + */
> +enum pt_features {
> +       /**
> +        * @PT_FEAT_FULL_VA: The table can span the full VA range from 0 to
> +        * PT_VADDR_MAX.
> +        */
> +       PT_FEAT_FULL_VA,
> +       /**
> +        * @PT_FEAT_DYNAMIC_TOP: The table's top level can be increased
> +        * dynamically during map. This requires HW support for atomically
> +        * setting both the table top pointer and the starting table level.
> +        */
> +       PT_FEAT_DYNAMIC_TOP,
> +       /**
> +        * @PT_FEAT_SIGN_EXTEND: The top most bit of the valid VA range sign
> +        * extends up to the full pt_vaddr_t. This divides the page table into
> +        * three VA ranges::
> +        *
> +        *   0         -> 2^N - 1             Lower
> +        *   2^N       -> (MAX - 2^N - 1)     Non-Canonical
> +        *   MAX - 2^N -> MAX                 Upper
> +        *
> +        * In this mode pt_common::max_vasz_lg2 includes the sign bit and the
> +        * upper bits that don't fall within the translation are just validated.
> +        *
> +        * If not set there is no sign extension and valid VA goes from 0 to
> +        * 2^N - 1.
> +        */
> +       PT_FEAT_SIGN_EXTEND,
> +       /**
> +        * @PT_FEAT_FLUSH_RANGE: IOTLB maintenance is done by flushing IOVA
> +        * ranges which will clean out any walk cache or any IOPTE fully
> +        * contained by the range. The optimization objective is to minimize the
> +        * number of flushes even if ranges include IOVA gaps that do not need
> +        * to be flushed.
> +        */
> +       PT_FEAT_FLUSH_RANGE,
> +       /**
> +        * @PT_FEAT_FLUSH_RANGE_NO_GAPS: Like PT_FEAT_FLUSH_RANGE except that
> +        * the optimization objective is to only flush IOVA that has been
> +        * changed. This mode is suitable for cases like hypervisor shadowing
> +        * where flushing unchanged ranges may cause the hypervisor to reparse
> +        * a significant amount of the page table.
> +        */
> +       PT_FEAT_FLUSH_RANGE_NO_GAPS,
> +       /* private: */
> +       PT_FEAT_FMT_START,
> +};
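
If I read PT_FEAT_SIGN_EXTEND right, a 48-bit sign-extended table gives the
usual x86-style canonical split: 0 through 0x00007fffffffffff is the lower
range, 0xffff800000000000 through 0xffffffffffffffff the upper range, and
anything in between is non-canonical and fails validation before the walk.
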
> +
> +#endif
> --
> 2.43.0
>
>

Reviewed-by: Samiullah Khawaja <skhawaja at google.com>


