Skip to content

Commit a689b8d

Browse files
authored
[SYCL] Protect access to the native handle of a sycl::event (#15179)
Fix for #14623. Currently, event_impl exposes a reference to the underlying UR handle. As a result, this handle can be read or updated at arbitrary times by different threads, causing a data race. This PR removes the methods that expose the reference and replaces them with a thread-safe getter and setter.
1 parent e374c69 commit a689b8d

19 files changed

+294
-182
lines changed

sycl/source/detail/event_impl.cpp

+48-33
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,19 @@ void event_impl::initContextIfNeeded() {
4444

4545
event_impl::~event_impl() {
4646
try {
47-
if (MEvent)
48-
getPlugin()->call(urEventRelease, MEvent);
47+
auto Handle = this->getHandle();
48+
if (Handle)
49+
getPlugin()->call(urEventRelease, Handle);
4950
} catch (std::exception &e) {
5051
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~event_impl", e);
5152
}
5253
}
5354

5455
void event_impl::waitInternal(bool *Success) {
55-
if (!MIsHostEvent && MEvent) {
56+
auto Handle = this->getHandle();
57+
if (!MIsHostEvent && Handle) {
5658
// Wait for the native event
57-
ur_result_t Err = getPlugin()->call_nocheck(urEventWait, 1, &MEvent);
59+
ur_result_t Err = getPlugin()->call_nocheck(urEventWait, 1, &Handle);
5860
// TODO drop the UR_RESULT_ERROR_UNKNOWN from here (this was waiting for
5961
// https://github.com/oneapi-src/unified-runtime/issues/1459 which is now
6062
// closed).
@@ -89,7 +91,7 @@ void event_impl::waitInternal(bool *Success) {
8991
}
9092

9193
void event_impl::setComplete() {
92-
if (MIsHostEvent || !MEvent) {
94+
if (MIsHostEvent || !this->getHandle()) {
9395
{
9496
std::unique_lock<std::mutex> lock(MMutex);
9597
#ifndef NDEBUG
@@ -116,8 +118,11 @@ static uint64_t inline getTimestamp() {
116118
.count();
117119
}
118120

119-
const ur_event_handle_t &event_impl::getHandleRef() const { return MEvent; }
120-
ur_event_handle_t &event_impl::getHandleRef() { return MEvent; }
121+
ur_event_handle_t event_impl::getHandle() const { return MEvent.load(); }
122+
123+
void event_impl::setHandle(const ur_event_handle_t &UREvent) {
124+
MEvent.store(UREvent);
125+
}
121126

122127
const ContextImplPtr &event_impl::getContextImpl() {
123128
initContextIfNeeded();
@@ -141,7 +146,7 @@ event_impl::event_impl(ur_event_handle_t Event, const context &SyclContext)
141146
MIsFlushed(true), MState(HES_Complete) {
142147

143148
ur_context_handle_t TempContext;
144-
getPlugin()->call(urEventGetInfo, MEvent, UR_EVENT_INFO_CONTEXT,
149+
getPlugin()->call(urEventGetInfo, this->getHandle(), UR_EVENT_INFO_CONTEXT,
145150
sizeof(ur_context_handle_t), &TempContext, nullptr);
146151

147152
if (MContext->getHandleRef() != TempContext) {
@@ -183,7 +188,7 @@ void *event_impl::instrumentationProlog(std::string &Name, int32_t StreamID,
183188
// Create a string with the event address so it
184189
// can be associated with other debug data
185190
xpti::utils::StringHelper SH;
186-
Name = SH.nameWithAddress<ur_event_handle_t>("event.wait", MEvent);
191+
Name = SH.nameWithAddress<ur_event_handle_t>("event.wait", this->getHandle());
187192

188193
// We can emit the wait associated with the graph if the
189194
// event does not have a command object or associated with
@@ -249,9 +254,10 @@ void event_impl::wait(std::shared_ptr<sycl::detail::event_impl> Self,
249254
TelemetryEvent = instrumentationProlog(Name, StreamID, IId);
250255
#endif
251256

252-
if (MEvent)
253-
// presence of MEvent means the command has been enqueued, so no need to
254-
// go via the slow path event waiting in the scheduler
257+
auto EventHandle = getHandle();
258+
if (EventHandle)
259+
// presence of the native handle means the command has been enqueued, so no
260+
// need to go via the slow path event waiting in the scheduler
255261
waitInternal(Success);
256262
else if (MCommand)
257263
detail::Scheduler::getInstance().waitForEvent(Self, Success);
@@ -294,7 +300,7 @@ event_impl::get_profiling_info<info::event_profiling::command_submit>() {
294300
// For profiling tag events we rely on the submission time reported as
295301
// the start time has undefined behavior.
296302
return get_event_profiling_info<info::event_profiling::command_submit>(
297-
this->getHandleRef(), this->getPlugin());
303+
this->getHandle(), this->getPlugin());
298304
}
299305

300306
// The delay between the submission and the actual start of a CommandBuffer
@@ -311,10 +317,11 @@ event_impl::get_profiling_info<info::event_profiling::command_submit>() {
311317
// made by forcing the re-sync of submit time to start time is less than
312318
// 0.5ms. These timing values were obtained empirically using an integrated
313319
// Intel GPU).
314-
if (MEventFromSubmittedExecCommandBuffer && !MIsHostEvent && MEvent) {
320+
auto Handle = this->getHandle();
321+
if (MEventFromSubmittedExecCommandBuffer && !MIsHostEvent && Handle) {
315322
uint64_t StartTime =
316323
get_event_profiling_info<info::event_profiling::command_start>(
317-
this->getHandleRef(), this->getPlugin());
324+
Handle, this->getPlugin());
318325
if (StartTime < MSubmitTime)
319326
MSubmitTime = StartTime;
320327
}
@@ -326,16 +333,17 @@ uint64_t
326333
event_impl::get_profiling_info<info::event_profiling::command_start>() {
327334
checkProfilingPreconditions();
328335
if (!MIsHostEvent) {
329-
if (MEvent) {
336+
auto Handle = getHandle();
337+
if (Handle) {
330338
auto StartTime =
331339
get_event_profiling_info<info::event_profiling::command_start>(
332-
this->getHandleRef(), this->getPlugin());
340+
Handle, this->getPlugin());
333341
if (!MFallbackProfiling) {
334342
return StartTime;
335343
} else {
336344
auto DeviceBaseTime =
337345
get_event_profiling_info<info::event_profiling::command_submit>(
338-
this->getHandleRef(), this->getPlugin());
346+
Handle, this->getPlugin());
339347
return MHostBaseTime - DeviceBaseTime + StartTime;
340348
}
341349
}
@@ -353,16 +361,17 @@ template <>
353361
uint64_t event_impl::get_profiling_info<info::event_profiling::command_end>() {
354362
checkProfilingPreconditions();
355363
if (!MIsHostEvent) {
356-
if (MEvent) {
364+
auto Handle = this->getHandle();
365+
if (Handle) {
357366
auto EndTime =
358367
get_event_profiling_info<info::event_profiling::command_end>(
359-
this->getHandleRef(), this->getPlugin());
368+
Handle, this->getPlugin());
360369
if (!MFallbackProfiling) {
361370
return EndTime;
362371
} else {
363372
auto DeviceBaseTime =
364373
get_event_profiling_info<info::event_profiling::command_submit>(
365-
this->getHandleRef(), this->getPlugin());
374+
Handle, this->getPlugin());
366375
return MHostBaseTime - DeviceBaseTime + EndTime;
367376
}
368377
}
@@ -377,8 +386,9 @@ uint64_t event_impl::get_profiling_info<info::event_profiling::command_end>() {
377386
}
378387

379388
template <> uint32_t event_impl::get_info<info::event::reference_count>() {
380-
if (!MIsHostEvent && MEvent) {
381-
return get_event_info<info::event::reference_count>(this->getHandleRef(),
389+
auto Handle = this->getHandle();
390+
if (!MIsHostEvent && Handle) {
391+
return get_event_info<info::event::reference_count>(Handle,
382392
this->getPlugin());
383393
}
384394
return 0;
@@ -392,9 +402,10 @@ event_impl::get_info<info::event::command_execution_status>() {
392402

393403
if (!MIsHostEvent) {
394404
// Command is enqueued and UrEvent is ready
395-
if (MEvent)
405+
auto Handle = this->getHandle();
406+
if (Handle)
396407
return get_event_info<info::event::command_execution_status>(
397-
this->getHandleRef(), this->getPlugin());
408+
Handle, this->getPlugin());
398409
// Command is blocked and not enqueued, UrEvent is not assigned yet
399410
else if (MCommand)
400411
return sycl::info::event_command_status::submitted;
@@ -471,17 +482,20 @@ ur_native_handle_t event_impl::getNative() {
471482
initContextIfNeeded();
472483

473484
auto Plugin = getPlugin();
474-
if (MIsDefaultConstructed && !MEvent) {
485+
auto Handle = getHandle();
486+
if (MIsDefaultConstructed && !Handle) {
475487
auto TempContext = MContext.get()->getHandleRef();
476488
ur_event_native_properties_t NativeProperties{};
489+
ur_event_handle_t UREvent = nullptr;
477490
Plugin->call(urEventCreateWithNativeHandle, 0, TempContext,
478-
&NativeProperties, &MEvent);
491+
&NativeProperties, &UREvent);
492+
this->setHandle(UREvent);
479493
}
480494
if (MContext->getBackend() == backend::opencl)
481-
Plugin->call(urEventRetain, getHandleRef());
482-
ur_native_handle_t Handle;
483-
Plugin->call(urEventGetNativeHandle, getHandleRef(), &Handle);
484-
return Handle;
495+
Plugin->call(urEventRetain, Handle);
496+
ur_native_handle_t OutHandle;
497+
Plugin->call(urEventGetNativeHandle, Handle, &OutHandle);
498+
return OutHandle;
485499
}
486500

487501
std::vector<EventImplPtr> event_impl::getWaitList() {
@@ -505,7 +519,8 @@ std::vector<EventImplPtr> event_impl::getWaitList() {
505519
void event_impl::flushIfNeeded(const QueueImplPtr &UserQueue) {
506520
// Some events might not have a native handle underneath even at this point,
507521
// e.g. those produced by memset with 0 size (no UR call is made).
508-
if (MIsFlushed || !MEvent)
522+
auto Handle = this->getHandle();
523+
if (MIsFlushed || !Handle)
509524
return;
510525

511526
QueueImplPtr Queue = MQueue.lock();
@@ -520,7 +535,7 @@ void event_impl::flushIfNeeded(const QueueImplPtr &UserQueue) {
520535

521536
// Check if the task for this event has already been submitted.
522537
ur_event_status_t Status = UR_EVENT_STATUS_QUEUED;
523-
getPlugin()->call(urEventGetInfo, MEvent,
538+
getPlugin()->call(urEventGetInfo, Handle,
524539
UR_EVENT_INFO_COMMAND_EXECUTION_STATUS,
525540
sizeof(ur_event_status_t), &Status, nullptr);
526541
if (Status == UR_EVENT_STATUS_QUEUED) {

sycl/source/detail/event_impl.hpp

+7-12
Original file line numberDiff line numberDiff line change
@@ -126,16 +126,11 @@ class event_impl {
126126
/// Marks this event as completed.
127127
void setComplete();
128128

129-
/// Returns raw interoperability event handle. Returned reference will be
130-
/// invalid if event_impl was destroyed.
131-
///
132-
/// \return a reference to an instance of plug-in event handle.
133-
ur_event_handle_t &getHandleRef();
134-
/// Returns raw interoperability event handle. Returned reference will be
135-
/// invalid if event_impl was destroyed.
136-
///
137-
/// \return a const reference to an instance of plug-in event handle.
138-
const ur_event_handle_t &getHandleRef() const;
129+
/// Returns raw interoperability event handle.
130+
ur_event_handle_t getHandle() const;
131+
132+
/// Set event handle for this event object.
133+
void setHandle(const ur_event_handle_t &UREvent);
139134

140135
/// Returns context that is associated with this event.
141136
///
@@ -240,7 +235,7 @@ class event_impl {
240235
/// have native handle.
241236
///
242237
/// @return true if no associated command and no event handle.
243-
bool isNOP() { return !MCommand && !getHandleRef(); }
238+
bool isNOP() { return !MCommand && !getHandle(); }
244239

245240
/// Calling this function queries the current device timestamp and sets it as
246241
/// submission time for the command associated with this event.
@@ -344,7 +339,7 @@ class event_impl {
344339
int32_t StreamID, uint64_t IId) const;
345340
void checkProfilingPreconditions() const;
346341

347-
ur_event_handle_t MEvent = nullptr;
342+
std::atomic<ur_event_handle_t> MEvent = nullptr;
348343
// Stores submission time of command associated with event
349344
uint64_t MSubmitTime = 0;
350345
uint64_t MHostBaseTime = 0;

sycl/source/detail/graph_impl.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,7 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
910910
}
911911

912912
NewEvent = CreateNewEvent();
913-
ur_event_handle_t *OutEvent = &NewEvent->getHandleRef();
913+
ur_event_handle_t UREvent = nullptr;
914914
// Merge requirements from the nodes into requirements (if any) from the
915915
// handler.
916916
CGData.MRequirements.insert(CGData.MRequirements.end(),
@@ -927,7 +927,8 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
927927
}
928928
ur_result_t Res = Queue->getPlugin()->call_nocheck(
929929
urCommandBufferEnqueueExp, CommandBuffer, Queue->getHandleRef(), 0,
930-
nullptr, OutEvent);
930+
nullptr, &UREvent);
931+
NewEvent->setHandle(UREvent);
931932
if (Res == UR_RESULT_ERROR_INVALID_QUEUE_PROPERTIES) {
932933
throw sycl::exception(
933934
make_error_code(errc::invalid),

sycl/source/detail/memory_manager.cpp

+4-5
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,9 @@ static void waitForEvents(const std::vector<EventImplPtr> &Events) {
124124
if (!Events.empty()) {
125125
const PluginPtr &Plugin = Events[0]->getPlugin();
126126
std::vector<ur_event_handle_t> UrEvents(Events.size());
127-
std::transform(Events.begin(), Events.end(), UrEvents.begin(),
128-
[](const EventImplPtr &EventImpl) {
129-
return EventImpl->getHandleRef();
130-
});
127+
std::transform(
128+
Events.begin(), Events.end(), UrEvents.begin(),
129+
[](const EventImplPtr &EventImpl) { return EventImpl->getHandle(); });
131130
if (!UrEvents.empty() && UrEvents[0]) {
132131
Plugin->call(urEventWait, UrEvents.size(), &UrEvents[0]);
133132
}
@@ -313,7 +312,7 @@ void *MemoryManager::allocateInteropMemObject(
313312
// If memory object is created with interop c'tor return cl_mem as is.
314313
assert(TargetContext == InteropContext && "Expected matching contexts");
315314

316-
OutEventToWait = InteropEvent->getHandleRef();
315+
OutEventToWait = InteropEvent->getHandle();
317316
// Retain the event since it will be released during alloca command
318317
// destruction
319318
if (nullptr != OutEventToWait) {

sycl/source/detail/queue_impl.cpp

+9-6
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ getUrEvents(const std::vector<sycl::event> &DepEvents) {
4949
std::vector<ur_event_handle_t> RetUrEvents;
5050
for (const sycl::event &Event : DepEvents) {
5151
const EventImplPtr &EventImpl = detail::getSyclObjImpl(Event);
52-
if (EventImpl->getHandleRef() != nullptr)
53-
RetUrEvents.push_back(EventImpl->getHandleRef());
52+
auto Handle = EventImpl->getHandle();
53+
if (Handle != nullptr)
54+
RetUrEvents.push_back(Handle);
5455
}
5556
return RetUrEvents;
5657
}
@@ -307,7 +308,7 @@ void queue_impl::addEvent(const event &Event) {
307308
}
308309
// As long as the queue supports urQueueFinish we only need to store events
309310
// for unenqueued commands and host tasks.
310-
else if (MEmulateOOO || EImpl->getHandleRef() == nullptr) {
311+
else if (MEmulateOOO || EImpl->getHandle() == nullptr) {
311312
std::weak_ptr<event_impl> EventWeakPtr{EImpl};
312313
std::lock_guard<std::mutex> Lock{MMutex};
313314
MEventsWeak.push_back(std::move(EventWeakPtr));
@@ -447,8 +448,10 @@ event queue_impl::submitMemOpHelper(const std::shared_ptr<queue_impl> &Self,
447448
auto EventImpl = detail::getSyclObjImpl(ResEvent);
448449
{
449450
NestedCallsTracker tracker;
450-
MemOpFunc(MemOpArgs..., getUrEvents(ExpandedDepEvents),
451-
&EventImpl->getHandleRef(), EventImpl);
451+
ur_event_handle_t UREvent = nullptr;
452+
MemOpFunc(MemOpArgs..., getUrEvents(ExpandedDepEvents), &UREvent,
453+
EventImpl);
454+
EventImpl->setHandle(UREvent);
452455
}
453456

454457
if (isInOrder()) {
@@ -603,7 +606,7 @@ void queue_impl::wait(const detail::code_location &CodeLoc) {
603606
EventImplWeakPtrIt->lock()) {
604607
// A nullptr UR event indicates that urQueueFinish will not cover it,
605608
// either because it's a host task event or an unenqueued one.
606-
if (!SupportsPiFinish || nullptr == EventImplSharedPtr->getHandleRef()) {
609+
if (!SupportsPiFinish || nullptr == EventImplSharedPtr->getHandle()) {
607610
EventImplSharedPtr->wait(EventImplSharedPtr);
608611
}
609612
}

sycl/source/detail/queue_impl.hpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -737,9 +737,10 @@ class queue_impl {
737737
template <typename HandlerType = handler>
738738
EventImplPtr insertHelperBarrier(const HandlerType &Handler) {
739739
auto ResEvent = std::make_shared<detail::event_impl>(Handler.MQueue);
740+
ur_event_handle_t UREvent = nullptr;
740741
getPlugin()->call(urEnqueueEventsWaitWithBarrier,
741-
Handler.MQueue->getHandleRef(), 0, nullptr,
742-
&ResEvent->getHandleRef());
742+
Handler.MQueue->getHandleRef(), 0, nullptr, &UREvent);
743+
ResEvent->setHandle(UREvent);
743744
return ResEvent;
744745
}
745746

sycl/source/detail/reduction.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,10 @@ addCounterInit(handler &CGH, std::shared_ptr<sycl::detail::queue_impl> &Queue,
172172
auto EventImpl = std::make_shared<detail::event_impl>(Queue);
173173
EventImpl->setContextImpl(detail::getSyclObjImpl(Queue->get_context()));
174174
EventImpl->setStateIncomplete();
175-
MemoryManager::fill_usm(Counter.get(), Queue, sizeof(int), {0}, {},
176-
&EventImpl->getHandleRef(), EventImpl);
175+
ur_event_handle_t UREvent = nullptr;
176+
MemoryManager::fill_usm(Counter.get(), Queue, sizeof(int), {0}, {}, &UREvent,
177+
EventImpl);
178+
EventImpl->setHandle(UREvent);
177179
CGH.depends_on(createSyclObjFromImpl<event>(EventImpl));
178180
}
179181

0 commit comments

Comments
 (0)