Skip to content

[SYCL][USM] Add templated forms of USM mallocs #1086

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 5, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 92 additions & 1 deletion sycl/include/CL/sycl/usm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ void *aligned_alloc_shared(size_t alignment, size_t size, const device &dev,
const context &ctxt);
void *aligned_alloc_shared(size_t alignment, size_t size, const queue &q);

///
// single form

///
void *malloc(size_t size, const device &dev, const context &ctxt,
usm::alloc kind);
void *malloc(size_t size, const queue &q, usm::alloc kind);
Expand All @@ -54,5 +55,95 @@ void *aligned_alloc(size_t alignment, size_t size, const device &dev,
void *aligned_alloc(size_t alignment, size_t size, const queue &q,
usm::alloc kind);

///
// Template forms
///
template <typename T>
T *malloc_device(size_t Count, const device &Dev, const context &Ctxt) {
return static_cast<T *>(malloc_device(Count*sizeof(T), Dev, Ctxt));
}

template <typename T> T *malloc_device(size_t Count, const queue &Q) {
return malloc_device<T>(Count, Q.get_device(), Q.get_context());
}

template <typename T>
T *aligned_alloc_device(size_t Alignment, size_t Count, const device &Dev,
const context &Ctxt) {
return static_cast<T *>(
aligned_alloc_device(Alignment, Count * sizeof(T), Dev, Ctxt));
}

template <typename T>
T *aligned_alloc_device(size_t Alignment, size_t Count, const queue &Q) {
return aligned_alloc_device<T>(Alignment, Count, Q.get_device(),
Q.get_context());
}

template <typename T> T *malloc_host(size_t Count, const context &Ctxt) {
return static_cast<T *>(malloc_host(Count * sizeof(T), Ctxt));
}

template <typename T> T *malloc_host(size_t Count, const queue &Q) {
return malloc_host<T>(Count, Q.get_context());
}

template <typename T>
T *malloc_shared(size_t Count, const device &Dev, const context &Ctxt) {
return static_cast<T *>(malloc_shared(Count * sizeof(T), Dev, Ctxt));
}

template <typename T> T *malloc_shared(size_t Count, const queue &Q) {
return malloc_shared<T>(Count, Q.get_device(), Q.get_context());
}

template <typename T>
T *aligned_alloc_host(size_t Alignment, size_t Count, const context &Ctxt) {
return static_cast<T *>(
aligned_alloc_host(Alignment, Count * sizeof(T), Ctxt));
}

template <typename T>
T *aligned_alloc_host(size_t Alignment, size_t Count, const queue &Q) {
return aligned_alloc_host<T>(Alignment, Count, Q.get_context());
}

template <typename T>
T *aligned_alloc_shared(size_t Alignment, size_t Count, const device &Dev,
const context &Ctxt) {
return static_cast<T *>(
aligned_alloc_shared(Alignment, Count * sizeof(T), Dev, Ctxt));
}

template <typename T>
T *aligned_alloc_shared(size_t Alignment, size_t Count, const queue &Q) {
return aligned_alloc_shared<T>(Alignment, Count, Q.get_device(),
Q.get_context());
}

template <typename T>
T *malloc(size_t Count, const device &Dev, const context &Ctxt,
usm::alloc Kind) {
return static_cast<T *>(malloc(Count * sizeof(T), Dev, Ctxt, Kind));
}

template <typename T> T *malloc(size_t Count, const queue &Q, usm::alloc Kind) {
return malloc<T>(Count, Q.get_device(), Q.get_context(), Kind);
}

template <typename T>
T *aligned_alloc(size_t Alignment, size_t Count, const device &Dev,
const context &Ctxt, usm::alloc Kind) {
return static_cast<T *>(
aligned_alloc(Alignment, Count * sizeof(T), Dev, Ctxt, Kind));
}

template <typename T>
T *aligned_alloc(size_t Alignment, size_t Count, const queue &Q,
usm::alloc Kind) {
return aligned_alloc<T>(Alignment, Count, Q.get_device(), Q.get_context(),
Kind);
}

} // namespace sycl
} // namespace cl