Skip to content

Commit 28c732d

Browse files
committed
[SYCL] Add sorting APIs for fixed-size private array input
PR doesn't include key/value sorting with fixed-size array input which will be added in a separate PR.
1 parent e7defab commit 28c732d

File tree

6 files changed

+491
-81
lines changed

6 files changed

+491
-81
lines changed

sycl/include/sycl/detail/group_sort_impl.hpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,43 @@
1515
#include <sycl/builtins.hpp>
1616
#include <sycl/group_algorithm.hpp>
1717
#include <sycl/group_barrier.hpp>
18+
#include <sycl/sycl_span.hpp>
19+
20+
#include <memory>
1821

1922
namespace sycl {
2023
inline namespace _V1 {
2124
namespace detail {
2225

26+
// Helpers for sorting algorithms
27+
#ifdef __SYCL_DEVICE_ONLY__
28+
template <typename T, typename Group>
29+
static __SYCL_ALWAYS_INLINE T *align_scratch(sycl::span<std::byte> scratch,
30+
Group g,
31+
size_t number_of_elements) {
32+
// Adjust the scratch pointer based on alignment of the type T.
33+
// Per extension specification if scratch size is less than the value
34+
// returned by memory_required then behavior is undefined, so we don't check
35+
// that the scratch size statisfies the requirement.
36+
T *scratch_begin = nullptr;
37+
// We must have a barrier here before array placement new because it is
38+
// possible that scratch memory is already in use, so we need to synchronize
39+
// work items.
40+
sycl::group_barrier(g);
41+
if (g.leader()) {
42+
void *scratch_ptr = scratch.data();
43+
size_t space = scratch.size();
44+
scratch_ptr = std::align(alignof(T), number_of_elements * sizeof(T),
45+
scratch_ptr, space);
46+
scratch_begin = ::new (scratch_ptr) T[number_of_elements];
47+
}
48+
// Broadcast leader's pointer (the beginning of the scratch) to all work
49+
// items in the group.
50+
scratch_begin = sycl::group_broadcast(g, scratch_begin);
51+
return scratch_begin;
52+
}
53+
#endif
54+
2355
// ---- merge sort implementation
2456

2557
// following two functions could be useless if std::[lower|upper]_bound worked

sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp

Lines changed: 113 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
#include <sycl/detail/pi.h> // for PI_ERROR_INVALID_DEVICE
1616
#include <sycl/exception.hpp> // for sycl_category, exception
1717
#include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
18-
#include <sycl/memory_enums.hpp> // for memory_scope
19-
#include <sycl/range.hpp> // for range
20-
#include <sycl/sycl_span.hpp> // for span
18+
#include <sycl/ext/oneapi/properties/properties.hpp>
19+
#include <sycl/memory_enums.hpp> // for memory_scope
20+
#include <sycl/range.hpp> // for range
21+
#include <sycl/sycl_span.hpp> // for span
2122

2223
#ifdef __SYCL_DEVICE_ONLY__
2324
#include <sycl/detail/group_sort_impl.hpp>
@@ -36,6 +37,68 @@ namespace sycl {
3637
inline namespace _V1 {
3738
namespace ext::oneapi::experimental {
3839

40+
// Selects how the elements of a work-item's fixed-size array map onto the
// group-wide sequence. In the span overload of group_sorter in this file,
// element i of work-item `local_id` is read back from index
// `local_id * ElementsPerWorkItem + i` for blocked placement and from
// `i * wg_size + local_id` for striped placement.
enum class group_algorithm_data_placement { blocked, striped };
41+
42+
// Compile-time property key for the placement of the *input* data, registered
// under detail::PropKind::InputDataPlacement.
struct input_data_placement_key
43+
: detail::compile_time_property_key<detail::PropKind::InputDataPlacement> {
44+
// Property value type for a given placement; the enumerator is carried as an
// std::integral_constant<int, ...>.
template <group_algorithm_data_placement Placement>
45+
using value_t =
46+
property_value<input_data_placement_key,
47+
std::integral_constant<int, static_cast<int>(Placement)>>;
48+
};
49+
50+
// Compile-time property key for the placement of the *output* data, registered
// under detail::PropKind::OutputDataPlacement.
struct output_data_placement_key
51+
: detail::compile_time_property_key<detail::PropKind::OutputDataPlacement> {
52+
// Property value type for a given placement; the enumerator is carried as an
// std::integral_constant<int, ...>.
template <group_algorithm_data_placement Placement>
53+
using value_t =
54+
property_value<output_data_placement_key,
55+
std::integral_constant<int, static_cast<int>(Placement)>>;
56+
};
57+
58+
// Property value object for the input placement, parameterized by the
// placement enumerator.
template <group_algorithm_data_placement Placement>
59+
inline constexpr input_data_placement_key::value_t<Placement>
60+
input_data_placement;
61+
62+
// Property value object for the output placement, parameterized by the
// placement enumerator.
template <group_algorithm_data_placement Placement>
63+
inline constexpr output_data_placement_key::value_t<Placement>
64+
output_data_placement;
65+
66+
// Shorthand objects for the two input placements.
inline constexpr input_data_placement_key::value_t<
67+
group_algorithm_data_placement::blocked>
68+
input_data_placement_blocked;
69+
inline constexpr input_data_placement_key::value_t<
70+
group_algorithm_data_placement::striped>
71+
input_data_placement_striped;
72+
73+
// Shorthand objects for the two output placements.
inline constexpr output_data_placement_key::value_t<
74+
group_algorithm_data_placement::blocked>
75+
output_data_placement_blocked;
76+
inline constexpr output_data_placement_key::value_t<
77+
group_algorithm_data_placement::striped>
78+
output_data_placement_striped;
79+
80+
namespace detail {
81+
82+
template <typename Properties>
83+
constexpr bool isInputBlocked(Properties properties) {
84+
if constexpr (properties.template has_property<input_data_placement_key>())
85+
return properties.template get_property<input_data_placement_key>() ==
86+
input_data_placement<group_algorithm_data_placement::blocked>;
87+
else
88+
return true;
89+
}
90+
91+
template <typename Properties>
92+
constexpr bool isOutputBlocked(Properties properties) {
93+
if constexpr (properties.template has_property<output_data_placement_key>())
94+
return properties.template get_property<output_data_placement_key>() ==
95+
output_data_placement<group_algorithm_data_placement::blocked>;
96+
else
97+
return true;
98+
}
99+
100+
} // namespace detail
101+
39102
// ---- group helpers
40103
template <typename Group, size_t Extent> class group_with_scratchpad {
41104
Group g;
@@ -63,26 +126,9 @@ template <typename Compare = std::less<>> class default_sorter {
63126
void operator()([[maybe_unused]] Group g, [[maybe_unused]] Ptr first,
64127
[[maybe_unused]] Ptr last) {
65128
#ifdef __SYCL_DEVICE_ONLY__
66-
// Adjust the scratch pointer based on alignment of the type T.
67-
// Per extension specification if scratch size is less than the value
68-
// returned by memory_required then behavior is undefined, so we don't check
69-
// that the scratch size statisfies the requirement.
70129
using T = typename sycl::detail::GetValueType<Ptr>::type;
71-
T *scratch_begin = nullptr;
72130
size_t n = last - first;
73-
// We must have a barrier here before array placement new because it is
74-
// possible that scratch memory is already in use, so we need to synchronize
75-
// work items.
76-
sycl::group_barrier(g);
77-
if (g.leader()) {
78-
void *scratch_ptr = scratch.data();
79-
size_t space = scratch.size();
80-
scratch_ptr = std::align(alignof(T), n * sizeof(T), scratch_ptr, space);
81-
scratch_begin = ::new (scratch_ptr) T[n];
82-
}
83-
// Broadcast leader's pointer (the beginning of the scratch) to all work
84-
// items in the group.
85-
scratch_begin = sycl::group_broadcast(g, scratch_begin);
131+
T *scratch_begin = sycl::detail::align_scratch<T>(scratch, g, n);
86132
sycl::detail::merge_sort(g, first, n, comp, scratch_begin);
87133
#else
88134
throw sycl::exception(
@@ -94,29 +140,10 @@ template <typename Compare = std::less<>> class default_sorter {
94140
template <typename Group, typename T>
95141
T operator()([[maybe_unused]] Group g, T val) {
96142
#ifdef __SYCL_DEVICE_ONLY__
97-
// Adjust the scratch pointer based on alignment of the type T.
98-
// Per extension specification if scratch size is less than the value
99-
// returned by memory_required then behavior is undefined, so we don't check
100-
// that the scratch size statisfies the requirement.
101-
T *scratch_begin = nullptr;
102143
std::size_t local_id = g.get_local_linear_id();
103144
auto range_size = g.get_local_range().size();
104-
// We must have a barrier here before array placement new because it is
105-
// possible that scratch memory is already in use, so we need to synchronize
106-
// work items.
107-
sycl::group_barrier(g);
108-
if (g.leader()) {
109-
void *scratch_ptr = scratch.data();
110-
size_t space = scratch.size();
111-
scratch_ptr =
112-
std::align(alignof(T), /* output storage and temporary storage */ 2 *
113-
range_size * sizeof(T),
114-
scratch_ptr, space);
115-
scratch_begin = ::new (scratch_ptr) T[2 * range_size];
116-
}
117-
// Broadcast leader's pointer (the beginning of the scratch) to all work
118-
// items in the group.
119-
scratch_begin = sycl::group_broadcast(g, scratch_begin);
145+
T *scratch_begin = sycl::detail::align_scratch<T>(
146+
scratch, g, /* output storage and temporary storage */ 2 * range_size);
120147
scratch_begin[local_id] = val;
121148
sycl::detail::merge_sort(g, scratch_begin, range_size, comp,
122149
scratch_begin + range_size);
@@ -252,26 +279,9 @@ template <typename CompareT = std::less<>> class joint_sorter {
252279
void operator()([[maybe_unused]] Group g, [[maybe_unused]] Ptr first,
253280
[[maybe_unused]] Ptr last) {
254281
#ifdef __SYCL_DEVICE_ONLY__
255-
// Adjust the scratch pointer based on alignment of the type T.
256-
// Per extension specification if scratch size is less than the value
257-
// returned by memory_required then behavior is undefined, so we don't check
258-
// that the scratch size statisfies the requirement.
259282
using T = typename sycl::detail::GetValueType<Ptr>::type;
260-
T *scratch_begin = nullptr;
261283
size_t n = last - first;
262-
// We must have a barrier here before array placement new because it is
263-
// possible that scratch memory is already in use, so we need to synchronize
264-
// work items.
265-
sycl::group_barrier(g);
266-
if (g.leader()) {
267-
void *scratch_ptr = scratch.data();
268-
size_t space = scratch.size();
269-
scratch_ptr = std::align(alignof(T), n * sizeof(T), scratch_ptr, space);
270-
scratch_begin = ::new (scratch_ptr) T[n];
271-
}
272-
// Broadcast leader's pointer (the beginning of the scratch) to all work
273-
// items in the group.
274-
scratch_begin = sycl::group_broadcast(g, scratch_begin);
284+
T *scratch_begin = sycl::detail::align_scratch<T>(scratch, g, n);
275285
sycl::detail::merge_sort(g, first, n, comp, scratch_begin);
276286
#else
277287
throw sycl::exception(
@@ -300,29 +310,10 @@ class group_sorter {
300310

301311
template <typename Group> T operator()([[maybe_unused]] Group g, T val) {
302312
#ifdef __SYCL_DEVICE_ONLY__
303-
// Adjust the scratch pointer based on alignment of the type T.
304-
// Per extension specification if scratch size is less than the value
305-
// returned by memory_required then behavior is undefined, so we don't check
306-
// that the scratch size statisfies the requirement.
307-
T *scratch_begin = nullptr;
308313
std::size_t local_id = g.get_local_linear_id();
309314
auto range_size = g.get_local_range().size();
310-
// We must have a barrier here before array placement new because it is
311-
// possible that scratch memory is already in use, so we need to synchronize
312-
// work items.
313-
sycl::group_barrier(g);
314-
if (g.leader()) {
315-
void *scratch_ptr = scratch.data();
316-
size_t space = scratch.size();
317-
scratch_ptr =
318-
std::align(alignof(T), /* output storage and temporary storage */ 2 *
319-
range_size * sizeof(T),
320-
scratch_ptr, space);
321-
scratch_begin = ::new (scratch_ptr) T[2 * range_size];
322-
}
323-
// Broadcast leader's pointer (the beginning of the scratch) to all work
324-
// items in the group.
325-
scratch_begin = sycl::group_broadcast(g, scratch_begin);
315+
T *scratch_begin = sycl::detail::align_scratch<T>(
316+
scratch, g, /* output storage and temporary storage */ 2 * range_size);
326317
scratch_begin[local_id] = val;
327318
sycl::detail::merge_sort(g, scratch_begin, range_size, comp,
328319
scratch_begin + range_size);
@@ -335,6 +326,34 @@ class group_sorter {
335326
return val;
336327
}
337328

329+
template <typename Group, typename Properties>
330+
void operator()([[maybe_unused]] Group g,
331+
[[maybe_unused]] sycl::span<T, ElementsPerWorkItem> values,
332+
[[maybe_unused]] Properties properties) {
333+
#ifdef __SYCL_DEVICE_ONLY__
334+
std::size_t local_id = g.get_local_linear_id();
335+
auto wg_size = g.get_local_range().size();
336+
auto number_of_elements = wg_size * ElementsPerWorkItem;
337+
T *scratch_begin = sycl::detail::align_scratch<T>(
338+
scratch, g,
339+
/* output storage and temporary storage */ 2 * number_of_elements);
340+
for (std::uint32_t i = 0; i < ElementsPerWorkItem; ++i)
341+
scratch_begin[local_id * ElementsPerWorkItem + i] = values[i];
342+
sycl::detail::merge_sort(g, scratch_begin, number_of_elements, comp,
343+
scratch_begin + number_of_elements);
344+
345+
std::size_t shift{};
346+
for (std::uint32_t i = 0; i < ElementsPerWorkItem; ++i) {
347+
if constexpr (detail::isOutputBlocked(properties)) {
348+
shift = local_id * ElementsPerWorkItem + i;
349+
} else {
350+
shift = i * wg_size + local_id;
351+
}
352+
values[i] = scratch_begin[shift];
353+
}
354+
#endif
355+
}
356+
338357
static std::size_t memory_required(sycl::memory_scope scope,
339358
size_t range_size) {
340359
return 2 * joint_sorter<>::template memory_required<T>(
@@ -480,6 +499,19 @@ class group_sorter {
480499
#endif
481500
}
482501

502+
template <typename Group, typename Properties>
503+
void operator()([[maybe_unused]] Group g,
504+
[[maybe_unused]] sycl::span<ValT, ElementsPerWorkItem> values,
505+
[[maybe_unused]] Properties properties) {
506+
#ifdef __SYCL_DEVICE_ONLY__
507+
sycl::detail::privateStaticSort<
508+
/*is_key_value=*/false, detail::isOutputBlocked(properties),
509+
OrderT == sorting_order::ascending, ElementsPerWorkItem, bits>(
510+
g, values.data(), /*empty*/ values.data(), scratch.data(), first_bit,
511+
last_bit);
512+
#endif
513+
}
514+
483515
static constexpr size_t
484516
memory_required([[maybe_unused]] sycl::memory_scope scope,
485517
size_t range_size) {

sycl/include/sycl/ext/oneapi/experimental/group_sort.hpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,59 @@ sort_over_group(experimental::group_with_scratchpad<Group, Extent> exec,
114114
default_sorters::group_sorter<T>(exec.get_memory()));
115115
}
116116

117+
template <typename Group, typename T, std::size_t ElementsPerWorkItem,
118+
typename Sorter,
119+
typename Properties = ext::oneapi::experimental::empty_properties_t>
120+
std::enable_if_t<sycl::ext::oneapi::experimental::is_property_list_v<
121+
std::decay_t<Properties>>,
122+
void>
123+
sort_over_group([[maybe_unused]] Group g,
124+
[[maybe_unused]] sycl::span<T, ElementsPerWorkItem> values,
125+
[[maybe_unused]] Sorter sorter,
126+
[[maybe_unused]] Properties properties = {}) {
127+
#ifdef __SYCL_DEVICE_ONLY__
128+
return sorter(g, values, properties);
129+
#else
130+
throw sycl::exception(
131+
std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
132+
"Group algorithms are not supported on host device.");
133+
#endif
134+
}
135+
136+
template <typename Group, typename T, std::size_t Extent,
137+
std::size_t ElementsPerWorkItem,
138+
typename Properties = ext::oneapi::experimental::empty_properties_t>
139+
std::enable_if_t<sycl::ext::oneapi::experimental::is_property_list_v<
140+
std::decay_t<Properties>>,
141+
void>
142+
sort_over_group(experimental::group_with_scratchpad<Group, Extent> exec,
143+
sycl::span<T, ElementsPerWorkItem> values,
144+
Properties properties = {}) {
145+
return sort_over_group(
146+
exec.get_group(), values,
147+
default_sorters::group_sorter<T, std::less<T>, ElementsPerWorkItem>(
148+
exec.get_memory()),
149+
properties);
150+
}
151+
152+
template <typename Group, typename T, std::size_t Extent,
153+
std::size_t ElementsPerWorkItem, typename Compare,
154+
typename Properties = ext::oneapi::experimental::empty_properties_t>
155+
std::enable_if_t<!sycl::ext::oneapi::experimental::is_property_list_v<
156+
std::decay_t<Compare>> &&
157+
sycl::ext::oneapi::experimental::is_property_list_v<
158+
std::decay_t<Properties>>,
159+
void>
160+
sort_over_group(experimental::group_with_scratchpad<Group, Extent> exec,
161+
sycl::span<T, ElementsPerWorkItem> values, Compare comp,
162+
Properties properties = {}) {
163+
return sort_over_group(
164+
exec.get_group(), values,
165+
default_sorters::group_sorter<T, Compare, ElementsPerWorkItem>(
166+
exec.get_memory(), comp),
167+
properties);
168+
}
169+
117170
// ---- joint_sort
118171
template <typename Group, typename Iter, typename Sorter>
119172
std::enable_if_t<detail::is_sorter<Sorter, Group, Iter>::value, void>

sycl/include/sycl/ext/oneapi/properties/property.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ enum PropKind : uint32_t {
207207
SingleTaskKernel = 66,
208208
IndirectlyCallable = 67,
209209
CallsIndirectly = 68,
210+
InputDataPlacement = 69,
211+
OutputDataPlacement = 70,
210212
// PropKindSize must always be the last value.
211213
// One past the highest property kind (OutputDataPlacement = 70); bump this
// whenever a new kind is appended above.
PropKindSize = 71,
212214
};

0 commit comments

Comments
 (0)