Skip to content

Commit d596593

Browse files
authored
[SYCL][USM] Fix USM malloc_shared and free to handle zero byte (#1273)
Enum variables were too commonly used by users. This kind of conflicts cannot be avoided 100%, but we can minimize the chance by using the prefix SYCL_ Signed-off-by: Byoungro So <[email protected]>
1 parent 20aa83e commit d596593

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

sycl/source/detail/usm/usm_impl.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ namespace usm {
2727
void *alignedAllocHost(size_t Alignment, size_t Size, const context &Ctxt,
2828
alloc Kind) {
2929
void *RetVal = nullptr;
30+
if (Size == 0)
31+
return nullptr;
3032
if (Ctxt.is_host()) {
3133
if (!Alignment) {
3234
// worst case default
@@ -72,6 +74,8 @@ void *alignedAllocHost(size_t Alignment, size_t Size, const context &Ctxt,
7274
void *alignedAlloc(size_t Alignment, size_t Size, const context &Ctxt,
7375
const device &Dev, alloc Kind) {
7476
void *RetVal = nullptr;
77+
if (Size == 0)
78+
return nullptr;
7579
if (Ctxt.is_host()) {
7680
if (Kind == alloc::unknown) {
7781
RetVal = nullptr;
@@ -126,6 +130,8 @@ void *alignedAlloc(size_t Alignment, size_t Size, const context &Ctxt,
126130
}
127131

128132
void free(void *Ptr, const context &Ctxt) {
133+
if (Ptr == nullptr)
134+
return;
129135
if (Ctxt.is_host()) {
130136
// need to use alignedFree here for Windows
131137
detail::OSUtil::alignedFree(Ptr);
+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// RUN: %clangxx -fsycl %s -o %t.out
2+
// RUN: env SYCL_DEVICE_TYPE=HOST %t.out
3+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
4+
5+
// This test checks if users will successfully allocate 160, 0, and -16 bytes of
6+
// shared memory, and also test user can call free() without worrying about
7+
// nullptr or invalid memory descriptor returned from malloc.
8+
9+
#include <CL/sycl.hpp>
10+
#include <iostream>
11+
#include <stdlib.h>
12+
using namespace cl::sycl;
13+
14+
int main(int argc, char *argv[]) {
15+
auto exception_handler = [](cl::sycl::exception_list exceptions) {
16+
for (std::exception_ptr const &e : exceptions) {
17+
try {
18+
std::rethrow_exception(e);
19+
} catch (cl::sycl::exception const &e) {
20+
std::cout << "Caught asynchronous SYCL "
21+
"exception:\n"
22+
<< e.what() << std::endl;
23+
}
24+
}
25+
};
26+
27+
queue myQueue(default_selector{}, exception_handler);
28+
std::cout << "Device: " << myQueue.get_device().get_info<info::device::name>()
29+
<< std::endl;
30+
31+
double *ia = (double *)malloc_shared(160, myQueue);
32+
double *ja = (double *)malloc_shared(0, myQueue);
33+
double *result = (double *)malloc_shared(-16, myQueue);
34+
35+
assert(ia != nullptr);
36+
assert(ja == nullptr);
37+
assert(result == nullptr);
38+
39+
std::cout << "ia : " << ia << " ja: " << ja << " result : " << result
40+
<< std::endl;
41+
42+
// followings should not throw CL_INVALID_VALUE
43+
cl::sycl::free(ia, myQueue);
44+
cl::sycl::free(nullptr, myQueue);
45+
cl::sycl::free(ja, myQueue);
46+
cl::sycl::free(result, myQueue);
47+
48+
return 0;
49+
}

0 commit comments

Comments
 (0)