Skip to content

Commit 62ec7e4

Browse files
authored
[SYCL][USM] Add more extensive checks to USM pointer queries (#1123)
Signed-off-by: James Brodman <[email protected]>
1 parent dd329d6 commit 62ec7e4

File tree

2 files changed

+46
-2
lines changed

2 files changed

+46
-2
lines changed

sycl/source/detail/usm/usm_impl.cpp

+19-2
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,9 @@ void *aligned_alloc(size_t Alignment, size_t Size, const queue &Q, alloc Kind) {
241241
/// @param ptr is the USM pointer to query
242242
/// @param ctxt is the sycl context the ptr was allocated in
243243
alloc get_pointer_type(const void *Ptr, const context &Ctxt) {
244+
if (!Ptr)
245+
return alloc::unknown;
246+
244247
// Everything on a host device is just system malloc so call it host
245248
if (Ctxt.is_host())
246249
return alloc::host;
@@ -251,8 +254,18 @@ alloc get_pointer_type(const void *Ptr, const context &Ctxt) {
251254

252255
// query type using PI function
253256
const detail::plugin &Plugin = CtxImpl->getPlugin();
254-
Plugin.call<detail::PiApiKind::piextUSMGetMemAllocInfo>(
255-
PICtx, Ptr, PI_MEM_ALLOC_TYPE, sizeof(pi_usm_type), &AllocTy, nullptr);
257+
RT::PiResult Err =
258+
Plugin.call_nocheck<detail::PiApiKind::piextUSMGetMemAllocInfo>(
259+
PICtx, Ptr, PI_MEM_ALLOC_TYPE, sizeof(pi_usm_type), &AllocTy,
260+
nullptr);
261+
262+
// PI_INVALID_VALUE means USM doesn't know about this ptr
263+
if (Err == PI_INVALID_VALUE)
264+
return alloc::unknown;
265+
// otherwise PI_SUCCESS is expected
266+
if (Err != PI_SUCCESS) {
267+
throw runtime_error("Error querying USM pointer: ", Err);
268+
}
256269

257270
alloc ResultAlloc;
258271
switch (AllocTy) {
@@ -278,6 +291,10 @@ alloc get_pointer_type(const void *Ptr, const context &Ctxt) {
278291
/// @param ptr is the USM pointer to query
279292
/// @param ctxt is the sycl context the ptr was allocated in
280293
device get_pointer_device(const void *Ptr, const context &Ctxt) {
294+
// Check if ptr is a valid USM pointer
295+
if (get_pointer_type(Ptr, Ctxt) == alloc::unknown)
296+
throw runtime_error("Ptr not a valid USM allocation!");
297+
281298
// Just return the host device in the host context
282299
if (Ctxt.is_host())
283300
return Ctxt.get_devices()[0];

sycl/test/usm/pointer_query.cpp

+27
Original file line numberDiff line numberDiff line change
@@ -92,5 +92,32 @@ int main() {
9292
}
9393
free(array, ctxt);
9494

95+
// Test invalid ptrs
96+
Kind = get_pointer_type(nullptr, ctxt);
97+
if (Kind != usm::alloc::unknown) {
98+
return 11;
99+
}
100+
101+
// next checks only valid for non-host contexts
102+
array = (int*)malloc(N*sizeof(int));
103+
Kind = get_pointer_type(array, ctxt);
104+
if (!ctxt.is_host()) {
105+
if (Kind != usm::alloc::unknown) {
106+
return 12;
107+
}
108+
try {
109+
D = get_pointer_device(array, ctxt);
110+
} catch (runtime_error) {
111+
return 0;
112+
}
113+
return 13;
114+
} else {
115+
// host ctxts always report host
116+
if (Kind != usm::alloc::host) {
117+
return 14;
118+
}
119+
}
120+
free(array);
121+
95122
return 0;
96123
}

0 commit comments

Comments
 (0)