Skip to content

Commit 665a5c8

Browse files
committed
[SYCL] Key/Value sorting with fixed-size private array input
1 parent 43286ab commit 665a5c8

File tree

5 files changed

+599
-22
lines changed

5 files changed

+599
-22
lines changed

sycl/include/sycl/detail/group_sort_impl.hpp

Lines changed: 88 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <sycl/group_barrier.hpp>
1818
#include <sycl/sycl_span.hpp>
1919

20+
#include <iterator>
2021
#include <memory>
2122

2223
namespace sycl {
@@ -52,6 +53,86 @@ static __SYCL_ALWAYS_INLINE T *align_scratch(sycl::span<std::byte> scratch,
5253
}
5354
#endif
5455

56+
/// Random-access iterator that walks two parallel arrays (keys and values) in
/// lock-step.
///
/// Dereferencing yields a std::tuple of references (one into each array), so
/// assigning through or swapping the dereferenced proxy modifies both
/// underlying arrays at the same position. All ordering and distance
/// operations are defined in terms of the key pointer only; the value pointer
/// is always advanced by the same amount.
///
/// \tparam T1 element type of the key array.
/// \tparam T2 element type of the value array.
template <typename T1, typename T2> class key_value_iterator {
public:
  using difference_type = std::ptrdiff_t;
  using value_type = std::tuple<T1, T2>;
  using reference = std::tuple<T1 &, T2 &>;
  using pointer = std::tuple<T1 *, T2 *>;
  using iterator_category = std::random_access_iterator_tag;

  /// Forward (and therefore random-access) iterators must be
  /// default-constructible. A default-constructed iterator holds two null
  /// pointers and may only be assigned to or compared.
  key_value_iterator() = default;

  /// Builds an iterator positioned at the first element of both arrays.
  key_value_iterator(T1 *Keys, T2 *Values) : KeyValue{Keys, Values} {}

  /// Returns a tuple of references to the current key and value.
  reference operator*() const {
    return std::tie(*(std::get<0>(KeyValue)), *(std::get<1>(KeyValue)));
  }

  /// Returns a tuple of references to the key/value pair \p i positions away.
  reference operator[](difference_type i) const { return *(*this + i); }

  /// Distance between two iterators, measured on the key pointers.
  difference_type operator-(const key_value_iterator &it) const {
    return std::get<0>(KeyValue) - std::get<0>(it.KeyValue);
  }

  /// Advances both pointers by \p i (which may be negative).
  key_value_iterator &operator+=(difference_type i) {
    KeyValue =
        std::make_tuple(std::get<0>(KeyValue) + i, std::get<1>(KeyValue) + i);
    return *this;
  }
  key_value_iterator &operator-=(difference_type i) { return *this += -i; }
  key_value_iterator &operator++() { return *this += 1; }
  key_value_iterator &operator--() { return *this -= 1; }

  /// Exposes the underlying (key pointer, value pointer) pair.
  std::tuple<T1 *, T2 *> base() const { return KeyValue; }

  key_value_iterator operator++(int) {
    key_value_iterator it(*this);
    ++(*this);
    return it;
  }
  key_value_iterator operator--(int) {
    key_value_iterator it(*this);
    --(*this);
    return it;
  }

  key_value_iterator operator-(difference_type i) const {
    key_value_iterator it(*this);
    return it -= i;
  }
  key_value_iterator operator+(difference_type i) const {
    key_value_iterator it(*this);
    return it += i;
  }
  friend key_value_iterator operator+(difference_type i,
                                      const key_value_iterator &it) {
    return it + i;
  }

  // Comparisons are all derived from the key-pointer distance.
  bool operator==(const key_value_iterator &it) const {
    return *this - it == 0;
  }

  bool operator!=(const key_value_iterator &it) const { return !(*this == it); }
  bool operator<(const key_value_iterator &it) const { return *this - it < 0; }
  bool operator>(const key_value_iterator &it) const { return it < *this; }
  bool operator<=(const key_value_iterator &it) const { return !(*this > it); }
  bool operator>=(const key_value_iterator &it) const { return !(*this < it); }

private:
  // Value-initialized so a default-constructed iterator holds null pointers.
  std::tuple<T1 *, T2 *> KeyValue{};
};
122+
123+
/// Swap for ordinary lvalues: defers directly to std::swap.
template <typename T> void swap(T &first, T &second) {
  std::swap(first, second);
}

/// Swap for temporaries that are tuples of references (the proxy objects
/// produced by dereferencing key_value_iterator).
///
/// The incoming temporaries are copied into named tuples of the same
/// reference types; std::swap on those tuples then exchanges the referenced
/// elements pairwise, so the objects the references point at end up swapped.
template <template <typename...> class Tuple, typename... T>
void swap(Tuple<T &...> &&first, Tuple<T &...> &&second) {
  Tuple<T &...> left_refs(first);
  Tuple<T &...> right_refs(second);
  std::swap(left_refs, right_refs);
}
135+
55136
// ---- merge sort implementation
56137

57138
// following two functions could be useless if std::[lower|upper]_bound worked
@@ -81,15 +162,6 @@ size_t upper_bound(Acc acc, const size_t first, const size_t last,
81162
[comp](auto x, auto y) { return !comp(y, x); });
82163
}
83164

84-
// swap for all data types including tuple-like types
// Generic overload: takes two lvalues of the same type and forwards them to
// std::swap.
85-
template <typename T> void swap_tuples(T &a, T &b) { std::swap(a, b); }
86-
87-
// Overload for rvalue two-element tuple-like objects: swaps element 0 and
// element 1 individually through std::get rather than swapping the tuple
// objects themselves.
template <template <typename...> class TupleLike, typename T1, typename T2>
88-
void swap_tuples(TupleLike<T1, T2> &&a, TupleLike<T1, T2> &&b) {
89-
std::swap(std::get<0>(a), std::get<0>(b));
90-
std::swap(std::get<1>(a), std::get<1>(b));
91-
}
92-
93165
// Maps an iterator type to the value_type reported by std::iterator_traits
// for that iterator.
template <typename Iter> struct GetValueType {
94166
using type = typename std::iterator_traits<Iter>::value_type;
95167
};
@@ -205,18 +277,20 @@ void bubble_sort(Iter first, const size_t begin, const size_t end,
205277
if (begin < end) {
206278
for (size_t i = begin; i < end; ++i) {
207279
// Handle intermediate items
208-
for (size_t idx = i + 1; idx < end; ++idx) {
209-
if (comp(first[idx], first[i])) {
210-
detail::swap_tuples(first[i], first[idx]);
280+
for (size_t idx = begin; idx < begin + (end - 1 - i); ++idx) {
281+
auto refs_1 = first[idx + 1];
282+
auto refs_2 = first[idx];
283+
if (comp(refs_1, refs_2)) {
284+
detail::swap(first[idx], first[idx + 1]);
211285
}
212286
}
213287
}
214288
}
215289
}
216290

217-
template <typename Group, typename Iter, typename T, typename Compare>
291+
template <typename Group, typename Iter, typename ScratchIter, typename Compare>
218292
void merge_sort(Group group, Iter first, const size_t n, Compare comp,
219-
T *scratch) {
293+
ScratchIter scratch) {
220294
const size_t idx = group.get_local_linear_id();
221295
const size_t local = group.get_local_range().size();
222296
const size_t chunk = (n - 1) / local + 1;

sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ template <typename Compare = std::less<>> class default_sorter {
113113
[[maybe_unused]] Ptr last) {
114114
#ifdef __SYCL_DEVICE_ONLY__
115115
using T = typename sycl::detail::GetValueType<Ptr>::type;
116-
size_t n = last - first;
116+
size_t n = std::distance(first, last);
117117
T *scratch_begin = sycl::detail::align_scratch<T>(scratch, g, n);
118118
sycl::detail::merge_sort(g, first, n, comp, scratch_begin);
119119
#else
@@ -206,8 +206,8 @@ class radix_sorter {
206206
sycl::detail::privateDynamicSort</*is_key_value=*/false,
207207
OrderT == sorting_order::ascending,
208208
/*empty*/ 1, BitsPerPass>(
209-
g, first, /*empty*/ first, last - first, scratch.data(), first_bit,
210-
last_bit);
209+
g, first, /*empty*/ first, std::distance(first, last), scratch.data(),
210+
first_bit, last_bit);
211211
#else
212212
throw sycl::exception(
213213
std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
@@ -266,7 +266,7 @@ template <typename CompareT = std::less<>> class joint_sorter {
266266
[[maybe_unused]] Ptr last) {
267267
#ifdef __SYCL_DEVICE_ONLY__
268268
using T = typename sycl::detail::GetValueType<Ptr>::type;
269-
size_t n = last - first;
269+
size_t n = std::distance(first, last);
270270
T *scratch_begin = sycl::detail::align_scratch<T>(scratch, g, n);
271271
sycl::detail::merge_sort(g, first, n, comp, scratch_begin);
272272
#else
@@ -374,6 +374,68 @@ class group_key_value_sorter {
374374
g, KeyValue(key, value));
375375
}
376376

377+
template <typename Group, typename Properties>
378+
void operator()(Group g, sycl::span<KeyTy, ElementsPerWorkItem> keys,
379+
sycl::span<ValueTy, ElementsPerWorkItem> values,
380+
Properties property) {
381+
#ifdef __SYCL_DEVICE_ONLY__
382+
auto range_size = g.get_local_linear_range();
383+
std::size_t local_id = g.get_local_linear_id();
384+
auto number_of_elements = range_size * ElementsPerWorkItem;
385+
KeyTy *keys_scratch_begin =
386+
sycl::detail::align_scratch<KeyTy>(scratch, g, number_of_elements);
387+
size_t KeysScratchSpace =
388+
number_of_elements * sizeof(KeyTy) + alignof(KeyTy);
389+
size_t ValuesScratchSpace =
390+
number_of_elements * sizeof(ValueTy) + alignof(ValueTy);
391+
ValueTy *values_scratch_begin = sycl::detail::align_scratch<ValueTy>(
392+
scratch.subspan(/* offset */ KeysScratchSpace,
393+
/* extent */ ValuesScratchSpace),
394+
g, number_of_elements);
395+
396+
KeyTy *keys_scratch_temp_begin = sycl::detail::align_scratch<KeyTy>(
397+
scratch.subspan(KeysScratchSpace + ValuesScratchSpace,
398+
KeysScratchSpace),
399+
g, number_of_elements);
400+
ValueTy *values_scratch_temp_begin = sycl::detail::align_scratch<ValueTy>(
401+
scratch.subspan(2 * KeysScratchSpace + ValuesScratchSpace,
402+
ValuesScratchSpace),
403+
g, number_of_elements);
404+
405+
std::size_t shift{};
406+
for (std::uint32_t i = 0; i < ElementsPerWorkItem; ++i) {
407+
if constexpr (detail::isInputBlocked(property)) {
408+
shift = local_id * ElementsPerWorkItem + i;
409+
} else {
410+
shift = i * range_size + local_id;
411+
}
412+
keys_scratch_begin[shift] = keys[i];
413+
values_scratch_begin[shift] = values[i];
414+
}
415+
416+
auto scratch_begin = sycl::detail::key_value_iterator(keys_scratch_begin,
417+
values_scratch_begin);
418+
auto scratch_temp_begin = sycl::detail::key_value_iterator(
419+
keys_scratch_temp_begin, values_scratch_temp_begin);
420+
sycl::detail::merge_sort(
421+
g, scratch_begin, number_of_elements,
422+
[this](auto x, auto y) { return comp(std::get<0>(x), std::get<0>(y)); },
423+
scratch_temp_begin);
424+
425+
// from temp
426+
for (std::uint32_t i = 0; i < ElementsPerWorkItem; ++i) {
427+
if constexpr (detail::isOutputBlocked(property)) {
428+
shift = local_id * ElementsPerWorkItem + i;
429+
} else {
430+
shift = i * range_size + local_id;
431+
}
432+
433+
keys[i] = std::get<0>(scratch_begin[shift]);
434+
values[i] = std::get<1>(scratch_begin[shift]);
435+
}
436+
#endif
437+
}
438+
377439
static std::size_t memory_required(sycl::memory_scope scope,
378440
std::size_t range_size) {
379441
return group_sorter<std::tuple<KeyTy, ValueTy>, CompareT,
@@ -422,8 +484,8 @@ class joint_sorter {
422484
sycl::detail::privateDynamicSort</*is_key_value=*/false,
423485
OrderT == sorting_order::ascending,
424486
/*empty*/ 1, BitsPerPass>(
425-
g, first, /*empty*/ first, last - first, scratch.data(), first_bit,
426-
last_bit);
487+
g, first, /*empty*/ first, std::distance(first, last), scratch.data(),
488+
first_bit, last_bit);
427489
#else
428490
throw sycl::exception(
429491
std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
@@ -550,6 +612,18 @@ class group_key_value_sorter {
550612
return {key, val};
551613
}
552614

615+
template <typename Group, typename Properties>
616+
void operator()(Group g, sycl::span<KeyTy, ElementsPerWorkItem> keys,
617+
sycl::span<ValueTy, ElementsPerWorkItem> vals,
618+
Properties properties) {
619+
#ifdef __SYCL_DEVICE_ONLY__
620+
sycl::detail::privateStaticSort<
621+
/*is_key_value=*/true, detail::isOutputBlocked(properties),
622+
Order == sorting_order::ascending, ElementsPerWorkItem, bits>(
623+
g, keys.data(), vals.data(), scratch.data(), first_bit, last_bit);
624+
#endif
625+
}
626+
553627
static constexpr std::size_t memory_required(sycl::memory_scope,
554628
std::size_t range_size) {
555629
return (std::max)(range_size * ElementsPerWorkItem *

sycl/include/sycl/ext/oneapi/experimental/group_sort.hpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,21 @@ struct is_key_value_sorter<
8181
std::tuple<Key, Value>> &&
8282
sycl::is_group_v<Group>>> : std::true_type {};
8383

84+
// Detection trait, primary template: evaluates to false for any argument
// combination that the partial specialization below does not accept.
template <typename Sorter, typename Group, typename Key, typename Value,
85+
typename Properties, size_t ElementsPerWorkItem, typename = void>
86+
struct is_array_key_value_sorter : std::false_type {};
87+
88+
// Evaluates to true when Sorter can be invoked with
// (Group, sycl::span<Key, ElementsPerWorkItem>,
//  sycl::span<Value, ElementsPerWorkItem>, Properties), that invocation
// returns void, and sycl::is_group_v<Group> holds. If the invocation is
// ill-formed, std::invoke_result_t fails to substitute and the primary
// template is selected instead.
template <typename Sorter, typename Group, typename Key, typename Value,
89+
typename Properties, size_t ElementsPerWorkItem>
90+
struct is_array_key_value_sorter<
91+
Sorter, Group, Key, Value, Properties, ElementsPerWorkItem,
92+
std::enable_if_t<
93+
std::is_same_v<std::invoke_result_t<
94+
Sorter, Group, sycl::span<Key, ElementsPerWorkItem>,
95+
sycl::span<Value, ElementsPerWorkItem>, Properties>,
96+
void> &&
97+
sycl::is_group_v<Group>>> : std::true_type {};
98+
8499
} // namespace detail
85100

86101
// ---- sort_over_group
@@ -239,6 +254,63 @@ sort_key_value_over_group(
239254
exec.get_memory()));
240255
}
241256

257+
template <std::size_t ElementsPerWorkItem, typename Group, typename T,
258+
typename U, typename ArraySorter,
259+
typename Properties = ext::oneapi::experimental::empty_properties_t>
260+
std::enable_if_t<
261+
sycl::ext::oneapi::experimental::is_property_list_v<
262+
std::decay_t<Properties>> &&
263+
detail::is_array_key_value_sorter<ArraySorter, Group, T, U, Properties,
264+
ElementsPerWorkItem>::value,
265+
void>
266+
sort_key_value_over_group(Group group, sycl::span<T, ElementsPerWorkItem> keys,
267+
sycl::span<U, ElementsPerWorkItem> values,
268+
ArraySorter array_sorter,
269+
Properties properties = {}) {
270+
array_sorter(group, keys, values, properties);
271+
}
272+
273+
template <typename Group, typename T, typename U, std::size_t Extent,
274+
std::size_t ElementsPerWorkItem, typename Compare,
275+
typename Properties = ext::oneapi::experimental::empty_properties_t>
276+
std::enable_if_t<
277+
sycl::ext::oneapi::experimental::is_property_list_v<
278+
std::decay_t<Properties>> &&
279+
!sycl::ext::oneapi::experimental::is_property_list_v<
280+
std::decay_t<Compare>> &&
281+
!detail::is_array_key_value_sorter<Compare, Group, T, U, Properties,
282+
ElementsPerWorkItem>::value,
283+
void>
284+
sort_key_value_over_group(
285+
experimental::group_with_scratchpad<Group, Extent> exec,
286+
sycl::span<T, ElementsPerWorkItem> keys,
287+
sycl::span<U, ElementsPerWorkItem> values, Compare comp,
288+
Properties properties = {}) {
289+
return experimental::sort_key_value_over_group(
290+
exec.get_group(), keys, values,
291+
typename experimental::default_sorters::group_key_value_sorter<
292+
T, U, Compare, ElementsPerWorkItem>(exec.get_memory(), comp),
293+
properties);
294+
}
295+
296+
template <typename Group, typename T, typename U, std::size_t Extent,
297+
std::size_t ElementsPerWorkItem,
298+
typename Properties = ext::oneapi::experimental::empty_properties_t>
299+
std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
300+
sycl::ext::oneapi::experimental::is_property_list_v<
301+
std::decay_t<Properties>>,
302+
void>
303+
sort_key_value_over_group(
304+
experimental::group_with_scratchpad<Group, Extent> exec,
305+
sycl::span<T, ElementsPerWorkItem> keys,
306+
sycl::span<U, ElementsPerWorkItem> values, Properties properties = {}) {
307+
return experimental::sort_key_value_over_group(
308+
exec.get_group(), keys, values,
309+
typename experimental::default_sorters::group_key_value_sorter<
310+
T, U, std::less<>, ElementsPerWorkItem>(exec.get_memory()),
311+
properties);
312+
}
313+
242314
} // namespace ext::oneapi::experimental
243315
} // namespace _V1
244316
} // namespace sycl

sycl/test-e2e/GroupAlgorithm/SYCL2020/group_sort/array_input_sort.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,9 @@ template <class T> void RunOverType(sycl::queue &Q, size_t DataSize) {
209209
for (T &Elem : ArrayDataRandom)
210210
Elem = T(distribution(generator));
211211

212-
auto blocked = oneapi_exp::properties{oneapi_exp::input_data_placement<
212+
auto blocked = oneapi_exp::properties{oneapi_exp::output_data_placement<
213213
oneapi_exp::group_algorithm_data_placement::blocked>};
214-
auto striped = oneapi_exp::properties{oneapi_exp::input_data_placement<
214+
auto striped = oneapi_exp::properties{oneapi_exp::output_data_placement<
215215
oneapi_exp::group_algorithm_data_placement::striped>};
216216
RunOnData<UseGroupT::WorkGroup, 1, PerWI>(Q, ArrayDataRandom, std::less<T>{},
217217
blocked);

0 commit comments

Comments
 (0)