Skip to content

Commit eb6ea20

Browse files
authored
Merge pull request #2657 from igchor/v2_usm_support_env
[L0 v2] support SYCL_PI_LEVEL_ZERO_DISABLE_USM_ALLOCATOR
2 parents 72d8641 + 781604a commit eb6ea20

File tree

2 files changed

+71
-38
lines changed

2 files changed

+71
-38
lines changed

source/adapters/level_zero/v2/usm.cpp

+48-38
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
#include "umf_pools/disjoint_pool_config_parser.hpp"
1616
#include "usm.hpp"
1717

18-
#include <umf/pools/pool_disjoint.h>
19-
#include <umf/pools/pool_proxy.h>
2018
#include <umf/providers/provider_level_zero.h>
2119

2220
namespace umf {
@@ -34,7 +32,17 @@ ur_result_t getProviderNativeError(const char *providerName,
3432
}
3533
} // namespace umf
3634

37-
static usm::DisjointPoolAllConfigs initializeDisjointPoolConfig() {
35+
static std::optional<usm::DisjointPoolAllConfigs>
36+
initializeDisjointPoolConfig() {
37+
const char *UrRetDisable = std::getenv("UR_L0_DISABLE_USM_ALLOCATOR");
38+
const char *PiRetDisable =
39+
std::getenv("SYCL_PI_LEVEL_ZERO_DISABLE_USM_ALLOCATOR");
40+
const char *Disable =
41+
UrRetDisable ? UrRetDisable : (PiRetDisable ? PiRetDisable : nullptr);
42+
if (Disable != nullptr && Disable != std::string("")) {
43+
return std::nullopt;
44+
}
45+
3846
const char *PoolUrTraceVal = std::getenv("UR_L0_USM_ALLOCATOR_TRACE");
3947

4048
int PoolTrace = 0;
@@ -47,7 +55,14 @@ static usm::DisjointPoolAllConfigs initializeDisjointPoolConfig() {
4755
return usm::DisjointPoolAllConfigs(PoolTrace);
4856
}
4957

50-
return usm::parseDisjointPoolConfig(PoolUrConfigVal, PoolTrace);
58+
// TODO: rework parseDisjointPoolConfig to return optional,
59+
// once EnableBuffers is no longer used (by legacy L0)
60+
auto configs = usm::parseDisjointPoolConfig(PoolUrConfigVal, PoolTrace);
61+
if (configs.EnableBuffers) {
62+
return configs;
63+
}
64+
65+
return std::nullopt;
5166
}
5267

5368
inline umf_usm_memory_type_t urToUmfMemoryType(ur_usm_type_t type) {
@@ -81,32 +96,35 @@ descToDisjoinPoolMemType(const usm::pool_descriptor &desc) {
8196
}
8297
}
8398

84-
static umf::pool_unique_handle_t
85-
makePool(usm::umf_disjoint_pool_config_t *poolParams,
86-
usm::pool_descriptor poolDescriptor) {
87-
umf_level_zero_memory_provider_params_handle_t params = NULL;
88-
umf_result_t umf_ret = umfLevelZeroMemoryProviderParamsCreate(&params);
99+
static umf::provider_unique_handle_t
100+
makeProvider(usm::pool_descriptor poolDescriptor) {
101+
umf_level_zero_memory_provider_params_handle_t hParams;
102+
umf_result_t umf_ret = umfLevelZeroMemoryProviderParamsCreate(&hParams);
89103
if (umf_ret != UMF_RESULT_SUCCESS) {
90104
throw umf::umf2urResult(umf_ret);
91105
}
92106

107+
std::unique_ptr<umf_level_zero_memory_provider_params_t,
108+
decltype(&umfLevelZeroMemoryProviderParamsDestroy)>
109+
params(hParams, &umfLevelZeroMemoryProviderParamsDestroy);
110+
93111
umf_ret = umfLevelZeroMemoryProviderParamsSetContext(
94-
params, poolDescriptor.hContext->getZeHandle());
112+
hParams, poolDescriptor.hContext->getZeHandle());
95113
if (umf_ret != UMF_RESULT_SUCCESS) {
96114
throw umf::umf2urResult(umf_ret);
97115
};
98116

99117
ze_device_handle_t level_zero_device_handle =
100118
poolDescriptor.hDevice ? poolDescriptor.hDevice->ZeDevice : nullptr;
101119

102-
umf_ret = umfLevelZeroMemoryProviderParamsSetDevice(params,
120+
umf_ret = umfLevelZeroMemoryProviderParamsSetDevice(hParams,
103121
level_zero_device_handle);
104122
if (umf_ret != UMF_RESULT_SUCCESS) {
105123
throw umf::umf2urResult(umf_ret);
106124
}
107125

108126
umf_ret = umfLevelZeroMemoryProviderParamsSetMemoryType(
109-
params, urToUmfMemoryType(poolDescriptor.type));
127+
hParams, urToUmfMemoryType(poolDescriptor.type));
110128
if (umf_ret != UMF_RESULT_SUCCESS) {
111129
throw umf::umf2urResult(umf_ret);
112130
}
@@ -123,46 +141,37 @@ makePool(usm::umf_disjoint_pool_config_t *poolParams,
123141
}
124142

125143
umf_ret = umfLevelZeroMemoryProviderParamsSetResidentDevices(
126-
params, residentZeHandles.data(), residentZeHandles.size());
144+
hParams, residentZeHandles.data(), residentZeHandles.size());
127145
if (umf_ret != UMF_RESULT_SUCCESS) {
128146
throw umf::umf2urResult(umf_ret);
129147
}
130148
}
131149

132150
auto [ret, provider] =
133-
umf::providerMakeUniqueFromOps(umfLevelZeroMemoryProviderOps(), params);
151+
umf::providerMakeUniqueFromOps(umfLevelZeroMemoryProviderOps(), hParams);
134152
if (ret != UMF_RESULT_SUCCESS) {
135153
throw umf::umf2urResult(ret);
136154
}
137155

138-
if (!poolParams) {
139-
auto [ret, poolHandle] = umf::poolMakeUniqueFromOps(
140-
umfProxyPoolOps(), std::move(provider), nullptr);
141-
if (ret != UMF_RESULT_SUCCESS)
142-
throw umf::umf2urResult(ret);
143-
return std::move(poolHandle);
144-
} else {
145-
auto umfParams = getUmfParamsHandle(*poolParams);
146-
147-
auto [ret, poolHandle] =
148-
umf::poolMakeUniqueFromOps(umfDisjointPoolOps(), std::move(provider),
149-
static_cast<void *>(umfParams.get()));
150-
if (ret != UMF_RESULT_SUCCESS)
151-
throw umf::umf2urResult(ret);
152-
return std::move(poolHandle);
153-
}
156+
return std::move(provider);
154157
}
155158

156159
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t hContext,
157160
ur_usm_pool_desc_t *pPoolDesc)
158161
: hContext(hContext) {
159162
// TODO: handle UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK from pPoolDesc
160163
auto disjointPoolConfigs = initializeDisjointPoolConfig();
161-
if (auto limits = find_stype_node<ur_usm_pool_limits_desc_t>(pPoolDesc)) {
162-
for (auto &config : disjointPoolConfigs.Configs) {
163-
config.MaxPoolableSize = limits->maxPoolableSize;
164-
config.SlabMinSize = limits->minDriverAllocSize;
164+
165+
if (disjointPoolConfigs.has_value()) {
166+
if (auto limits = find_stype_node<ur_usm_pool_limits_desc_t>(pPoolDesc)) {
167+
for (auto &config : disjointPoolConfigs.value().Configs) {
168+
config.MaxPoolableSize = limits->maxPoolableSize;
169+
config.SlabMinSize = limits->minDriverAllocSize;
170+
}
165171
}
172+
} else {
173+
// If pooling is disabled, do nothing.
174+
logger::info("USM pooling is disabled. Skiping pool limits adjustment.");
166175
}
167176

168177
auto [result, descriptors] = usm::pool_descriptor::create(this, hContext);
@@ -171,12 +180,13 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t hContext,
171180
}
172181

173182
for (auto &desc : descriptors) {
174-
if (disjointPoolConfigs.EnableBuffers) {
183+
if (disjointPoolConfigs.has_value()) {
175184
auto &poolConfig =
176-
disjointPoolConfigs.Configs[descToDisjoinPoolMemType(desc)];
177-
poolManager.addPool(desc, makePool(&poolConfig, desc));
185+
disjointPoolConfigs.value().Configs[descToDisjoinPoolMemType(desc)];
186+
poolManager.addPool(
187+
desc, usm::makeDisjointPool(makeProvider(desc), poolConfig));
178188
} else {
179-
poolManager.addPool(desc, makePool(nullptr, desc));
189+
poolManager.addPool(desc, usm::makeProxyPool(makeProvider(desc)));
180190
}
181191
}
182192
}

source/common/ur_pool_manager.hpp

+23
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <umf/memory_pool.h>
2323
#include <umf/memory_provider.h>
2424
#include <umf/pools/pool_disjoint.h>
25+
#include <umf/pools/pool_proxy.h>
2526

2627
#include <functional>
2728
#include <unordered_map>
@@ -290,6 +291,28 @@ template <typename D> struct pool_manager {
290291
}
291292
};
292293

294+
inline umf::pool_unique_handle_t
295+
makeDisjointPool(umf::provider_unique_handle_t &&provider,
296+
usm::umf_disjoint_pool_config_t &poolParams) {
297+
auto umfParams = getUmfParamsHandle(poolParams);
298+
auto [ret, poolHandle] =
299+
umf::poolMakeUniqueFromOps(umfDisjointPoolOps(), std::move(provider),
300+
static_cast<void *>(umfParams.get()));
301+
if (ret != UMF_RESULT_SUCCESS)
302+
throw umf::umf2urResult(ret);
303+
return std::move(poolHandle);
304+
}
305+
306+
inline umf::pool_unique_handle_t
307+
makeProxyPool(umf::provider_unique_handle_t &&provider) {
308+
auto [ret, poolHandle] = umf::poolMakeUniqueFromOps(
309+
umfProxyPoolOps(), std::move(provider), nullptr);
310+
if (ret != UMF_RESULT_SUCCESS)
311+
throw umf::umf2urResult(ret);
312+
313+
return std::move(poolHandle);
314+
}
315+
293316
} // namespace usm
294317

295318
namespace std {

0 commit comments

Comments
 (0)