15
15
#include " umf_pools/disjoint_pool_config_parser.hpp"
16
16
#include " usm.hpp"
17
17
18
- #include < umf/pools/pool_disjoint.h>
19
- #include < umf/pools/pool_proxy.h>
20
18
#include < umf/providers/provider_level_zero.h>
21
19
22
20
namespace umf {
@@ -34,7 +32,17 @@ ur_result_t getProviderNativeError(const char *providerName,
34
32
}
35
33
} // namespace umf
36
34
37
- static usm::DisjointPoolAllConfigs initializeDisjointPoolConfig () {
35
+ static std::optional<usm::DisjointPoolAllConfigs>
36
+ initializeDisjointPoolConfig () {
37
+ const char *UrRetDisable = std::getenv (" UR_L0_DISABLE_USM_ALLOCATOR" );
38
+ const char *PiRetDisable =
39
+ std::getenv (" SYCL_PI_LEVEL_ZERO_DISABLE_USM_ALLOCATOR" );
40
+ const char *Disable =
41
+ UrRetDisable ? UrRetDisable : (PiRetDisable ? PiRetDisable : nullptr );
42
+ if (Disable != nullptr && Disable != std::string (" " )) {
43
+ return std::nullopt;
44
+ }
45
+
38
46
const char *PoolUrTraceVal = std::getenv (" UR_L0_USM_ALLOCATOR_TRACE" );
39
47
40
48
int PoolTrace = 0 ;
@@ -47,7 +55,14 @@ static usm::DisjointPoolAllConfigs initializeDisjointPoolConfig() {
47
55
return usm::DisjointPoolAllConfigs (PoolTrace);
48
56
}
49
57
50
- return usm::parseDisjointPoolConfig (PoolUrConfigVal, PoolTrace);
58
+ // TODO: rework parseDisjointPoolConfig to return optional,
59
+ // once EnableBuffers is no longer used (by legacy L0)
60
+ auto configs = usm::parseDisjointPoolConfig (PoolUrConfigVal, PoolTrace);
61
+ if (configs.EnableBuffers ) {
62
+ return configs;
63
+ }
64
+
65
+ return std::nullopt;
51
66
}
52
67
53
68
inline umf_usm_memory_type_t urToUmfMemoryType (ur_usm_type_t type) {
@@ -81,32 +96,35 @@ descToDisjoinPoolMemType(const usm::pool_descriptor &desc) {
81
96
}
82
97
}
83
98
84
- static umf::pool_unique_handle_t
85
- makePool (usm::umf_disjoint_pool_config_t *poolParams,
86
- usm::pool_descriptor poolDescriptor) {
87
- umf_level_zero_memory_provider_params_handle_t params = NULL ;
88
- umf_result_t umf_ret = umfLevelZeroMemoryProviderParamsCreate (¶ms);
99
+ static umf::provider_unique_handle_t
100
+ makeProvider (usm::pool_descriptor poolDescriptor) {
101
+ umf_level_zero_memory_provider_params_handle_t hParams;
102
+ umf_result_t umf_ret = umfLevelZeroMemoryProviderParamsCreate (&hParams);
89
103
if (umf_ret != UMF_RESULT_SUCCESS) {
90
104
throw umf::umf2urResult (umf_ret);
91
105
}
92
106
107
+ std::unique_ptr<umf_level_zero_memory_provider_params_t ,
108
+ decltype (&umfLevelZeroMemoryProviderParamsDestroy)>
109
+ params (hParams, &umfLevelZeroMemoryProviderParamsDestroy);
110
+
93
111
umf_ret = umfLevelZeroMemoryProviderParamsSetContext (
94
- params , poolDescriptor.hContext ->getZeHandle ());
112
+ hParams , poolDescriptor.hContext ->getZeHandle ());
95
113
if (umf_ret != UMF_RESULT_SUCCESS) {
96
114
throw umf::umf2urResult (umf_ret);
97
115
};
98
116
99
117
ze_device_handle_t level_zero_device_handle =
100
118
poolDescriptor.hDevice ? poolDescriptor.hDevice ->ZeDevice : nullptr ;
101
119
102
- umf_ret = umfLevelZeroMemoryProviderParamsSetDevice (params ,
120
+ umf_ret = umfLevelZeroMemoryProviderParamsSetDevice (hParams ,
103
121
level_zero_device_handle);
104
122
if (umf_ret != UMF_RESULT_SUCCESS) {
105
123
throw umf::umf2urResult (umf_ret);
106
124
}
107
125
108
126
umf_ret = umfLevelZeroMemoryProviderParamsSetMemoryType (
109
- params , urToUmfMemoryType (poolDescriptor.type ));
127
+ hParams , urToUmfMemoryType (poolDescriptor.type ));
110
128
if (umf_ret != UMF_RESULT_SUCCESS) {
111
129
throw umf::umf2urResult (umf_ret);
112
130
}
@@ -123,46 +141,37 @@ makePool(usm::umf_disjoint_pool_config_t *poolParams,
123
141
}
124
142
125
143
umf_ret = umfLevelZeroMemoryProviderParamsSetResidentDevices (
126
- params , residentZeHandles.data (), residentZeHandles.size ());
144
+ hParams , residentZeHandles.data (), residentZeHandles.size ());
127
145
if (umf_ret != UMF_RESULT_SUCCESS) {
128
146
throw umf::umf2urResult (umf_ret);
129
147
}
130
148
}
131
149
132
150
auto [ret, provider] =
133
- umf::providerMakeUniqueFromOps (umfLevelZeroMemoryProviderOps (), params );
151
+ umf::providerMakeUniqueFromOps (umfLevelZeroMemoryProviderOps (), hParams );
134
152
if (ret != UMF_RESULT_SUCCESS) {
135
153
throw umf::umf2urResult (ret);
136
154
}
137
155
138
- if (!poolParams) {
139
- auto [ret, poolHandle] = umf::poolMakeUniqueFromOps (
140
- umfProxyPoolOps (), std::move (provider), nullptr );
141
- if (ret != UMF_RESULT_SUCCESS)
142
- throw umf::umf2urResult (ret);
143
- return std::move (poolHandle);
144
- } else {
145
- auto umfParams = getUmfParamsHandle (*poolParams);
146
-
147
- auto [ret, poolHandle] =
148
- umf::poolMakeUniqueFromOps (umfDisjointPoolOps (), std::move (provider),
149
- static_cast <void *>(umfParams.get ()));
150
- if (ret != UMF_RESULT_SUCCESS)
151
- throw umf::umf2urResult (ret);
152
- return std::move (poolHandle);
153
- }
156
+ return std::move (provider);
154
157
}
155
158
156
159
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_ (ur_context_handle_t hContext,
157
160
ur_usm_pool_desc_t *pPoolDesc)
158
161
: hContext(hContext) {
159
162
// TODO: handle UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK from pPoolDesc
160
163
auto disjointPoolConfigs = initializeDisjointPoolConfig ();
161
- if (auto limits = find_stype_node<ur_usm_pool_limits_desc_t >(pPoolDesc)) {
162
- for (auto &config : disjointPoolConfigs.Configs ) {
163
- config.MaxPoolableSize = limits->maxPoolableSize ;
164
- config.SlabMinSize = limits->minDriverAllocSize ;
164
+
165
+ if (disjointPoolConfigs.has_value ()) {
166
+ if (auto limits = find_stype_node<ur_usm_pool_limits_desc_t >(pPoolDesc)) {
167
+ for (auto &config : disjointPoolConfigs.value ().Configs ) {
168
+ config.MaxPoolableSize = limits->maxPoolableSize ;
169
+ config.SlabMinSize = limits->minDriverAllocSize ;
170
+ }
165
171
}
172
+ } else {
173
+ // If pooling is disabled, do nothing.
174
+ logger::info (" USM pooling is disabled. Skiping pool limits adjustment." );
166
175
}
167
176
168
177
auto [result, descriptors] = usm::pool_descriptor::create (this , hContext);
@@ -171,12 +180,13 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t hContext,
171
180
}
172
181
173
182
for (auto &desc : descriptors) {
174
- if (disjointPoolConfigs.EnableBuffers ) {
183
+ if (disjointPoolConfigs.has_value () ) {
175
184
auto &poolConfig =
176
- disjointPoolConfigs.Configs [descToDisjoinPoolMemType (desc)];
177
- poolManager.addPool (desc, makePool (&poolConfig, desc));
185
+ disjointPoolConfigs.value ().Configs [descToDisjoinPoolMemType (desc)];
186
+ poolManager.addPool (
187
+ desc, usm::makeDisjointPool (makeProvider (desc), poolConfig));
178
188
} else {
179
- poolManager.addPool (desc, makePool ( nullptr , desc));
189
+ poolManager.addPool (desc, usm::makeProxyPool ( makeProvider ( desc) ));
180
190
}
181
191
}
182
192
}
0 commit comments