[PATCH 09/11] RISC-V: drivers/iommu/riscv: Add SVA with PASID/ATS/PRI support.

Tomasz Jeznach tjeznach at rivosinc.com
Wed Jul 19 12:33:53 PDT 2023


Introduce SVA (Shared Virtual Addressing) support for the RISC-V IOMMU,
with ATS/PRI services for capable devices.

Co-developed-by: Sebastien Boeuf <seb at rivosinc.com>
Signed-off-by: Sebastien Boeuf <seb at rivosinc.com>
Signed-off-by: Tomasz Jeznach <tjeznach at rivosinc.com>
---
 drivers/iommu/riscv/iommu.c | 601 +++++++++++++++++++++++++++++++++++-
 drivers/iommu/riscv/iommu.h |  14 +
 2 files changed, 610 insertions(+), 5 deletions(-)

diff --git a/drivers/iommu/riscv/iommu.c b/drivers/iommu/riscv/iommu.c
index 2ef6952a2109..6042c35be3ca 100644
--- a/drivers/iommu/riscv/iommu.c
+++ b/drivers/iommu/riscv/iommu.c
@@ -384,6 +384,89 @@ static inline void riscv_iommu_cmd_iodir_set_did(struct riscv_iommu_command *cmd
 	    FIELD_PREP(RISCV_IOMMU_CMD_IODIR_DID, devid) | RISCV_IOMMU_CMD_IODIR_DV;
 }
 
+/* Tag an IODIR command with the process identifier (PASID) it targets. */
+static inline void riscv_iommu_cmd_iodir_set_pid(struct riscv_iommu_command *cmd,
+						 unsigned pasid)
+{
+	cmd->dword0 = cmd->dword0 | FIELD_PREP(RISCV_IOMMU_CMD_IODIR_PID, pasid);
+}
+
+/* Initialise @cmd as an ATS.INVAL command with an empty payload word. */
+static void riscv_iommu_cmd_ats_inval(struct riscv_iommu_command *cmd)
+{
+	cmd->dword1 = 0;
+	cmd->dword0 = FIELD_PREP(RISCV_IOMMU_CMD_FUNC, RISCV_IOMMU_CMD_ATS_FUNC_INVAL) |
+		      FIELD_PREP(RISCV_IOMMU_CMD_OPCODE, RISCV_IOMMU_CMD_ATS_OPCODE);
+}
+
+/* Initialise @cmd as an ATS.PRGR (page request group response) command. */
+static inline void riscv_iommu_cmd_ats_prgr(struct riscv_iommu_command *cmd)
+{
+	cmd->dword1 = 0;
+	cmd->dword0 = FIELD_PREP(RISCV_IOMMU_CMD_FUNC, RISCV_IOMMU_CMD_ATS_FUNC_PRGR) |
+		      FIELD_PREP(RISCV_IOMMU_CMD_OPCODE, RISCV_IOMMU_CMD_ATS_OPCODE);
+}
+
+/* Fill in the requester ID (RID) field of an ATS command. */
+static void riscv_iommu_cmd_ats_set_rid(struct riscv_iommu_command *cmd, u32 rid)
+{
+	cmd->dword0 = cmd->dword0 | FIELD_PREP(RISCV_IOMMU_CMD_ATS_RID, rid);
+}
+
+/* Attach a PASID to an ATS command and mark it as present (PV). */
+static void riscv_iommu_cmd_ats_set_pid(struct riscv_iommu_command *cmd, u32 pid)
+{
+	cmd->dword0 |= RISCV_IOMMU_CMD_ATS_PV;
+	cmd->dword0 |= FIELD_PREP(RISCV_IOMMU_CMD_ATS_PID, pid);
+}
+
+/* Select the destination segment of an ATS command and mark it valid (DSV). */
+static void riscv_iommu_cmd_ats_set_dseg(struct riscv_iommu_command *cmd, u8 seg)
+{
+	cmd->dword0 |= RISCV_IOMMU_CMD_ATS_DSV;
+	cmd->dword0 |= FIELD_PREP(RISCV_IOMMU_CMD_ATS_DSEG, seg);
+}
+
+/*
+ * Store the ATS message payload (invalidation range or PRG response fields)
+ * in the second command word, replacing whatever was there.
+ */
+static void riscv_iommu_cmd_ats_set_payload(struct riscv_iommu_command *cmd, u64 payload)
+{
+	cmd->dword1 = payload;
+}
+
+/*
+ * Prepare the ATS invalidation payload for the inclusive range [start, end].
+ *
+ * PCI Express specification
+ * Section 10.2.3.2 Translation Range Size (S) Field:
+ * the invalidated range is a naturally aligned power-of-two number of 4 KiB
+ * pages. Bits [63:12] hold the aligned address; when S (bit 11) is set, the
+ * size is encoded by the position of the lowest clear address bit above it.
+ *
+ * The smallest aligned block that contains both @start and @end is derived
+ * from the highest bit in which their page numbers differ. Rounding only the
+ * length up and then aligning @start down is not sufficient: [0x1000, 0x2fff]
+ * would become [0x0, 0x1fff] and miss the last page.
+ */
+static unsigned long riscv_iommu_ats_inval_payload(unsigned long start,
+						   unsigned long end, bool global_inv)
+{
+	unsigned long page_start = start >> 12;
+	unsigned long page_end = end >> 12;
+	unsigned int log2_span = fls_long(page_start ^ page_end);
+	unsigned long span_mask = (1UL << log2_span) - 1;
+	unsigned long payload;
+
+	/* Align to the covering block; S and the size bits sit below it. */
+	page_start &= ~span_mask;
+	payload = (page_start << 12) | (span_mask << 11);
+
+	if (global_inv)
+		payload |= RISCV_IOMMU_CMD_ATS_INVAL_G;
+
+	return payload;
+}
+
+/*
+ * ATS invalidation payload that covers every translation: bits [62:11] set
+ * (size flag plus all address bits but the top one), optionally global.
+ */
+static unsigned long riscv_iommu_ats_inval_all_payload(bool global_inv)
+{
+	unsigned long everything = GENMASK_ULL(62, 11);
+
+	return global_inv ? everything | RISCV_IOMMU_CMD_ATS_INVAL_G : everything;
+}
+
+/*
+ * Build the payload of an ATS "Page Request Group Response" from the
+ * destination device ID, the response code and the page request group index.
+ */
+static unsigned long riscv_iommu_ats_prgr_payload(u16 dest_id, u8 resp_code, u16 grp_idx)
+{
+	unsigned long payload;
+
+	payload  = FIELD_PREP(RISCV_IOMMU_CMD_ATS_PRGR_PRG_INDEX, grp_idx);
+	payload |= FIELD_PREP(RISCV_IOMMU_CMD_ATS_PRGR_RESP_CODE, resp_code);
+	payload |= FIELD_PREP(RISCV_IOMMU_CMD_ATS_PRGR_DST_ID, dest_id);
+
+	return payload;
+}
+
 /* TODO: Convert into lock-less MPSC implementation. */
 static bool riscv_iommu_post_sync(struct riscv_iommu_device *iommu,
 				  struct riscv_iommu_command *cmd, bool sync)
@@ -460,6 +543,16 @@ static bool riscv_iommu_iodir_inv_devid(struct riscv_iommu_device *iommu, unsign
 	return riscv_iommu_post(iommu, &cmd);
 }
 
+/* Post an IODIR.INVAL_PDT command for one (device ID, PASID) pair. */
+static bool riscv_iommu_iodir_inv_pasid(struct riscv_iommu_device *iommu,
+					unsigned devid, unsigned pasid)
+{
+	struct riscv_iommu_command inval;
+
+	riscv_iommu_cmd_iodir_inval_pdt(&inval);
+	riscv_iommu_cmd_iodir_set_did(&inval, devid);
+	riscv_iommu_cmd_iodir_set_pid(&inval, pasid);
+
+	return riscv_iommu_post(iommu, &inval);
+}
+
 static bool riscv_iommu_iofence_sync(struct riscv_iommu_device *iommu)
 {
 	struct riscv_iommu_command cmd;
@@ -467,6 +560,62 @@ static bool riscv_iommu_iofence_sync(struct riscv_iommu_device *iommu)
 	return riscv_iommu_post_sync(iommu, &cmd, true);
 }
 
+/*
+ * mmu_notifier ->invalidate_range callback for SVA domains.
+ *
+ * Flushes the IOMMU translation cache for the domain's PSCID over
+ * [start, end). It then sends ATS invalidations to the PASID-enabled
+ * endpoints on domain->endpoints. Large, empty or inverted ranges fall back
+ * to a single PSCID-wide flush plus an "invalidate all" ATS payload. This
+ * avoids posting one command per page for arbitrarily large ranges.
+ *
+ * NOTE(review): riscv_iommu_set_dev_pasid() does not visibly add endpoints
+ * to domain->endpoints, so for SVA domains the ATS loop may find nothing to
+ * invalidate. Confirm how SVA attachments are meant to be tracked.
+ */
+static void riscv_iommu_mm_invalidate(struct mmu_notifier *mn,
+				      struct mm_struct *mm, unsigned long start,
+				      unsigned long end)
+{
+	/* Above this many pages, one PSCID-wide flush replaces per-page ones. */
+	const unsigned long max_inval_pages = 512;
+	struct riscv_iommu_command cmd;
+	struct riscv_iommu_endpoint *endpoint;
+	struct riscv_iommu_domain *domain =
+	    container_of(mn, struct riscv_iommu_domain, mn);
+	bool ranged = end > start &&
+		      ((end - start) >> PAGE_SHIFT) <= max_inval_pages;
+	unsigned long payload;
+	unsigned long iova;
+
+	/*
+	 * The mm_types defines vm_end as the first byte after the end address,
+	 * different from IOMMU subsystem using the last address of an address
+	 * range. So do a simple translation here by updating what end means.
+	 */
+	if (ranged)
+		payload = riscv_iommu_ats_inval_payload(start, end - 1, true);
+	else
+		payload = riscv_iommu_ats_inval_all_payload(true);
+
+	riscv_iommu_cmd_inval_vma(&cmd);
+	riscv_iommu_cmd_inval_set_gscid(&cmd, 0);
+	riscv_iommu_cmd_inval_set_pscid(&cmd, domain->pscid);
+	if (ranged) {
+		/* Cover only the range that is needed */
+		for (iova = start; iova < end; iova += PAGE_SIZE) {
+			riscv_iommu_cmd_inval_set_addr(&cmd, iova);
+			riscv_iommu_post(domain->iommu, &cmd);
+		}
+	} else {
+		/* No address: invalidate everything tagged with this PSCID. */
+		riscv_iommu_post(domain->iommu, &cmd);
+	}
+
+	riscv_iommu_iofence_sync(domain->iommu);
+
+	/* ATS invalidation for every device and for specific translation range. */
+	list_for_each_entry(endpoint, &domain->endpoints, domain) {
+		if (!endpoint->pasid_enabled)
+			continue;
+
+		riscv_iommu_cmd_ats_inval(&cmd);
+		riscv_iommu_cmd_ats_set_dseg(&cmd, endpoint->domid);
+		riscv_iommu_cmd_ats_set_rid(&cmd, endpoint->devid);
+		riscv_iommu_cmd_ats_set_pid(&cmd, domain->pasid);
+		riscv_iommu_cmd_ats_set_payload(&cmd, payload);
+		riscv_iommu_post(domain->iommu, &cmd);
+	}
+	riscv_iommu_iofence_sync(domain->iommu);
+}
+
+/*
+ * mmu_notifier ->release callback for SVA domains. Currently a no-op.
+ *
+ * NOTE(review): nothing here stops a device from walking the page tables of
+ * an exiting mm through the process-context entry that
+ * riscv_iommu_set_dev_pasid() installs. Presumably that entry should be
+ * invalidated, or pointed at an empty table, and the IOTLB/ATC flushed
+ * here. Confirm the intended teardown order.
+ */
+static void riscv_iommu_mm_release(struct mmu_notifier *mn, struct mm_struct *mm)
+{
+	/* TODO: removed from notifier, cleanup PSCID mapping, flush IOTLB */
+}
+
+/* Notifier callbacks registered on the mm bound to an SVA domain. */
+static const struct mmu_notifier_ops riscv_iommu_mmuops = {
+	.invalidate_range	= riscv_iommu_mm_invalidate,
+	.release		= riscv_iommu_mm_release,
+};
+
 /* Command queue primary interrupt handler */
 static irqreturn_t riscv_iommu_cmdq_irq_check(int irq, void *data)
 {
@@ -608,6 +757,128 @@ static void riscv_iommu_add_device(struct riscv_iommu_device *iommu, struct devi
 	mutex_unlock(&iommu->eps_mutex);
 }
 
+/*
+ * Find the endpoint whose requester ID equals @devid in the IOMMU's endpoint
+ * rb-tree and return its device with an extra reference, or NULL.
+ * The caller drops the reference with put_device().
+ */
+static struct device *riscv_iommu_get_device(struct riscv_iommu_device *iommu,
+					     unsigned devid)
+{
+	struct riscv_iommu_endpoint *entry;
+	struct device *found = NULL;
+	struct rb_node *cur;
+
+	mutex_lock(&iommu->eps_mutex);
+
+	cur = iommu->eps.rb_node;
+	while (cur && !found) {
+		entry = rb_entry(cur, struct riscv_iommu_endpoint, node);
+		if (entry->devid == devid)
+			found = get_device(entry->dev);
+		else
+			cur = (entry->devid < devid) ? cur->rb_right : cur->rb_left;
+	}
+
+	mutex_unlock(&iommu->eps_mutex);
+
+	return found;
+}
+
+/*
+ * Send an ATS "Page Request Group Response" for @dev. The generic IOMMU
+ * response code in @msg is translated into the PCIe PRG response code.
+ *
+ * Return: 0 after posting the command, -ENODEV if @dev has no endpoint data,
+ * -EINVAL for an unknown response code (nothing is sent in that case).
+ */
+static int riscv_iommu_ats_prgr(struct device *dev, struct iommu_page_response *msg)
+{
+	struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
+	struct riscv_iommu_command cmd;
+	u8 resp_code;
+	unsigned long payload;
+
+	if (!ep)
+		return -ENODEV;
+
+	/* PCIe PRG response codes: Success, Invalid Request, Response Failure. */
+	switch (msg->code) {
+	case IOMMU_PAGE_RESP_SUCCESS:
+		resp_code = 0b0000;
+		break;
+	case IOMMU_PAGE_RESP_INVALID:
+		resp_code = 0b0001;
+		break;
+	case IOMMU_PAGE_RESP_FAILURE:
+		resp_code = 0b1111;
+		break;
+	default:
+		/* Never post a response with an uninitialised code. */
+		return -EINVAL;
+	}
+	payload = riscv_iommu_ats_prgr_payload(ep->devid, resp_code, msg->grpid);
+
+	/* ATS Page Request Group Response */
+	riscv_iommu_cmd_ats_prgr(&cmd);
+	riscv_iommu_cmd_ats_set_dseg(&cmd, ep->domid);
+	riscv_iommu_cmd_ats_set_rid(&cmd, ep->devid);
+	if (msg->flags & IOMMU_PAGE_RESP_PASID_VALID)
+		riscv_iommu_cmd_ats_set_pid(&cmd, msg->pasid);
+	riscv_iommu_cmd_ats_set_payload(&cmd, payload);
+	riscv_iommu_post(ep->iommu, &cmd);
+
+	return 0;
+}
+
+/*
+ * Handle one record from the page-request queue. The record is translated
+ * into an IOMMU_FAULT_PAGE_REQ event and reported for the requesting device.
+ * If reporting fails, the page request group is answered with
+ * IOMMU_PAGE_RESP_FAILURE once its last request has arrived.
+ */
+static void riscv_iommu_page_request(struct riscv_iommu_device *iommu,
+				     struct riscv_iommu_pq_record *req)
+{
+	struct iommu_fault_event event = { 0 };
+	struct iommu_fault_page_request *prm = &event.fault.prm;
+	int ret;
+	struct device *dev;
+	unsigned devid = FIELD_GET(RISCV_IOMMU_PREQ_HDR_DID, req->hdr);
+
+	/* Ignore PGR Stop marker. */
+	if ((req->payload & RISCV_IOMMU_PREQ_PAYLOAD_M) == RISCV_IOMMU_PREQ_PAYLOAD_L)
+		return;
+
+	dev = riscv_iommu_get_device(iommu, devid);
+	if (!dev) {
+		/* TODO: Handle invalid page request */
+		return;
+	}
+
+	event.fault.type = IOMMU_FAULT_PAGE_REQ;
+
+	if (req->payload & RISCV_IOMMU_PREQ_PAYLOAD_L)
+		prm->flags |= IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE;
+	if (req->payload & RISCV_IOMMU_PREQ_PAYLOAD_W)
+		prm->perm |= IOMMU_FAULT_PERM_WRITE;
+	if (req->payload & RISCV_IOMMU_PREQ_PAYLOAD_R)
+		prm->perm |= IOMMU_FAULT_PERM_READ;
+
+	prm->grpid = FIELD_GET(RISCV_IOMMU_PREQ_PRG_INDEX, req->payload);
+	prm->addr = FIELD_GET(RISCV_IOMMU_PREQ_UADDR, req->payload) << PAGE_SHIFT;
+
+	if (req->hdr & RISCV_IOMMU_PREQ_HDR_PV) {
+		prm->flags |= IOMMU_FAULT_PAGE_REQUEST_PASID_VALID;
+		/* TODO: where to find this bit */
+		prm->flags |= IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID;
+		prm->pasid = FIELD_GET(RISCV_IOMMU_PREQ_HDR_PID, req->hdr);
+	}
+
+	ret = iommu_report_device_fault(dev, &event);
+
+	/*
+	 * A PRG response completes a whole page request group and is only
+	 * expected after the group's last request. Fail the group at that
+	 * point instead of answering each individual request.
+	 */
+	if (ret && (prm->flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE)) {
+		struct iommu_page_response resp = {
+			.grpid = prm->grpid,
+			.code = IOMMU_PAGE_RESP_FAILURE,
+		};
+
+		if (prm->flags & IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID) {
+			resp.flags |= IOMMU_PAGE_RESP_PASID_VALID;
+			resp.pasid = prm->pasid;
+		}
+		riscv_iommu_ats_prgr(dev, &resp);
+	}
+
+	put_device(dev);
+}
+
+/* iommu_ops->page_response: forward @msg to the device as an ATS PRG response. */
+static int riscv_iommu_page_response(struct device *dev,
+				     struct iommu_fault_event *evt,
+				     struct iommu_page_response *msg)
+{
+	int ret;
+
+	/* @evt is not needed: @msg carries everything the response encodes. */
+	ret = riscv_iommu_ats_prgr(dev, msg);
+
+	return ret;
+}
+
 /* Page request interface queue primary interrupt handler */
 static irqreturn_t riscv_iommu_priq_irq_check(int irq, void *data)
 {
@@ -626,7 +897,7 @@ static irqreturn_t riscv_iommu_priq_process(int irq, void *data)
 	struct riscv_iommu_queue *q = (struct riscv_iommu_queue *)data;
 	struct riscv_iommu_device *iommu;
 	struct riscv_iommu_pq_record *requests;
-	unsigned cnt, idx, ctrl;
+	unsigned cnt, len, idx, ctrl;
 
 	iommu = container_of(q, struct riscv_iommu_device, priq);
 	requests = (struct riscv_iommu_pq_record *)q->base;
@@ -649,7 +920,8 @@ static irqreturn_t riscv_iommu_priq_process(int irq, void *data)
 		cnt = riscv_iommu_queue_consume(iommu, q, &idx);
 		if (!cnt)
 			break;
-		dev_warn(iommu->dev, "unexpected %u page requests\n", cnt);
+		for (len = 0; len < cnt; idx++, len++)
+			riscv_iommu_page_request(iommu, &requests[idx]);
 		riscv_iommu_queue_release(iommu, q, cnt);
 	} while (1);
 
@@ -660,6 +932,169 @@ static irqreturn_t riscv_iommu_priq_process(int irq, void *data)
  * Endpoint management
  */
 
+/* Endpoint features/capabilities */
+
+/*
+ * Undo riscv_iommu_enable_ep(): turn off ATS, PRI and PASID on a PCI endpoint
+ * and release its process-context table page.
+ */
+static void riscv_iommu_disable_ep(struct riscv_iommu_endpoint *ep)
+{
+	struct pci_dev *pdev;
+
+	if (!dev_is_pci(ep->dev))
+		return;
+
+	pdev = to_pci_dev(ep->dev);
+
+	if (ep->pasid_enabled) {
+		pci_disable_ats(pdev);
+		pci_disable_pri(pdev);
+		pci_disable_pasid(pdev);
+		ep->pasid_enabled = false;
+	}
+
+	/*
+	 * Free the page allocated by riscv_iommu_enable_ep(). First wait for
+	 * previously posted commands (e.g. device-context invalidation) to
+	 * complete, so the IOMMU no longer walks this table.
+	 */
+	if (ep->pc) {
+		riscv_iommu_iofence_sync(ep->iommu);
+		free_page((unsigned long)ep->pc);
+		ep->pc = NULL;
+	}
+}
+
+/*
+ * Enable PASID, PRI and ATS on a PCI endpoint and allocate its one-page
+ * process-context table. All of them are required: on any failure, whatever
+ * was enabled is turned off again and ep->pasid_enabled stays false.
+ */
+static void riscv_iommu_enable_ep(struct riscv_iommu_endpoint *ep)
+{
+	/* The process-context table is a single page: cap usable PASIDs to it. */
+	const int pc_entries = PAGE_SIZE / sizeof(struct riscv_iommu_pc);
+	int rc, feat, num;
+	struct pci_dev *pdev;
+	struct device *dev = ep->dev;
+
+	if (!dev_is_pci(dev))
+		return;
+
+	if (!ep->iommu->iommu.max_pasids)
+		return;
+
+	pdev = to_pci_dev(dev);
+
+	if (!pci_ats_supported(pdev))
+		return;
+
+	if (!pci_pri_supported(pdev))
+		return;
+
+	feat = pci_pasid_features(pdev);
+	if (feat < 0)
+		return;
+
+	/* pci_max_pasids() may also return a negative errno. */
+	num = pci_max_pasids(pdev);
+	if (num <= 0) {
+		dev_warn(dev, "Can't enable PASID (num: %d)\n", num);
+		return;
+	}
+
+	if (num > ep->iommu->iommu.max_pasids)
+		num = ep->iommu->iommu.max_pasids;
+
+	if (num > pc_entries)
+		num = pc_entries;
+
+	rc = pci_enable_pasid(pdev, feat);
+	if (rc) {
+		dev_warn(dev, "Can't enable PASID (rc: %d)\n", rc);
+		return;
+	}
+
+	rc = pci_reset_pri(pdev);
+	if (rc) {
+		dev_warn(dev, "Can't reset PRI (rc: %d)\n", rc);
+		goto err_disable_pasid;
+	}
+
+	/* TODO: Get supported PRI queue length, hard-code to 32 entries */
+	rc = pci_enable_pri(pdev, 32);
+	if (rc) {
+		dev_warn(dev, "Can't enable PRI (rc: %d)\n", rc);
+		goto err_disable_pasid;
+	}
+
+	rc = pci_enable_ats(pdev, PAGE_SHIFT);
+	if (rc) {
+		dev_warn(dev, "Can't enable ATS (rc: %d)\n", rc);
+		goto err_disable_pri;
+	}
+
+	ep->pc = (struct riscv_iommu_pc *)get_zeroed_page(GFP_KERNEL);
+	if (!ep->pc)
+		goto err_disable_ats;
+
+	ep->pasid_enabled = true;
+	ep->pasid_feat = feat;
+	ep->pasid_bits = ilog2(num);
+
+	dev_dbg(ep->dev, "PASID/ATS support enabled, %d bits\n", ep->pasid_bits);
+	return;
+
+err_disable_ats:
+	pci_disable_ats(pdev);
+err_disable_pri:
+	pci_disable_pri(pdev);
+err_disable_pasid:
+	pci_disable_pasid(pdev);
+}
+
+/*
+ * IOMMU_DEV_FEAT_SVA: add the device to the I/O page fault queue and route
+ * its faults to iommu_queue_iopf(). Requires a PASID-enabled endpoint on an
+ * IOMMU that has a page-fault queue.
+ */
+static int riscv_iommu_enable_sva(struct device *dev)
+{
+	int ret;
+	struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
+
+	if (!ep || !ep->iommu || !ep->iommu->pq_work)
+		return -EINVAL;
+
+	if (!ep->pasid_enabled)
+		return -ENODEV;
+
+	ret = iopf_queue_add_device(ep->iommu->pq_work, dev);
+	if (ret)
+		return ret;
+
+	ret = iommu_register_device_fault_handler(dev, iommu_queue_iopf, dev);
+	/* Don't leave the device on the IOPF queue without a fault handler. */
+	if (ret)
+		iopf_queue_remove_device(ep->iommu->pq_work, dev);
+
+	return ret;
+}
+
+/*
+ * IOMMU_DEV_FEAT_SVA teardown: unregister the fault handler, then remove the
+ * device from the I/O page fault queue (reverse order of enable).
+ */
+static int riscv_iommu_disable_sva(struct device *dev)
+{
+	int ret;
+	struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
+
+	/* Same preconditions as enable; without pq_work there is no queue. */
+	if (!ep || !ep->iommu || !ep->iommu->pq_work)
+		return -EINVAL;
+
+	ret = iommu_unregister_device_fault_handler(dev);
+	if (!ret)
+		ret = iopf_queue_remove_device(ep->iommu->pq_work, dev);
+
+	return ret;
+}
+
+/*
+ * IOMMU_DEV_FEAT_IOPF: available only on endpoints where PASID/PRI/ATS were
+ * enabled by riscv_iommu_enable_ep().
+ */
+static int riscv_iommu_enable_iopf(struct device *dev)
+{
+	struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
+
+	if (!ep || !ep->pasid_enabled)
+		return -EINVAL;
+
+	return 0;
+}
+
+/* iommu_ops->dev_enable_feat: dispatch to the per-feature enable helper. */
+static int riscv_iommu_dev_enable_feat(struct device *dev, enum iommu_dev_features feat)
+{
+	if (feat == IOMMU_DEV_FEAT_IOPF)
+		return riscv_iommu_enable_iopf(dev);
+
+	if (feat == IOMMU_DEV_FEAT_SVA)
+		return riscv_iommu_enable_sva(dev);
+
+	return -ENODEV;
+}
+
+/* iommu_ops->dev_disable_feat: IOPF needs no teardown, SVA does. */
+static int riscv_iommu_dev_disable_feat(struct device *dev, enum iommu_dev_features feat)
+{
+	if (feat == IOMMU_DEV_FEAT_SVA)
+		return riscv_iommu_disable_sva(dev);
+
+	return (feat == IOMMU_DEV_FEAT_IOPF) ? 0 : -ENODEV;
+}
+
 static int riscv_iommu_of_xlate(struct device *dev, struct of_phandle_args *args)
 {
 	return iommu_fwspec_add_ids(dev, args->args, 1);
@@ -812,6 +1247,7 @@ static struct iommu_device *riscv_iommu_probe_device(struct device *dev)
 
 	dev_iommu_priv_set(dev, ep);
 	riscv_iommu_add_device(iommu, dev);
+	riscv_iommu_enable_ep(ep);
 
 	return &iommu->iommu;
 }
@@ -843,6 +1279,8 @@ static void riscv_iommu_release_device(struct device *dev)
 		riscv_iommu_iodir_inv_devid(iommu, ep->devid);
 	}
 
+	riscv_iommu_disable_ep(ep);
+
 	/* Remove endpoint from IOMMU tracking structures */
 	mutex_lock(&iommu->eps_mutex);
 	rb_erase(&ep->node, &iommu->eps);
@@ -878,7 +1316,8 @@ static struct iommu_domain *riscv_iommu_domain_alloc(unsigned type)
 	    type != IOMMU_DOMAIN_DMA_FQ &&
 	    type != IOMMU_DOMAIN_UNMANAGED &&
 	    type != IOMMU_DOMAIN_IDENTITY &&
-	    type != IOMMU_DOMAIN_BLOCKED)
+	    type != IOMMU_DOMAIN_BLOCKED &&
+	    type != IOMMU_DOMAIN_SVA)
 		return NULL;
 
 	domain = kzalloc(sizeof(*domain), GFP_KERNEL);
@@ -906,6 +1345,9 @@ static void riscv_iommu_domain_free(struct iommu_domain *iommu_domain)
 		pr_warn("IOMMU domain is not empty!\n");
 	}
 
+	if (domain->mn.ops && iommu_domain->mm)
+		mmu_notifier_unregister(&domain->mn, iommu_domain->mm);
+
 	if (domain->pgtbl.cookie)
 		free_io_pgtable_ops(&domain->pgtbl.ops);
 
@@ -1023,14 +1465,29 @@ static int riscv_iommu_attach_dev(struct iommu_domain *iommu_domain, struct devi
 	 */
 	val = FIELD_PREP(RISCV_IOMMU_DC_TA_PSCID, domain->pscid);
 
-	dc->ta = cpu_to_le64(val);
-	dc->fsc = cpu_to_le64(riscv_iommu_domain_atp(domain));
+	if (ep->pasid_enabled) {
+		ep->pc[0].ta = cpu_to_le64(val | RISCV_IOMMU_PC_TA_V);
+		ep->pc[0].fsc = cpu_to_le64(riscv_iommu_domain_atp(domain));
+		dc->ta = 0;
+		dc->fsc = cpu_to_le64(virt_to_pfn(ep->pc) |
+		    FIELD_PREP(RISCV_IOMMU_DC_FSC_MODE, RISCV_IOMMU_DC_FSC_PDTP_MODE_PD8));
+	} else {
+		dc->ta = cpu_to_le64(val);
+		dc->fsc = cpu_to_le64(riscv_iommu_domain_atp(domain));
+	}
 
 	wmb();
 
 	/* Mark device context as valid, synchronise device context cache. */
 	val = RISCV_IOMMU_DC_TC_V;
 
+	if (ep->pasid_enabled) {
+		val |= RISCV_IOMMU_DC_TC_EN_ATS |
+		       RISCV_IOMMU_DC_TC_EN_PRI |
+		       RISCV_IOMMU_DC_TC_DPE |
+		       RISCV_IOMMU_DC_TC_PDTV;
+	}
+
 	if (ep->iommu->cap & RISCV_IOMMU_CAP_AMO) {
 		val |= RISCV_IOMMU_DC_TC_GADE |
 		       RISCV_IOMMU_DC_TC_SADE;
@@ -1051,13 +1508,107 @@ static int riscv_iommu_attach_dev(struct iommu_domain *iommu_domain, struct devi
 	return 0;
 }
 
+/*
+ * Bind the mm of an SVA domain to @pasid on @dev. Registers the domain's mmu
+ * notifier (once per domain) and installs a process-context entry pointing at
+ * the CPU page table, tagged with the domain's PSCID.
+ *
+ * Return: 0 on success or a negative errno.
+ */
+static int riscv_iommu_set_dev_pasid(struct iommu_domain *iommu_domain,
+				     struct device *dev, ioasid_t pasid)
+{
+	struct riscv_iommu_domain *domain;
+	struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
+	u64 ta, fsc;
+	int ret;
+
+	if (!iommu_domain || !iommu_domain->mm)
+		return -EINVAL;
+
+	domain = iommu_domain_to_riscv(iommu_domain);
+
+	/* Driver uses TC.DPE mode, PASID #0 is incorrect. */
+	if (pasid == 0)
+		return -EINVAL;
+
+	/* Incorrect domain identifier */
+	if ((int)domain->pscid < 0)
+		return -ENOMEM;
+
+	/* Process Context table should be set for pasid enabled endpoints. */
+	if (!ep || !ep->pasid_enabled || !ep->dc || !ep->pc)
+		return -ENODEV;
+
+	/*
+	 * ep->pc is a single page: never index past it, regardless of the
+	 * PASID range the core believes the device supports.
+	 */
+	if (pasid >= PAGE_SIZE / sizeof(*ep->pc))
+		return -EINVAL;
+
+	domain->pasid = pasid;
+	domain->iommu = ep->iommu;
+
+	/*
+	 * One SVA domain may be attached to several devices: register the
+	 * notifier only once. On failure, clear ops again so that the
+	 * domain_free path does not unregister a notifier that was never
+	 * registered.
+	 */
+	if (!domain->mn.ops) {
+		domain->mn.ops = &riscv_iommu_mmuops;
+		ret = mmu_notifier_register(&domain->mn, iommu_domain->mm);
+		if (ret) {
+			domain->mn.ops = NULL;
+			return ret;
+		}
+	}
+
+	/* TODO: get SXL value for the process, use 32 bit or SATP mode */
+	fsc = virt_to_pfn(iommu_domain->mm->pgd) | satp_mode;
+	ta = RISCV_IOMMU_PC_TA_V | FIELD_PREP(RISCV_IOMMU_PC_TA_PSCID, domain->pscid);
+
+	fsc = le64_to_cpu(xchg_relaxed(&(ep->pc[pasid].fsc), cpu_to_le64(fsc)));
+	ta = le64_to_cpu(xchg_relaxed(&(ep->pc[pasid].ta), cpu_to_le64(ta)));
+
+	wmb();
+
+	/* The replaced entry was valid: drop any cached copy of it. */
+	if (ta & RISCV_IOMMU_PC_TA_V) {
+		riscv_iommu_iodir_inv_pasid(ep->iommu, ep->devid, pasid);
+		riscv_iommu_iofence_sync(ep->iommu);
+	}
+
+	dev_info(dev, "domain type %d attached w/ PSCID %u PASID %u\n",
+	    domain->domain.type, domain->pscid, domain->pasid);
+
+	return 0;
+}
+
+/*
+ * Tear down the PASID binding installed by riscv_iommu_set_dev_pasid().
+ * Invalidates the process-context entry and flushes the IOMMU translation
+ * cache for its PSCID. Then invalidates the device's ATC for @pasid.
+ */
+static void riscv_iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid)
+{
+	struct riscv_iommu_endpoint *ep = dev_iommu_priv_get(dev);
+	struct riscv_iommu_command cmd;
+	unsigned long payload = riscv_iommu_ats_inval_all_payload(false);
+	u64 ta;
+
+	/* No table, or a PASID outside the one-page table: nothing to undo. */
+	if (!ep || !ep->pc || pasid >= PAGE_SIZE / sizeof(*ep->pc))
+		return;
+
+	/* invalidate TA.V */
+	ta = le64_to_cpu(xchg_relaxed(&(ep->pc[pasid].ta), 0));
+
+	wmb();
+
+	dev_info(dev, "domain removed w/ PSCID %u PASID %u\n",
+	    (unsigned)FIELD_GET(RISCV_IOMMU_PC_TA_PSCID, ta), pasid);
+
+	/* 1. invalidate PDT entry */
+	riscv_iommu_iodir_inv_pasid(ep->iommu, ep->devid, pasid);
+
+	/* 2. invalidate all matching IOATC entries (if PASID was valid) */
+	if (ta & RISCV_IOMMU_PC_TA_V) {
+		riscv_iommu_cmd_inval_vma(&cmd);
+		riscv_iommu_cmd_inval_set_gscid(&cmd, 0);
+		riscv_iommu_cmd_inval_set_pscid(&cmd,
+		    FIELD_GET(RISCV_IOMMU_PC_TA_PSCID, ta));
+		riscv_iommu_post(ep->iommu, &cmd);
+	}
+
+	/* 3. Wait IOATC flush to happen */
+	riscv_iommu_iofence_sync(ep->iommu);
+
+	/* 4. ATS invalidation */
+	riscv_iommu_cmd_ats_inval(&cmd);
+	riscv_iommu_cmd_ats_set_dseg(&cmd, ep->domid);
+	riscv_iommu_cmd_ats_set_rid(&cmd, ep->devid);
+	riscv_iommu_cmd_ats_set_pid(&cmd, pasid);
+	riscv_iommu_cmd_ats_set_payload(&cmd, payload);
+	riscv_iommu_post(ep->iommu, &cmd);
+
+	/* 5. Wait DevATC flush to happen */
+	riscv_iommu_iofence_sync(ep->iommu);
+}
+
 static void riscv_iommu_flush_iotlb_range(struct iommu_domain *iommu_domain,
 					  unsigned long *start, unsigned long *end,
 					  size_t *pgsize)
 {
 	struct riscv_iommu_domain *domain = iommu_domain_to_riscv(iommu_domain);
 	struct riscv_iommu_command cmd;
+	struct riscv_iommu_endpoint *endpoint;
 	unsigned long iova;
+	unsigned long payload;
 
 	if (domain->mode == RISCV_IOMMU_DC_FSC_MODE_BARE)
 		return;
@@ -1065,6 +1616,12 @@ static void riscv_iommu_flush_iotlb_range(struct iommu_domain *iommu_domain,
 	/* Domain not attached to an IOMMU! */
 	BUG_ON(!domain->iommu);
 
+	if (start && end) {
+		payload = riscv_iommu_ats_inval_payload(*start, *end, true);
+	} else {
+		payload = riscv_iommu_ats_inval_all_payload(true);
+	}
+
 	riscv_iommu_cmd_inval_vma(&cmd);
 	riscv_iommu_cmd_inval_set_pscid(&cmd, domain->pscid);
 
@@ -1078,6 +1635,20 @@ static void riscv_iommu_flush_iotlb_range(struct iommu_domain *iommu_domain,
 		riscv_iommu_post(domain->iommu, &cmd);
 	}
 	riscv_iommu_iofence_sync(domain->iommu);
+
+	/* ATS invalidation for every device and for every translation */
+	list_for_each_entry(endpoint, &domain->endpoints, domain) {
+		if (!endpoint->pasid_enabled)
+			continue;
+
+		riscv_iommu_cmd_ats_inval(&cmd);
+		riscv_iommu_cmd_ats_set_dseg(&cmd, endpoint->domid);
+		riscv_iommu_cmd_ats_set_rid(&cmd, endpoint->devid);
+		riscv_iommu_cmd_ats_set_pid(&cmd, domain->pasid);
+		riscv_iommu_cmd_ats_set_payload(&cmd, payload);
+		riscv_iommu_post(domain->iommu, &cmd);
+	}
+	riscv_iommu_iofence_sync(domain->iommu);
 }
 
 static void riscv_iommu_flush_iotlb_all(struct iommu_domain *iommu_domain)
@@ -1310,6 +1881,7 @@ static int riscv_iommu_enable(struct riscv_iommu_device *iommu, unsigned request
 static const struct iommu_domain_ops riscv_iommu_domain_ops = {
 	.free = riscv_iommu_domain_free,
 	.attach_dev = riscv_iommu_attach_dev,
+	.set_dev_pasid = riscv_iommu_set_dev_pasid,
 	.map_pages = riscv_iommu_map_pages,
 	.unmap_pages = riscv_iommu_unmap_pages,
 	.iova_to_phys = riscv_iommu_iova_to_phys,
@@ -1326,9 +1898,13 @@ static const struct iommu_ops riscv_iommu_ops = {
 	.probe_device = riscv_iommu_probe_device,
 	.probe_finalize = riscv_iommu_probe_finalize,
 	.release_device = riscv_iommu_release_device,
+	.remove_dev_pasid = riscv_iommu_remove_dev_pasid,
 	.device_group = riscv_iommu_device_group,
 	.get_resv_regions = riscv_iommu_get_resv_regions,
 	.of_xlate = riscv_iommu_of_xlate,
+	.dev_enable_feat = riscv_iommu_dev_enable_feat,
+	.dev_disable_feat = riscv_iommu_dev_disable_feat,
+	.page_response = riscv_iommu_page_response,
 	.default_domain_ops = &riscv_iommu_domain_ops,
 };
 
@@ -1340,6 +1916,7 @@ void riscv_iommu_remove(struct riscv_iommu_device *iommu)
 	riscv_iommu_queue_free(iommu, &iommu->cmdq);
 	riscv_iommu_queue_free(iommu, &iommu->fltq);
 	riscv_iommu_queue_free(iommu, &iommu->priq);
+	iopf_queue_free(iommu->pq_work);
 }
 
 int riscv_iommu_init(struct riscv_iommu_device *iommu)
@@ -1362,6 +1939,12 @@ int riscv_iommu_init(struct riscv_iommu_device *iommu)
 	}
 #endif
 
+	if (iommu->cap & RISCV_IOMMU_CAP_PD20)
+		iommu->iommu.max_pasids = 1u << 20;
+	else if (iommu->cap & RISCV_IOMMU_CAP_PD17)
+		iommu->iommu.max_pasids = 1u << 17;
+	else if (iommu->cap & RISCV_IOMMU_CAP_PD8)
+		iommu->iommu.max_pasids = 1u << 8;
 	/*
 	 * Assign queue lengths from module parameters if not already
 	 * set on the device tree.
@@ -1387,6 +1970,13 @@ int riscv_iommu_init(struct riscv_iommu_device *iommu)
 		goto fail;
 	if (!(iommu->cap & RISCV_IOMMU_CAP_ATS))
 		goto no_ats;
+	/* PRI functionally depends on ATS’s capabilities. */
+	iommu->pq_work = iopf_queue_alloc(dev_name(dev));
+	if (!iommu->pq_work) {
+		dev_err(dev, "failed to allocate iopf queue\n");
+		ret = -ENOMEM;
+		goto fail;
+	}
 
 	ret = riscv_iommu_queue_init(iommu, RISCV_IOMMU_PAGE_REQUEST_QUEUE);
 	if (ret)
@@ -1424,5 +2014,6 @@ int riscv_iommu_init(struct riscv_iommu_device *iommu)
 	riscv_iommu_queue_free(iommu, &iommu->priq);
 	riscv_iommu_queue_free(iommu, &iommu->fltq);
 	riscv_iommu_queue_free(iommu, &iommu->cmdq);
+	iopf_queue_free(iommu->pq_work);
 	return ret;
 }
diff --git a/drivers/iommu/riscv/iommu.h b/drivers/iommu/riscv/iommu.h
index fe32a4eff14e..83e8d00fd0f8 100644
--- a/drivers/iommu/riscv/iommu.h
+++ b/drivers/iommu/riscv/iommu.h
@@ -17,9 +17,11 @@
 #include <linux/iova.h>
 #include <linux/io.h>
 #include <linux/idr.h>
+#include <linux/mmu_notifier.h>
 #include <linux/list.h>
 #include <linux/iommu.h>
 #include <linux/io-pgtable.h>
+#include <linux/mmu_notifier.h>
 
 #include "iommu-bits.h"
 
@@ -76,6 +78,9 @@ struct riscv_iommu_device {
 	unsigned ddt_mode;
 	bool ddtp_in_iomem;
 
+	/* I/O page fault queue */
+	struct iopf_queue *pq_work;
+
 	/* hardware queues */
 	struct riscv_iommu_queue cmdq;
 	struct riscv_iommu_queue fltq;
@@ -91,11 +96,14 @@ struct riscv_iommu_domain {
 	struct io_pgtable pgtbl;
 
 	struct list_head endpoints;
+	struct list_head notifiers;
 	struct mutex lock;
+	struct mmu_notifier mn;
 	struct riscv_iommu_device *iommu;
 
 	unsigned mode;		/* RIO_ATP_MODE_* enum */
 	unsigned pscid;		/* RISC-V IOMMU PSCID */
+	ioasid_t pasid;		/* IOMMU_DOMAIN_SVA: Cached PASID */
 
 	pgd_t *pgd_root;	/* page table root pointer */
 };
@@ -107,10 +115,16 @@ struct riscv_iommu_endpoint {
 	unsigned domid;    			/* PCI domain number, segment */
 	struct rb_node node;    		/* device tracking node (lookup by devid) */
 	struct riscv_iommu_dc *dc;		/* device context pointer */
+	struct riscv_iommu_pc *pc;		/* process context root, valid if pasid_enabled is true */
 	struct riscv_iommu_device *iommu;	/* parent iommu device */
 
 	struct mutex lock;
 	struct list_head domain;		/* endpoint attached managed domain */
+
+	/* end point info bits */
+	unsigned pasid_bits;
+	unsigned pasid_feat;
+	bool pasid_enabled;
 };
 
 /* Helper functions and macros */
-- 
2.34.1




More information about the linux-riscv mailing list