[PATCH 09/11] RISC-V: drivers/iommu/riscv: Add SVA with PASID/ATS/PRI support.

Zong Li zong.li at sifive.com
Mon Jul 31 02:04:12 PDT 2023


On Thu, Jul 20, 2023 at 3:35 AM Tomasz Jeznach <tjeznach at rivosinc.com> wrote:
>
> Introduces SVA (Shared Virtual Address) for RISC-V IOMMU, with
> ATS/PRI services for capable devices.
>
> Co-developed-by: Sebastien Boeuf <seb at rivosinc.com>
> Signed-off-by: Sebastien Boeuf <seb at rivosinc.com>
> Signed-off-by: Tomasz Jeznach <tjeznach at rivosinc.com>
> ---
>  drivers/iommu/riscv/iommu.c | 601 +++++++++++++++++++++++++++++++++++-
>  drivers/iommu/riscv/iommu.h |  14 +
>  2 files changed, 610 insertions(+), 5 deletions(-)
>
> diff --git a/drivers/iommu/riscv/iommu.c b/drivers/iommu/riscv/iommu.c
> index 2ef6952a2109..6042c35be3ca 100644
> --- a/drivers/iommu/riscv/iommu.c
> +++ b/drivers/iommu/riscv/iommu.c
> @@ -384,6 +384,89 @@ static inline void riscv_iommu_cmd_iodir_set_did(struct riscv_iommu_command *cmd
>             FIELD_PREP(RISCV_IOMMU_CMD_IODIR_DID, devid) | RISCV_IOMMU_CMD_IODIR_DV;
>  }
>
> +static inline void riscv_iommu_cmd_iodir_set_pid(struct riscv_iommu_command *cmd,
> +                                                unsigned pasid)
> +{
> +       cmd->dword0 |= FIELD_PREP(RISCV_IOMMU_CMD_IODIR_PID, pasid);
> +}
> +
> +static void riscv_iommu_cmd_ats_inval(struct riscv_iommu_command *cmd)
> +{
> +       cmd->dword0 = FIELD_PREP(RISCV_IOMMU_CMD_OPCODE, RISCV_IOMMU_CMD_ATS_OPCODE) |
> +           FIELD_PREP(RISCV_IOMMU_CMD_FUNC, RISCV_IOMMU_CMD_ATS_FUNC_INVAL);
> +       cmd->dword1 = 0;
> +}
> +
> +static inline void riscv_iommu_cmd_ats_prgr(struct riscv_iommu_command *cmd)
> +{
> +       cmd->dword0 = FIELD_PREP(RISCV_IOMMU_CMD_OPCODE, RISCV_IOMMU_CMD_ATS_OPCODE) |
> +           FIELD_PREP(RISCV_IOMMU_CMD_FUNC, RISCV_IOMMU_CMD_ATS_FUNC_PRGR);
> +       cmd->dword1 = 0;
> +}
> +
> +static void riscv_iommu_cmd_ats_set_rid(struct riscv_iommu_command *cmd, u32 rid)
> +{
> +       cmd->dword0 |= FIELD_PREP(RISCV_IOMMU_CMD_ATS_RID, rid);
> +}
> +
> +static void riscv_iommu_cmd_ats_set_pid(struct riscv_iommu_command *cmd, u32 pid)
> +{
> +       cmd->dword0 |= FIELD_PREP(RISCV_IOMMU_CMD_ATS_PID, pid) | RISCV_IOMMU_CMD_ATS_PV;
> +}
> +
> +static void riscv_iommu_cmd_ats_set_dseg(struct riscv_iommu_command *cmd, u8 seg)
> +{
> +       cmd->dword0 |= FIELD_PREP(RISCV_IOMMU_CMD_ATS_DSEG, seg) | RISCV_IOMMU_CMD_ATS_DSV;
> +}
> +
> +static void riscv_iommu_cmd_ats_set_payload(struct riscv_iommu_command *cmd, u64 payload)
> +{
> +       cmd->dword1 = payload;
> +}
> +
> +/* Prepare the ATS invalidation payload */
> +static unsigned long riscv_iommu_ats_inval_payload(unsigned long start,
> +                                                  unsigned long end, bool global_inv)
> +{
> +       size_t len = end - start + 1;
> +       unsigned long payload = 0;
> +
> +       /*
> +        * PCI Express specification
> +        * Section 10.2.3.2 Translation Range Size (S) Field
> +        */
> +       if (len < PAGE_SIZE)
> +               len = PAGE_SIZE;
> +       else
> +               len = __roundup_pow_of_two(len);
> +
> +       payload = (start & ~(len - 1)) | (((len - 1) >> 12) << 11);
> +
> +       if (global_inv)
> +               payload |= RISCV_IOMMU_CMD_ATS_INVAL_G;
> +
> +       return payload;
> +}
> +
> +/* Prepare the ATS invalidation payload for all translations to be invalidated. */
> +static unsigned long riscv_iommu_ats_inval_all_payload(bool global_inv)
> +{
> +       unsigned long payload = GENMASK_ULL(62, 11);
> +
> +       if (global_inv)
> +               payload |= RISCV_IOMMU_CMD_ATS_INVAL_G;
> +
> +       return payload;
> +}
> +
> +/* Prepare the ATS "Page Request Group Response" payload */
> +static unsigned long riscv_iommu_ats_prgr_payload(u16 dest_id, u8 resp_code, u16 grp_idx)
> +{
> +       return FIELD_PREP(RISCV_IOMMU_CMD_ATS_PRGR_DST_ID, dest_id) |
> +           FIELD_PREP(RISCV_IOMMU_CMD_ATS_PRGR_RESP_CODE, resp_code) |
> +           FIELD_PREP(RISCV_IOMMU_CMD_ATS_PRGR_PRG_INDEX, grp_idx);
> +}
> +
>  /* TODO: Convert into lock-less MPSC implementation. */
>  static bool riscv_iommu_post_sync(struct riscv_iommu_device *iommu,
>                                   struct riscv_iommu_command *cmd, bool sync)
> @@ -460,6 +543,16 @@ static bool riscv_iommu_iodir_inv_devid(struct riscv_iommu_device *iommu, unsign
>         return riscv_iommu_post(iommu, &cmd);
>  }
>
> +static bool riscv_iommu_iodir_inv_pasid(struct riscv_iommu_device *iommu,
> +                                       unsigned devid, unsigned pasid)
> +{
> +       struct riscv_iommu_command cmd;
> +       riscv_iommu_cmd_iodir_inval_pdt(&cmd);
> +       riscv_iommu_cmd_iodir_set_did(&cmd, devid);
> +       riscv_iommu_cmd_iodir_set_pid(&cmd, pasid);
> +       return riscv_iommu_post(iommu, &cmd);
> +}
> +
>  static bool riscv_iommu_iofence_sync(struct riscv_iommu_device *iommu)
>  {
>         struct riscv_iommu_command cmd;
> @@ -467,6 +560,62 @@ static bool riscv_iommu_iofence_sync(struct riscv_iommu_device *iommu)
>         return riscv_iommu_post_sync(iommu, &cmd, true);
>  }
>
> +static void riscv_iommu_mm_invalidate(struct mmu_notifier *mn,
> +                                     struct mm_struct *mm, unsigned long start,
> +                                     unsigned long end)
> +{
> +       struct riscv_iommu_command cmd;
> +       struct riscv_iommu_endpoint *endpoint;
> +       struct riscv_iommu_domain *domain =
> +           container_of(mn, struct riscv_iommu_domain, mn);
> +       unsigned long iova;
> +       /*
> +        * The mm_types defines vm_end as the first byte after the end address,
> +        * different from IOMMU subsystem using the last address of an address
> +        * range. So do a simple translation here by updating what end means.
> +        */
> +       unsigned long payload = riscv_iommu_ats_inval_payload(start, end - 1, true);
> +
> +       riscv_iommu_cmd_inval_vma(&cmd);
> +       riscv_iommu_cmd_inval_set_gscid(&cmd, 0);
> +       riscv_iommu_cmd_inval_set_pscid(&cmd, domain->pscid);
> +       if (end > start) {
> +               /* Cover only the range that is needed */
> +               for (iova = start; iova < end; iova += PAGE_SIZE) {
> +                       riscv_iommu_cmd_inval_set_addr(&cmd, iova);
> +                       riscv_iommu_post(domain->iommu, &cmd);
> +               }
> +       } else {
> +               riscv_iommu_post(domain->iommu, &cmd);
> +       }
> +
> +       riscv_iommu_iofence_sync(domain->iommu);
> +
> +       /* ATS invalidation for every device and for specific translation range. */
> +       list_for_each_entry(endpoint, &domain->endpoints, domain) {
> +               if (!endpoint->pasid_enabled)
> +                       continue;
> +
> +               riscv_iommu_cmd_ats_inval(&cmd);
> +               riscv_iommu_cmd_ats_set_dseg(&cmd, endpoint->domid);
> +               riscv_iommu_cmd_ats_set_rid(&cmd, endpoint->devid);
> +               riscv_iommu_cmd_ats_set_pid(&cmd, domain->pasid);
> +               riscv_iommu_cmd_ats_set_payload(&cmd, payload);
> +               riscv_iommu_post(domain->iommu, &cmd);
> +       }
> +       riscv_iommu_iofence_sync(domain->iommu);
> +}
> +
> +static void riscv_iommu_mm_release(struct mmu_notifier *mn, struct mm_struct *mm)
> +{
> +       /* TODO: removed from notifier, cleanup PSCID mapping, flush IOTLB */
> +}
> +
> +static const struct mmu_notifier_ops riscv_iommu_mmuops = {
> +       .release = riscv_iommu_mm_release,
> +       .invalidate_range = riscv_iommu_mm_invalidate,
> +};
> +
>  /* Command queue primary interrupt handler */
>  static irqreturn_t riscv_iommu_cmdq_irq_check(int irq, void *data)
>  {
> @@ -608,6 +757,128 @@ static void riscv_iommu_add_device(struct riscv_iommu_device *iommu, struct devi
>         mutex_unlock(&iommu->eps_mutex);
>  }
>
> +/*
> + * Get device reference based on device identifier (requester id).
> + * Decrement reference count with put_device() call.
> + */
> +static struct device *riscv_iommu_get_device(struct riscv_iommu_device *iommu,
> +                                            unsigned devid)
> +{
> +       struct rb_node *node;
> +       struct riscv_iommu_endpoint *ep;
> +       struct device *dev = NULL;
> +
> +       mutex_lock(&iommu->eps_mutex);
> +
> +       node = iommu->eps.rb_node;
> +       while (node && !dev) {
> +               ep = rb_entry(node, struct riscv_iommu_endpoint, node);
> +               if (ep->devid < devid)
> +                       node = node->rb_right;
> +               else if (ep->devid > devid)
> +                       node = node->rb_left;
> +               else
> +                       dev = get_device(ep->dev);
> +       }
> +
> +       mutex_unlock(&iommu->eps_mutex);
> +
> +       return dev;
> +}
> +
> +static int riscv_iommu_ats_prgr(struct device *dev, struct iommu_page_response *msg)
> +{
> +       struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
> +       struct riscv_iommu_command cmd;
> +       u8 resp_code;
> +       unsigned long payload;
> +
> +       switch (msg->code) {
> +       case IOMMU_PAGE_RESP_SUCCESS:
> +               resp_code = 0b0000;
> +               break;
> +       case IOMMU_PAGE_RESP_INVALID:
> +               resp_code = 0b0001;
> +               break;
> +       case IOMMU_PAGE_RESP_FAILURE:
> +               resp_code = 0b1111;
> +               break;
> +       }
> +       payload = riscv_iommu_ats_prgr_payload(ep->devid, resp_code, msg->grpid);
> +
> +       /* ATS Page Request Group Response */
> +       riscv_iommu_cmd_ats_prgr(&cmd);
> +       riscv_iommu_cmd_ats_set_dseg(&cmd, ep->domid);
> +       riscv_iommu_cmd_ats_set_rid(&cmd, ep->devid);
> +       if (msg->flags & IOMMU_PAGE_RESP_PASID_VALID)
> +               riscv_iommu_cmd_ats_set_pid(&cmd, msg->pasid);
> +       riscv_iommu_cmd_ats_set_payload(&cmd, payload);
> +       riscv_iommu_post(ep->iommu, &cmd);
> +
> +       return 0;
> +}
> +
> +static void riscv_iommu_page_request(struct riscv_iommu_device *iommu,
> +                                    struct riscv_iommu_pq_record *req)
> +{
> +       struct iommu_fault_event event = { 0 };
> +       struct iommu_fault_page_request *prm = &event.fault.prm;
> +       int ret;
> +       struct device *dev;
> +       unsigned devid = FIELD_GET(RISCV_IOMMU_PREQ_HDR_DID, req->hdr);
> +
> +       /* Ignore PGR Stop marker. */
> +       if ((req->payload & RISCV_IOMMU_PREQ_PAYLOAD_M) == RISCV_IOMMU_PREQ_PAYLOAD_L)
> +               return;
> +
> +       dev = riscv_iommu_get_device(iommu, devid);
> +       if (!dev) {
> +               /* TODO: Handle invalid page request */
> +               return;
> +       }
> +
> +       event.fault.type = IOMMU_FAULT_PAGE_REQ;
> +
> +       if (req->payload & RISCV_IOMMU_PREQ_PAYLOAD_L)
> +               prm->flags |= IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE;
> +       if (req->payload & RISCV_IOMMU_PREQ_PAYLOAD_W)
> +               prm->perm |= IOMMU_FAULT_PERM_WRITE;
> +       if (req->payload & RISCV_IOMMU_PREQ_PAYLOAD_R)
> +               prm->perm |= IOMMU_FAULT_PERM_READ;
> +
> +       prm->grpid = FIELD_GET(RISCV_IOMMU_PREQ_PRG_INDEX, req->payload);
> +       prm->addr = FIELD_GET(RISCV_IOMMU_PREQ_UADDR, req->payload) << PAGE_SHIFT;
> +
> +       if (req->hdr & RISCV_IOMMU_PREQ_HDR_PV) {
> +               prm->flags |= IOMMU_FAULT_PAGE_REQUEST_PASID_VALID;
> +               /* TODO: where to find this bit */
> +               prm->flags |= IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID;
> +               prm->pasid = FIELD_GET(RISCV_IOMMU_PREQ_HDR_PID, req->hdr);
> +       }
> +
> +       ret = iommu_report_device_fault(dev, &event);
> +       if (ret) {
> +               struct iommu_page_response resp = {
> +                       .grpid = prm->grpid,
> +                       .code = IOMMU_PAGE_RESP_FAILURE,
> +               };
> +               if (prm->flags & IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID) {
> +                       resp.flags |= IOMMU_PAGE_RESP_PASID_VALID;
> +                       resp.pasid = prm->pasid;
> +               }
> +               riscv_iommu_ats_prgr(dev, &resp);
> +       }
> +
> +       put_device(dev);
> +}
> +
> +static int riscv_iommu_page_response(struct device *dev,
> +                                    struct iommu_fault_event *evt,
> +                                    struct iommu_page_response *msg)
> +{
> +       return riscv_iommu_ats_prgr(dev, msg);
> +}
> +
>  /* Page request interface queue primary interrupt handler */
>  static irqreturn_t riscv_iommu_priq_irq_check(int irq, void *data)
>  {
> @@ -626,7 +897,7 @@ static irqreturn_t riscv_iommu_priq_process(int irq, void *data)
>         struct riscv_iommu_queue *q = (struct riscv_iommu_queue *)data;
>         struct riscv_iommu_device *iommu;
>         struct riscv_iommu_pq_record *requests;
> -       unsigned cnt, idx, ctrl;
> +       unsigned cnt, len, idx, ctrl;
>
>         iommu = container_of(q, struct riscv_iommu_device, priq);
>         requests = (struct riscv_iommu_pq_record *)q->base;
> @@ -649,7 +920,8 @@ static irqreturn_t riscv_iommu_priq_process(int irq, void *data)
>                 cnt = riscv_iommu_queue_consume(iommu, q, &idx);
>                 if (!cnt)
>                         break;
> -               dev_warn(iommu->dev, "unexpected %u page requests\n", cnt);
> +               for (len = 0; len < cnt; idx++, len++)
> +                       riscv_iommu_page_request(iommu, &requests[idx]);
>                 riscv_iommu_queue_release(iommu, q, cnt);
>         } while (1);
>
> @@ -660,6 +932,169 @@ static irqreturn_t riscv_iommu_priq_process(int irq, void *data)
>   * Endpoint management
>   */
>
> +/* Endpoint features/capabilities */
> +static void riscv_iommu_disable_ep(struct riscv_iommu_endpoint *ep)
> +{
> +       struct pci_dev *pdev;
> +
> +       if (!dev_is_pci(ep->dev))
> +               return;
> +
> +       pdev = to_pci_dev(ep->dev);
> +
> +       if (ep->pasid_enabled) {
> +               pci_disable_ats(pdev);
> +               pci_disable_pri(pdev);
> +               pci_disable_pasid(pdev);
> +               ep->pasid_enabled = false;
> +       }
> +}
> +
> +static void riscv_iommu_enable_ep(struct riscv_iommu_endpoint *ep)
> +{
> +       int rc, feat, num;
> +       struct pci_dev *pdev;
> +       struct device *dev = ep->dev;
> +
> +       if (!dev_is_pci(dev))
> +               return;
> +
> +       if (!ep->iommu->iommu.max_pasids)
> +               return;
> +
> +       pdev = to_pci_dev(dev);
> +
> +       if (!pci_ats_supported(pdev))
> +               return;
> +
> +       if (!pci_pri_supported(pdev))
> +               return;
> +
> +       feat = pci_pasid_features(pdev);
> +       if (feat < 0)
> +               return;
> +
> +       num = pci_max_pasids(pdev);
> +       if (!num) {
> +               dev_warn(dev, "Can't enable PASID (num: %d)\n", num);
> +               return;
> +       }
> +
> +       if (num > ep->iommu->iommu.max_pasids)
> +               num = ep->iommu->iommu.max_pasids;
> +
> +       rc = pci_enable_pasid(pdev, feat);
> +       if (rc) {
> +               dev_warn(dev, "Can't enable PASID (rc: %d)\n", rc);
> +               return;
> +       }
> +
> +       rc = pci_reset_pri(pdev);
> +       if (rc) {
> +               dev_warn(dev, "Can't reset PRI (rc: %d)\n", rc);
> +               pci_disable_pasid(pdev);
> +               return;
> +       }
> +
> +       /* TODO: Get supported PRI queue length, hard-code to 32 entries */
> +       rc = pci_enable_pri(pdev, 32);
> +       if (rc) {
> +               dev_warn(dev, "Can't enable PRI (rc: %d)\n", rc);
> +               pci_disable_pasid(pdev);
> +               return;
> +       }
> +
> +       rc = pci_enable_ats(pdev, PAGE_SHIFT);
> +       if (rc) {
> +               dev_warn(dev, "Can't enable ATS (rc: %d)\n", rc);
> +               pci_disable_pri(pdev);
> +               pci_disable_pasid(pdev);
> +               return;
> +       }
> +
> +       ep->pc = (struct riscv_iommu_pc *)get_zeroed_page(GFP_KERNEL);
> +       if (!ep->pc) {
> +               pci_disable_ats(pdev);
> +               pci_disable_pri(pdev);
> +               pci_disable_pasid(pdev);
> +               return;
> +       }
> +
> +       ep->pasid_enabled = true;
> +       ep->pasid_feat = feat;
> +       ep->pasid_bits = ilog2(num);
> +
> +       dev_dbg(ep->dev, "PASID/ATS support enabled, %d bits\n", ep->pasid_bits);
> +}
> +
> +static int riscv_iommu_enable_sva(struct device *dev)
> +{
> +       int ret;
> +       struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
> +
> +       if (!ep || !ep->iommu || !ep->iommu->pq_work)
> +               return -EINVAL;
> +
> +       if (!ep->pasid_enabled)
> +               return -ENODEV;
> +
> +       ret = iopf_queue_add_device(ep->iommu->pq_work, dev);
> +       if (ret)
> +               return ret;
> +
> +       return iommu_register_device_fault_handler(dev, iommu_queue_iopf, dev);
> +}
> +
> +static int riscv_iommu_disable_sva(struct device *dev)
> +{
> +       int ret;
> +       struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
> +
> +       ret = iommu_unregister_device_fault_handler(dev);
> +       if (!ret)
> +               ret = iopf_queue_remove_device(ep->iommu->pq_work, dev);
> +
> +       return ret;
> +}
> +
> +static int riscv_iommu_enable_iopf(struct device *dev)
> +{
> +       struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
> +
> +       if (ep && ep->pasid_enabled)
> +               return 0;
> +
> +       return -EINVAL;
> +}
> +
> +static int riscv_iommu_dev_enable_feat(struct device *dev, enum iommu_dev_features feat)
> +{
> +       switch (feat) {
> +       case IOMMU_DEV_FEAT_IOPF:
> +               return riscv_iommu_enable_iopf(dev);
> +
> +       case IOMMU_DEV_FEAT_SVA:
> +               return riscv_iommu_enable_sva(dev);
> +
> +       default:
> +               return -ENODEV;
> +       }
> +}
> +
> +static int riscv_iommu_dev_disable_feat(struct device *dev, enum iommu_dev_features feat)
> +{
> +       switch (feat) {
> +       case IOMMU_DEV_FEAT_IOPF:
> +               return 0;
> +
> +       case IOMMU_DEV_FEAT_SVA:
> +               return riscv_iommu_disable_sva(dev);
> +
> +       default:
> +               return -ENODEV;
> +       }
> +}
> +
>  static int riscv_iommu_of_xlate(struct device *dev, struct of_phandle_args *args)
>  {
>         return iommu_fwspec_add_ids(dev, args->args, 1);
> @@ -812,6 +1247,7 @@ static struct iommu_device *riscv_iommu_probe_device(struct device *dev)
>
>         dev_iommu_priv_set(dev, ep);
>         riscv_iommu_add_device(iommu, dev);
> +       riscv_iommu_enable_ep(ep);
>
>         return &iommu->iommu;
>  }
> @@ -843,6 +1279,8 @@ static void riscv_iommu_release_device(struct device *dev)
>                 riscv_iommu_iodir_inv_devid(iommu, ep->devid);
>         }
>
> +       riscv_iommu_disable_ep(ep);
> +
>         /* Remove endpoint from IOMMU tracking structures */
>         mutex_lock(&iommu->eps_mutex);
>         rb_erase(&ep->node, &iommu->eps);
> @@ -878,7 +1316,8 @@ static struct iommu_domain *riscv_iommu_domain_alloc(unsigned type)
>             type != IOMMU_DOMAIN_DMA_FQ &&
>             type != IOMMU_DOMAIN_UNMANAGED &&
>             type != IOMMU_DOMAIN_IDENTITY &&
> -           type != IOMMU_DOMAIN_BLOCKED)
> +           type != IOMMU_DOMAIN_BLOCKED &&
> +           type != IOMMU_DOMAIN_SVA)
>                 return NULL;
>
>         domain = kzalloc(sizeof(*domain), GFP_KERNEL);
> @@ -906,6 +1345,9 @@ static void riscv_iommu_domain_free(struct iommu_domain *iommu_domain)
>                 pr_warn("IOMMU domain is not empty!\n");
>         }
>
> +       if (domain->mn.ops && iommu_domain->mm)
> +               mmu_notifier_unregister(&domain->mn, iommu_domain->mm);
> +
>         if (domain->pgtbl.cookie)
>                 free_io_pgtable_ops(&domain->pgtbl.ops);
>
> @@ -1023,14 +1465,29 @@ static int riscv_iommu_attach_dev(struct iommu_domain *iommu_domain, struct devi
>          */
>         val = FIELD_PREP(RISCV_IOMMU_DC_TA_PSCID, domain->pscid);
>
> -       dc->ta = cpu_to_le64(val);
> -       dc->fsc = cpu_to_le64(riscv_iommu_domain_atp(domain));
> +       if (ep->pasid_enabled) {
> +               ep->pc[0].ta = cpu_to_le64(val | RISCV_IOMMU_PC_TA_V);
> +               ep->pc[0].fsc = cpu_to_le64(riscv_iommu_domain_atp(domain));
> +               dc->ta = 0;
> +               dc->fsc = cpu_to_le64(virt_to_pfn(ep->pc) |
> +                   FIELD_PREP(RISCV_IOMMU_DC_FSC_MODE, RISCV_IOMMU_DC_FSC_PDTP_MODE_PD8));

Could I ask why we decided to use PD8 directly, rather than PD17 or PD20?

> +       } else {
> +               dc->ta = cpu_to_le64(val);
> +               dc->fsc = cpu_to_le64(riscv_iommu_domain_atp(domain));
> +       }
>
>         wmb();
>
>         /* Mark device context as valid, synchronise device context cache. */
>         val = RISCV_IOMMU_DC_TC_V;
>
> +       if (ep->pasid_enabled) {
> +               val |= RISCV_IOMMU_DC_TC_EN_ATS |
> +                      RISCV_IOMMU_DC_TC_EN_PRI |
> +                      RISCV_IOMMU_DC_TC_DPE |
> +                      RISCV_IOMMU_DC_TC_PDTV;
> +       }
> +
>         if (ep->iommu->cap & RISCV_IOMMU_CAP_AMO) {
>                 val |= RISCV_IOMMU_DC_TC_GADE |
>                        RISCV_IOMMU_DC_TC_SADE;
> @@ -1051,13 +1508,107 @@ static int riscv_iommu_attach_dev(struct iommu_domain *iommu_domain, struct devi
>         return 0;
>  }
>
> +static int riscv_iommu_set_dev_pasid(struct iommu_domain *iommu_domain,
> +                                    struct device *dev, ioasid_t pasid)
> +{
> +       struct riscv_iommu_domain *domain = iommu_domain_to_riscv(iommu_domain);
> +       struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
> +       u64 ta, fsc;
> +
> +       if (!iommu_domain || !iommu_domain->mm)
> +               return -EINVAL;
> +
> +       /* Driver uses TC.DPE mode, PASID #0 is incorrect. */
> +       if (pasid == 0)
> +               return -EINVAL;
> +
> +       /* Incorrect domain identifier */
> +       if ((int)domain->pscid < 0)
> +               return -ENOMEM;
> +
> +       /* Process Context table should be set for pasid enabled endpoints. */
> +       if (!ep || !ep->pasid_enabled || !ep->dc || !ep->pc)
> +               return -ENODEV;
> +
> +       domain->pasid = pasid;
> +       domain->iommu = ep->iommu;
> +       domain->mn.ops = &riscv_iommu_mmuops;
> +
> +       /* register mm notifier */
> +       if (mmu_notifier_register(&domain->mn, iommu_domain->mm))
> +               return -ENODEV;
> +
> +       /* TODO: get SXL value for the process, use 32 bit or SATP mode */
> +       fsc = virt_to_pfn(iommu_domain->mm->pgd) | satp_mode;
> +       ta = RISCV_IOMMU_PC_TA_V | FIELD_PREP(RISCV_IOMMU_PC_TA_PSCID, domain->pscid);
> +
> +       fsc = le64_to_cpu(xchg_relaxed(&(ep->pc[pasid].fsc), cpu_to_le64(fsc)));
> +       ta = le64_to_cpu(xchg_relaxed(&(ep->pc[pasid].ta), cpu_to_le64(ta)));
> +
> +       wmb();
> +
> +       if (ta & RISCV_IOMMU_PC_TA_V) {
> +               riscv_iommu_iodir_inv_pasid(ep->iommu, ep->devid, pasid);
> +               riscv_iommu_iofence_sync(ep->iommu);
> +       }
> +
> +       dev_info(dev, "domain type %d attached w/ PSCID %u PASID %u\n",
> +           domain->domain.type, domain->pscid, domain->pasid);
> +
> +       return 0;
> +}
> +
> +static void riscv_iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid)
> +{
> +       struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
> +       struct riscv_iommu_command cmd;
> +       unsigned long payload = riscv_iommu_ats_inval_all_payload(false);
> +       u64 ta;
> +
> +       /* invalidate TA.V */
> +       ta = le64_to_cpu(xchg_relaxed(&(ep->pc[pasid].ta), 0));
> +
> +       wmb();
> +
> +       dev_info(dev, "domain removed w/ PSCID %u PASID %u\n",
> +           (unsigned)FIELD_GET(RISCV_IOMMU_PC_TA_PSCID, ta), pasid);
> +
> +       /* 1. invalidate PDT entry */
> +       riscv_iommu_iodir_inv_pasid(ep->iommu, ep->devid, pasid);
> +
> +       /* 2. invalidate all matching IOATC entries (if PASID was valid) */
> +       if (ta & RISCV_IOMMU_PC_TA_V) {
> +               riscv_iommu_cmd_inval_vma(&cmd);
> +               riscv_iommu_cmd_inval_set_gscid(&cmd, 0);
> +               riscv_iommu_cmd_inval_set_pscid(&cmd,
> +                   FIELD_GET(RISCV_IOMMU_PC_TA_PSCID, ta));
> +               riscv_iommu_post(ep->iommu, &cmd);
> +       }
> +
> +       /* 3. Wait IOATC flush to happen */
> +       riscv_iommu_iofence_sync(ep->iommu);
> +
> +       /* 4. ATS invalidation */
> +       riscv_iommu_cmd_ats_inval(&cmd);
> +       riscv_iommu_cmd_ats_set_dseg(&cmd, ep->domid);
> +       riscv_iommu_cmd_ats_set_rid(&cmd, ep->devid);
> +       riscv_iommu_cmd_ats_set_pid(&cmd, pasid);
> +       riscv_iommu_cmd_ats_set_payload(&cmd, payload);
> +       riscv_iommu_post(ep->iommu, &cmd);
> +
> +       /* 5. Wait DevATC flush to happen */
> +       riscv_iommu_iofence_sync(ep->iommu);
> +}
> +
>  static void riscv_iommu_flush_iotlb_range(struct iommu_domain *iommu_domain,
>                                           unsigned long *start, unsigned long *end,
>                                           size_t *pgsize)
>  {
>         struct riscv_iommu_domain *domain = iommu_domain_to_riscv(iommu_domain);
>         struct riscv_iommu_command cmd;
> +       struct riscv_iommu_endpoint *endpoint;
>         unsigned long iova;
> +       unsigned long payload;
>
>         if (domain->mode == RISCV_IOMMU_DC_FSC_MODE_BARE)
>                 return;
> @@ -1065,6 +1616,12 @@ static void riscv_iommu_flush_iotlb_range(struct iommu_domain *iommu_domain,
>         /* Domain not attached to an IOMMU! */
>         BUG_ON(!domain->iommu);
>
> +       if (start && end) {
> +               payload = riscv_iommu_ats_inval_payload(*start, *end, true);
> +       } else {
> +               payload = riscv_iommu_ats_inval_all_payload(true);
> +       }
> +
>         riscv_iommu_cmd_inval_vma(&cmd);
>         riscv_iommu_cmd_inval_set_pscid(&cmd, domain->pscid);
>
> @@ -1078,6 +1635,20 @@ static void riscv_iommu_flush_iotlb_range(struct iommu_domain *iommu_domain,
>                 riscv_iommu_post(domain->iommu, &cmd);
>         }
>         riscv_iommu_iofence_sync(domain->iommu);
> +
> +       /* ATS invalidation for every device and for every translation */
> +       list_for_each_entry(endpoint, &domain->endpoints, domain) {
> +               if (!endpoint->pasid_enabled)
> +                       continue;
> +
> +               riscv_iommu_cmd_ats_inval(&cmd);
> +               riscv_iommu_cmd_ats_set_dseg(&cmd, endpoint->domid);
> +               riscv_iommu_cmd_ats_set_rid(&cmd, endpoint->devid);
> +               riscv_iommu_cmd_ats_set_pid(&cmd, domain->pasid);
> +               riscv_iommu_cmd_ats_set_payload(&cmd, payload);
> +               riscv_iommu_post(domain->iommu, &cmd);
> +       }
> +       riscv_iommu_iofence_sync(domain->iommu);
>  }
>
>  static void riscv_iommu_flush_iotlb_all(struct iommu_domain *iommu_domain)
> @@ -1310,6 +1881,7 @@ static int riscv_iommu_enable(struct riscv_iommu_device *iommu, unsigned request
>  static const struct iommu_domain_ops riscv_iommu_domain_ops = {
>         .free = riscv_iommu_domain_free,
>         .attach_dev = riscv_iommu_attach_dev,
> +       .set_dev_pasid = riscv_iommu_set_dev_pasid,
>         .map_pages = riscv_iommu_map_pages,
>         .unmap_pages = riscv_iommu_unmap_pages,
>         .iova_to_phys = riscv_iommu_iova_to_phys,
> @@ -1326,9 +1898,13 @@ static const struct iommu_ops riscv_iommu_ops = {
>         .probe_device = riscv_iommu_probe_device,
>         .probe_finalize = riscv_iommu_probe_finalize,
>         .release_device = riscv_iommu_release_device,
> +       .remove_dev_pasid = riscv_iommu_remove_dev_pasid,
>         .device_group = riscv_iommu_device_group,
>         .get_resv_regions = riscv_iommu_get_resv_regions,
>         .of_xlate = riscv_iommu_of_xlate,
> +       .dev_enable_feat = riscv_iommu_dev_enable_feat,
> +       .dev_disable_feat = riscv_iommu_dev_disable_feat,
> +       .page_response = riscv_iommu_page_response,
>         .default_domain_ops = &riscv_iommu_domain_ops,
>  };
>
> @@ -1340,6 +1916,7 @@ void riscv_iommu_remove(struct riscv_iommu_device *iommu)
>         riscv_iommu_queue_free(iommu, &iommu->cmdq);
>         riscv_iommu_queue_free(iommu, &iommu->fltq);
>         riscv_iommu_queue_free(iommu, &iommu->priq);
> +       iopf_queue_free(iommu->pq_work);
>  }
>
>  int riscv_iommu_init(struct riscv_iommu_device *iommu)
> @@ -1362,6 +1939,12 @@ int riscv_iommu_init(struct riscv_iommu_device *iommu)
>         }
>  #endif
>
> +       if (iommu->cap & RISCV_IOMMU_CAP_PD20)
> +               iommu->iommu.max_pasids = 1u << 20;
> +       else if (iommu->cap & RISCV_IOMMU_CAP_PD17)
> +               iommu->iommu.max_pasids = 1u << 17;
> +       else if (iommu->cap & RISCV_IOMMU_CAP_PD8)
> +               iommu->iommu.max_pasids = 1u << 8;
>         /*
>          * Assign queue lengths from module parameters if not already
>          * set on the device tree.
> @@ -1387,6 +1970,13 @@ int riscv_iommu_init(struct riscv_iommu_device *iommu)
>                 goto fail;
>         if (!(iommu->cap & RISCV_IOMMU_CAP_ATS))
>                 goto no_ats;
> +       /* PRI functionally depends on ATS’s capabilities. */
> +       iommu->pq_work = iopf_queue_alloc(dev_name(dev));
> +       if (!iommu->pq_work) {
> +               dev_err(dev, "failed to allocate iopf queue\n");
> +               ret = -ENOMEM;
> +               goto fail;
> +       }
>
>         ret = riscv_iommu_queue_init(iommu, RISCV_IOMMU_PAGE_REQUEST_QUEUE);
>         if (ret)
> @@ -1424,5 +2014,6 @@ int riscv_iommu_init(struct riscv_iommu_device *iommu)
>         riscv_iommu_queue_free(iommu, &iommu->priq);
>         riscv_iommu_queue_free(iommu, &iommu->fltq);
>         riscv_iommu_queue_free(iommu, &iommu->cmdq);
> +       iopf_queue_free(iommu->pq_work);
>         return ret;
>  }
> diff --git a/drivers/iommu/riscv/iommu.h b/drivers/iommu/riscv/iommu.h
> index fe32a4eff14e..83e8d00fd0f8 100644
> --- a/drivers/iommu/riscv/iommu.h
> +++ b/drivers/iommu/riscv/iommu.h
> @@ -17,9 +17,11 @@
>  #include <linux/iova.h>
>  #include <linux/io.h>
>  #include <linux/idr.h>
> +#include <linux/mmu_notifier.h>
>  #include <linux/list.h>
>  #include <linux/iommu.h>
>  #include <linux/io-pgtable.h>
> +#include <linux/mmu_notifier.h>

You include mmu_notifier.h twice in this header.

>
>  #include "iommu-bits.h"
>
> @@ -76,6 +78,9 @@ struct riscv_iommu_device {
>         unsigned ddt_mode;
>         bool ddtp_in_iomem;
>
> +       /* I/O page fault queue */
> +       struct iopf_queue *pq_work;
> +
>         /* hardware queues */
>         struct riscv_iommu_queue cmdq;
>         struct riscv_iommu_queue fltq;
> @@ -91,11 +96,14 @@ struct riscv_iommu_domain {
>         struct io_pgtable pgtbl;
>
>         struct list_head endpoints;
> +       struct list_head notifiers;
>         struct mutex lock;
> +       struct mmu_notifier mn;
>         struct riscv_iommu_device *iommu;
>
>         unsigned mode;          /* RIO_ATP_MODE_* enum */
>         unsigned pscid;         /* RISC-V IOMMU PSCID */
> +       ioasid_t pasid;         /* IOMMU_DOMAIN_SVA: Cached PASID */
>
>         pgd_t *pgd_root;        /* page table root pointer */
>  };
> @@ -107,10 +115,16 @@ struct riscv_iommu_endpoint {
>         unsigned domid;                         /* PCI domain number, segment */
>         struct rb_node node;                    /* device tracking node (lookup by devid) */
>         struct riscv_iommu_dc *dc;              /* device context pointer */
> +       struct riscv_iommu_pc *pc;              /* process context root, valid if pasid_enabled is true */
>         struct riscv_iommu_device *iommu;       /* parent iommu device */
>
>         struct mutex lock;
>         struct list_head domain;                /* endpoint attached managed domain */
> +
> +       /* end point info bits */
> +       unsigned pasid_bits;
> +       unsigned pasid_feat;
> +       bool pasid_enabled;
>  };
>
>  /* Helper functions and macros */
> --
> 2.34.1
>
>
> _______________________________________________
> linux-riscv mailing list
> linux-riscv at lists.infradead.org
> http://lists.infradead.org/mailman/listinfo/linux-riscv



More information about the linux-riscv mailing list