15
15
#include < sycl/detail/pi.h> // for PI_ERROR_INVALID_DEVICE
16
16
#include < sycl/exception.hpp> // for sycl_category, exception
17
17
#include < sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
18
- #include < sycl/memory_enums.hpp> // for memory_scope
19
- #include < sycl/range.hpp> // for range
20
- #include < sycl/sycl_span.hpp> // for span
18
+ #include < sycl/ext/oneapi/properties/properties.hpp>
19
+ #include < sycl/memory_enums.hpp> // for memory_scope
20
+ #include < sycl/range.hpp> // for range
21
+ #include < sycl/sycl_span.hpp> // for span
21
22
22
23
#ifdef __SYCL_DEVICE_ONLY__
23
24
#include < sycl/detail/group_sort_impl.hpp>
@@ -36,6 +37,68 @@ namespace sycl {
36
37
inline namespace _V1 {
37
38
namespace ext ::oneapi::experimental {
38
39
40
/// Describes how the elements owned by each work-item map onto the
/// group-wide sequence handled by a group algorithm.
///
/// With N elements per work-item and a work-group of size W, the span-based
/// merge-sort group_sorter overload in this file reads sorted element
/// `id * N + i` for `blocked` output and `i * W + id` for `striped` output.
enum class group_algorithm_data_placement {
  blocked,
  striped
};
41
+
42
+ struct input_data_placement_key
43
+ : detail::compile_time_property_key<detail::PropKind::InputDataPlacement> {
44
+ template <group_algorithm_data_placement Placement>
45
+ using value_t =
46
+ property_value<input_data_placement_key,
47
+ std::integral_constant<int , static_cast <int >(Placement)>>;
48
+ };
49
+
50
+ struct output_data_placement_key
51
+ : detail::compile_time_property_key<detail::PropKind::OutputDataPlacement> {
52
+ template <group_algorithm_data_placement Placement>
53
+ using value_t =
54
+ property_value<output_data_placement_key,
55
+ std::integral_constant<int , static_cast <int >(Placement)>>;
56
+ };
57
+
58
+ template <group_algorithm_data_placement Placement>
59
+ inline constexpr input_data_placement_key::value_t <Placement>
60
+ input_data_placement;
61
+
62
+ template <group_algorithm_data_placement Placement>
63
+ inline constexpr output_data_placement_key::value_t <Placement>
64
+ output_data_placement;
65
+
66
+ inline constexpr input_data_placement_key::value_t <
67
+ group_algorithm_data_placement::blocked>
68
+ input_data_placement_blocked;
69
+ inline constexpr input_data_placement_key::value_t <
70
+ group_algorithm_data_placement::striped>
71
+ input_data_placement_striped;
72
+
73
+ inline constexpr output_data_placement_key::value_t <
74
+ group_algorithm_data_placement::blocked>
75
+ output_data_placement_blocked;
76
+ inline constexpr output_data_placement_key::value_t <
77
+ group_algorithm_data_placement::striped>
78
+ output_data_placement_striped;
79
+
80
+ namespace detail {
81
+
82
+ template <typename Properties>
83
+ constexpr bool isInputBlocked (Properties properties) {
84
+ if constexpr (properties.template has_property <input_data_placement_key>())
85
+ return properties.template get_property <input_data_placement_key>() ==
86
+ input_data_placement<group_algorithm_data_placement::blocked>;
87
+ else
88
+ return true ;
89
+ }
90
+
91
+ template <typename Properties>
92
+ constexpr bool isOutputBlocked (Properties properties) {
93
+ if constexpr (properties.template has_property <output_data_placement_key>())
94
+ return properties.template get_property <output_data_placement_key>() ==
95
+ output_data_placement<group_algorithm_data_placement::blocked>;
96
+ else
97
+ return true ;
98
+ }
99
+
100
+ } // namespace detail
101
+
39
102
// ---- group helpers
40
103
template <typename Group, size_t Extent> class group_with_scratchpad {
41
104
Group g;
@@ -63,26 +126,9 @@ template <typename Compare = std::less<>> class default_sorter {
63
126
void operator ()([[maybe_unused]] Group g, [[maybe_unused]] Ptr first,
64
127
[[maybe_unused]] Ptr last) {
65
128
#ifdef __SYCL_DEVICE_ONLY__
66
- // Adjust the scratch pointer based on alignment of the type T.
67
- // Per extension specification if scratch size is less than the value
68
- // returned by memory_required then behavior is undefined, so we don't check
69
- // that the scratch size statisfies the requirement.
70
129
using T = typename sycl::detail::GetValueType<Ptr >::type;
71
- T *scratch_begin = nullptr ;
72
130
size_t n = last - first;
73
- // We must have a barrier here before array placement new because it is
74
- // possible that scratch memory is already in use, so we need to synchronize
75
- // work items.
76
- sycl::group_barrier (g);
77
- if (g.leader ()) {
78
- void *scratch_ptr = scratch.data ();
79
- size_t space = scratch.size ();
80
- scratch_ptr = std::align (alignof (T), n * sizeof (T), scratch_ptr, space);
81
- scratch_begin = ::new (scratch_ptr) T[n];
82
- }
83
- // Broadcast leader's pointer (the beginning of the scratch) to all work
84
- // items in the group.
85
- scratch_begin = sycl::group_broadcast (g, scratch_begin);
131
+ T *scratch_begin = sycl::detail::align_scratch<T>(scratch, g, n);
86
132
sycl::detail::merge_sort (g, first, n, comp, scratch_begin);
87
133
#else
88
134
throw sycl::exception (
@@ -94,29 +140,10 @@ template <typename Compare = std::less<>> class default_sorter {
94
140
template <typename Group, typename T>
95
141
T operator ()([[maybe_unused]] Group g, T val) {
96
142
#ifdef __SYCL_DEVICE_ONLY__
97
- // Adjust the scratch pointer based on alignment of the type T.
98
- // Per extension specification if scratch size is less than the value
99
- // returned by memory_required then behavior is undefined, so we don't check
100
- // that the scratch size statisfies the requirement.
101
- T *scratch_begin = nullptr ;
102
143
std::size_t local_id = g.get_local_linear_id ();
103
144
auto range_size = g.get_local_range ().size ();
104
- // We must have a barrier here before array placement new because it is
105
- // possible that scratch memory is already in use, so we need to synchronize
106
- // work items.
107
- sycl::group_barrier (g);
108
- if (g.leader ()) {
109
- void *scratch_ptr = scratch.data ();
110
- size_t space = scratch.size ();
111
- scratch_ptr =
112
- std::align (alignof (T), /* output storage and temporary storage */ 2 *
113
- range_size * sizeof (T),
114
- scratch_ptr, space);
115
- scratch_begin = ::new (scratch_ptr) T[2 * range_size];
116
- }
117
- // Broadcast leader's pointer (the beginning of the scratch) to all work
118
- // items in the group.
119
- scratch_begin = sycl::group_broadcast (g, scratch_begin);
145
+ T *scratch_begin = sycl::detail::align_scratch<T>(
146
+ scratch, g, /* output storage and temporary storage */ 2 * range_size);
120
147
scratch_begin[local_id] = val;
121
148
sycl::detail::merge_sort (g, scratch_begin, range_size, comp,
122
149
scratch_begin + range_size);
@@ -252,26 +279,9 @@ template <typename CompareT = std::less<>> class joint_sorter {
252
279
void operator ()([[maybe_unused]] Group g, [[maybe_unused]] Ptr first,
253
280
[[maybe_unused]] Ptr last) {
254
281
#ifdef __SYCL_DEVICE_ONLY__
255
- // Adjust the scratch pointer based on alignment of the type T.
256
- // Per extension specification if scratch size is less than the value
257
- // returned by memory_required then behavior is undefined, so we don't check
258
- // that the scratch size statisfies the requirement.
259
282
using T = typename sycl::detail::GetValueType<Ptr >::type;
260
- T *scratch_begin = nullptr ;
261
283
size_t n = last - first;
262
- // We must have a barrier here before array placement new because it is
263
- // possible that scratch memory is already in use, so we need to synchronize
264
- // work items.
265
- sycl::group_barrier (g);
266
- if (g.leader ()) {
267
- void *scratch_ptr = scratch.data ();
268
- size_t space = scratch.size ();
269
- scratch_ptr = std::align (alignof (T), n * sizeof (T), scratch_ptr, space);
270
- scratch_begin = ::new (scratch_ptr) T[n];
271
- }
272
- // Broadcast leader's pointer (the beginning of the scratch) to all work
273
- // items in the group.
274
- scratch_begin = sycl::group_broadcast (g, scratch_begin);
284
+ T *scratch_begin = sycl::detail::align_scratch<T>(scratch, g, n);
275
285
sycl::detail::merge_sort (g, first, n, comp, scratch_begin);
276
286
#else
277
287
throw sycl::exception (
@@ -300,29 +310,10 @@ class group_sorter {
300
310
301
311
template <typename Group> T operator ()([[maybe_unused]] Group g, T val) {
302
312
#ifdef __SYCL_DEVICE_ONLY__
303
- // Adjust the scratch pointer based on alignment of the type T.
304
- // Per extension specification if scratch size is less than the value
305
- // returned by memory_required then behavior is undefined, so we don't check
306
- // that the scratch size statisfies the requirement.
307
- T *scratch_begin = nullptr ;
308
313
std::size_t local_id = g.get_local_linear_id ();
309
314
auto range_size = g.get_local_range ().size ();
310
- // We must have a barrier here before array placement new because it is
311
- // possible that scratch memory is already in use, so we need to synchronize
312
- // work items.
313
- sycl::group_barrier (g);
314
- if (g.leader ()) {
315
- void *scratch_ptr = scratch.data ();
316
- size_t space = scratch.size ();
317
- scratch_ptr =
318
- std::align (alignof (T), /* output storage and temporary storage */ 2 *
319
- range_size * sizeof (T),
320
- scratch_ptr, space);
321
- scratch_begin = ::new (scratch_ptr) T[2 * range_size];
322
- }
323
- // Broadcast leader's pointer (the beginning of the scratch) to all work
324
- // items in the group.
325
- scratch_begin = sycl::group_broadcast (g, scratch_begin);
315
+ T *scratch_begin = sycl::detail::align_scratch<T>(
316
+ scratch, g, /* output storage and temporary storage */ 2 * range_size);
326
317
scratch_begin[local_id] = val;
327
318
sycl::detail::merge_sort (g, scratch_begin, range_size, comp,
328
319
scratch_begin + range_size);
@@ -335,6 +326,34 @@ class group_sorter {
335
326
return val;
336
327
}
337
328
329
+ template <typename Group, typename Properties>
330
+ void operator ()([[maybe_unused]] Group g,
331
+ [[maybe_unused]] sycl::span<T, ElementsPerWorkItem> values,
332
+ [[maybe_unused]] Properties properties) {
333
+ #ifdef __SYCL_DEVICE_ONLY__
334
+ std::size_t local_id = g.get_local_linear_id ();
335
+ auto wg_size = g.get_local_range ().size ();
336
+ auto number_of_elements = wg_size * ElementsPerWorkItem;
337
+ T *scratch_begin = sycl::detail::align_scratch<T>(
338
+ scratch, g,
339
+ /* output storage and temporary storage */ 2 * number_of_elements);
340
+ for (std::uint32_t i = 0 ; i < ElementsPerWorkItem; ++i)
341
+ scratch_begin[local_id * ElementsPerWorkItem + i] = values[i];
342
+ sycl::detail::merge_sort (g, scratch_begin, number_of_elements, comp,
343
+ scratch_begin + number_of_elements);
344
+
345
+ std::size_t shift{};
346
+ for (std::uint32_t i = 0 ; i < ElementsPerWorkItem; ++i) {
347
+ if constexpr (detail::isOutputBlocked (properties)) {
348
+ shift = local_id * ElementsPerWorkItem + i;
349
+ } else {
350
+ shift = i * wg_size + local_id;
351
+ }
352
+ values[i] = scratch_begin[shift];
353
+ }
354
+ #endif
355
+ }
356
+
338
357
static std::size_t memory_required (sycl::memory_scope scope,
339
358
size_t range_size) {
340
359
return 2 * joint_sorter<>::template memory_required<T>(
@@ -480,6 +499,19 @@ class group_sorter {
480
499
#endif
481
500
}
482
501
502
+ template <typename Group, typename Properties>
503
+ void operator ()([[maybe_unused]] Group g,
504
+ [[maybe_unused]] sycl::span<ValT, ElementsPerWorkItem> values,
505
+ [[maybe_unused]] Properties properties) {
506
+ #ifdef __SYCL_DEVICE_ONLY__
507
+ sycl::detail::privateStaticSort<
508
+ /* is_key_value=*/ false , detail::isOutputBlocked (properties),
509
+ OrderT == sorting_order::ascending, ElementsPerWorkItem, bits>(
510
+ g, values.data (), /* empty*/ values.data (), scratch.data (), first_bit,
511
+ last_bit);
512
+ #endif
513
+ }
514
+
483
515
static constexpr size_t
484
516
memory_required ([[maybe_unused]] sycl::memory_scope scope,
485
517
size_t range_size) {
0 commit comments