15
15
#include < sycl/detail/pi.h> // for PI_ERROR_INVALID_DEVICE
16
16
#include < sycl/exception.hpp> // for sycl_category, exception
17
17
#include < sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
18
- #include < sycl/memory_enums.hpp> // for memory_scope
19
- #include < sycl/range.hpp> // for range
20
- #include < sycl/sycl_span.hpp> // for span
18
+ #include < sycl/ext/oneapi/properties/properties.hpp>
19
+ #include < sycl/memory_enums.hpp> // for memory_scope
20
+ #include < sycl/range.hpp> // for range
21
+ #include < sycl/sycl_span.hpp> // for span
21
22
22
23
#ifdef __SYCL_DEVICE_ONLY__
23
24
#include < sycl/detail/group_sort_impl.hpp>
@@ -36,6 +37,54 @@ namespace sycl {
36
37
inline namespace _V1 {
37
38
namespace ext ::oneapi::experimental {
38
39
40
+ enum class group_algorithm_data_placement { blocked, striped };
41
+
42
+ struct input_data_placement_key
43
+ : detail::compile_time_property_key<detail::PropKind::InputDataPlacement> {
44
+ template <group_algorithm_data_placement Placement>
45
+ using value_t =
46
+ property_value<input_data_placement_key,
47
+ std::integral_constant<int , static_cast <int >(Placement)>>;
48
+ };
49
+
50
+ struct output_data_placement_key
51
+ : detail::compile_time_property_key<detail::PropKind::OutputDataPlacement> {
52
+ template <group_algorithm_data_placement Placement>
53
+ using value_t =
54
+ property_value<output_data_placement_key,
55
+ std::integral_constant<int , static_cast <int >(Placement)>>;
56
+ };
57
+
58
+ template <group_algorithm_data_placement Placement>
59
+ inline constexpr input_data_placement_key::value_t <Placement>
60
+ input_data_placement;
61
+
62
+ template <group_algorithm_data_placement Placement>
63
+ inline constexpr output_data_placement_key::value_t <Placement>
64
+ output_data_placement;
65
+
66
+ namespace detail {
67
+
68
+ template <typename Properties>
69
+ constexpr bool isInputBlocked (Properties properties) {
70
+ if constexpr (properties.template has_property <input_data_placement_key>())
71
+ return properties.template get_property <input_data_placement_key>() ==
72
+ input_data_placement<group_algorithm_data_placement::blocked>;
73
+ else
74
+ return true ;
75
+ }
76
+
77
+ template <typename Properties>
78
+ constexpr bool isOutputBlocked (Properties properties) {
79
+ if constexpr (properties.template has_property <output_data_placement_key>())
80
+ return properties.template get_property <output_data_placement_key>() ==
81
+ output_data_placement<group_algorithm_data_placement::blocked>;
82
+ else
83
+ return true ;
84
+ }
85
+
86
+ } // namespace detail
87
+
39
88
// ---- group helpers
40
89
template <typename Group, size_t Extent> class group_with_scratchpad {
41
90
Group g;
@@ -63,26 +112,9 @@ template <typename Compare = std::less<>> class default_sorter {
63
112
void operator ()([[maybe_unused]] Group g, [[maybe_unused]] Ptr first,
64
113
[[maybe_unused]] Ptr last) {
65
114
#ifdef __SYCL_DEVICE_ONLY__
66
- // Adjust the scratch pointer based on alignment of the type T.
67
- // Per extension specification if scratch size is less than the value
68
- // returned by memory_required then behavior is undefined, so we don't check
69
- // that the scratch size statisfies the requirement.
70
115
using T = typename sycl::detail::GetValueType<Ptr >::type;
71
- T *scratch_begin = nullptr ;
72
116
size_t n = last - first;
73
- // We must have a barrier here before array placement new because it is
74
- // possible that scratch memory is already in use, so we need to synchronize
75
- // work items.
76
- sycl::group_barrier (g);
77
- if (g.leader ()) {
78
- void *scratch_ptr = scratch.data ();
79
- size_t space = scratch.size ();
80
- scratch_ptr = std::align (alignof (T), n * sizeof (T), scratch_ptr, space);
81
- scratch_begin = ::new (scratch_ptr) T[n];
82
- }
83
- // Broadcast leader's pointer (the beginning of the scratch) to all work
84
- // items in the group.
85
- scratch_begin = sycl::group_broadcast (g, scratch_begin);
117
+ T *scratch_begin = sycl::detail::align_scratch<T>(scratch, g, n);
86
118
sycl::detail::merge_sort (g, first, n, comp, scratch_begin);
87
119
#else
88
120
throw sycl::exception (
@@ -94,29 +126,10 @@ template <typename Compare = std::less<>> class default_sorter {
94
126
template <typename Group, typename T>
95
127
T operator ()([[maybe_unused]] Group g, T val) {
96
128
#ifdef __SYCL_DEVICE_ONLY__
97
- // Adjust the scratch pointer based on alignment of the type T.
98
- // Per extension specification if scratch size is less than the value
99
- // returned by memory_required then behavior is undefined, so we don't check
100
- // that the scratch size statisfies the requirement.
101
- T *scratch_begin = nullptr ;
102
129
std::size_t local_id = g.get_local_linear_id ();
103
130
auto range_size = g.get_local_range ().size ();
104
- // We must have a barrier here before array placement new because it is
105
- // possible that scratch memory is already in use, so we need to synchronize
106
- // work items.
107
- sycl::group_barrier (g);
108
- if (g.leader ()) {
109
- void *scratch_ptr = scratch.data ();
110
- size_t space = scratch.size ();
111
- scratch_ptr =
112
- std::align (alignof (T), /* output storage and temporary storage */ 2 *
113
- range_size * sizeof (T),
114
- scratch_ptr, space);
115
- scratch_begin = ::new (scratch_ptr) T[2 * range_size];
116
- }
117
- // Broadcast leader's pointer (the beginning of the scratch) to all work
118
- // items in the group.
119
- scratch_begin = sycl::group_broadcast (g, scratch_begin);
131
+ T *scratch_begin = sycl::detail::align_scratch<T>(
132
+ scratch, g, /* output storage and temporary storage */ 2 * range_size);
120
133
scratch_begin[local_id] = val;
121
134
sycl::detail::merge_sort (g, scratch_begin, range_size, comp,
122
135
scratch_begin + range_size);
@@ -252,26 +265,9 @@ template <typename CompareT = std::less<>> class joint_sorter {
252
265
void operator ()([[maybe_unused]] Group g, [[maybe_unused]] Ptr first,
253
266
[[maybe_unused]] Ptr last) {
254
267
#ifdef __SYCL_DEVICE_ONLY__
255
- // Adjust the scratch pointer based on alignment of the type T.
256
- // Per extension specification if scratch size is less than the value
257
- // returned by memory_required then behavior is undefined, so we don't check
258
- // that the scratch size statisfies the requirement.
259
268
using T = typename sycl::detail::GetValueType<Ptr >::type;
260
- T *scratch_begin = nullptr ;
261
269
size_t n = last - first;
262
- // We must have a barrier here before array placement new because it is
263
- // possible that scratch memory is already in use, so we need to synchronize
264
- // work items.
265
- sycl::group_barrier (g);
266
- if (g.leader ()) {
267
- void *scratch_ptr = scratch.data ();
268
- size_t space = scratch.size ();
269
- scratch_ptr = std::align (alignof (T), n * sizeof (T), scratch_ptr, space);
270
- scratch_begin = ::new (scratch_ptr) T[n];
271
- }
272
- // Broadcast leader's pointer (the beginning of the scratch) to all work
273
- // items in the group.
274
- scratch_begin = sycl::group_broadcast (g, scratch_begin);
270
+ T *scratch_begin = sycl::detail::align_scratch<T>(scratch, g, n);
275
271
sycl::detail::merge_sort (g, first, n, comp, scratch_begin);
276
272
#else
277
273
throw sycl::exception (
@@ -300,29 +296,10 @@ class group_sorter {
300
296
301
297
template <typename Group> T operator ()([[maybe_unused]] Group g, T val) {
302
298
#ifdef __SYCL_DEVICE_ONLY__
303
- // Adjust the scratch pointer based on alignment of the type T.
304
- // Per extension specification if scratch size is less than the value
305
- // returned by memory_required then behavior is undefined, so we don't check
306
- // that the scratch size statisfies the requirement.
307
- T *scratch_begin = nullptr ;
308
299
std::size_t local_id = g.get_local_linear_id ();
309
300
auto range_size = g.get_local_range ().size ();
310
- // We must have a barrier here before array placement new because it is
311
- // possible that scratch memory is already in use, so we need to synchronize
312
- // work items.
313
- sycl::group_barrier (g);
314
- if (g.leader ()) {
315
- void *scratch_ptr = scratch.data ();
316
- size_t space = scratch.size ();
317
- scratch_ptr =
318
- std::align (alignof (T), /* output storage and temporary storage */ 2 *
319
- range_size * sizeof (T),
320
- scratch_ptr, space);
321
- scratch_begin = ::new (scratch_ptr) T[2 * range_size];
322
- }
323
- // Broadcast leader's pointer (the beginning of the scratch) to all work
324
- // items in the group.
325
- scratch_begin = sycl::group_broadcast (g, scratch_begin);
301
+ T *scratch_begin = sycl::detail::align_scratch<T>(
302
+ scratch, g, /* output storage and temporary storage */ 2 * range_size);
326
303
scratch_begin[local_id] = val;
327
304
sycl::detail::merge_sort (g, scratch_begin, range_size, comp,
328
305
scratch_begin + range_size);
@@ -335,6 +312,34 @@ class group_sorter {
335
312
return val;
336
313
}
337
314
315
+ template <typename Group, typename Properties>
316
+ void operator ()([[maybe_unused]] Group g,
317
+ [[maybe_unused]] sycl::span<T, ElementsPerWorkItem> values,
318
+ [[maybe_unused]] Properties properties) {
319
+ #ifdef __SYCL_DEVICE_ONLY__
320
+ std::size_t local_id = g.get_local_linear_id ();
321
+ auto wg_size = g.get_local_range ().size ();
322
+ auto number_of_elements = wg_size * ElementsPerWorkItem;
323
+ T *scratch_begin = sycl::detail::align_scratch<T>(
324
+ scratch, g,
325
+ /* output storage and temporary storage */ 2 * number_of_elements);
326
+ for (std::uint32_t i = 0 ; i < ElementsPerWorkItem; ++i)
327
+ scratch_begin[local_id * ElementsPerWorkItem + i] = values[i];
328
+ sycl::detail::merge_sort (g, scratch_begin, number_of_elements, comp,
329
+ scratch_begin + number_of_elements);
330
+
331
+ std::size_t shift{};
332
+ for (std::uint32_t i = 0 ; i < ElementsPerWorkItem; ++i) {
333
+ if constexpr (detail::isOutputBlocked (properties)) {
334
+ shift = local_id * ElementsPerWorkItem + i;
335
+ } else {
336
+ shift = i * wg_size + local_id;
337
+ }
338
+ values[i] = scratch_begin[shift];
339
+ }
340
+ #endif
341
+ }
342
+
338
343
static std::size_t memory_required (sycl::memory_scope scope,
339
344
size_t range_size) {
340
345
return 2 * joint_sorter<>::template memory_required<T>(
@@ -480,6 +485,19 @@ class group_sorter {
480
485
#endif
481
486
}
482
487
488
+ template <typename Group, typename Properties>
489
+ void operator ()([[maybe_unused]] Group g,
490
+ [[maybe_unused]] sycl::span<ValT, ElementsPerWorkItem> values,
491
+ [[maybe_unused]] Properties properties) {
492
+ #ifdef __SYCL_DEVICE_ONLY__
493
+ sycl::detail::privateStaticSort<
494
+ /* is_key_value=*/ false , detail::isOutputBlocked (properties),
495
+ OrderT == sorting_order::ascending, ElementsPerWorkItem, bits>(
496
+ g, values.data (), /* empty*/ values.data (), scratch.data (), first_bit,
497
+ last_bit);
498
+ #endif
499
+ }
500
+
483
501
static constexpr size_t
484
502
memory_required ([[maybe_unused]] sycl::memory_scope scope,
485
503
size_t range_size) {
0 commit comments