Skip to content

Commit 6aa7995

Browse files
committed
[SYCL] Add support for key/value sorting APIs
1 parent 849299f commit 6aa7995

File tree

4 files changed

+400
-5
lines changed

4 files changed

+400
-5
lines changed

sycl/include/sycl/detail/group_sort_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ template <size_t items_per_work_item, uint32_t radix_bits, bool is_comp_asc,
578578
typename ValsT, typename GroupT>
579579
void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
580580
const uint32_t last_iter, KeysT *keys,
581-
ValsT vals, const ScratchMemory &memory) {
581+
ValsT *vals, const ScratchMemory &memory) {
582582
const uint32_t radix_states = getStatesInBits(radix_bits);
583583
const size_t wgsize = group.get_local_linear_range();
584584
const size_t idx = group.get_local_linear_id();

sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,41 @@ class group_sorter {
343343
}
344344
};
345345

346+
template <typename T, typename U, typename CompareT = std::less<>,
347+
std::size_t ElementsPerWorkItem = 1>
348+
class group_key_value_sorter {
349+
CompareT comp;
350+
sycl::span<std::byte> scratch;
351+
352+
public:
353+
template <std::size_t Extent>
354+
group_key_value_sorter(sycl::span<std::byte, Extent> scratch_,
355+
CompareT comp_ = {})
356+
: comp(comp_), scratch(scratch_) {}
357+
358+
template <typename Group>
359+
std::tuple<T, U> operator()(Group g, T key, U value) {
360+
static_assert(ElementsPerWorkItem == 1,
361+
"ElementsPerWorkItem must be equal 1");
362+
363+
using KeyValue = std::tuple<T, U>;
364+
auto this_comp = this->comp;
365+
auto comp_key_value = [this_comp](const KeyValue &lhs,
366+
const KeyValue &rhs) {
367+
return this_comp(std::get<0>(lhs), std::get<0>(rhs));
368+
};
369+
return group_sorter<KeyValue, decltype(comp_key_value),
370+
ElementsPerWorkItem>(scratch, comp_key_value)(
371+
g, KeyValue(key, value));
372+
}
373+
374+
static constexpr std::size_t memory_required(sycl::memory_scope scope,
375+
std::size_t range_size) {
376+
return group_sorter<std::tuple<T, U>, CompareT,
377+
ElementsPerWorkItem>::memory_required(scope,
378+
range_size);
379+
}
380+
};
346381
} // namespace default_sorters
347382

348383
// Radix sorters provided by the second version of the extension specification.
@@ -455,6 +490,56 @@ class group_sorter {
455490
}
456491
};
457492

493+
template <typename T, typename U,
494+
sorting_order Order = sorting_order::ascending,
495+
size_t ElementsPerWorkItem = 1, unsigned int BitsPerPass = 4>
496+
class group_key_value_sorter {
497+
sycl::span<std::byte> scratch;
498+
uint32_t first_bit;
499+
uint32_t last_bit;
500+
501+
static constexpr uint32_t bits = BitsPerPass;
502+
using bitset_t = std::bitset<sizeof(T) * CHAR_BIT>;
503+
504+
public:
505+
template <std::size_t Extent>
506+
group_key_value_sorter(sycl::span<std::byte, Extent> scratch_,
507+
const bitset_t mask = bitset_t{}.set())
508+
: scratch(scratch_) {
509+
static_assert(
510+
(std::is_arithmetic<T>::value || std::is_same<T, sycl::half>::value),
511+
"radix sort is not usable");
512+
for (first_bit = 0; first_bit < mask.size() && !mask[first_bit];
513+
++first_bit)
514+
;
515+
for (last_bit = first_bit; last_bit < mask.size() && mask[last_bit];
516+
++last_bit)
517+
;
518+
}
519+
520+
template <typename Group>
521+
std::tuple<T, U> operator()([[maybe_unused]] Group g, T key, U val) {
522+
static_assert(ElementsPerWorkItem == 1, "ElementsPerWorkItem must be 1");
523+
T key_result[]{key};
524+
U val_result[]{val};
525+
#ifdef __SYCL_DEVICE_ONLY__
526+
sycl::detail::privateStaticSort<
527+
/*is_key_value=*/true,
528+
/*is_blocked=*/true, Order == sorting_order::ascending, 1, bits>(
529+
g, key_result, val_result, scratch.data(), first_bit, last_bit);
530+
#endif
531+
key = key_result[0];
532+
val = val_result[0];
533+
return {key, val};
534+
}
535+
536+
static constexpr std::size_t memory_required(sycl::memory_scope scope,
537+
std::size_t range_size) {
538+
return (std::max)(range_size * ElementsPerWorkItem *
539+
(sizeof(T) + sizeof(U)),
540+
range_size * (1 << bits) * sizeof(uint32_t));
541+
}
542+
};
458543
} // namespace radix_sorters
459544

460545
} // namespace ext::oneapi::experimental

sycl/include/sycl/ext/oneapi/experimental/group_sort.hpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,29 @@ struct is_sorter_impl<Sorter, Group, Ptr,
6868
template <typename Sorter, typename Group, typename ValOrPtr>
6969
struct is_sorter : decltype(is_sorter_impl<Sorter, Group, ValOrPtr>::test(0)) {
7070
};
71+
72+
template <typename Sorter, typename Group, typename Key, typename Value>
73+
struct is_key_value_sorter_impl {
74+
template <typename G>
75+
using is_expected_return_type =
76+
typename std::is_same<std::tuple<Key, Value>,
77+
decltype(std::declval<Sorter>()(
78+
std::declval<G>(), std::declval<Key>(),
79+
std::declval<Value>()))>;
80+
81+
template <typename G = Group>
82+
static decltype(std::integral_constant<bool,
83+
is_expected_return_type<G>::value &&
84+
sycl::is_group_v<G>>{})
85+
test(int);
86+
87+
template <typename = Group> static std::false_type test(...);
88+
};
89+
90+
template <typename Sorter, typename Group, typename Key, typename Value>
91+
struct is_key_value_sorter
92+
: decltype(is_key_value_sorter_impl<Sorter, Group, Key, Value>::test(0)){};
93+
7194
} // namespace detail
7295

7396
// ---- sort_over_group
@@ -131,6 +154,43 @@ joint_sort(experimental::group_with_scratchpad<Group, Extent> exec, Iter first,
131154
default_sorters::joint_sorter<>(exec.get_memory()));
132155
}
133156

157+
template <typename Group, typename T, typename U, typename Sorter>
158+
std::enable_if_t<detail::is_key_value_sorter<Sorter, Group, T, U>::value,
159+
std::tuple<T, U>>
160+
sort_key_value_over_group([[maybe_unused]] Group g, [[maybe_unused]] T key,
161+
[[maybe_unused]] U value,
162+
[[maybe_unused]] Sorter sorter) {
163+
#ifdef __SYCL_DEVICE_ONLY__
164+
return sorter(g, key, value);
165+
#else
166+
throw sycl::exception(
167+
std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
168+
"Group algorithms are not supported on host device.");
169+
#endif
170+
}
171+
172+
template <typename Group, typename T, typename U, typename Compare,
173+
std::size_t Extent>
174+
std::enable_if_t<!detail::is_key_value_sorter<Compare, Group, T, U>::value,
175+
std::tuple<T, U>>
176+
sort_key_value_over_group(
177+
experimental::group_with_scratchpad<Group, Extent> exec, T key, U value,
178+
Compare comp) {
179+
return sort_key_value_over_group(
180+
exec.get_group(), key, value,
181+
default_sorters::group_key_value_sorter<T, U, Compare>(exec.get_memory(),
182+
comp));
183+
}
184+
185+
template <typename T, typename U, typename Group, std::size_t Extent>
186+
std::enable_if_t<sycl::is_group_v<std::decay_t<Group>>, std::tuple<T, U>>
187+
sort_key_value_over_group(
188+
experimental::group_with_scratchpad<Group, Extent> exec, T key, U value) {
189+
return sort_key_value_over_group(
190+
exec.get_group(), key, value,
191+
default_sorters::group_key_value_sorter<T, U>(exec.get_memory()));
192+
}
193+
134194
} // namespace ext::oneapi::experimental
135195
} // namespace _V1
136196
} // namespace sycl

0 commit comments

Comments
 (0)