Skip to content

Commit 8e3b8ce

Browse files
authored
[SYCL] Add sorting APIs for fixed-size private array input (intel#14185)
This PR covers items (1), (2), and (3) from https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_group_sort.asciidoc#functions-with-fixed-size-arrays. It does not include key/value sorting with fixed-size array input, which will be added in a separate PR.
1 parent c3d8e27 commit 8e3b8ce

File tree

6 files changed

+479
-82
lines changed

6 files changed

+479
-82
lines changed

sycl/include/sycl/detail/group_sort_impl.hpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,43 @@
1515
#include <sycl/builtins.hpp>
1616
#include <sycl/group_algorithm.hpp>
1717
#include <sycl/group_barrier.hpp>
18+
#include <sycl/sycl_span.hpp>
19+
20+
#include <memory>
1821

1922
namespace sycl {
2023
inline namespace _V1 {
2124
namespace detail {
2225

26+
// Helpers for sorting algorithms
27+
#ifdef __SYCL_DEVICE_ONLY__
28+
// Returns a pointer to storage for `number_of_elements` objects of type T
// carved out of `scratch`, to be shared by the work-items of group `g`.
//
// Every work-item of `g` has to reach this call: the body executes a group
// barrier followed by a group broadcast. Only the group leader aligns the
// scratch pointer (std::align) and runs the array placement new; the pointer
// it obtains is then broadcast and returned by all work-items.
//
// NOTE(review): std::align yields nullptr when the remaining space is too
// small, and that result is passed to placement new unchecked. This matches
// the undefined-behavior note in the body, but is worth keeping in mind.
template <typename T, typename Group>
29+
static __SYCL_ALWAYS_INLINE T *align_scratch(sycl::span<std::byte> scratch,
30+
Group g,
31+
size_t number_of_elements) {
32+
// Adjust the scratch pointer based on alignment of the type T.
33+
// Per extension specification if scratch size is less than the value
34+
// returned by memory_required then behavior is undefined, so we don't check
35+
// that the scratch size satisfies the requirement.
36+
T *scratch_begin = nullptr;
37+
// We must have a barrier here before array placement new because it is
38+
// possible that scratch memory is already in use, so we need to synchronize
39+
// work items.
40+
sycl::group_barrier(g);
41+
if (g.leader()) {
42+
void *scratch_ptr = scratch.data();
43+
size_t space = scratch.size();
44+
scratch_ptr = std::align(alignof(T), number_of_elements * sizeof(T),
45+
scratch_ptr, space);
46+
scratch_begin = ::new (scratch_ptr) T[number_of_elements];
47+
}
48+
// Broadcast leader's pointer (the beginning of the scratch) to all work
49+
// items in the group.
50+
scratch_begin = sycl::group_broadcast(g, scratch_begin);
51+
return scratch_begin;
52+
}
53+
#endif
54+
2355
// ---- merge sort implementation
2456

2557
// following two functions could be useless if std::[lower|upper]_bound worked

sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp

Lines changed: 99 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
#include <sycl/detail/pi.h> // for PI_ERROR_INVALID_DEVICE
1616
#include <sycl/exception.hpp> // for sycl_category, exception
1717
#include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
18-
#include <sycl/memory_enums.hpp> // for memory_scope
19-
#include <sycl/range.hpp> // for range
20-
#include <sycl/sycl_span.hpp> // for span
18+
#include <sycl/ext/oneapi/properties/properties.hpp>
19+
#include <sycl/memory_enums.hpp> // for memory_scope
20+
#include <sycl/range.hpp> // for range
21+
#include <sycl/sycl_span.hpp> // for span
2122

2223
#ifdef __SYCL_DEVICE_ONLY__
2324
#include <sycl/detail/group_sort_impl.hpp>
@@ -36,6 +37,54 @@ namespace sycl {
3637
inline namespace _V1 {
3738
namespace ext::oneapi::experimental {
3839

40+
// Layout of per-work-item data across a group; used as the value of the
// input_data_placement / output_data_placement properties.
// NOTE(review): judging by the fixed-size-array group_sorter in this header,
// `blocked` means work-item `id` holds sorted elements
// [id * ElementsPerWorkItem, (id + 1) * ElementsPerWorkItem) and `striped`
// means element i of work-item `id` is sorted element i * wg_size + id.
enum class group_algorithm_data_placement { blocked, striped };
41+
42+
// Compile-time property key (PropKind::InputDataPlacement) describing how the
// input data of a group algorithm is placed. value_t<Placement> is the
// matching property value; the placement is carried as an
// std::integral_constant<int, ...> holding the enumerator cast to int.
struct input_data_placement_key
43+
: detail::compile_time_property_key<detail::PropKind::InputDataPlacement> {
44+
template <group_algorithm_data_placement Placement>
45+
using value_t =
46+
property_value<input_data_placement_key,
47+
std::integral_constant<int, static_cast<int>(Placement)>>;
48+
};
49+
50+
// Compile-time property key (PropKind::OutputDataPlacement) describing how
// the output data of a group algorithm is placed. value_t<Placement> is the
// matching property value; the placement is carried as an
// std::integral_constant<int, ...> holding the enumerator cast to int.
struct output_data_placement_key
51+
: detail::compile_time_property_key<detail::PropKind::OutputDataPlacement> {
52+
template <group_algorithm_data_placement Placement>
53+
using value_t =
54+
property_value<output_data_placement_key,
55+
std::integral_constant<int, static_cast<int>(Placement)>>;
56+
};
57+
58+
// Property value object for a given input placement, e.g.
// input_data_placement<group_algorithm_data_placement::striped>.
template <group_algorithm_data_placement Placement>
59+
inline constexpr input_data_placement_key::value_t<Placement>
60+
input_data_placement;
61+
62+
// Property value object for a given output placement, e.g.
// output_data_placement<group_algorithm_data_placement::striped>.
template <group_algorithm_data_placement Placement>
63+
inline constexpr output_data_placement_key::value_t<Placement>
64+
output_data_placement;
65+
66+
namespace detail {
67+
68+
// Returns true when `properties` carries input_data_placement_key with the
// `blocked` value, and also when the key is absent (blocked is the fallback).
// Returns false only when the key is present with a different value.
// NOTE(review): no caller of this helper is visible in this chunk.
template <typename Properties>
69+
constexpr bool isInputBlocked(Properties properties) {
70+
if constexpr (properties.template has_property<input_data_placement_key>())
71+
return properties.template get_property<input_data_placement_key>() ==
72+
input_data_placement<group_algorithm_data_placement::blocked>;
73+
else
74+
return true;
75+
}
76+
77+
// Returns true when `properties` carries output_data_placement_key with the
// `blocked` value, and also when the key is absent (blocked is the fallback).
// Returns false only when the key is present with a different value.
template <typename Properties>
78+
constexpr bool isOutputBlocked(Properties properties) {
79+
if constexpr (properties.template has_property<output_data_placement_key>())
80+
return properties.template get_property<output_data_placement_key>() ==
81+
output_data_placement<group_algorithm_data_placement::blocked>;
82+
else
83+
return true;
84+
}
85+
86+
} // namespace detail
87+
3988
// ---- group helpers
4089
template <typename Group, size_t Extent> class group_with_scratchpad {
4190
Group g;
@@ -63,26 +112,9 @@ template <typename Compare = std::less<>> class default_sorter {
63112
void operator()([[maybe_unused]] Group g, [[maybe_unused]] Ptr first,
64113
[[maybe_unused]] Ptr last) {
65114
#ifdef __SYCL_DEVICE_ONLY__
66-
// Adjust the scratch pointer based on alignment of the type T.
67-
// Per extension specification if scratch size is less than the value
68-
// returned by memory_required then behavior is undefined, so we don't check
69-
// that the scratch size statisfies the requirement.
70115
using T = typename sycl::detail::GetValueType<Ptr>::type;
71-
T *scratch_begin = nullptr;
72116
size_t n = last - first;
73-
// We must have a barrier here before array placement new because it is
74-
// possible that scratch memory is already in use, so we need to synchronize
75-
// work items.
76-
sycl::group_barrier(g);
77-
if (g.leader()) {
78-
void *scratch_ptr = scratch.data();
79-
size_t space = scratch.size();
80-
scratch_ptr = std::align(alignof(T), n * sizeof(T), scratch_ptr, space);
81-
scratch_begin = ::new (scratch_ptr) T[n];
82-
}
83-
// Broadcast leader's pointer (the beginning of the scratch) to all work
84-
// items in the group.
85-
scratch_begin = sycl::group_broadcast(g, scratch_begin);
117+
T *scratch_begin = sycl::detail::align_scratch<T>(scratch, g, n);
86118
sycl::detail::merge_sort(g, first, n, comp, scratch_begin);
87119
#else
88120
throw sycl::exception(
@@ -94,29 +126,10 @@ template <typename Compare = std::less<>> class default_sorter {
94126
template <typename Group, typename T>
95127
T operator()([[maybe_unused]] Group g, T val) {
96128
#ifdef __SYCL_DEVICE_ONLY__
97-
// Adjust the scratch pointer based on alignment of the type T.
98-
// Per extension specification if scratch size is less than the value
99-
// returned by memory_required then behavior is undefined, so we don't check
100-
// that the scratch size statisfies the requirement.
101-
T *scratch_begin = nullptr;
102129
std::size_t local_id = g.get_local_linear_id();
103130
auto range_size = g.get_local_range().size();
104-
// We must have a barrier here before array placement new because it is
105-
// possible that scratch memory is already in use, so we need to synchronize
106-
// work items.
107-
sycl::group_barrier(g);
108-
if (g.leader()) {
109-
void *scratch_ptr = scratch.data();
110-
size_t space = scratch.size();
111-
scratch_ptr =
112-
std::align(alignof(T), /* output storage and temporary storage */ 2 *
113-
range_size * sizeof(T),
114-
scratch_ptr, space);
115-
scratch_begin = ::new (scratch_ptr) T[2 * range_size];
116-
}
117-
// Broadcast leader's pointer (the beginning of the scratch) to all work
118-
// items in the group.
119-
scratch_begin = sycl::group_broadcast(g, scratch_begin);
131+
T *scratch_begin = sycl::detail::align_scratch<T>(
132+
scratch, g, /* output storage and temporary storage */ 2 * range_size);
120133
scratch_begin[local_id] = val;
121134
sycl::detail::merge_sort(g, scratch_begin, range_size, comp,
122135
scratch_begin + range_size);
@@ -252,26 +265,9 @@ template <typename CompareT = std::less<>> class joint_sorter {
252265
void operator()([[maybe_unused]] Group g, [[maybe_unused]] Ptr first,
253266
[[maybe_unused]] Ptr last) {
254267
#ifdef __SYCL_DEVICE_ONLY__
255-
// Adjust the scratch pointer based on alignment of the type T.
256-
// Per extension specification if scratch size is less than the value
257-
// returned by memory_required then behavior is undefined, so we don't check
258-
// that the scratch size statisfies the requirement.
259268
using T = typename sycl::detail::GetValueType<Ptr>::type;
260-
T *scratch_begin = nullptr;
261269
size_t n = last - first;
262-
// We must have a barrier here before array placement new because it is
263-
// possible that scratch memory is already in use, so we need to synchronize
264-
// work items.
265-
sycl::group_barrier(g);
266-
if (g.leader()) {
267-
void *scratch_ptr = scratch.data();
268-
size_t space = scratch.size();
269-
scratch_ptr = std::align(alignof(T), n * sizeof(T), scratch_ptr, space);
270-
scratch_begin = ::new (scratch_ptr) T[n];
271-
}
272-
// Broadcast leader's pointer (the beginning of the scratch) to all work
273-
// items in the group.
274-
scratch_begin = sycl::group_broadcast(g, scratch_begin);
270+
T *scratch_begin = sycl::detail::align_scratch<T>(scratch, g, n);
275271
sycl::detail::merge_sort(g, first, n, comp, scratch_begin);
276272
#else
277273
throw sycl::exception(
@@ -300,29 +296,10 @@ class group_sorter {
300296

301297
template <typename Group> T operator()([[maybe_unused]] Group g, T val) {
302298
#ifdef __SYCL_DEVICE_ONLY__
303-
// Adjust the scratch pointer based on alignment of the type T.
304-
// Per extension specification if scratch size is less than the value
305-
// returned by memory_required then behavior is undefined, so we don't check
306-
// that the scratch size statisfies the requirement.
307-
T *scratch_begin = nullptr;
308299
std::size_t local_id = g.get_local_linear_id();
309300
auto range_size = g.get_local_range().size();
310-
// We must have a barrier here before array placement new because it is
311-
// possible that scratch memory is already in use, so we need to synchronize
312-
// work items.
313-
sycl::group_barrier(g);
314-
if (g.leader()) {
315-
void *scratch_ptr = scratch.data();
316-
size_t space = scratch.size();
317-
scratch_ptr =
318-
std::align(alignof(T), /* output storage and temporary storage */ 2 *
319-
range_size * sizeof(T),
320-
scratch_ptr, space);
321-
scratch_begin = ::new (scratch_ptr) T[2 * range_size];
322-
}
323-
// Broadcast leader's pointer (the beginning of the scratch) to all work
324-
// items in the group.
325-
scratch_begin = sycl::group_broadcast(g, scratch_begin);
301+
T *scratch_begin = sycl::detail::align_scratch<T>(
302+
scratch, g, /* output storage and temporary storage */ 2 * range_size);
326303
scratch_begin[local_id] = val;
327304
sycl::detail::merge_sort(g, scratch_begin, range_size, comp,
328305
scratch_begin + range_size);
@@ -335,6 +312,34 @@ class group_sorter {
335312
return val;
336313
}
337314

315+
// Fixed-size-array overload: sorts the ElementsPerWorkItem values that each
// work-item of `g` passes in `values` and writes the result back into
// `values`. The body is compiled only for device code; on host it is empty.
//
// Uses the sorter's `scratch` and `comp` members (declared outside this
// chunk). Scratch is requested for 2 * wg_size * ElementsPerWorkItem objects
// of T: the first half stages the data, the second half is handed to
// merge_sort as temporary storage.
//
// NOTE(review): only the output placement is read from `properties`; the
// input is always staged in blocked order.
// NOTE(review): no barrier is visible between merge_sort and the read-back
// loop - presumably merge_sort synchronizes the group before returning;
// confirm.
template <typename Group, typename Properties>
316+
void operator()([[maybe_unused]] Group g,
317+
[[maybe_unused]] sycl::span<T, ElementsPerWorkItem> values,
318+
[[maybe_unused]] Properties properties) {
319+
#ifdef __SYCL_DEVICE_ONLY__
320+
std::size_t local_id = g.get_local_linear_id();
321+
auto wg_size = g.get_local_range().size();
322+
auto number_of_elements = wg_size * ElementsPerWorkItem;
323+
T *scratch_begin = sycl::detail::align_scratch<T>(
324+
scratch, g,
325+
/* output storage and temporary storage */ 2 * number_of_elements);
326+
// Stage this work-item's values contiguously at its own slot in scratch.
for (std::uint32_t i = 0; i < ElementsPerWorkItem; ++i)
327+
scratch_begin[local_id * ElementsPerWorkItem + i] = values[i];
328+
sycl::detail::merge_sort(g, scratch_begin, number_of_elements, comp,
329+
scratch_begin + number_of_elements);
330+
331+
// Read the sorted data back: contiguous chunk per work-item when the output
// is blocked, otherwise stride wg_size starting at local_id.
std::size_t shift{};
332+
for (std::uint32_t i = 0; i < ElementsPerWorkItem; ++i) {
333+
if constexpr (detail::isOutputBlocked(properties)) {
334+
shift = local_id * ElementsPerWorkItem + i;
335+
} else {
336+
shift = i * wg_size + local_id;
337+
}
338+
values[i] = scratch_begin[shift];
339+
}
340+
#endif
341+
}
342+
338343
static std::size_t memory_required(sycl::memory_scope scope,
339344
size_t range_size) {
340345
return 2 * joint_sorter<>::template memory_required<T>(
@@ -480,6 +485,19 @@ class group_sorter {
480485
#endif
481486
}
482487

488+
// Fixed-size-array overload: sorts the ElementsPerWorkItem values that each
// work-item of `g` passes in `values` by forwarding to
// sycl::detail::privateStaticSort. The body is compiled only for device code;
// on host it is empty.
//
// Template arguments passed on: not a key/value sort, output placement taken
// from isOutputBlocked(properties), ascending flag derived from OrderT, the
// per-work-item element count, and `bits`. values.data() is passed twice; the
// second occurrence is marked /*empty*/, i.e. a placeholder argument.
// `scratch`, `first_bit` and `last_bit` are members declared outside this
// chunk.
template <typename Group, typename Properties>
489+
void operator()([[maybe_unused]] Group g,
490+
[[maybe_unused]] sycl::span<ValT, ElementsPerWorkItem> values,
491+
[[maybe_unused]] Properties properties) {
492+
#ifdef __SYCL_DEVICE_ONLY__
493+
sycl::detail::privateStaticSort<
494+
/*is_key_value=*/false, detail::isOutputBlocked(properties),
495+
OrderT == sorting_order::ascending, ElementsPerWorkItem, bits>(
496+
g, values.data(), /*empty*/ values.data(), scratch.data(), first_bit,
497+
last_bit);
498+
#endif
499+
}
500+
483501
static constexpr size_t
484502
memory_required([[maybe_unused]] sycl::memory_scope scope,
485503
size_t range_size) {

sycl/include/sycl/ext/oneapi/experimental/group_sort.hpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,59 @@ sort_over_group(experimental::group_with_scratchpad<Group, Extent> exec,
114114
default_sorters::group_sorter<T>(exec.get_memory()));
115115
}
116116

117+
// sort_over_group for a fixed-size per-work-item array with a user-supplied
// sorter. Enabled only when Properties (after decay) is a property list.
//
// In device code this calls sorter(g, values, properties). In host code it
// throws sycl::exception with PI_ERROR_INVALID_DEVICE.
template <typename Group, typename T, std::size_t ElementsPerWorkItem,
118+
typename Sorter,
119+
typename Properties = ext::oneapi::experimental::empty_properties_t>
120+
std::enable_if_t<sycl::ext::oneapi::experimental::is_property_list_v<
121+
std::decay_t<Properties>>,
122+
void>
123+
sort_over_group([[maybe_unused]] Group g,
124+
[[maybe_unused]] sycl::span<T, ElementsPerWorkItem> values,
125+
[[maybe_unused]] Sorter sorter,
126+
[[maybe_unused]] Properties properties = {}) {
127+
#ifdef __SYCL_DEVICE_ONLY__
128+
return sorter(g, values, properties);
129+
#else
130+
throw sycl::exception(
131+
std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
132+
"Group algorithms are not supported on host device.");
133+
#endif
134+
}
135+
136+
// sort_over_group for a fixed-size per-work-item array using the scratch
// memory carried by `exec` and the std::less<T> comparison. Enabled only when
// Properties (after decay) is a property list.
//
// Builds a default_sorters::group_sorter<T, std::less<T>, ElementsPerWorkItem>
// over exec.get_memory() and forwards to the sorter-taking overload with
// exec.get_group().
template <typename Group, typename T, std::size_t Extent,
137+
std::size_t ElementsPerWorkItem,
138+
typename Properties = ext::oneapi::experimental::empty_properties_t>
139+
std::enable_if_t<sycl::ext::oneapi::experimental::is_property_list_v<
140+
std::decay_t<Properties>>,
141+
void>
142+
sort_over_group(experimental::group_with_scratchpad<Group, Extent> exec,
143+
sycl::span<T, ElementsPerWorkItem> values,
144+
Properties properties = {}) {
145+
return sort_over_group(
146+
exec.get_group(), values,
147+
default_sorters::group_sorter<T, std::less<T>, ElementsPerWorkItem>(
148+
exec.get_memory()),
149+
properties);
150+
}
151+
152+
// sort_over_group for a fixed-size per-work-item array using the scratch
// memory carried by `exec` and a caller-provided comparison `comp`.
//
// Enabled only when Compare is NOT a property list and Properties is one;
// presumably the first condition keeps a call that passes a property list as
// the third argument from selecting this overload.
//
// Builds a default_sorters::group_sorter<T, Compare, ElementsPerWorkItem>
// over exec.get_memory() and `comp`, then forwards to the sorter-taking
// overload with exec.get_group().
template <typename Group, typename T, std::size_t Extent,
153+
std::size_t ElementsPerWorkItem, typename Compare,
154+
typename Properties = ext::oneapi::experimental::empty_properties_t>
155+
std::enable_if_t<!sycl::ext::oneapi::experimental::is_property_list_v<
156+
std::decay_t<Compare>> &&
157+
sycl::ext::oneapi::experimental::is_property_list_v<
158+
std::decay_t<Properties>>,
159+
void>
160+
sort_over_group(experimental::group_with_scratchpad<Group, Extent> exec,
161+
sycl::span<T, ElementsPerWorkItem> values, Compare comp,
162+
Properties properties = {}) {
163+
return sort_over_group(
164+
exec.get_group(), values,
165+
default_sorters::group_sorter<T, Compare, ElementsPerWorkItem>(
166+
exec.get_memory(), comp),
167+
properties);
168+
}
169+
117170
// ---- joint_sort
118171
template <typename Group, typename Iter, typename Sorter>
119172
std::enable_if_t<detail::is_sorter<Sorter, Group, Iter>::value, void>

sycl/include/sycl/ext/oneapi/properties/property.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,10 @@ enum PropKind : uint32_t {
207207
SingleTaskKernel = 66,
208208
IndirectlyCallable = 67,
209209
CallsIndirectly = 68,
210+
InputDataPlacement = 69,
211+
OutputDataPlacement = 70,
210212
// PropKindSize must always be the last value.
211-
PropKindSize = 69,
213+
PropKindSize = 71,
212214
};
213215

214216
struct property_key_base_tag {};

0 commit comments

Comments
 (0)