[PATCH 4/5] staging: vc04_services: use kref + RCU to reference count services

Marcelo Diop-Gonzalez marcgonzalez at google.com
Wed Feb 12 18:43:32 UTC 2020


Currently reference counts are implemented by locking service_spinlock
and then incrementing the service's ->ref_count field, calling
kfree() when the last reference has been dropped. At the same time,
code in multiple places dereferences pointers to services without
holding a reference, so a service could be freed while such code is
still using it.

It should be possible to avoid taking any lock in unlock_service()
or service_release(): dropping the last reference only sets a single
array element to NULL, and service creation locks a mutex before
looking for a NULL slot to put the new service in.

Using a struct kref and RCU-delaying the freeing of services fixes
this race condition while still making it possible to skip
grabbing a reference in many places. It also removes contention on a
single global spinlock: taking a reference on state->services[i] no
longer has to wait for somebody else who is in the middle of taking
a reference on state->services[j].

Signed-off-by: Marcelo Diop-Gonzalez <marcgonzalez at google.com>
---
 .../interface/vchiq_arm/vchiq_arm.c           |  25 +-
 .../interface/vchiq_arm/vchiq_core.c          | 222 +++++++++---------
 .../interface/vchiq_arm/vchiq_core.h          |  12 +-
 3 files changed, 140 insertions(+), 119 deletions(-)

diff --git a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_arm.c b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_arm.c
index c456ced431af..3ed0e4ea7f5c 100644
--- a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_arm.c
+++ b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_arm.c
@@ -22,6 +22,7 @@
 #include <linux/platform_device.h>
 #include <linux/compat.h>
 #include <linux/dma-mapping.h>
+#include <linux/rcupdate.h>
 #include <soc/bcm2835/raspberrypi-firmware.h>
 
 #include "vchiq_core.h"
@@ -2096,10 +2097,12 @@ int vchiq_dump_platform_instances(void *dump_context)
 	/* There is no list of instances, so instead scan all services,
 		marking those that have been dumped. */
 
+	rcu_read_lock();
 	for (i = 0; i < state->unused_service; i++) {
-		struct vchiq_service *service = state->services[i];
+		struct vchiq_service *service;
 		struct vchiq_instance *instance;
 
+		service = rcu_dereference(state->services[i]);
 		if (!service || service->base.callback != service_callback)
 			continue;
 
@@ -2107,18 +2110,26 @@ int vchiq_dump_platform_instances(void *dump_context)
 		if (instance)
 			instance->mark = 0;
 	}
+	rcu_read_unlock();
 
 	for (i = 0; i < state->unused_service; i++) {
-		struct vchiq_service *service = state->services[i];
+		struct vchiq_service *service;
 		struct vchiq_instance *instance;
 		int err;
 
-		if (!service || service->base.callback != service_callback)
+		rcu_read_lock();
+		service = rcu_dereference(state->services[i]);
+		if (!service || service->base.callback != service_callback) {
+			rcu_read_unlock();
 			continue;
+		}
 
 		instance = service->instance;
-		if (!instance || instance->mark)
+		if (!instance || instance->mark) {
+			rcu_read_unlock();
 			continue;
+		}
+		rcu_read_unlock();
 
 		len = snprintf(buf, sizeof(buf),
 			       "Instance %pK: pid %d,%s completions %d/%d",
@@ -2128,7 +2139,6 @@ int vchiq_dump_platform_instances(void *dump_context)
 			       instance->completion_insert -
 			       instance->completion_remove,
 			       MAX_COMPLETIONS);
-
 		err = vchiq_dump(dump_context, buf, len + 1);
 		if (err)
 			return err;
@@ -2585,8 +2595,10 @@ vchiq_dump_service_use_state(struct vchiq_state *state)
 	if (active_services > MAX_SERVICES)
 		only_nonzero = 1;
 
+	rcu_read_lock();
 	for (i = 0; i < active_services; i++) {
-		struct vchiq_service *service_ptr = state->services[i];
+		struct vchiq_service *service_ptr =
+			rcu_dereference(state->services[i]);
 
 		if (!service_ptr)
 			continue;
@@ -2604,6 +2616,7 @@ vchiq_dump_service_use_state(struct vchiq_state *state)
 		if (found >= MAX_SERVICES)
 			break;
 	}
+	rcu_read_unlock();
 
 	read_unlock_bh(&arm_state->susp_res_lock);
 
diff --git a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.c b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.c
index b2d9013b7f79..65270a5b29db 100644
--- a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.c
+++ b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.c
@@ -1,6 +1,9 @@
 // SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
 /* Copyright (c) 2010-2012 Broadcom. All rights reserved. */
 
+#include <linux/kref.h>
+#include <linux/rcupdate.h>
+
 #include "vchiq_core.h"
 
 #define VCHIQ_SLOT_HANDLER_STACK 8192
@@ -54,7 +57,6 @@ int vchiq_core_log_level = VCHIQ_LOG_DEFAULT;
 int vchiq_core_msg_log_level = VCHIQ_LOG_DEFAULT;
 int vchiq_sync_log_level = VCHIQ_LOG_DEFAULT;
 
-static DEFINE_SPINLOCK(service_spinlock);
 DEFINE_SPINLOCK(bulk_waiter_spinlock);
 static DEFINE_SPINLOCK(quota_spinlock);
 
@@ -136,44 +138,41 @@ find_service_by_handle(unsigned int handle)
 {
 	struct vchiq_service *service;
 
-	spin_lock(&service_spinlock);
+	rcu_read_lock();
 	service = handle_to_service(handle);
 	if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
-	    service->handle == handle) {
-		WARN_ON(service->ref_count == 0);
-		service->ref_count++;
-	} else
-		service = NULL;
-	spin_unlock(&service_spinlock);
-
-	if (!service)
-		vchiq_log_info(vchiq_core_log_level,
-			"Invalid service handle 0x%x", handle);
-
-	return service;
+	    service->handle == handle &&
+	    kref_get_unless_zero(&service->ref_count)) {
+		service = rcu_pointer_handoff(service);
+		rcu_read_unlock();
+		return service;
+	}
+	rcu_read_unlock();
+	vchiq_log_info(vchiq_core_log_level,
+		       "Invalid service handle 0x%x", handle);
+	return NULL;
 }
 
 struct vchiq_service *
 find_service_by_port(struct vchiq_state *state, int localport)
 {
-	struct vchiq_service *service = NULL;
 
 	if ((unsigned int)localport <= VCHIQ_PORT_MAX) {
-		spin_lock(&service_spinlock);
-		service = state->services[localport];
-		if (service && service->srvstate != VCHIQ_SRVSTATE_FREE) {
-			WARN_ON(service->ref_count == 0);
-			service->ref_count++;
-		} else
-			service = NULL;
-		spin_unlock(&service_spinlock);
-	}
-
-	if (!service)
-		vchiq_log_info(vchiq_core_log_level,
-			"Invalid port %d", localport);
+		struct vchiq_service *service;
 
-	return service;
+		rcu_read_lock();
+		service = rcu_dereference(state->services[localport]);
+		if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
+		    kref_get_unless_zero(&service->ref_count)) {
+			service = rcu_pointer_handoff(service);
+			rcu_read_unlock();
+			return service;
+		}
+		rcu_read_unlock();
+	}
+	vchiq_log_info(vchiq_core_log_level,
+		       "Invalid port %d", localport);
+	return NULL;
 }
 
 struct vchiq_service *
@@ -182,22 +181,20 @@ find_service_for_instance(struct vchiq_instance *instance,
 {
 	struct vchiq_service *service;
 
-	spin_lock(&service_spinlock);
+	rcu_read_lock();
 	service = handle_to_service(handle);
 	if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
 	    service->handle == handle &&
-	    service->instance == instance) {
-		WARN_ON(service->ref_count == 0);
-		service->ref_count++;
-	} else
-		service = NULL;
-	spin_unlock(&service_spinlock);
-
-	if (!service)
-		vchiq_log_info(vchiq_core_log_level,
-			"Invalid service handle 0x%x", handle);
-
-	return service;
+	    service->instance == instance &&
+	    kref_get_unless_zero(&service->ref_count)) {
+		service = rcu_pointer_handoff(service);
+		rcu_read_unlock();
+		return service;
+	}
+	rcu_read_unlock();
+	vchiq_log_info(vchiq_core_log_level,
+		       "Invalid service handle 0x%x", handle);
+	return NULL;
 }
 
 struct vchiq_service *
@@ -206,23 +203,21 @@ find_closed_service_for_instance(struct vchiq_instance *instance,
 {
 	struct vchiq_service *service;
 
-	spin_lock(&service_spinlock);
+	rcu_read_lock();
 	service = handle_to_service(handle);
 	if (service &&
 	    (service->srvstate == VCHIQ_SRVSTATE_FREE ||
 	     service->srvstate == VCHIQ_SRVSTATE_CLOSED) &&
 	    service->handle == handle &&
-	    service->instance == instance) {
-		WARN_ON(service->ref_count == 0);
-		service->ref_count++;
-	} else
-		service = NULL;
-	spin_unlock(&service_spinlock);
-
-	if (!service)
-		vchiq_log_info(vchiq_core_log_level,
-			"Invalid service handle 0x%x", handle);
-
+	    service->instance == instance &&
+	    kref_get_unless_zero(&service->ref_count)) {
+		service = rcu_pointer_handoff(service);
+		rcu_read_unlock();
+		return service;
+	}
+	rcu_read_unlock();
+	vchiq_log_info(vchiq_core_log_level,
+		       "Invalid service handle 0x%x", handle);
 	return service;
 }
 
@@ -233,19 +228,19 @@ next_service_by_instance(struct vchiq_state *state, struct vchiq_instance *insta
 	struct vchiq_service *service = NULL;
 	int idx = *pidx;
 
-	spin_lock(&service_spinlock);
+	rcu_read_lock();
 	while (idx < state->unused_service) {
-		struct vchiq_service *srv = state->services[idx++];
+		struct vchiq_service *srv;
 
+		srv = rcu_dereference(state->services[idx++]);
 		if (srv && srv->srvstate != VCHIQ_SRVSTATE_FREE &&
-		    srv->instance == instance) {
-			service = srv;
-			WARN_ON(service->ref_count == 0);
-			service->ref_count++;
+		    srv->instance == instance &&
+		    kref_get_unless_zero(&srv->ref_count)) {
+			service = rcu_pointer_handoff(srv);
 			break;
 		}
 	}
-	spin_unlock(&service_spinlock);
+	rcu_read_unlock();
 
 	*pidx = idx;
 
@@ -255,43 +250,34 @@ next_service_by_instance(struct vchiq_state *state, struct vchiq_instance *insta
 void
 lock_service(struct vchiq_service *service)
 {
-	spin_lock(&service_spinlock);
-	WARN_ON(!service);
-	if (service) {
-		WARN_ON(service->ref_count == 0);
-		service->ref_count++;
+	if (!service) {
+		WARN(1, "%s service is NULL\n", __func__);
+		return;
 	}
-	spin_unlock(&service_spinlock);
+	kref_get(&service->ref_count);
+}
+
+static void service_release(struct kref *kref)
+{
+	struct vchiq_service *service =
+		container_of(kref, struct vchiq_service, ref_count);
+	struct vchiq_state *state = service->state;
+
+	WARN_ON(service->srvstate != VCHIQ_SRVSTATE_FREE);
+	rcu_assign_pointer(state->services[service->localport], NULL);
+	if (service->userdata_term)
+		service->userdata_term(service->base.userdata);
+	kfree_rcu(service, rcu);
 }
 
 void
 unlock_service(struct vchiq_service *service)
 {
-	spin_lock(&service_spinlock);
 	if (!service) {
 		WARN(1, "%s: service is NULL\n", __func__);
-		goto unlock;
-	}
-	if (!service->ref_count) {
-		WARN(1, "%s: ref_count is zero\n", __func__);
-		goto unlock;
-	}
-	service->ref_count--;
-	if (!service->ref_count) {
-		struct vchiq_state *state = service->state;
-
-		WARN_ON(service->srvstate != VCHIQ_SRVSTATE_FREE);
-		state->services[service->localport] = NULL;
-	} else {
-		service = NULL;
+		return;
 	}
-unlock:
-	spin_unlock(&service_spinlock);
-
-	if (service && service->userdata_term)
-		service->userdata_term(service->base.userdata);
-
-	kfree(service);
+	kref_put(&service->ref_count, service_release);
 }
 
 int
@@ -310,9 +296,14 @@ vchiq_get_client_id(unsigned int handle)
 void *
 vchiq_get_service_userdata(unsigned int handle)
 {
-	struct vchiq_service *service = handle_to_service(handle);
+	void *userdata;
+	struct vchiq_service *service;
 
-	return service ? service->base.userdata : NULL;
+	rcu_read_lock();
+	service = handle_to_service(handle);
+	userdata = service ? service->base.userdata : NULL;
+	rcu_read_unlock();
+	return userdata;
 }
 
 static void
@@ -460,19 +451,23 @@ get_listening_service(struct vchiq_state *state, int fourcc)
 
 	WARN_ON(fourcc == VCHIQ_FOURCC_INVALID);
 
+	rcu_read_lock();
 	for (i = 0; i < state->unused_service; i++) {
-		struct vchiq_service *service = state->services[i];
+		struct vchiq_service *service;
 
+		service = rcu_dereference(state->services[i]);
 		if (service &&
 		    service->public_fourcc == fourcc &&
 		    (service->srvstate == VCHIQ_SRVSTATE_LISTENING ||
 		     (service->srvstate == VCHIQ_SRVSTATE_OPEN &&
-		      service->remoteport == VCHIQ_PORT_FREE))) {
-			lock_service(service);
+		      service->remoteport == VCHIQ_PORT_FREE)) &&
+		    kref_get_unless_zero(&service->ref_count)) {
+			service = rcu_pointer_handoff(service);
+			rcu_read_unlock();
 			return service;
 		}
 	}
-
+	rcu_read_unlock();
 	return NULL;
 }
 
@@ -482,15 +477,20 @@ get_connected_service(struct vchiq_state *state, unsigned int port)
 {
 	int i;
 
+	rcu_read_lock();
 	for (i = 0; i < state->unused_service; i++) {
-		struct vchiq_service *service = state->services[i];
+		struct vchiq_service *service =
+			rcu_dereference(state->services[i]);
 
 		if (service && service->srvstate == VCHIQ_SRVSTATE_OPEN &&
-		    service->remoteport == port) {
-			lock_service(service);
+		    service->remoteport == port &&
+		    kref_get_unless_zero(&service->ref_count)) {
+			service = rcu_pointer_handoff(service);
+			rcu_read_unlock();
 			return service;
 		}
 	}
+	rcu_read_unlock();
 	return NULL;
 }
 
@@ -2260,7 +2260,7 @@ vchiq_add_service_internal(struct vchiq_state *state,
 			   vchiq_userdata_term userdata_term)
 {
 	struct vchiq_service *service;
-	struct vchiq_service **pservice = NULL;
+	struct vchiq_service __rcu **pservice = NULL;
 	struct vchiq_service_quota *service_quota;
 	int i;
 
@@ -2272,7 +2272,7 @@ vchiq_add_service_internal(struct vchiq_state *state,
 	service->base.callback = params->callback;
 	service->base.userdata = params->userdata;
 	service->handle        = VCHIQ_SERVICE_HANDLE_INVALID;
-	service->ref_count     = 1;
+	kref_init(&service->ref_count);
 	service->srvstate      = VCHIQ_SRVSTATE_FREE;
 	service->userdata_term = userdata_term;
 	service->localport     = VCHIQ_PORT_FREE;
@@ -2298,7 +2298,7 @@ vchiq_add_service_internal(struct vchiq_state *state,
 	mutex_init(&service->bulk_mutex);
 	memset(&service->stats, 0, sizeof(service->stats));
 
-	/* Although it is perfectly possible to use service_spinlock
+	/* Although it is perfectly possible to use a spinlock
 	** to protect the creation of services, it is overkill as it
 	** disables interrupts while the array is searched.
 	** The only danger is of another thread trying to create a
@@ -2316,17 +2316,17 @@ vchiq_add_service_internal(struct vchiq_state *state,
 
 	if (srvstate == VCHIQ_SRVSTATE_OPENING) {
 		for (i = 0; i < state->unused_service; i++) {
-			struct vchiq_service *srv = state->services[i];
-
-			if (!srv) {
+			if (!rcu_access_pointer(state->services[i])) {
 				pservice = &state->services[i];
 				break;
 			}
 		}
 	} else {
+		rcu_read_lock();
 		for (i = (state->unused_service - 1); i >= 0; i--) {
-			struct vchiq_service *srv = state->services[i];
+			struct vchiq_service *srv;
 
+			srv = rcu_dereference(state->services[i]);
 			if (!srv)
 				pservice = &state->services[i];
 			else if ((srv->public_fourcc == params->fourcc)
@@ -2339,6 +2339,7 @@ vchiq_add_service_internal(struct vchiq_state *state,
 				break;
 			}
 		}
+		rcu_read_unlock();
 	}
 
 	if (pservice) {
@@ -2350,7 +2351,7 @@ vchiq_add_service_internal(struct vchiq_state *state,
 			(state->id * VCHIQ_MAX_SERVICES) |
 			service->localport;
 		handle_seq += VCHIQ_MAX_STATES * VCHIQ_MAX_SERVICES;
-		*pservice = service;
+		rcu_assign_pointer(*pservice, service);
 		if (pservice == &state->services[state->unused_service])
 			state->unused_service++;
 	}
@@ -2416,10 +2417,10 @@ vchiq_open_service_internal(struct vchiq_service *service, int client_id)
 			   (service->srvstate != VCHIQ_SRVSTATE_OPENSYNC)) {
 			if (service->srvstate != VCHIQ_SRVSTATE_CLOSEWAIT)
 				vchiq_log_error(vchiq_core_log_level,
-						"%d: osi - srvstate = %s (ref %d)",
+						"%d: osi - srvstate = %s (ref %u)",
 						service->state->id,
 						srvstate_names[service->srvstate],
-						service->ref_count);
+						kref_read(&service->ref_count));
 			status = VCHIQ_ERROR;
 			VCHIQ_SERVICE_STATS_INC(service, error_count);
 			vchiq_release_service_internal(service);
@@ -3425,10 +3426,13 @@ int vchiq_dump_service_state(void *dump_context, struct vchiq_service *service)
 	char buf[80];
 	int len;
 	int err;
+	unsigned int ref_count;
 
+	/*Don't include the lock just taken*/
+	ref_count = kref_read(&service->ref_count) - 1;
 	len = scnprintf(buf, sizeof(buf), "Service %u: %s (ref %u)",
 			service->localport, srvstate_names[service->srvstate],
-			service->ref_count - 1); /*Don't include the lock just taken*/
+			ref_count);
 
 	if (service->srvstate != VCHIQ_SRVSTATE_FREE) {
 		char remoteport[30];
diff --git a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.h b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.h
index 604d0c330819..30e4965c7666 100644
--- a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.h
+++ b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.h
@@ -7,6 +7,8 @@
 #include <linux/mutex.h>
 #include <linux/completion.h>
 #include <linux/kthread.h>
+#include <linux/kref.h>
+#include <linux/rcupdate.h>
 #include <linux/wait.h>
 
 #include "vchiq_cfg.h"
@@ -251,7 +253,8 @@ struct vchiq_slot_info {
 struct vchiq_service {
 	struct vchiq_service_base base;
 	unsigned int handle;
-	unsigned int ref_count;
+	struct kref ref_count;
+	struct rcu_head rcu;
 	int srvstate;
 	vchiq_userdata_term userdata_term;
 	unsigned int localport;
@@ -464,7 +467,7 @@ struct vchiq_state {
 		int error_count;
 	} stats;
 
-	struct vchiq_service *services[VCHIQ_MAX_SERVICES];
+	struct vchiq_service __rcu *services[VCHIQ_MAX_SERVICES];
 	struct vchiq_service_quota service_quotas[VCHIQ_MAX_SERVICES];
 	struct vchiq_slot_info slot_info[VCHIQ_MAX_SLOTS];
 
@@ -545,12 +548,13 @@ request_poll(struct vchiq_state *state, struct vchiq_service *service,
 static inline struct vchiq_service *
 handle_to_service(unsigned int handle)
 {
+	int idx = handle & (VCHIQ_MAX_SERVICES - 1);
 	struct vchiq_state *state = vchiq_states[(handle / VCHIQ_MAX_SERVICES) &
 		(VCHIQ_MAX_STATES - 1)];
+
 	if (!state)
 		return NULL;
-
-	return state->services[handle & (VCHIQ_MAX_SERVICES - 1)];
+	return rcu_dereference(state->services[idx]);
 }
 
 extern struct vchiq_service *
-- 
2.25.0.225.g125e21ebc7-goog



More information about the devel mailing list