Skip to content

Commit 7eae5c8

Browse files
authored
Merge pull request #2048 from RossBrunton/ross/refc
Use reference counting on factories
2 parents 9c652ff + b78cfa7 commit 7eae5c8

File tree

3 files changed

+121
-9
lines changed

3 files changed

+121
-9
lines changed

scripts/templates/ldrddi.cpp.mako

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,11 +273,17 @@ namespace ur_loader
273273
274274
%endif
275275
%endif
276-
## Before we can re-enable the releases we will need ref-counted object_t.
277-
## See unified-runtime github issue #1784
278-
##%if item['release']:
279-
##// release loader handle
280-
##${item['factory']}.release( ${item['name']} );
276+
## Possibly handle release/retain ref counting - there are no ur_exp_image factories
277+
%if 'factory' in item and '_exp_image_' not in item['factory']:
278+
%if item['release']:
279+
// release loader handle
280+
context->factories.${item['factory']}.release( ${item['name']} );
281+
%endif
282+
%if item['retain']:
283+
// increment refcount of handle
284+
context->factories.${item['factory']}.retain( ${item['name']} );
285+
%endif
286+
%endif
281287
%if not item['release'] and not item['retain'] and not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
282288
try
283289
{

source/common/ur_singleton.hpp

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,26 @@
1111
#ifndef UR_SINGLETON_H
1212
#define UR_SINGLETON_H 1
1313

14+
#include <cassert>
1415
#include <memory>
1516
#include <mutex>
1617
#include <unordered_map>
1718

1819
//////////////////////////////////////////////////////////////////////////
1920
/// an abstract factory for creation of singleton objects
2021
template <typename singleton_tn, typename key_tn> class singleton_factory_t {
22+
struct entry_t {
23+
std::unique_ptr<singleton_tn> ptr;
24+
size_t ref_count;
25+
};
26+
2127
protected:
2228
using singleton_t = singleton_tn;
2329
using key_t = typename std::conditional<std::is_pointer<key_tn>::value,
2430
size_t, key_tn>::type;
2531

2632
using ptr_t = std::unique_ptr<singleton_t>;
27-
using map_t = std::unordered_map<key_t, ptr_t>;
33+
using map_t = std::unordered_map<key_t, entry_t>;
2834

2935
std::mutex mut; ///< lock for thread-safety
3036
map_t map; ///< single instance of singleton for each unique key
@@ -60,16 +66,31 @@ template <typename singleton_tn, typename key_tn> class singleton_factory_t {
6066
if (map.end() == iter) {
6167
auto ptr =
6268
std::make_unique<singleton_t>(std::forward<Ts>(params)...);
63-
iter = map.emplace(key, std::move(ptr)).first;
69+
iter = map.emplace(key, entry_t{std::move(ptr), 0}).first;
70+
} else {
71+
iter->second.ref_count++;
6472
}
65-
return iter->second.get();
73+
return iter->second.ptr.get();
74+
}
75+
76+
void retain(key_tn key) {
77+
std::lock_guard<std::mutex> lk(mut);
78+
auto iter = map.find(getKey(key));
79+
assert(iter != map.end());
80+
iter->second.ref_count++;
6681
}
6782

6883
//////////////////////////////////////////////////////////////////////////
6984
/// once the key is no longer valid, release one reference to the singleton; the entry is erased when no references remain
7085
void release(key_tn key) {
7186
std::lock_guard<std::mutex> lk(mut);
72-
map.erase(getKey(key));
87+
auto iter = map.find(getKey(key));
88+
assert(iter != map.end());
89+
if (iter->second.ref_count == 0) {
90+
map.erase(iter);
91+
} else {
92+
iter->second.ref_count--;
93+
}
7394
}
7495

7596
void clear() {

source/loader/ur_ldrddi.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease(
8585
// forward to device-platform
8686
result = pfnAdapterRelease(hAdapter);
8787

88+
// release loader handle
89+
context->factories.ur_adapter_factory.release(hAdapter);
90+
8891
return result;
8992
}
9093

@@ -110,6 +113,9 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain(
110113
// forward to device-platform
111114
result = pfnAdapterRetain(hAdapter);
112115

116+
// increment refcount of handle
117+
context->factories.ur_adapter_factory.retain(hAdapter);
118+
113119
return result;
114120
}
115121

@@ -647,6 +653,9 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain(
647653
// forward to device-platform
648654
result = pfnRetain(hDevice);
649655

656+
// increment refcount of handle
657+
context->factories.ur_device_factory.retain(hDevice);
658+
650659
return result;
651660
}
652661

@@ -673,6 +682,9 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease(
673682
// forward to device-platform
674683
result = pfnRelease(hDevice);
675684

685+
// release loader handle
686+
context->factories.ur_device_factory.release(hDevice);
687+
676688
return result;
677689
}
678690

@@ -943,6 +955,9 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain(
943955
// forward to device-platform
944956
result = pfnRetain(hContext);
945957

958+
// increment refcount of handle
959+
context->factories.ur_context_factory.retain(hContext);
960+
946961
return result;
947962
}
948963

@@ -969,6 +984,9 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
969984
// forward to device-platform
970985
result = pfnRelease(hContext);
971986

987+
// release loader handle
988+
context->factories.ur_context_factory.release(hContext);
989+
972990
return result;
973991
}
974992

@@ -1271,6 +1289,9 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain(
12711289
// forward to device-platform
12721290
result = pfnRetain(hMem);
12731291

1292+
// increment refcount of handle
1293+
context->factories.ur_mem_factory.retain(hMem);
1294+
12741295
return result;
12751296
}
12761297

@@ -1297,6 +1318,9 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease(
12971318
// forward to device-platform
12981319
result = pfnRelease(hMem);
12991320

1321+
// release loader handle
1322+
context->factories.ur_mem_factory.release(hMem);
1323+
13001324
return result;
13011325
}
13021326

@@ -1648,6 +1672,9 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain(
16481672
// forward to device-platform
16491673
result = pfnRetain(hSampler);
16501674

1675+
// increment refcount of handle
1676+
context->factories.ur_sampler_factory.retain(hSampler);
1677+
16511678
return result;
16521679
}
16531680

@@ -1674,6 +1701,9 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease(
16741701
// forward to device-platform
16751702
result = pfnRelease(hSampler);
16761703

1704+
// release loader handle
1705+
context->factories.ur_sampler_factory.release(hSampler);
1706+
16771707
return result;
16781708
}
16791709

@@ -2107,6 +2137,9 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain(
21072137
// forward to device-platform
21082138
result = pfnPoolRetain(pPool);
21092139

2140+
// increment refcount of handle
2141+
context->factories.ur_usm_pool_factory.retain(pPool);
2142+
21102143
return result;
21112144
}
21122145

@@ -2132,6 +2165,9 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease(
21322165
// forward to device-platform
21332166
result = pfnPoolRelease(pPool);
21342167

2168+
// release loader handle
2169+
context->factories.ur_usm_pool_factory.release(pPool);
2170+
21352171
return result;
21362172
}
21372173

@@ -2517,6 +2553,9 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain(
25172553
// forward to device-platform
25182554
result = pfnRetain(hPhysicalMem);
25192555

2556+
// increment refcount of handle
2557+
context->factories.ur_physical_mem_factory.retain(hPhysicalMem);
2558+
25202559
return result;
25212560
}
25222561

@@ -2545,6 +2584,9 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease(
25452584
// forward to device-platform
25462585
result = pfnRelease(hPhysicalMem);
25472586

2587+
// release loader handle
2588+
context->factories.ur_physical_mem_factory.release(hPhysicalMem);
2589+
25482590
return result;
25492591
}
25502592

@@ -2876,6 +2918,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain(
28762918
// forward to device-platform
28772919
result = pfnRetain(hProgram);
28782920

2921+
// increment refcount of handle
2922+
context->factories.ur_program_factory.retain(hProgram);
2923+
28792924
return result;
28802925
}
28812926

@@ -2902,6 +2947,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease(
29022947
// forward to device-platform
29032948
result = pfnRelease(hProgram);
29042949

2950+
// release loader handle
2951+
context->factories.ur_program_factory.release(hProgram);
2952+
29052953
return result;
29062954
}
29072955

@@ -3499,6 +3547,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(
34993547
// forward to device-platform
35003548
result = pfnRetain(hKernel);
35013549

3550+
// increment refcount of handle
3551+
context->factories.ur_kernel_factory.retain(hKernel);
3552+
35023553
return result;
35033554
}
35043555

@@ -3525,6 +3576,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease(
35253576
// forward to device-platform
35263577
result = pfnRelease(hKernel);
35273578

3579+
// release loader handle
3580+
context->factories.ur_kernel_factory.release(hKernel);
3581+
35283582
return result;
35293583
}
35303584

@@ -3975,6 +4029,9 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain(
39754029
// forward to device-platform
39764030
result = pfnRetain(hQueue);
39774031

4032+
// increment refcount of handle
4033+
context->factories.ur_queue_factory.retain(hQueue);
4034+
39784035
return result;
39794036
}
39804037

@@ -4001,6 +4058,9 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease(
40014058
// forward to device-platform
40024059
result = pfnRelease(hQueue);
40034060

4061+
// release loader handle
4062+
context->factories.ur_queue_factory.release(hQueue);
4063+
40044064
return result;
40054065
}
40064066

@@ -4305,6 +4365,9 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain(
43054365
// forward to device-platform
43064366
result = pfnRetain(hEvent);
43074367

4368+
// increment refcount of handle
4369+
context->factories.ur_event_factory.retain(hEvent);
4370+
43084371
return result;
43094372
}
43104373

@@ -4330,6 +4393,9 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease(
43304393
// forward to device-platform
43314394
result = pfnRelease(hEvent);
43324395

4396+
// release loader handle
4397+
context->factories.ur_event_factory.release(hEvent);
4398+
43334399
return result;
43344400
}
43354401

@@ -6862,6 +6928,9 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseExternalMemoryExp(
68626928
// forward to device-platform
68636929
result = pfnReleaseExternalMemoryExp(hContext, hDevice, hExternalMem);
68646930

6931+
// release loader handle
6932+
context->factories.ur_exp_external_mem_factory.release(hExternalMem);
6933+
68656934
return result;
68666935
}
68676936

@@ -6952,6 +7021,10 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseExternalSemaphoreExp(
69527021
result =
69537022
pfnReleaseExternalSemaphoreExp(hContext, hDevice, hExternalSemaphore);
69547023

7024+
// release loader handle
7025+
context->factories.ur_exp_external_semaphore_factory.release(
7026+
hExternalSemaphore);
7027+
69557028
return result;
69567029
}
69577030

@@ -7179,6 +7252,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp(
71797252
// forward to device-platform
71807253
result = pfnRetainExp(hCommandBuffer);
71817254

7255+
// increment refcount of handle
7256+
context->factories.ur_exp_command_buffer_factory.retain(hCommandBuffer);
7257+
71827258
return result;
71837259
}
71847260

@@ -7209,6 +7285,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp(
72097285
// forward to device-platform
72107286
result = pfnReleaseExp(hCommandBuffer);
72117287

7288+
// release loader handle
7289+
context->factories.ur_exp_command_buffer_factory.release(hCommandBuffer);
7290+
72127291
return result;
72137292
}
72147293

@@ -8525,6 +8604,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp(
85258604
// forward to device-platform
85268605
result = pfnRetainCommandExp(hCommand);
85278606

8607+
// increment refcount of handle
8608+
context->factories.ur_exp_command_buffer_command_factory.retain(hCommand);
8609+
85288610
return result;
85298611
}
85308612

@@ -8556,6 +8638,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp(
85568638
// forward to device-platform
85578639
result = pfnReleaseCommandExp(hCommand);
85588640

8641+
// release loader handle
8642+
context->factories.ur_exp_command_buffer_command_factory.release(hCommand);
8643+
85598644
return result;
85608645
}
85618646

0 commit comments

Comments
 (0)