[PATCH 06/11] firmware: arm_scmi: add support for protocol modularization

Cristian Marussi cristian.marussi at arm.com
Wed Oct 14 11:05:40 EDT 2020


Modify protocol initialization callback adding a new parameter representing
a reference to the available xfer core operations and introduce a macro to
simply register with the core new protocols as loadable drivers.
Keep standard protocols as builtin.

Signed-off-by: Cristian Marussi <cristian.marussi at arm.com>
---
 drivers/firmware/arm_scmi/base.c    | 56 ++++++++++--------
 drivers/firmware/arm_scmi/bus.c     | 14 ++++-
 drivers/firmware/arm_scmi/clock.c   | 56 +++++++++---------
 drivers/firmware/arm_scmi/common.h  | 42 +++++++++-----
 drivers/firmware/arm_scmi/driver.c  | 50 ++++++++++------
 drivers/firmware/arm_scmi/perf.c    | 88 +++++++++++++++--------------
 drivers/firmware/arm_scmi/power.c   | 46 ++++++++-------
 drivers/firmware/arm_scmi/reset.c   | 46 ++++++++-------
 drivers/firmware/arm_scmi/sensors.c | 52 +++++++++--------
 drivers/firmware/arm_scmi/system.c  | 16 ++++--
 include/linux/scmi_protocol.h       | 18 +++++-
 11 files changed, 288 insertions(+), 196 deletions(-)

diff --git a/drivers/firmware/arm_scmi/base.c b/drivers/firmware/arm_scmi/base.c
index f40821eeb103..8d7214fd2187 100644
--- a/drivers/firmware/arm_scmi/base.c
+++ b/drivers/firmware/arm_scmi/base.c
@@ -15,6 +15,8 @@
 #define SCMI_BASE_NUM_SOURCES		1
 #define SCMI_BASE_MAX_CMD_ERR_COUNT	1024
 
+static const struct scmi_xfer_ops *ops;
+
 enum scmi_base_protocol_cmd {
 	BASE_DISCOVER_VENDOR = 0x3,
 	BASE_DISCOVER_SUB_VENDOR = 0x4,
@@ -61,19 +63,19 @@ static int scmi_base_attributes_get(const struct scmi_handle *handle)
 	struct scmi_msg_resp_base_attributes *attr_info;
 	struct scmi_revision_info *rev = handle->version;
 
-	ret = scmi_xfer_get_init(handle, PROTOCOL_ATTRIBUTES,
+	ret = ops->xfer_get_init(handle, PROTOCOL_ATTRIBUTES,
 				 SCMI_PROTOCOL_BASE, 0, sizeof(*attr_info), &t);
 	if (ret)
 		return ret;
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (!ret) {
 		attr_info = t->rx.buf;
 		rev->num_protocols = attr_info->num_protocols;
 		rev->num_agents = attr_info->num_agents;
 	}
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 
 	return ret;
 }
@@ -105,15 +107,15 @@ scmi_base_vendor_id_get(const struct scmi_handle *handle, bool sub_vendor)
 		size = ARRAY_SIZE(rev->vendor_id);
 	}
 
-	ret = scmi_xfer_get_init(handle, cmd, SCMI_PROTOCOL_BASE, 0, size, &t);
+	ret = ops->xfer_get_init(handle, cmd, SCMI_PROTOCOL_BASE, 0, size, &t);
 	if (ret)
 		return ret;
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (!ret)
 		memcpy(vendor_id, t->rx.buf, size);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 
 	return ret;
 }
@@ -135,18 +137,18 @@ scmi_base_implementation_version_get(const struct scmi_handle *handle)
 	struct scmi_xfer *t;
 	struct scmi_revision_info *rev = handle->version;
 
-	ret = scmi_xfer_get_init(handle, BASE_DISCOVER_IMPLEMENT_VERSION,
+	ret = ops->xfer_get_init(handle, BASE_DISCOVER_IMPLEMENT_VERSION,
 				 SCMI_PROTOCOL_BASE, 0, sizeof(*impl_ver), &t);
 	if (ret)
 		return ret;
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (!ret) {
 		impl_ver = t->rx.buf;
 		rev->impl_ver = le32_to_cpu(*impl_ver);
 	}
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 
 	return ret;
 }
@@ -170,7 +172,7 @@ static int scmi_base_implementation_list_get(const struct scmi_handle *handle,
 	u32 tot_num_ret = 0, loop_num_ret;
 	struct device *dev = handle->dev;
 
-	ret = scmi_xfer_get_init(handle, BASE_DISCOVER_LIST_PROTOCOLS,
+	ret = ops->xfer_get_init(handle, BASE_DISCOVER_LIST_PROTOCOLS,
 				 SCMI_PROTOCOL_BASE, sizeof(*num_skip), 0, &t);
 	if (ret)
 		return ret;
@@ -183,7 +185,7 @@ static int scmi_base_implementation_list_get(const struct scmi_handle *handle,
 		/* Set the number of protocols to be skipped/already read */
 		*num_skip = cpu_to_le32(tot_num_ret);
 
-		ret = scmi_do_xfer(handle, t);
+		ret = ops->do_xfer(handle, t);
 		if (ret)
 			break;
 
@@ -198,10 +200,10 @@ static int scmi_base_implementation_list_get(const struct scmi_handle *handle,
 
 		tot_num_ret += loop_num_ret;
 
-		scmi_reset_rx_to_maxsz(handle, t);
+		ops->reset_rx_to_maxsz(handle, t);
 	} while (loop_num_ret);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 
 	return ret;
 }
@@ -224,7 +226,7 @@ static int scmi_base_discover_agent_get(const struct scmi_handle *handle,
 	int ret;
 	struct scmi_xfer *t;
 
-	ret = scmi_xfer_get_init(handle, BASE_DISCOVER_AGENT,
+	ret = ops->xfer_get_init(handle, BASE_DISCOVER_AGENT,
 				 SCMI_PROTOCOL_BASE, sizeof(__le32),
 				 SCMI_MAX_STR_SIZE, &t);
 	if (ret)
@@ -232,11 +234,11 @@ static int scmi_base_discover_agent_get(const struct scmi_handle *handle,
 
 	put_unaligned_le32(id, t->tx.buf);
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (!ret)
 		strlcpy(name, t->rx.buf, SCMI_MAX_STR_SIZE);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 
 	return ret;
 }
@@ -248,7 +250,7 @@ static int scmi_base_error_notify(const struct scmi_handle *handle, bool enable)
 	struct scmi_xfer *t;
 	struct scmi_msg_base_error_notify *cfg;
 
-	ret = scmi_xfer_get_init(handle, BASE_NOTIFY_ERRORS,
+	ret = ops->xfer_get_init(handle, BASE_NOTIFY_ERRORS,
 				 SCMI_PROTOCOL_BASE, sizeof(*cfg), 0, &t);
 	if (ret)
 		return ret;
@@ -256,9 +258,9 @@ static int scmi_base_error_notify(const struct scmi_handle *handle, bool enable)
 	cfg = t->tx.buf;
 	cfg->event_control = cpu_to_le32(evt_cntl);
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -326,7 +328,8 @@ static const struct scmi_protocol_events base_protocol_events = {
 	.num_sources = SCMI_BASE_NUM_SOURCES,
 };
 
-static int scmi_base_protocol_init(const struct scmi_handle *h)
+static int scmi_base_protocol_init(const struct scmi_handle *h,
+				   const struct scmi_xfer_ops *xops)
 {
 	int id, ret;
 	u8 *prot_imp;
@@ -336,7 +339,8 @@ static int scmi_base_protocol_init(const struct scmi_handle *h)
 	struct device *dev = handle->dev;
 	struct scmi_revision_info *rev = handle->version;
 
-	ret = scmi_version_get(handle, SCMI_PROTOCOL_BASE, &version);
+	ops = xops;
+	ret = ops->version_get(handle, SCMI_PROTOCOL_BASE, &version);
 	if (ret)
 		return ret;
 
@@ -375,4 +379,12 @@ static struct scmi_protocol scmi_base = {
 	.events = &base_protocol_events,
 };
 
-DEFINE_SCMI_PROTOCOL_REGISTER_UNREGISTER(base, scmi_base)
+int __init scmi_base_register(void)
+{
+	return scmi_protocol_register(&scmi_base, NULL);
+}
+
+void __exit scmi_base_unregister(void)
+{
+	return scmi_protocol_unregister(&scmi_base);
+}
diff --git a/drivers/firmware/arm_scmi/bus.c b/drivers/firmware/arm_scmi/bus.c
index 3a2be1193c85..2ce98fae56e3 100644
--- a/drivers/firmware/arm_scmi/bus.c
+++ b/drivers/firmware/arm_scmi/bus.c
@@ -56,7 +56,7 @@ const struct scmi_protocol *scmi_get_protocol(int protocol_id)
 	const struct scmi_protocol *proto;
 
 	proto = idr_find(&scmi_available_protocols, protocol_id);
-	if (!proto) {
+	if (!proto || !try_module_get(proto->owner)) {
 		pr_warn("SCMI Protocol 0x%x not found!\n", protocol_id);
 		return NULL;
 	}
@@ -66,6 +66,15 @@ const struct scmi_protocol *scmi_get_protocol(int protocol_id)
 	return proto;
 }
 
+void scmi_put_protocol(int protocol_id)
+{
+	const struct scmi_protocol *proto;
+
+	proto = idr_find(&scmi_available_protocols, protocol_id);
+	if (proto)
+		module_put(proto->owner);
+}
+
 static int scmi_dev_probe(struct device *dev)
 {
 	struct scmi_driver *scmi_drv = to_scmi_driver(dev->driver);
@@ -186,7 +195,7 @@ void scmi_set_handle(struct scmi_device *scmi_dev)
 	scmi_dev->handle = scmi_handle_get(&scmi_dev->dev);
 }
 
-int scmi_protocol_register(struct scmi_protocol *proto)
+int scmi_protocol_register(struct scmi_protocol *proto, struct module *owner)
 {
 	int ret;
 
@@ -200,6 +209,7 @@ int scmi_protocol_register(struct scmi_protocol *proto)
 		return -EINVAL;
 	}
 
+	proto->owner = owner;
 	spin_lock(&protocol_lock);
 	ret = idr_alloc(&scmi_available_protocols, proto,
 			proto->id, proto->id + 1, GFP_ATOMIC);
diff --git a/drivers/firmware/arm_scmi/clock.c b/drivers/firmware/arm_scmi/clock.c
index 539c94860b8f..a2f552c87b3e 100644
--- a/drivers/firmware/arm_scmi/clock.c
+++ b/drivers/firmware/arm_scmi/clock.c
@@ -9,6 +9,8 @@
 
 #include "common.h"
 
+static const struct scmi_xfer_ops *ops;
+
 enum scmi_clock_protocol_cmd {
 	CLOCK_ATTRIBUTES = 0x3,
 	CLOCK_DESCRIBE_RATES = 0x4,
@@ -81,20 +83,20 @@ static int scmi_clock_protocol_attributes_get(const struct scmi_handle *handle,
 	struct scmi_xfer *t;
 	struct scmi_msg_resp_clock_protocol_attributes *attr;
 
-	ret = scmi_xfer_get_init(handle, PROTOCOL_ATTRIBUTES,
+	ret = ops->xfer_get_init(handle, PROTOCOL_ATTRIBUTES,
 				 SCMI_PROTOCOL_CLOCK, 0, sizeof(*attr), &t);
 	if (ret)
 		return ret;
 
 	attr = t->rx.buf;
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (!ret) {
 		ci->num_clocks = le16_to_cpu(attr->num_clocks);
 		ci->max_async_req = attr->max_async_req;
 	}
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -105,7 +107,7 @@ static int scmi_clock_attributes_get(const struct scmi_handle *handle,
 	struct scmi_xfer *t;
 	struct scmi_msg_resp_clock_attributes *attr;
 
-	ret = scmi_xfer_get_init(handle, CLOCK_ATTRIBUTES, SCMI_PROTOCOL_CLOCK,
+	ret = ops->xfer_get_init(handle, CLOCK_ATTRIBUTES, SCMI_PROTOCOL_CLOCK,
 				 sizeof(clk_id), sizeof(*attr), &t);
 	if (ret)
 		return ret;
@@ -113,13 +115,13 @@ static int scmi_clock_attributes_get(const struct scmi_handle *handle,
 	put_unaligned_le32(clk_id, t->tx.buf);
 	attr = t->rx.buf;
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (!ret)
 		strlcpy(clk->name, attr->name, SCMI_MAX_STR_SIZE);
 	else
 		clk->name[0] = '\0';
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -148,7 +150,7 @@ scmi_clock_describe_rates_get(const struct scmi_handle *handle, u32 clk_id,
 	struct scmi_msg_clock_describe_rates *clk_desc;
 	struct scmi_msg_resp_clock_describe_rates *rlist;
 
-	ret = scmi_xfer_get_init(handle, CLOCK_DESCRIBE_RATES,
+	ret = ops->xfer_get_init(handle, CLOCK_DESCRIBE_RATES,
 				 SCMI_PROTOCOL_CLOCK, sizeof(*clk_desc), 0, &t);
 	if (ret)
 		return ret;
@@ -161,7 +163,7 @@ scmi_clock_describe_rates_get(const struct scmi_handle *handle, u32 clk_id,
 		/* Set the number of rates to be skipped/already read */
 		clk_desc->rate_index = cpu_to_le32(tot_rate_cnt);
 
-		ret = scmi_do_xfer(handle, t);
+		ret = ops->do_xfer(handle, t);
 		if (ret)
 			goto err;
 
@@ -193,7 +195,7 @@ scmi_clock_describe_rates_get(const struct scmi_handle *handle, u32 clk_id,
 
 		tot_rate_cnt += num_returned;
 
-		scmi_reset_rx_to_maxsz(handle, t);
+		ops->reset_rx_to_maxsz(handle, t);
 		/*
 		 * check for both returned and remaining to avoid infinite
 		 * loop due to buggy firmware
@@ -208,7 +210,7 @@ scmi_clock_describe_rates_get(const struct scmi_handle *handle, u32 clk_id,
 	clk->rate_discrete = rate_discrete;
 
 err:
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -218,18 +220,18 @@ scmi_clock_rate_get(const struct scmi_handle *handle, u32 clk_id, u64 *value)
 	int ret;
 	struct scmi_xfer *t;
 
-	ret = scmi_xfer_get_init(handle, CLOCK_RATE_GET, SCMI_PROTOCOL_CLOCK,
+	ret = ops->xfer_get_init(handle, CLOCK_RATE_GET, SCMI_PROTOCOL_CLOCK,
 				 sizeof(__le32), sizeof(u64), &t);
 	if (ret)
 		return ret;
 
 	put_unaligned_le32(clk_id, t->tx.buf);
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (!ret)
 		*value = get_unaligned_le64(t->rx.buf);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -241,9 +243,9 @@ static int scmi_clock_rate_set(const struct scmi_handle *handle, u32 clk_id,
 	struct scmi_xfer *t;
 	struct scmi_clock_set_rate *cfg;
 	struct clock_info *ci =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_CLOCK);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_CLOCK);
 
-	ret = scmi_xfer_get_init(handle, CLOCK_RATE_SET, SCMI_PROTOCOL_CLOCK,
+	ret = ops->xfer_get_init(handle, CLOCK_RATE_SET, SCMI_PROTOCOL_CLOCK,
 				 sizeof(*cfg), 0, &t);
 	if (ret)
 		return ret;
@@ -259,14 +261,14 @@ static int scmi_clock_rate_set(const struct scmi_handle *handle, u32 clk_id,
 	cfg->value_high = cpu_to_le32(rate >> 32);
 
 	if (flags & CLOCK_SET_ASYNC)
-		ret = scmi_do_xfer_with_response(handle, t);
+		ret = ops->do_xfer_with_response(handle, t);
 	else
-		ret = scmi_do_xfer(handle, t);
+		ret = ops->do_xfer(handle, t);
 
 	if (ci->max_async_req)
 		atomic_dec(&ci->cur_async_req);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -277,7 +279,7 @@ scmi_clock_config_set(const struct scmi_handle *handle, u32 clk_id, u32 config)
 	struct scmi_xfer *t;
 	struct scmi_clock_set_config *cfg;
 
-	ret = scmi_xfer_get_init(handle, CLOCK_CONFIG_SET, SCMI_PROTOCOL_CLOCK,
+	ret = ops->xfer_get_init(handle, CLOCK_CONFIG_SET, SCMI_PROTOCOL_CLOCK,
 				 sizeof(*cfg), 0, &t);
 	if (ret)
 		return ret;
@@ -286,9 +288,9 @@ scmi_clock_config_set(const struct scmi_handle *handle, u32 clk_id, u32 config)
 	cfg->id = cpu_to_le32(clk_id);
 	cfg->attributes = cpu_to_le32(config);
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -305,7 +307,7 @@ static int scmi_clock_disable(const struct scmi_handle *handle, u32 clk_id)
 static int scmi_clock_count_get(const struct scmi_handle *handle)
 {
 	struct clock_info *ci =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_CLOCK);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_CLOCK);
 
 	return ci->num_clocks;
 }
@@ -314,7 +316,7 @@ static const struct scmi_clock_info *
 scmi_clock_info_get(const struct scmi_handle *handle, u32 clk_id)
 {
 	struct clock_info *ci =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_CLOCK);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_CLOCK);
 	struct scmi_clock_info *clk = ci->clk + clk_id;
 
 	if (!clk->name[0])
@@ -332,13 +334,15 @@ static const struct scmi_clk_ops clk_ops = {
 	.disable = scmi_clock_disable,
 };
 
-static int scmi_clock_protocol_init(const struct scmi_handle *handle)
+static int scmi_clock_protocol_init(const struct scmi_handle *handle,
+				    const struct scmi_xfer_ops *xops)
 {
 	u32 version;
 	int clkid, ret;
 	struct clock_info *cinfo;
 
-	scmi_version_get(handle, SCMI_PROTOCOL_CLOCK, &version);
+	ops = xops;
+	ops->version_get(handle, SCMI_PROTOCOL_CLOCK, &version);
 
 	dev_dbg(handle->dev, "Clock Version %d.%d\n",
 		PROTOCOL_REV_MAJOR(version), PROTOCOL_REV_MINOR(version));
@@ -363,7 +367,7 @@ static int scmi_clock_protocol_init(const struct scmi_handle *handle)
 	}
 
 	cinfo->version = version;
-	return scmi_set_proto_priv(handle, SCMI_PROTOCOL_CLOCK, cinfo);
+	return ops->set_proto_priv(handle, SCMI_PROTOCOL_CLOCK, cinfo);
 }
 
 static struct scmi_protocol scmi_clock = {
diff --git a/drivers/firmware/arm_scmi/common.h b/drivers/firmware/arm_scmi/common.h
index 5a91e3324697..ec81edc12ca5 100644
--- a/drivers/firmware/arm_scmi/common.h
+++ b/drivers/firmware/arm_scmi/common.h
@@ -14,6 +14,7 @@
 #include <linux/device.h>
 #include <linux/errno.h>
 #include <linux/kernel.h>
+#include <linux/module.h>
 #include <linux/scmi_protocol.h>
 #include <linux/types.h>
 
@@ -145,26 +146,39 @@ struct scmi_xfer {
 	struct completion *async_done;
 };
 
-void scmi_xfer_put(const struct scmi_handle *h, struct scmi_xfer *xfer);
-int scmi_do_xfer(const struct scmi_handle *h, struct scmi_xfer *xfer);
-int scmi_do_xfer_with_response(const struct scmi_handle *h,
-			       struct scmi_xfer *xfer);
-int scmi_xfer_get_init(const struct scmi_handle *h, u8 msg_id, u8 prot_id,
-		       size_t tx_size, size_t rx_size, struct scmi_xfer **p);
-void scmi_reset_rx_to_maxsz(const struct scmi_handle *handle,
-			    struct scmi_xfer *xfer);
+struct scmi_xfer_ops {
+	int (*version_get)(const struct scmi_handle *handle, u8 protocol,
+			   u32 *version);
+	int (*set_proto_priv)(const struct scmi_handle *handle, u8 protocol_id,
+			      void *priv);
+	void *(*get_proto_priv)(const struct scmi_handle *handle,
+				u8 protocol_id);
+	int (*xfer_get_init)(const struct scmi_handle *handle, u8 msg_id,
+			     u8 prot_id, size_t tx_size, size_t rx_size,
+			     struct scmi_xfer **p);
+	void (*reset_rx_to_maxsz)(const struct scmi_handle *handle,
+				  struct scmi_xfer *xfer);
+	int (*do_xfer)(const struct scmi_handle *handle,
+		       struct scmi_xfer *xfer);
+	int (*do_xfer_with_response)(const struct scmi_handle *handle,
+				     struct scmi_xfer *xfer);
+	void (*xfer_put)(const struct scmi_handle *handle,
+			 struct scmi_xfer *xfer);
+};
+
 int scmi_handle_put(const struct scmi_handle *handle);
 struct scmi_handle *scmi_handle_get(struct device *dev);
 void scmi_set_handle(struct scmi_device *scmi_dev);
-int scmi_version_get(const struct scmi_handle *h, u8 protocol, u32 *version);
 void scmi_setup_protocol_implemented(const struct scmi_handle *handle,
 				     u8 *prot_imp);
 
-typedef int (*scmi_prot_init_fn_t)(const struct scmi_handle *);
+typedef int (*scmi_prot_init_fn_t)(const struct scmi_handle *,
+				   const struct scmi_xfer_ops *);
 
 /**
  * struct scmi_protocol  - Protocol descriptor
  * @id: Protocol ID.
+ * @owner: Module reference if any.
  * @init: Mandatory protocol initialization function.
  * @deinit: Optional protocol de-initialization function.
  * @ops: Optional reference to the operations provided by the protocol and
@@ -173,6 +187,7 @@ typedef int (*scmi_prot_init_fn_t)(const struct scmi_handle *);
  */
 struct scmi_protocol {
 	const u8				id;
+	struct module				*owner;
 	const scmi_prot_init_fn_t		init;
 	const scmi_prot_init_fn_t		deinit;
 	const void				*ops;
@@ -196,7 +211,7 @@ DECLARE_SCMI_REGISTER_UNREGISTER(system);
 #define DEFINE_SCMI_PROTOCOL_REGISTER_UNREGISTER(name, proto) \
 int __init scmi_##name##_register(void) \
 { \
-	return scmi_protocol_register(&(proto)); \
+	return scmi_protocol_register(&(proto), THIS_MODULE); \
 } \
 \
 void __exit scmi_##name##_unregister(void) \
@@ -205,14 +220,11 @@ void __exit scmi_##name##_unregister(void) \
 }
 
 const struct scmi_protocol *scmi_get_protocol(int protocol_id);
+void scmi_put_protocol(int protocol_id);
 
 int scmi_acquire_protocol(const struct scmi_handle *handle, u8 protocol_id);
 void scmi_release_protocol(const struct scmi_handle *handle, u8 protocol_id);
 
-void *scmi_get_proto_priv(const struct scmi_handle *h, u8 prot);
-int scmi_set_proto_priv(const struct scmi_handle *handle, const u8 proto,
-			void *priv);
-
 /* SCMI Transport */
 /**
  * struct scmi_chan_info - Structure representing a SCMI channel information
diff --git a/drivers/firmware/arm_scmi/driver.c b/drivers/firmware/arm_scmi/driver.c
index 25a4152537e6..55df134c2338 100644
--- a/drivers/firmware/arm_scmi/driver.c
+++ b/drivers/firmware/arm_scmi/driver.c
@@ -370,7 +370,8 @@ void scmi_rx_callback(struct scmi_chan_info *cinfo, u32 msg_hdr)
  * @handle: Pointer to SCMI entity handle
  * @xfer: message that was reserved by scmi_xfer_get
  */
-void scmi_xfer_put(const struct scmi_handle *handle, struct scmi_xfer *xfer)
+static void scmi_xfer_put(const struct scmi_handle *handle,
+			  struct scmi_xfer *xfer)
 {
 	struct scmi_info *info = handle_to_scmi_info(handle);
 
@@ -398,7 +399,8 @@ static bool scmi_xfer_done_no_timeout(struct scmi_chan_info *cinfo,
  *	return corresponding error, else if all goes well,
  *	return 0.
  */
-int scmi_do_xfer(const struct scmi_handle *handle, struct scmi_xfer *xfer)
+static int scmi_do_xfer(const struct scmi_handle *handle,
+			struct scmi_xfer *xfer)
 {
 	int ret;
 	int timeout;
@@ -451,8 +453,8 @@ int scmi_do_xfer(const struct scmi_handle *handle, struct scmi_xfer *xfer)
 	return ret;
 }
 
-void scmi_reset_rx_to_maxsz(const struct scmi_handle *handle,
-			    struct scmi_xfer *xfer)
+static void scmi_reset_rx_to_maxsz(const struct scmi_handle *handle,
+				   struct scmi_xfer *xfer)
 {
 	struct scmi_info *info = handle_to_scmi_info(handle);
 
@@ -471,8 +473,8 @@ void scmi_reset_rx_to_maxsz(const struct scmi_handle *handle,
  * Return: -ETIMEDOUT in case of no delayed response, if transmit error,
  *	return corresponding error, else if all goes well, return 0.
  */
-int scmi_do_xfer_with_response(const struct scmi_handle *handle,
-			       struct scmi_xfer *xfer)
+static int scmi_do_xfer_with_response(const struct scmi_handle *handle,
+				      struct scmi_xfer *xfer)
 {
 	int ret, timeout = msecs_to_jiffies(SCMI_MAX_RESPONSE_TIMEOUT);
 	DECLARE_COMPLETION_ONSTACK(async_response);
@@ -503,8 +505,9 @@ int scmi_do_xfer_with_response(const struct scmi_handle *handle,
  * Return: 0 if all went fine with @p pointing to message, else
  *	corresponding error.
  */
-int scmi_xfer_get_init(const struct scmi_handle *handle, u8 msg_id, u8 prot_id,
-		       size_t tx_size, size_t rx_size, struct scmi_xfer **p)
+static int scmi_xfer_get_init(const struct scmi_handle *handle, u8 msg_id,
+			      u8 prot_id, size_t tx_size, size_t rx_size,
+			      struct scmi_xfer **p)
 {
 	int ret;
 	struct scmi_xfer *xfer;
@@ -546,8 +549,8 @@ int scmi_xfer_get_init(const struct scmi_handle *handle, u8 msg_id, u8 prot_id,
  *
  * Return: 0 if all went fine, else return appropriate error.
  */
-int scmi_version_get(const struct scmi_handle *handle, u8 protocol,
-		     u32 *version)
+static int scmi_version_get(const struct scmi_handle *handle, u8 protocol,
+			    u32 *version)
 {
 	int ret;
 	__le32 *rev_info;
@@ -568,8 +571,8 @@ int scmi_version_get(const struct scmi_handle *handle, u8 protocol,
 	return ret;
 }
 
-int scmi_set_proto_priv(const struct scmi_handle *handle,
-			u8 protocol_id, void *priv)
+static int scmi_set_proto_priv(const struct scmi_handle *handle,
+			       u8 protocol_id, void *priv)
 {
 	struct scmi_info *info = handle_to_scmi_info(handle);
 
@@ -585,7 +588,8 @@ int scmi_set_proto_priv(const struct scmi_handle *handle,
 	return 0;
 }
 
-void *scmi_get_proto_priv(const struct scmi_handle *handle, u8 protocol_id)
+static void *scmi_get_proto_priv(const struct scmi_handle *handle,
+				 u8 protocol_id)
 {
 	struct scmi_info *info = handle_to_scmi_info(handle);
 
@@ -594,6 +598,17 @@ void *scmi_get_proto_priv(const struct scmi_handle *handle, u8 protocol_id)
 	return info->protocols_private_data[protocol_id];
 }
 
+static const struct scmi_xfer_ops xfer_ops = {
+	.version_get = scmi_version_get,
+	.set_proto_priv = scmi_set_proto_priv,
+	.get_proto_priv = scmi_get_proto_priv,
+	.xfer_get_init = scmi_xfer_get_init,
+	.reset_rx_to_maxsz = scmi_reset_rx_to_maxsz,
+	.do_xfer = scmi_do_xfer,
+	.do_xfer_with_response = scmi_do_xfer_with_response,
+	.xfer_put = scmi_xfer_put,
+};
+
 /**
  * scmi_get_protocol_instance  - Protocol initialization helper.
  * @handle: A reference to the SCMI platform instance.
@@ -624,7 +639,7 @@ scmi_get_protocol_instance(const struct scmi_handle *handle, u8 protocol_id)
 		/* Fail if protocol not registered on bus */
 		proto = scmi_get_protocol(protocol_id);
 		if (!proto) {
-			ret = -EINVAL;
+			ret = -EPROBE_DEFER;
 			goto out;
 		}
 
@@ -641,7 +656,7 @@ scmi_get_protocol_instance(const struct scmi_handle *handle, u8 protocol_id)
 		pi->proto = proto;
 		refcount_set(&pi->users, 1);
 		/* proto->init is assured NON NULL by scmi_protocol_register */
-		ret = pi->proto->init(handle);
+		ret = pi->proto->init(handle, &xfer_ops);
 		if (ret)
 			goto clean;
 
@@ -664,6 +679,7 @@ scmi_get_protocol_instance(const struct scmi_handle *handle, u8 protocol_id)
 	return pi;
 
 clean:
+	scmi_put_protocol(protocol_id);
 	devres_release_group(handle->dev, gid);
 out:
 	mutex_unlock(&info->protocols_mtx);
@@ -714,13 +730,15 @@ void scmi_release_protocol(const struct scmi_handle *handle, u8 protocol_id)
 			scmi_deregister_protocol_events(handle, protocol_id);
 
 		if (pi->proto->deinit)
-			pi->proto->deinit(handle);
+			pi->proto->deinit(handle, &xfer_ops);
 
 		info->protocols_private_data[protocol_id] = NULL;
 		info->protocols[protocol_id] = NULL;
 		/* Ensure deinitialized protocol is visible */
 		smp_wmb();
 
+		scmi_put_protocol(protocol_id);
+
 		devres_release_group(handle->dev, gid);
 		dev_dbg(handle->dev, "De-Initialized protocol: 0x%X\n",
 			protocol_id);
diff --git a/drivers/firmware/arm_scmi/perf.c b/drivers/firmware/arm_scmi/perf.c
index b3038362f362..60a28ca39455 100644
--- a/drivers/firmware/arm_scmi/perf.c
+++ b/drivers/firmware/arm_scmi/perf.c
@@ -19,6 +19,8 @@
 #include "common.h"
 #include "notify.h"
 
+static const struct scmi_xfer_ops *ops;
+
 enum scmi_performance_protocol_cmd {
 	PERF_DOMAIN_ATTRIBUTES = 0x3,
 	PERF_DESCRIBE_LEVELS = 0x4,
@@ -182,14 +184,14 @@ static int scmi_perf_attributes_get(const struct scmi_handle *handle,
 	struct scmi_xfer *t;
 	struct scmi_msg_resp_perf_attributes *attr;
 
-	ret = scmi_xfer_get_init(handle, PROTOCOL_ATTRIBUTES,
+	ret = ops->xfer_get_init(handle, PROTOCOL_ATTRIBUTES,
 				 SCMI_PROTOCOL_PERF, 0, sizeof(*attr), &t);
 	if (ret)
 		return ret;
 
 	attr = t->rx.buf;
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (!ret) {
 		u16 flags = le16_to_cpu(attr->flags);
 
@@ -200,7 +202,7 @@ static int scmi_perf_attributes_get(const struct scmi_handle *handle,
 		pi->stats_size = le32_to_cpu(attr->stats_size);
 	}
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -212,7 +214,7 @@ scmi_perf_domain_attributes_get(const struct scmi_handle *handle, u32 domain,
 	struct scmi_xfer *t;
 	struct scmi_msg_resp_perf_domain_attributes *attr;
 
-	ret = scmi_xfer_get_init(handle, PERF_DOMAIN_ATTRIBUTES,
+	ret = ops->xfer_get_init(handle, PERF_DOMAIN_ATTRIBUTES,
 				 SCMI_PROTOCOL_PERF, sizeof(domain),
 				 sizeof(*attr), &t);
 	if (ret)
@@ -221,7 +223,7 @@ scmi_perf_domain_attributes_get(const struct scmi_handle *handle, u32 domain,
 	put_unaligned_le32(domain, t->tx.buf);
 	attr = t->rx.buf;
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (!ret) {
 		u32 flags = le32_to_cpu(attr->flags);
 
@@ -245,7 +247,7 @@ scmi_perf_domain_attributes_get(const struct scmi_handle *handle, u32 domain,
 		strlcpy(dom_info->name, attr->name, SCMI_MAX_STR_SIZE);
 	}
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -268,7 +270,7 @@ scmi_perf_describe_levels_get(const struct scmi_handle *handle, u32 domain,
 	struct scmi_msg_perf_describe_levels *dom_info;
 	struct scmi_msg_resp_perf_describe_levels *level_info;
 
-	ret = scmi_xfer_get_init(handle, PERF_DESCRIBE_LEVELS,
+	ret = ops->xfer_get_init(handle, PERF_DESCRIBE_LEVELS,
 				 SCMI_PROTOCOL_PERF, sizeof(*dom_info), 0, &t);
 	if (ret)
 		return ret;
@@ -281,7 +283,7 @@ scmi_perf_describe_levels_get(const struct scmi_handle *handle, u32 domain,
 		/* Set the number of OPPs to be skipped/already read */
 		dom_info->level_index = cpu_to_le32(tot_opp_cnt);
 
-		ret = scmi_do_xfer(handle, t);
+		ret = ops->do_xfer(handle, t);
 		if (ret)
 			break;
 
@@ -305,7 +307,7 @@ scmi_perf_describe_levels_get(const struct scmi_handle *handle, u32 domain,
 
 		tot_opp_cnt += num_returned;
 
-		scmi_reset_rx_to_maxsz(handle, t);
+		ops->reset_rx_to_maxsz(handle, t);
 		/*
 		 * check for both returned and remaining to avoid infinite
 		 * loop due to buggy firmware
@@ -313,7 +315,7 @@ scmi_perf_describe_levels_get(const struct scmi_handle *handle, u32 domain,
 	} while (num_returned && num_remaining);
 
 	perf_dom->opp_count = tot_opp_cnt;
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 
 	sort(perf_dom->opp, tot_opp_cnt, sizeof(*opp), opp_cmp_func, NULL);
 	return ret;
@@ -360,7 +362,7 @@ static int scmi_perf_mb_limits_set(const struct scmi_handle *handle, u32 domain,
 	struct scmi_xfer *t;
 	struct scmi_perf_set_limits *limits;
 
-	ret = scmi_xfer_get_init(handle, PERF_LIMITS_SET, SCMI_PROTOCOL_PERF,
+	ret = ops->xfer_get_init(handle, PERF_LIMITS_SET, SCMI_PROTOCOL_PERF,
 				 sizeof(*limits), 0, &t);
 	if (ret)
 		return ret;
@@ -370,16 +372,16 @@ static int scmi_perf_mb_limits_set(const struct scmi_handle *handle, u32 domain,
 	limits->max_level = cpu_to_le32(max_perf);
 	limits->min_level = cpu_to_le32(min_perf);
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
 static int scmi_perf_limits_set(const struct scmi_handle *handle, u32 domain,
 				u32 max_perf, u32 min_perf)
 {
-	struct scmi_perf_info *pi = scmi_get_proto_priv(handle,
+	struct scmi_perf_info *pi = ops->get_proto_priv(handle,
 							SCMI_PROTOCOL_PERF);
 	struct perf_dom_info *dom = pi->dom_info + domain;
 
@@ -400,14 +402,14 @@ static int scmi_perf_mb_limits_get(const struct scmi_handle *handle, u32 domain,
 	struct scmi_xfer *t;
 	struct scmi_perf_get_limits *limits;
 
-	ret = scmi_xfer_get_init(handle, PERF_LIMITS_GET, SCMI_PROTOCOL_PERF,
+	ret = ops->xfer_get_init(handle, PERF_LIMITS_GET, SCMI_PROTOCOL_PERF,
 				 sizeof(__le32), 0, &t);
 	if (ret)
 		return ret;
 
 	put_unaligned_le32(domain, t->tx.buf);
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (!ret) {
 		limits = t->rx.buf;
 
@@ -415,7 +417,7 @@ static int scmi_perf_mb_limits_get(const struct scmi_handle *handle, u32 domain,
 		*min_perf = le32_to_cpu(limits->min_level);
 	}
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -423,7 +425,7 @@ static int scmi_perf_limits_get(const struct scmi_handle *handle, u32 domain,
 				u32 *max_perf, u32 *min_perf)
 {
 	struct scmi_perf_info *pi =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_PERF);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_PERF);
 	struct perf_dom_info *dom = pi->dom_info + domain;
 
 	if (dom->fc_info && dom->fc_info->limit_get_addr) {
@@ -442,7 +444,7 @@ static int scmi_perf_mb_level_set(const struct scmi_handle *handle, u32 domain,
 	struct scmi_xfer *t;
 	struct scmi_perf_set_level *lvl;
 
-	ret = scmi_xfer_get_init(handle, PERF_LEVEL_SET, SCMI_PROTOCOL_PERF,
+	ret = ops->xfer_get_init(handle, PERF_LEVEL_SET, SCMI_PROTOCOL_PERF,
 				 sizeof(*lvl), 0, &t);
 	if (ret)
 		return ret;
@@ -452,9 +454,9 @@ static int scmi_perf_mb_level_set(const struct scmi_handle *handle, u32 domain,
 	lvl->domain = cpu_to_le32(domain);
 	lvl->level = cpu_to_le32(level);
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -462,7 +464,7 @@ static int scmi_perf_level_set(const struct scmi_handle *handle, u32 domain,
 			       u32 level, bool poll)
 {
 	struct scmi_perf_info *pi =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_PERF);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_PERF);
 	struct perf_dom_info *dom = pi->dom_info + domain;
 
 	if (dom->fc_info && dom->fc_info->level_set_addr) {
@@ -480,7 +482,7 @@ static int scmi_perf_mb_level_get(const struct scmi_handle *handle, u32 domain,
 	int ret;
 	struct scmi_xfer *t;
 
-	ret = scmi_xfer_get_init(handle, PERF_LEVEL_GET, SCMI_PROTOCOL_PERF,
+	ret = ops->xfer_get_init(handle, PERF_LEVEL_GET, SCMI_PROTOCOL_PERF,
 				 sizeof(u32), sizeof(u32), &t);
 	if (ret)
 		return ret;
@@ -488,11 +490,11 @@ static int scmi_perf_mb_level_get(const struct scmi_handle *handle, u32 domain,
 	t->hdr.poll_completion = poll;
 	put_unaligned_le32(domain, t->tx.buf);
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (!ret)
 		*level = get_unaligned_le32(t->rx.buf);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -500,7 +502,7 @@ static int scmi_perf_level_get(const struct scmi_handle *handle, u32 domain,
 			       u32 *level, bool poll)
 {
 	struct scmi_perf_info *pi =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_PERF);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_PERF);
 	struct perf_dom_info *dom = pi->dom_info + domain;
 
 	if (dom->fc_info && dom->fc_info->level_get_addr) {
@@ -519,7 +521,7 @@ static int scmi_perf_level_limits_notify(const struct scmi_handle *handle,
 	struct scmi_xfer *t;
 	struct scmi_perf_notify_level_or_limits *notify;
 
-	ret = scmi_xfer_get_init(handle, message_id, SCMI_PROTOCOL_PERF,
+	ret = ops->xfer_get_init(handle, message_id, SCMI_PROTOCOL_PERF,
 				 sizeof(*notify), 0, &t);
 	if (ret)
 		return ret;
@@ -528,9 +530,9 @@ static int scmi_perf_level_limits_notify(const struct scmi_handle *handle,
 	notify->domain = cpu_to_le32(domain);
 	notify->notify_enable = enable ? cpu_to_le32(BIT(0)) : 0;
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -561,7 +563,7 @@ scmi_perf_domain_desc_fc(const struct scmi_handle *handle, u32 domain,
 	if (!p_addr)
 		return;
 
-	ret = scmi_xfer_get_init(handle, PERF_DESCRIBE_FASTCHANNEL,
+	ret = ops->xfer_get_init(handle, PERF_DESCRIBE_FASTCHANNEL,
 				 SCMI_PROTOCOL_PERF,
 				 sizeof(*info), sizeof(*resp), &t);
 	if (ret)
@@ -571,7 +573,7 @@ scmi_perf_domain_desc_fc(const struct scmi_handle *handle, u32 domain,
 	info->domain = cpu_to_le32(domain);
 	info->message_id = cpu_to_le32(message_id);
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (ret)
 		goto err_xfer;
 
@@ -609,7 +611,7 @@ scmi_perf_domain_desc_fc(const struct scmi_handle *handle, u32 domain,
 		*p_db = db;
 	}
 err_xfer:
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 }
 
 static void scmi_perf_domain_init_fc(const struct scmi_handle *handle,
@@ -652,7 +654,7 @@ static int scmi_dvfs_device_opps_add(const struct scmi_handle *handle,
 	struct scmi_opp *opp;
 	struct perf_dom_info *dom;
 	struct scmi_perf_info *pi =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_PERF);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_PERF);
 
 	domain = scmi_dev_domain_id(dev);
 	if (domain < 0)
@@ -682,7 +684,7 @@ static int scmi_dvfs_transition_latency_get(const struct scmi_handle *handle,
 {
 	struct perf_dom_info *dom;
 	struct scmi_perf_info *pi =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_PERF);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_PERF);
 	int domain = scmi_dev_domain_id(dev);
 
 	if (domain < 0)
@@ -697,7 +699,7 @@ static int scmi_dvfs_freq_set(const struct scmi_handle *handle, u32 domain,
 			      unsigned long freq, bool poll)
 {
 	struct scmi_perf_info *pi =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_PERF);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_PERF);
 	struct perf_dom_info *dom = pi->dom_info + domain;
 
 	return scmi_perf_level_set(handle, domain, freq / dom->mult_factor,
@@ -710,7 +712,7 @@ static int scmi_dvfs_freq_get(const struct scmi_handle *handle, u32 domain,
 	int ret;
 	u32 level;
 	struct scmi_perf_info *pi =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_PERF);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_PERF);
 	struct perf_dom_info *dom = pi->dom_info + domain;
 
 	ret = scmi_perf_level_get(handle, domain, &level, poll);
@@ -724,7 +726,7 @@ static int scmi_dvfs_est_power_get(const struct scmi_handle *handle, u32 domain,
 				   unsigned long *freq, unsigned long *power)
 {
 	struct scmi_perf_info *pi =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_PERF);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_PERF);
 	struct perf_dom_info *dom;
 	unsigned long opp_freq;
 	int idx, ret = -EINVAL;
@@ -753,7 +755,7 @@ static bool scmi_fast_switch_possible(const struct scmi_handle *handle,
 {
 	struct perf_dom_info *dom;
 	struct scmi_perf_info *pi =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_PERF);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_PERF);
 
 	dom = pi->dom_info + scmi_dev_domain_id(dev);
 
@@ -842,7 +844,7 @@ static void *scmi_perf_fill_custom_report(const struct scmi_handle *handle,
 static int scmi_perf_get_num_sources(const struct scmi_handle *handle)
 {
 	struct scmi_perf_info *pinfo =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_PERF);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_PERF);
 
 	if (!pinfo)
 		return -EINVAL;
@@ -876,13 +878,15 @@ static const struct scmi_protocol_events perf_protocol_events = {
 	.num_events = ARRAY_SIZE(perf_events),
 };
 
-static int scmi_perf_protocol_init(const struct scmi_handle *handle)
+static int scmi_perf_protocol_init(const struct scmi_handle *handle,
+				   const struct scmi_xfer_ops *xops)
 {
 	int domain;
 	u32 version;
 	struct scmi_perf_info *pinfo;
 
-	scmi_version_get(handle, SCMI_PROTOCOL_PERF, &version);
+	ops = xops;
+	ops->version_get(handle, SCMI_PROTOCOL_PERF, &version);
 
 	dev_dbg(handle->dev, "Performance Version %d.%d\n",
 		PROTOCOL_REV_MAJOR(version), PROTOCOL_REV_MINOR(version));
@@ -909,7 +913,7 @@ static int scmi_perf_protocol_init(const struct scmi_handle *handle)
 	}
 
 	pinfo->version = version;
-	return scmi_set_proto_priv(handle, SCMI_PROTOCOL_PERF, pinfo);
+	return ops->set_proto_priv(handle, SCMI_PROTOCOL_PERF, pinfo);
 }
 
 static struct scmi_protocol scmi_perf = {
diff --git a/drivers/firmware/arm_scmi/power.c b/drivers/firmware/arm_scmi/power.c
index cb9b1f6a56dd..766e1782c9ff 100644
--- a/drivers/firmware/arm_scmi/power.c
+++ b/drivers/firmware/arm_scmi/power.c
@@ -12,6 +12,8 @@
 #include "common.h"
 #include "notify.h"
 
+static const struct scmi_xfer_ops *ops;
+
 enum scmi_power_protocol_cmd {
 	POWER_DOMAIN_ATTRIBUTES = 0x3,
 	POWER_STATE_SET = 0x4,
@@ -75,14 +77,14 @@ static int scmi_power_attributes_get(const struct scmi_handle *handle,
 	struct scmi_xfer *t;
 	struct scmi_msg_resp_power_attributes *attr;
 
-	ret = scmi_xfer_get_init(handle, PROTOCOL_ATTRIBUTES,
+	ret = ops->xfer_get_init(handle, PROTOCOL_ATTRIBUTES,
 				 SCMI_PROTOCOL_POWER, 0, sizeof(*attr), &t);
 	if (ret)
 		return ret;
 
 	attr = t->rx.buf;
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (!ret) {
 		pi->num_domains = le16_to_cpu(attr->num_domains);
 		pi->stats_addr = le32_to_cpu(attr->stats_addr_low) |
@@ -90,7 +92,7 @@ static int scmi_power_attributes_get(const struct scmi_handle *handle,
 		pi->stats_size = le32_to_cpu(attr->stats_size);
 	}
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -102,7 +104,7 @@ scmi_power_domain_attributes_get(const struct scmi_handle *handle, u32 domain,
 	struct scmi_xfer *t;
 	struct scmi_msg_resp_power_domain_attributes *attr;
 
-	ret = scmi_xfer_get_init(handle, POWER_DOMAIN_ATTRIBUTES,
+	ret = ops->xfer_get_init(handle, POWER_DOMAIN_ATTRIBUTES,
 				 SCMI_PROTOCOL_POWER, sizeof(domain),
 				 sizeof(*attr), &t);
 	if (ret)
@@ -111,7 +113,7 @@ scmi_power_domain_attributes_get(const struct scmi_handle *handle, u32 domain,
 	put_unaligned_le32(domain, t->tx.buf);
 	attr = t->rx.buf;
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (!ret) {
 		u32 flags = le32_to_cpu(attr->flags);
 
@@ -121,7 +123,7 @@ scmi_power_domain_attributes_get(const struct scmi_handle *handle, u32 domain,
 		strlcpy(dom_info->name, attr->name, SCMI_MAX_STR_SIZE);
 	}
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -132,7 +134,7 @@ scmi_power_state_set(const struct scmi_handle *handle, u32 domain, u32 state)
 	struct scmi_xfer *t;
 	struct scmi_power_set_state *st;
 
-	ret = scmi_xfer_get_init(handle, POWER_STATE_SET, SCMI_PROTOCOL_POWER,
+	ret = ops->xfer_get_init(handle, POWER_STATE_SET, SCMI_PROTOCOL_POWER,
 				 sizeof(*st), 0, &t);
 	if (ret)
 		return ret;
@@ -142,9 +144,9 @@ scmi_power_state_set(const struct scmi_handle *handle, u32 domain, u32 state)
 	st->domain = cpu_to_le32(domain);
 	st->state = cpu_to_le32(state);
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -154,25 +156,25 @@ scmi_power_state_get(const struct scmi_handle *handle, u32 domain, u32 *state)
 	int ret;
 	struct scmi_xfer *t;
 
-	ret = scmi_xfer_get_init(handle, POWER_STATE_GET, SCMI_PROTOCOL_POWER,
+	ret = ops->xfer_get_init(handle, POWER_STATE_GET, SCMI_PROTOCOL_POWER,
 				 sizeof(u32), sizeof(u32), &t);
 	if (ret)
 		return ret;
 
 	put_unaligned_le32(domain, t->tx.buf);
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (!ret)
 		*state = get_unaligned_le32(t->rx.buf);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
 static int scmi_power_num_domains_get(const struct scmi_handle *handle)
 {
 	struct scmi_power_info *pi =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_POWER);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_POWER);
 
 	return pi->num_domains;
 }
@@ -180,7 +182,7 @@ static int scmi_power_num_domains_get(const struct scmi_handle *handle)
 static char *scmi_power_name_get(const struct scmi_handle *handle, u32 domain)
 {
 	struct scmi_power_info *pi =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_POWER);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_POWER);
 	struct power_dom_info *dom = pi->dom_info + domain;
 
 	return dom->name;
@@ -200,7 +202,7 @@ static int scmi_power_request_notify(const struct scmi_handle *handle,
 	struct scmi_xfer *t;
 	struct scmi_power_state_notify *notify;
 
-	ret = scmi_xfer_get_init(handle, POWER_STATE_NOTIFY,
+	ret = ops->xfer_get_init(handle, POWER_STATE_NOTIFY,
 				 SCMI_PROTOCOL_POWER, sizeof(*notify), 0, &t);
 	if (ret)
 		return ret;
@@ -209,9 +211,9 @@ static int scmi_power_request_notify(const struct scmi_handle *handle,
 	notify->domain = cpu_to_le32(domain);
 	notify->notify_enable = enable ? cpu_to_le32(BIT(0)) : 0;
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -251,7 +253,7 @@ static void *scmi_power_fill_custom_report(const struct scmi_handle *handle,
 static int scmi_power_get_num_sources(const struct scmi_handle *handle)
 {
 	struct scmi_power_info *pinfo =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_POWER);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_POWER);
 
 	if (!pinfo)
 		return -EINVAL;
@@ -281,13 +283,15 @@ static const struct scmi_protocol_events power_protocol_events = {
 	.num_events = ARRAY_SIZE(power_events),
 };
 
-static int scmi_power_protocol_init(const struct scmi_handle *handle)
+static int scmi_power_protocol_init(const struct scmi_handle *handle,
+				    const struct scmi_xfer_ops *xops)
 {
 	int domain;
 	u32 version;
 	struct scmi_power_info *pinfo;
 
-	scmi_version_get(handle, SCMI_PROTOCOL_POWER, &version);
+	ops = xops;
+	ops->version_get(handle, SCMI_PROTOCOL_POWER, &version);
 
 	dev_dbg(handle->dev, "Power Version %d.%d\n",
 		PROTOCOL_REV_MAJOR(version), PROTOCOL_REV_MINOR(version));
@@ -310,7 +314,7 @@ static int scmi_power_protocol_init(const struct scmi_handle *handle)
 	}
 
 	pinfo->version = version;
-	return scmi_set_proto_priv(handle, SCMI_PROTOCOL_POWER, pinfo);
+	return ops->set_proto_priv(handle, SCMI_PROTOCOL_POWER, pinfo);
 }
 
 static struct scmi_protocol scmi_power = {
diff --git a/drivers/firmware/arm_scmi/reset.c b/drivers/firmware/arm_scmi/reset.c
index 83bfd0514d4d..9accad66e07e 100644
--- a/drivers/firmware/arm_scmi/reset.c
+++ b/drivers/firmware/arm_scmi/reset.c
@@ -12,6 +12,8 @@
 #include "common.h"
 #include "notify.h"
 
+static const struct scmi_xfer_ops *ops;
+
 enum scmi_reset_protocol_cmd {
 	RESET_DOMAIN_ATTRIBUTES = 0x3,
 	RESET = 0x4,
@@ -71,18 +73,18 @@ static int scmi_reset_attributes_get(const struct scmi_handle *handle,
 	struct scmi_xfer *t;
 	u32 attr;
 
-	ret = scmi_xfer_get_init(handle, PROTOCOL_ATTRIBUTES,
+	ret = ops->xfer_get_init(handle, PROTOCOL_ATTRIBUTES,
 				 SCMI_PROTOCOL_RESET, 0, sizeof(attr), &t);
 	if (ret)
 		return ret;
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (!ret) {
 		attr = get_unaligned_le32(t->rx.buf);
 		pi->num_domains = attr & NUM_RESET_DOMAIN_MASK;
 	}
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -94,7 +96,7 @@ scmi_reset_domain_attributes_get(const struct scmi_handle *handle, u32 domain,
 	struct scmi_xfer *t;
 	struct scmi_msg_resp_reset_domain_attributes *attr;
 
-	ret = scmi_xfer_get_init(handle, RESET_DOMAIN_ATTRIBUTES,
+	ret = ops->xfer_get_init(handle, RESET_DOMAIN_ATTRIBUTES,
 				 SCMI_PROTOCOL_RESET, sizeof(domain),
 				 sizeof(*attr), &t);
 	if (ret)
@@ -103,7 +105,7 @@ scmi_reset_domain_attributes_get(const struct scmi_handle *handle, u32 domain,
 	put_unaligned_le32(domain, t->tx.buf);
 	attr = t->rx.buf;
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (!ret) {
 		u32 attributes = le32_to_cpu(attr->attributes);
 
@@ -115,14 +117,14 @@ scmi_reset_domain_attributes_get(const struct scmi_handle *handle, u32 domain,
 		strlcpy(dom_info->name, attr->name, SCMI_MAX_STR_SIZE);
 	}
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
 static int scmi_reset_num_domains_get(const struct scmi_handle *handle)
 {
 	struct scmi_reset_info *pi =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_RESET);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_RESET);
 
 	return pi->num_domains;
 }
@@ -130,7 +132,7 @@ static int scmi_reset_num_domains_get(const struct scmi_handle *handle)
 static char *scmi_reset_name_get(const struct scmi_handle *handle, u32 domain)
 {
 	struct scmi_reset_info *pi =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_RESET);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_RESET);
 	struct reset_dom_info *dom = pi->dom_info + domain;
 
 	return dom->name;
@@ -139,7 +141,7 @@ static char *scmi_reset_name_get(const struct scmi_handle *handle, u32 domain)
 static int scmi_reset_latency_get(const struct scmi_handle *handle, u32 domain)
 {
 	struct scmi_reset_info *pi =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_RESET);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_RESET);
 	struct reset_dom_info *dom = pi->dom_info + domain;
 
 	return dom->latency_us;
@@ -152,13 +154,13 @@ static int scmi_domain_reset(const struct scmi_handle *handle, u32 domain,
 	struct scmi_xfer *t;
 	struct scmi_msg_reset_domain_reset *dom;
 	struct scmi_reset_info *pi =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_RESET);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_RESET);
 	struct reset_dom_info *rdom = pi->dom_info + domain;
 
 	if (rdom->async_reset)
 		flags |= ASYNCHRONOUS_RESET;
 
-	ret = scmi_xfer_get_init(handle, RESET, SCMI_PROTOCOL_RESET,
+	ret = ops->xfer_get_init(handle, RESET, SCMI_PROTOCOL_RESET,
 				 sizeof(*dom), 0, &t);
 	if (ret)
 		return ret;
@@ -169,11 +171,11 @@ static int scmi_domain_reset(const struct scmi_handle *handle, u32 domain,
 	dom->reset_state = cpu_to_le32(state);
 
 	if (rdom->async_reset)
-		ret = scmi_do_xfer_with_response(handle, t);
+		ret = ops->do_xfer_with_response(handle, t);
 	else
-		ret = scmi_do_xfer(handle, t);
+		ret = ops->do_xfer(handle, t);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -213,7 +215,7 @@ static int scmi_reset_notify(const struct scmi_handle *handle, u32 domain_id,
 	struct scmi_xfer *t;
 	struct scmi_msg_reset_notify *cfg;
 
-	ret = scmi_xfer_get_init(handle, RESET_NOTIFY,
+	ret = ops->xfer_get_init(handle, RESET_NOTIFY,
 				 SCMI_PROTOCOL_RESET, sizeof(*cfg), 0, &t);
 	if (ret)
 		return ret;
@@ -222,9 +224,9 @@ static int scmi_reset_notify(const struct scmi_handle *handle, u32 domain_id,
 	cfg->id = cpu_to_le32(domain_id);
 	cfg->event_control = cpu_to_le32(evt_cntl);
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -264,7 +266,7 @@ static void *scmi_reset_fill_custom_report(const struct scmi_handle *handle,
 static int scmi_reset_get_num_sources(const struct scmi_handle *handle)
 {
 	struct scmi_reset_info *pinfo =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_RESET);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_RESET);
 
 	if (!pinfo)
 		return -EINVAL;
@@ -293,13 +295,15 @@ static const struct scmi_protocol_events reset_protocol_events = {
 	.num_events = ARRAY_SIZE(reset_events),
 };
 
-static int scmi_reset_protocol_init(const struct scmi_handle *handle)
+static int scmi_reset_protocol_init(const struct scmi_handle *handle,
+				    const struct scmi_xfer_ops *xops)
 {
 	int domain;
 	u32 version;
 	struct scmi_reset_info *pinfo;
 
-	scmi_version_get(handle, SCMI_PROTOCOL_RESET, &version);
+	ops = xops;
+	ops->version_get(handle, SCMI_PROTOCOL_RESET, &version);
 
 	dev_dbg(handle->dev, "Reset Version %d.%d\n",
 		PROTOCOL_REV_MAJOR(version), PROTOCOL_REV_MINOR(version));
@@ -322,7 +326,7 @@ static int scmi_reset_protocol_init(const struct scmi_handle *handle)
 	}
 
 	pinfo->version = version;
-	return scmi_set_proto_priv(handle, SCMI_PROTOCOL_RESET, pinfo);
+	return ops->set_proto_priv(handle, SCMI_PROTOCOL_RESET, pinfo);
 }
 
 static struct scmi_protocol scmi_reset = {
diff --git a/drivers/firmware/arm_scmi/sensors.c b/drivers/firmware/arm_scmi/sensors.c
index 79bdd53ab7ba..3a58dbca2b70 100644
--- a/drivers/firmware/arm_scmi/sensors.c
+++ b/drivers/firmware/arm_scmi/sensors.c
@@ -12,6 +12,8 @@
 #include "common.h"
 #include "notify.h"
 
+static const struct scmi_xfer_ops *ops;
+
 enum scmi_sensor_protocol_cmd {
 	SENSOR_DESCRIPTION_GET = 0x3,
 	SENSOR_TRIP_POINT_NOTIFY = 0x4,
@@ -94,14 +96,14 @@ static int scmi_sensor_attributes_get(const struct scmi_handle *handle,
 	struct scmi_xfer *t;
 	struct scmi_msg_resp_sensor_attributes *attr;
 
-	ret = scmi_xfer_get_init(handle, PROTOCOL_ATTRIBUTES,
+	ret = ops->xfer_get_init(handle, PROTOCOL_ATTRIBUTES,
 				 SCMI_PROTOCOL_SENSOR, 0, sizeof(*attr), &t);
 	if (ret)
 		return ret;
 
 	attr = t->rx.buf;
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 	if (!ret) {
 		si->num_sensors = le16_to_cpu(attr->num_sensors);
 		si->max_requests = attr->max_requests;
@@ -110,7 +112,7 @@ static int scmi_sensor_attributes_get(const struct scmi_handle *handle,
 		si->reg_size = le32_to_cpu(attr->reg_size);
 	}
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -123,7 +125,7 @@ static int scmi_sensor_description_get(const struct scmi_handle *handle,
 	struct scmi_xfer *t;
 	struct scmi_msg_resp_sensor_description *buf;
 
-	ret = scmi_xfer_get_init(handle, SENSOR_DESCRIPTION_GET,
+	ret = ops->xfer_get_init(handle, SENSOR_DESCRIPTION_GET,
 				 SCMI_PROTOCOL_SENSOR, sizeof(__le32), 0, &t);
 	if (ret)
 		return ret;
@@ -134,7 +136,7 @@ static int scmi_sensor_description_get(const struct scmi_handle *handle,
 		/* Set the number of sensors to be skipped/already read */
 		put_unaligned_le32(desc_index, t->tx.buf);
 
-		ret = scmi_do_xfer(handle, t);
+		ret = ops->do_xfer(handle, t);
 		if (ret)
 			break;
 
@@ -167,14 +169,14 @@ static int scmi_sensor_description_get(const struct scmi_handle *handle,
 
 		desc_index += num_returned;
 
-		scmi_reset_rx_to_maxsz(handle, t);
+		ops->reset_rx_to_maxsz(handle, t);
 		/*
 		 * check for both returned and remaining to avoid infinite
 		 * loop due to buggy firmware
 		 */
 	} while (num_returned && num_remaining);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -186,7 +188,7 @@ static int scmi_sensor_trip_point_notify(const struct scmi_handle *handle,
 	struct scmi_xfer *t;
 	struct scmi_msg_sensor_trip_point_notify *cfg;
 
-	ret = scmi_xfer_get_init(handle, SENSOR_TRIP_POINT_NOTIFY,
+	ret = ops->xfer_get_init(handle, SENSOR_TRIP_POINT_NOTIFY,
 				 SCMI_PROTOCOL_SENSOR, sizeof(*cfg), 0, &t);
 	if (ret)
 		return ret;
@@ -195,9 +197,9 @@ static int scmi_sensor_trip_point_notify(const struct scmi_handle *handle,
 	cfg->id = cpu_to_le32(sensor_id);
 	cfg->event_control = cpu_to_le32(evt_cntl);
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -210,7 +212,7 @@ scmi_sensor_trip_point_config(const struct scmi_handle *handle, u32 sensor_id,
 	struct scmi_xfer *t;
 	struct scmi_msg_set_sensor_trip_point *trip;
 
-	ret = scmi_xfer_get_init(handle, SENSOR_TRIP_POINT_CONFIG,
+	ret = ops->xfer_get_init(handle, SENSOR_TRIP_POINT_CONFIG,
 				 SCMI_PROTOCOL_SENSOR, sizeof(*trip), 0, &t);
 	if (ret)
 		return ret;
@@ -221,9 +223,9 @@ scmi_sensor_trip_point_config(const struct scmi_handle *handle, u32 sensor_id,
 	trip->value_low = cpu_to_le32(trip_value & 0xffffffff);
 	trip->value_high = cpu_to_le32(trip_value >> 32);
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -234,10 +236,10 @@ static int scmi_sensor_reading_get(const struct scmi_handle *handle,
 	struct scmi_xfer *t;
 	struct scmi_msg_sensor_reading_get *sensor;
 	struct sensors_info *si =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_SENSOR);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_SENSOR);
 	struct scmi_sensor_info *s = si->sensors + sensor_id;
 
-	ret = scmi_xfer_get_init(handle, SENSOR_READING_GET,
+	ret = ops->xfer_get_init(handle, SENSOR_READING_GET,
 				 SCMI_PROTOCOL_SENSOR, sizeof(*sensor),
 				 sizeof(u64), &t);
 	if (ret)
@@ -248,18 +250,18 @@ static int scmi_sensor_reading_get(const struct scmi_handle *handle,
 
 	if (s->async) {
 		sensor->flags = cpu_to_le32(SENSOR_READ_ASYNC);
-		ret = scmi_do_xfer_with_response(handle, t);
+		ret = ops->do_xfer_with_response(handle, t);
 		if (!ret)
 			*value = get_unaligned_le64((void *)
 						    ((__le32 *)t->rx.buf + 1));
 	} else {
 		sensor->flags = cpu_to_le32(0);
-		ret = scmi_do_xfer(handle, t);
+		ret = ops->do_xfer(handle, t);
 		if (!ret)
 			*value = get_unaligned_le64(t->rx.buf);
 	}
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -267,7 +269,7 @@ static const struct scmi_sensor_info *
 scmi_sensor_info_get(const struct scmi_handle *handle, u32 sensor_id)
 {
 	struct sensors_info *si =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_SENSOR);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_SENSOR);
 
 	return si->sensors + sensor_id;
 }
@@ -275,7 +277,7 @@ scmi_sensor_info_get(const struct scmi_handle *handle, u32 sensor_id)
 static int scmi_sensor_count_get(const struct scmi_handle *handle)
 {
 	struct sensors_info *si =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_SENSOR);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_SENSOR);
 
 	return si->num_sensors;
 }
@@ -324,7 +326,7 @@ static void *scmi_sensor_fill_custom_report(const struct scmi_handle *handle,
 static int scmi_sensor_get_num_sources(const struct scmi_handle *handle)
 {
 	struct sensors_info *sinfo =
-		scmi_get_proto_priv(handle, SCMI_PROTOCOL_SENSOR);
+		ops->get_proto_priv(handle, SCMI_PROTOCOL_SENSOR);
 
 	if (!sinfo)
 		return -EINVAL;
@@ -353,12 +355,14 @@ static const struct scmi_protocol_events sensor_protocol_events = {
 	.num_events = ARRAY_SIZE(sensor_events),
 };
 
-static int scmi_sensors_protocol_init(const struct scmi_handle *handle)
+static int scmi_sensors_protocol_init(const struct scmi_handle *handle,
+				      const struct scmi_xfer_ops *xops)
 {
 	u32 version;
 	struct sensors_info *sinfo;
 
-	scmi_version_get(handle, SCMI_PROTOCOL_SENSOR, &version);
+	ops = xops;
+	ops->version_get(handle, SCMI_PROTOCOL_SENSOR, &version);
 
 	dev_dbg(handle->dev, "Sensor Version %d.%d\n",
 		PROTOCOL_REV_MAJOR(version), PROTOCOL_REV_MINOR(version));
@@ -377,7 +381,7 @@ static int scmi_sensors_protocol_init(const struct scmi_handle *handle)
 	scmi_sensor_description_get(handle, sinfo);
 
 	sinfo->version = version;
-	return scmi_set_proto_priv(handle, SCMI_PROTOCOL_SENSOR, sinfo);
+	return ops->set_proto_priv(handle, SCMI_PROTOCOL_SENSOR, sinfo);
 }
 
 static struct scmi_protocol scmi_sensors = {
diff --git a/drivers/firmware/arm_scmi/system.c b/drivers/firmware/arm_scmi/system.c
index ae884fc669f5..4db3cc9cea3b 100644
--- a/drivers/firmware/arm_scmi/system.c
+++ b/drivers/firmware/arm_scmi/system.c
@@ -14,6 +14,8 @@
 
 #define SCMI_SYSTEM_NUM_SOURCES		1
 
+static const struct scmi_xfer_ops *ops;
+
 enum scmi_system_protocol_cmd {
 	SYSTEM_POWER_STATE_NOTIFY = 0x5,
 };
@@ -39,7 +41,7 @@ static int scmi_system_request_notify(const struct scmi_handle *handle,
 	struct scmi_xfer *t;
 	struct scmi_system_power_state_notify *notify;
 
-	ret = scmi_xfer_get_init(handle, SYSTEM_POWER_STATE_NOTIFY,
+	ret = ops->xfer_get_init(handle, SYSTEM_POWER_STATE_NOTIFY,
 				 SCMI_PROTOCOL_SYSTEM, sizeof(*notify), 0, &t);
 	if (ret)
 		return ret;
@@ -47,9 +49,9 @@ static int scmi_system_request_notify(const struct scmi_handle *handle,
 	notify = t->tx.buf;
 	notify->notify_enable = enable ? cpu_to_le32(BIT(0)) : 0;
 
-	ret = scmi_do_xfer(handle, t);
+	ret = ops->do_xfer(handle, t);
 
-	scmi_xfer_put(handle, t);
+	ops->xfer_put(handle, t);
 	return ret;
 }
 
@@ -109,12 +111,14 @@ static const struct scmi_protocol_events system_protocol_events = {
 	.num_sources = SCMI_SYSTEM_NUM_SOURCES,
 };
 
-static int scmi_system_protocol_init(const struct scmi_handle *handle)
+static int scmi_system_protocol_init(const struct scmi_handle *handle,
+				     const struct scmi_xfer_ops *xops)
 {
 	u32 version;
 	struct scmi_system_info *pinfo;
 
-	scmi_version_get(handle, SCMI_PROTOCOL_SYSTEM, &version);
+	ops = xops;
+	ops->version_get(handle, SCMI_PROTOCOL_SYSTEM, &version);
 
 	dev_dbg(handle->dev, "System Power Version %d.%d\n",
 		PROTOCOL_REV_MAJOR(version), PROTOCOL_REV_MINOR(version));
@@ -124,7 +128,7 @@ static int scmi_system_protocol_init(const struct scmi_handle *handle)
 		return -ENOMEM;
 
 	pinfo->version = version;
-	return scmi_set_proto_priv(handle, SCMI_PROTOCOL_SYSTEM, pinfo);
+	return ops->set_proto_priv(handle, SCMI_PROTOCOL_SYSTEM, pinfo);
 }
 
 static struct scmi_protocol scmi_system = {
diff --git a/include/linux/scmi_protocol.h b/include/linux/scmi_protocol.h
index 650d0877a5c8..da675d6f90c0 100644
--- a/include/linux/scmi_protocol.h
+++ b/include/linux/scmi_protocol.h
@@ -351,8 +351,24 @@ static inline void scmi_driver_unregister(struct scmi_driver *driver) {}
 #define module_scmi_driver(__scmi_driver)	\
 	module_driver(__scmi_driver, scmi_register, scmi_unregister)
 
+#define scmi_load(proto) \
+	scmi_protocol_register(proto, THIS_MODULE)
+#define scmi_unload(proto) \
+	scmi_protocol_unregister(proto)
+
+/**
+ * module_scmi_protocol() - Helper macro for registering a scmi protocol
+ * @__scmi_protocol: scmi_protocol structure
+ *
+ * Helper macro for scmi drivers to set up proper module init / exit
+ * functions.  Replaces module_init() and module_exit() and keeps people from
+ * printing pointless things to the kernel log when their driver is loaded.
+ */
+#define module_scmi_protocol(__scmi_protocol)	\
+	module_driver(__scmi_protocol, scmi_load, scmi_unload)
+
 struct scmi_protocol;
-int scmi_protocol_register(struct scmi_protocol *proto);
+int scmi_protocol_register(struct scmi_protocol *proto, struct module *owner);
 void scmi_protocol_unregister(const struct scmi_protocol *proto);
 
 /* SCMI Notification API - Custom Event Reports */
-- 
2.17.1




More information about the linux-arm-kernel mailing list