Skip to content

Commit d277bc4

Browse files
committed
[UR] Use reference counting on OpenCL adapters
1 parent 4ead2bf commit d277bc4

File tree

11 files changed

+95
-95
lines changed

11 files changed

+95
-95
lines changed

unified-runtime/source/adapters/opencl/adapter.cpp

Lines changed: 33 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
#include <dlfcn.h>
1919
#endif
2020

21+
// There can only be one OpenCL adapter alive at a time.
22+
// If it is alive (more get/retains than releases called), this is a pointer to
23+
// it.
24+
static ur_adapter_handle_t liveAdapter = nullptr;
25+
2126
ur_adapter_handle_t_::ur_adapter_handle_t_() {
2227
#ifdef _MSC_VER
2328

@@ -42,45 +47,38 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
4247
#undef CL_CORE_FUNCTION
4348

4449
#endif // _MSC_VER
50+
assert(!liveAdapter);
51+
liveAdapter = this;
4552
}
4653

47-
static ur_adapter_handle_t adapter = nullptr;
54+
ur_adapter_handle_t_::~ur_adapter_handle_t_() {
55+
assert(liveAdapter == this);
56+
liveAdapter = nullptr;
57+
}
4858

4959
ur_adapter_handle_t ur::cl::getAdapter() {
50-
if (!adapter) {
60+
if (!liveAdapter) {
5161
die("OpenCL adapter used before initalization or after destruction");
5262
}
53-
return adapter;
54-
}
55-
56-
static void globalAdapterShutdown() {
57-
if (cl_ext::ExtFuncPtrCache) {
58-
delete cl_ext::ExtFuncPtrCache;
59-
cl_ext::ExtFuncPtrCache = nullptr;
60-
}
61-
if (adapter) {
62-
delete adapter;
63-
adapter = nullptr;
64-
}
63+
return liveAdapter;
6564
}
6665

6766
UR_APIEXPORT ur_result_t UR_APICALL
6867
urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
6968
uint32_t *pNumAdapters) {
69+
static std::mutex AdapterConstructionMutex{};
70+
7071
if (NumEntries > 0 && phAdapters) {
71-
// Sometimes urAdaterGet may be called after the library already been torn
72-
// down, we also need to create a temporary handle for it.
73-
if (!adapter) {
74-
adapter = new ur_adapter_handle_t_();
75-
atexit(globalAdapterShutdown);
76-
}
72+
std::lock_guard<std::mutex> Lock{AdapterConstructionMutex};
7773

78-
std::lock_guard<std::mutex> Lock{adapter->Mutex};
79-
if (adapter->RefCount++ == 0) {
80-
cl_ext::ExtFuncPtrCache = new cl_ext::ExtFuncPtrCacheT();
74+
if (!liveAdapter) {
75+
*phAdapters = new ur_adapter_handle_t_();
76+
} else {
77+
*phAdapters = liveAdapter;
8178
}
8279

83-
*phAdapters = adapter;
80+
auto &adapter = *phAdapters;
81+
adapter->RefCount++;
8482
}
8583

8684
if (pNumAdapters) {
@@ -90,21 +88,16 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
9088
return UR_RESULT_SUCCESS;
9189
}
9290

93-
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
94-
++adapter->RefCount;
91+
UR_APIEXPORT ur_result_t UR_APICALL
92+
urAdapterRetain(ur_adapter_handle_t hAdapter) {
93+
++hAdapter->RefCount;
9594
return UR_RESULT_SUCCESS;
9695
}
9796

98-
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
99-
// Check first if the adapter is valid pointer
100-
if (adapter) {
101-
std::lock_guard<std::mutex> Lock{adapter->Mutex};
102-
if (--adapter->RefCount == 0) {
103-
if (cl_ext::ExtFuncPtrCache) {
104-
delete cl_ext::ExtFuncPtrCache;
105-
cl_ext::ExtFuncPtrCache = nullptr;
106-
}
107-
}
97+
UR_APIEXPORT ur_result_t UR_APICALL
98+
urAdapterRelease(ur_adapter_handle_t hAdapter) {
99+
if (--hAdapter->RefCount == 0) {
100+
delete hAdapter;
108101
}
109102
return UR_RESULT_SUCCESS;
110103
}
@@ -117,18 +110,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetLastError(
117110
return UR_RESULT_SUCCESS;
118111
}
119112

120-
UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
121-
ur_adapter_info_t propName,
122-
size_t propSize,
123-
void *pPropValue,
124-
size_t *pPropSizeRet) {
113+
UR_APIEXPORT ur_result_t UR_APICALL
114+
urAdapterGetInfo(ur_adapter_handle_t hAdapter, ur_adapter_info_t propName,
115+
size_t propSize, void *pPropValue, size_t *pPropSizeRet) {
125116
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);
126117

127118
switch (propName) {
128119
case UR_ADAPTER_INFO_BACKEND:
129120
return ReturnValue(UR_ADAPTER_BACKEND_OPENCL);
130121
case UR_ADAPTER_INFO_REFERENCE_COUNT:
131-
return ReturnValue(adapter->RefCount.load());
122+
return ReturnValue(hAdapter->RefCount.load());
132123
case UR_ADAPTER_INFO_VERSION:
133124
return ReturnValue(uint32_t{1});
134125
default:

unified-runtime/source/adapters/opencl/adapter.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,22 @@
77
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
88
//
99
//===----------------------------------------------------------------------===//
10+
#pragma once
11+
1012
#include "logger/ur_logger.hpp"
1113
#include "platform.hpp"
1214

1315
#include "CL/cl.h"
16+
#include "common.hpp"
1417
#include "logger/ur_logger.hpp"
1518

1619
struct ur_adapter_handle_t_ {
1720
ur_adapter_handle_t_();
21+
~ur_adapter_handle_t_();
1822

1923
std::atomic<uint32_t> RefCount = 0;
20-
std::mutex Mutex;
2124
logger::Logger &log = logger::get_logger("opencl");
25+
cl_ext::ExtFuncPtrCacheT fnCache{};
2226

2327
std::vector<std::unique_ptr<ur_platform_handle_t_>> URPlatforms;
2428
uint32_t NumPlatforms = 0;
@@ -34,5 +38,8 @@ struct ur_adapter_handle_t_ {
3438
namespace ur {
3539
namespace cl {
3640
ur_adapter_handle_t getAdapter();
41+
inline cl_ext::ExtFuncPtrCacheT &getExtFnPtrCache() {
42+
return getAdapter()->fnCache;
43+
}
3744
} // namespace cl
3845
} // namespace ur

unified-runtime/source/adapters/opencl/command_buffer.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include "command_buffer.hpp"
12+
#include "adapter.hpp"
1213
#include "common.hpp"
1314
#include "context.hpp"
1415
#include "event.hpp"
@@ -25,7 +26,7 @@ ur_exp_command_buffer_handle_t_::~ur_exp_command_buffer_handle_t_() {
2526
cl_ext::clReleaseCommandBufferKHR_fn clReleaseCommandBufferKHR = nullptr;
2627
cl_int Res =
2728
cl_ext::getExtFuncFromContext<decltype(clReleaseCommandBufferKHR)>(
28-
CLContext, cl_ext::ExtFuncPtrCache->clReleaseCommandBufferKHRCache,
29+
CLContext, ur::cl::getExtFnPtrCache().clReleaseCommandBufferKHRCache,
2930
cl_ext::ReleaseCommandBufferName, &clReleaseCommandBufferKHR);
3031
assert(Res == CL_SUCCESS);
3132
(void)Res;
@@ -53,7 +54,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(
5354
cl_ext::clCreateCommandBufferKHR_fn clCreateCommandBufferKHR = nullptr;
5455
UR_RETURN_ON_FAILURE(
5556
cl_ext::getExtFuncFromContext<decltype(clCreateCommandBufferKHR)>(
56-
CLContext, cl_ext::ExtFuncPtrCache->clCreateCommandBufferKHRCache,
57+
CLContext, ur::cl::getExtFnPtrCache().clCreateCommandBufferKHRCache,
5758
cl_ext::CreateCommandBufferName, &clCreateCommandBufferKHR));
5859

5960
const bool IsUpdatable = pCommandBufferDesc->isUpdatable;
@@ -114,7 +115,7 @@ urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
114115
cl_ext::clFinalizeCommandBufferKHR_fn clFinalizeCommandBufferKHR = nullptr;
115116
UR_RETURN_ON_FAILURE(
116117
cl_ext::getExtFuncFromContext<decltype(clFinalizeCommandBufferKHR)>(
117-
CLContext, cl_ext::ExtFuncPtrCache->clFinalizeCommandBufferKHRCache,
118+
CLContext, ur::cl::getExtFnPtrCache().clFinalizeCommandBufferKHRCache,
118119
cl_ext::FinalizeCommandBufferName, &clFinalizeCommandBufferKHR));
119120

120121
CL_RETURN_ON_FAILURE(
@@ -146,7 +147,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
146147
cl_ext::clCommandNDRangeKernelKHR_fn clCommandNDRangeKernelKHR = nullptr;
147148
UR_RETURN_ON_FAILURE(
148149
cl_ext::getExtFuncFromContext<decltype(clCommandNDRangeKernelKHR)>(
149-
CLContext, cl_ext::ExtFuncPtrCache->clCommandNDRangeKernelKHRCache,
150+
CLContext, ur::cl::getExtFnPtrCache().clCommandNDRangeKernelKHRCache,
150151
cl_ext::CommandNRRangeKernelName, &clCommandNDRangeKernelKHR));
151152

152153
cl_mutable_command_khr CommandHandle = nullptr;
@@ -236,7 +237,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
236237
cl_ext::clCommandCopyBufferKHR_fn clCommandCopyBufferKHR = nullptr;
237238
UR_RETURN_ON_FAILURE(
238239
cl_ext::getExtFuncFromContext<decltype(clCommandCopyBufferKHR)>(
239-
CLContext, cl_ext::ExtFuncPtrCache->clCommandCopyBufferKHRCache,
240+
CLContext, ur::cl::getExtFnPtrCache().clCommandCopyBufferKHRCache,
240241
cl_ext::CommandCopyBufferName, &clCommandCopyBufferKHR));
241242

242243
const bool IsInOrder = hCommandBuffer->IsInOrder;
@@ -278,7 +279,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
278279
cl_ext::clCommandCopyBufferRectKHR_fn clCommandCopyBufferRectKHR = nullptr;
279280
UR_RETURN_ON_FAILURE(
280281
cl_ext::getExtFuncFromContext<decltype(clCommandCopyBufferRectKHR)>(
281-
CLContext, cl_ext::ExtFuncPtrCache->clCommandCopyBufferRectKHRCache,
282+
CLContext, ur::cl::getExtFnPtrCache().clCommandCopyBufferRectKHRCache,
282283
cl_ext::CommandCopyBufferRectName, &clCommandCopyBufferRectKHR));
283284

284285
const bool IsInOrder = hCommandBuffer->IsInOrder;
@@ -386,7 +387,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
386387
cl_ext::clCommandFillBufferKHR_fn clCommandFillBufferKHR = nullptr;
387388
UR_RETURN_ON_FAILURE(
388389
cl_ext::getExtFuncFromContext<decltype(clCommandFillBufferKHR)>(
389-
CLContext, cl_ext::ExtFuncPtrCache->clCommandFillBufferKHRCache,
390+
CLContext, ur::cl::getExtFnPtrCache().clCommandFillBufferKHRCache,
390391
cl_ext::CommandFillBufferName, &clCommandFillBufferKHR));
391392

392393
const bool IsInOrder = hCommandBuffer->IsInOrder;
@@ -457,7 +458,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
457458
cl_ext::clEnqueueCommandBufferKHR_fn clEnqueueCommandBufferKHR = nullptr;
458459
UR_RETURN_ON_FAILURE(
459460
cl_ext::getExtFuncFromContext<decltype(clEnqueueCommandBufferKHR)>(
460-
CLContext, cl_ext::ExtFuncPtrCache->clEnqueueCommandBufferKHRCache,
461+
CLContext, ur::cl::getExtFnPtrCache().clEnqueueCommandBufferKHRCache,
461462
cl_ext::EnqueueCommandBufferName, &clEnqueueCommandBufferKHR));
462463

463464
const uint32_t NumberOfQueues = 1;
@@ -615,7 +616,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
615616
cl_ext::clUpdateMutableCommandsKHR_fn clUpdateMutableCommandsKHR = nullptr;
616617
UR_RETURN_ON_FAILURE(
617618
cl_ext::getExtFuncFromContext<decltype(clUpdateMutableCommandsKHR)>(
618-
CLContext, cl_ext::ExtFuncPtrCache->clUpdateMutableCommandsKHRCache,
619+
CLContext, ur::cl::getExtFnPtrCache().clUpdateMutableCommandsKHRCache,
619620
cl_ext::UpdateMutableCommandsName, &clUpdateMutableCommandsKHR));
620621

621622
if (!hCommandBuffer->IsFinalized || !hCommandBuffer->IsUpdatable)

unified-runtime/source/adapters/opencl/common.hpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -342,11 +342,6 @@ struct ExtFuncPtrCacheT {
342342
#undef CL_EXTENSION_FUNC
343343
}
344344
};
345-
// A raw pointer is used here since the lifetime of this map has to be tied to
346-
// piTeardown to avoid issues with static destruction order (a user application
347-
// might have static objects that indirectly access this cache in their
348-
// destructor).
349-
inline ExtFuncPtrCacheT *ExtFuncPtrCache;
350345

351346
// USM helper function to get an extension function pointer
352347
template <typename T>

unified-runtime/source/adapters/opencl/context.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -125,20 +125,9 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,
125125

126126
UR_APIEXPORT ur_result_t UR_APICALL
127127
urContextRelease(ur_context_handle_t hContext) {
128-
// If we're reasonably sure this context is about to be detroyed we should
129-
// clear the ext function pointer cache. This isn't foolproof sadly but it
130-
// should drastically reduce the chances of the pathological case described
131-
// in the comments in common.hpp.
132128
static std::mutex contextReleaseMutex;
133-
auto clContext = hContext->CLContext;
134129

135130
std::lock_guard<std::mutex> lock(contextReleaseMutex);
136-
size_t refCount = hContext->getReferenceCount();
137-
// ExtFuncPtrCache is destroyed in an atexit() callback, so it doesn't
138-
// necessarily outlive the adapter (or all the contexts).
139-
if (refCount == 1 && cl_ext::ExtFuncPtrCache) {
140-
cl_ext::ExtFuncPtrCache->clearCache(clContext);
141-
}
142131

143132
if (hContext->decrementReferenceCount() == 0) {
144133
delete hContext;

unified-runtime/source/adapters/opencl/context.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010
#pragma once
1111

12+
#include "adapter.hpp"
1213
#include "common.hpp"
1314
#include "device.hpp"
1415

@@ -29,6 +30,9 @@ struct ur_context_handle_t_ {
2930
Devices.emplace_back(phDevices[i]);
3031
urDeviceRetain(phDevices[i]);
3132
}
33+
// The context retains a reference to the adapter so it can clear the
34+
// function ptr cache on distruction
35+
urAdapterRetain(ur::cl::getAdapter());
3236
RefCount = 1;
3337
}
3438

@@ -42,6 +46,13 @@ struct ur_context_handle_t_ {
4246
const ur_device_handle_t *phDevices,
4347
ur_context_handle_t &Context);
4448
~ur_context_handle_t_() {
49+
// If we're reasonably sure this context is about to be destroyed we should
50+
// clear the ext function pointer cache. This isn't foolproof sadly but it
51+
// should drastically reduce the chances of the pathological case described
52+
// in the comments in common.hpp.
53+
ur::cl::getExtFnPtrCache().clearCache(CLContext);
54+
urAdapterRelease(ur::cl::getAdapter());
55+
4556
for (uint32_t i = 0; i < DeviceCount; i++) {
4657
urDeviceRelease(Devices[i]);
4758
}

unified-runtime/source/adapters/opencl/enqueue.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//
99
//===----------------------------------------------------------------------===//
1010

11+
#include "adapter.hpp"
1112
#include "common.hpp"
1213
#include "context.hpp"
1314
#include "event.hpp"
@@ -410,7 +411,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
410411
MapUREventsToCL(numEventsInWaitList, phEventWaitList, CLWaitEvents);
411412
cl_ext::clEnqueueWriteGlobalVariable_fn F = nullptr;
412413
UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<decltype(F)>(
413-
Ctx, cl_ext::ExtFuncPtrCache->clEnqueueWriteGlobalVariableCache,
414+
Ctx, ur::cl::getExtFnPtrCache().clEnqueueWriteGlobalVariableCache,
414415
cl_ext::EnqueueWriteGlobalVariableName, &F));
415416

416417
cl_int Res =
@@ -432,7 +433,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead(
432433
MapUREventsToCL(numEventsInWaitList, phEventWaitList, CLWaitEvents);
433434
cl_ext::clEnqueueReadGlobalVariable_fn F = nullptr;
434435
UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<decltype(F)>(
435-
Ctx, cl_ext::ExtFuncPtrCache->clEnqueueReadGlobalVariableCache,
436+
Ctx, ur::cl::getExtFnPtrCache().clEnqueueReadGlobalVariableCache,
436437
cl_ext::EnqueueReadGlobalVariableName, &F));
437438

438439
cl_int Res =
@@ -456,7 +457,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueReadHostPipe(
456457
cl_ext::clEnqueueReadHostPipeINTEL_fn FuncPtr = nullptr;
457458
UR_RETURN_ON_FAILURE(
458459
cl_ext::getExtFuncFromContext<cl_ext::clEnqueueReadHostPipeINTEL_fn>(
459-
CLContext, cl_ext::ExtFuncPtrCache->clEnqueueReadHostPipeINTELCache,
460+
CLContext, ur::cl::getExtFnPtrCache().clEnqueueReadHostPipeINTELCache,
460461
cl_ext::EnqueueReadHostPipeName, &FuncPtr));
461462

462463
if (FuncPtr) {
@@ -484,7 +485,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueWriteHostPipe(
484485
cl_ext::clEnqueueWriteHostPipeINTEL_fn FuncPtr = nullptr;
485486
UR_RETURN_ON_FAILURE(
486487
cl_ext::getExtFuncFromContext<cl_ext::clEnqueueWriteHostPipeINTEL_fn>(
487-
CLContext, cl_ext::ExtFuncPtrCache->clEnqueueWriteHostPipeINTELCache,
488+
CLContext,
489+
ur::cl::getExtFnPtrCache().clEnqueueWriteHostPipeINTELCache,
488490
cl_ext::EnqueueWriteHostPipeName, &FuncPtr));
489491

490492
if (FuncPtr) {

0 commit comments

Comments
 (0)