Skip to content

Commit e162168

Browse files
committed
made changes for multi-dimensional groups
Signed-off-by: Fedorov, Andrey <[email protected]>
1 parent b6ae2e5 commit e162168

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

sycl/include/CL/sycl/detail/group_sort_impl.hpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ struct GetValueType<sycl::multi_ptr<ElementType, Space>> {
6767
using type = ElementType;
6868
};
6969

70-
// since we couldn't assign data to raw memory, it's better to use placement for
71-
// first assignment
70+
// since we couldn't assign data to raw memory, it's better to use placement
71+
// for first assignment
7272
template <typename Acc, typename T>
7373
void set_value(Acc ptr, const std::size_t idx, const T &val, bool is_first) {
7474
if (is_first) {
@@ -97,8 +97,8 @@ void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
9797
const std::size_t local_size_1 = local_end_1 - local_start_1;
9898
const std::size_t local_size_2 = local_end_2 - local_start_2;
9999

100-
// TODO: process cases where all elements of 1st sequence > 2nd, 2nd > 1st to
101-
// improve performance
100+
// TODO: process cases where all elements of 1st sequence > 2nd, 2nd > 1st
101+
// to improve performance
102102

103103
// Process 1st sequence
104104
if (local_start_1 < local_end_1) {
@@ -204,7 +204,7 @@ void merge_sort(Group group, Iter first, const std::size_t n, Compare comp,
204204
std::byte *scratch) {
205205
using T = typename GetValueType<Iter>::type;
206206
auto id = sycl::detail::Builder::getNDItem<Group::dimensions>();
207-
const std::size_t idx = id.get_local_id();
207+
const std::size_t idx = id.get_local_linear_id();
208208
const std::size_t local = group.get_local_range().size();
209209
const std::size_t chunk = (n - 1) / local + 1;
210210

sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ template <typename Compare = std::less<>> class default_sorter {
6363
auto range_size = g.get_local_range().size();
6464
if (scratch_size >= memory_required<T>(Group::fence_scope, range_size)) {
6565
auto id = sycl::detail::Builder::getNDItem<Group::dimensions>();
66-
uint32_t local_id = id.get_local_id();
66+
std::size_t local_id = id.get_local_linear_id();
6767
T *temp = reinterpret_cast<T *>(scratch);
6868
::new (temp + local_id) T(val);
6969
sycl::detail::merge_sort(g, temp, range_size, comp,

0 commit comments

Comments
 (0)