Skip to content

Commit 3910d0c

Browse files
authored
[SYCL] Add support for key/value sorting APIs (#13942)
Add the group_key_value_sorter sorters and the sort_key_value_over_group APIs, based on the sycl_ext_oneapi_group_sort extension (https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_group_sort.asciidoc). This PR was split out from a larger PR: #13713. Co-authored-by: "Andrei Fedorov [[email protected]](mailto:[email protected])" Co-authored-by: "Romanov Vlad [[email protected]](mailto:[email protected])"
1 parent 87f47b4 commit 3910d0c

File tree

6 files changed

+509
-50
lines changed

6 files changed

+509
-50
lines changed

sycl/include/sycl/detail/group_sort_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ template <size_t items_per_work_item, uint32_t radix_bits, bool is_comp_asc,
578578
typename ValsT, typename GroupT>
579579
void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
580580
const uint32_t last_iter, KeysT *keys,
581-
ValsT vals, const ScratchMemory &memory) {
581+
ValsT *vals, const ScratchMemory &memory) {
582582
const uint32_t radix_states = getStatesInBits(radix_bits);
583583
const size_t wgsize = group.get_local_linear_range();
584584
const size_t idx = group.get_local_linear_id();

sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,7 @@ template <typename CompareT = std::less<>> class joint_sorter {
281281
}
282282

283283
template <typename T>
284-
static constexpr size_t memory_required(sycl::memory_scope,
285-
size_t range_size) {
284+
static size_t memory_required(sycl::memory_scope, size_t range_size) {
286285
return range_size * sizeof(T) + alignof(T);
287286
}
288287
};
@@ -336,13 +335,47 @@ class group_sorter {
336335
return val;
337336
}
338337

339-
static constexpr std::size_t memory_required(sycl::memory_scope scope,
340-
size_t range_size) {
338+
static std::size_t memory_required(sycl::memory_scope scope,
339+
size_t range_size) {
341340
return 2 * joint_sorter<>::template memory_required<T>(
342341
scope, range_size * ElementsPerWorkItem);
343342
}
344343
};
345344

345+
template <typename KeyTy, typename ValueTy, typename CompareT = std::less<>,
346+
std::size_t ElementsPerWorkItem = 1>
347+
class group_key_value_sorter {
348+
CompareT comp;
349+
sycl::span<std::byte> scratch;
350+
351+
public:
352+
template <std::size_t Extent>
353+
group_key_value_sorter(sycl::span<std::byte, Extent> scratch_,
354+
CompareT comp_ = {})
355+
: comp(comp_), scratch(scratch_) {}
356+
357+
template <typename Group>
358+
std::tuple<KeyTy, ValueTy> operator()(Group g, KeyTy key, ValueTy value) {
359+
static_assert(ElementsPerWorkItem == 1,
360+
"ElementsPerWorkItem must be equal 1");
361+
362+
using KeyValue = std::tuple<KeyTy, ValueTy>;
363+
auto comp_key_value = [this_comp = this->comp](const KeyValue &lhs,
364+
const KeyValue &rhs) {
365+
return this_comp(std::get<0>(lhs), std::get<0>(rhs));
366+
};
367+
return group_sorter<KeyValue, decltype(comp_key_value),
368+
ElementsPerWorkItem>(scratch, comp_key_value)(
369+
g, KeyValue(key, value));
370+
}
371+
372+
static std::size_t memory_required(sycl::memory_scope scope,
373+
std::size_t range_size) {
374+
return group_sorter<std::tuple<KeyTy, ValueTy>, CompareT,
375+
ElementsPerWorkItem>::memory_required(scope,
376+
range_size);
377+
}
378+
};
346379
} // namespace default_sorters
347380

348381
// Radix sorters provided by the second version of the extension specification.
@@ -455,6 +488,57 @@ class group_sorter {
455488
}
456489
};
457490

491+
template <typename KeyTy, typename ValueTy,
492+
sorting_order Order = sorting_order::ascending,
493+
size_t ElementsPerWorkItem = 1, unsigned int BitsPerPass = 4>
494+
class group_key_value_sorter {
495+
sycl::span<std::byte> scratch;
496+
uint32_t first_bit;
497+
uint32_t last_bit;
498+
499+
static constexpr uint32_t bits = BitsPerPass;
500+
using bitset_t = std::bitset<sizeof(KeyTy) * CHAR_BIT>;
501+
502+
public:
503+
template <std::size_t Extent>
504+
group_key_value_sorter(sycl::span<std::byte, Extent> scratch_,
505+
const bitset_t mask = bitset_t{}.set())
506+
: scratch(scratch_) {
507+
static_assert((std::is_arithmetic<KeyTy>::value ||
508+
std::is_same<KeyTy, sycl::half>::value),
509+
"radix sort is not usable");
510+
for (first_bit = 0; first_bit < mask.size() && !mask[first_bit];
511+
++first_bit)
512+
;
513+
for (last_bit = first_bit; last_bit < mask.size() && mask[last_bit];
514+
++last_bit)
515+
;
516+
}
517+
518+
template <typename Group>
519+
std::tuple<KeyTy, ValueTy> operator()([[maybe_unused]] Group g, KeyTy key,
520+
ValueTy val) {
521+
static_assert(ElementsPerWorkItem == 1, "ElementsPerWorkItem must be 1");
522+
KeyTy key_result[]{key};
523+
ValueTy val_result[]{val};
524+
#ifdef __SYCL_DEVICE_ONLY__
525+
sycl::detail::privateStaticSort<
526+
/*is_key_value=*/true,
527+
/*is_blocked=*/true, Order == sorting_order::ascending, 1, bits>(
528+
g, key_result, val_result, scratch.data(), first_bit, last_bit);
529+
#endif
530+
key = key_result[0];
531+
val = val_result[0];
532+
return {key, val};
533+
}
534+
535+
static constexpr std::size_t memory_required(sycl::memory_scope,
536+
std::size_t range_size) {
537+
return (std::max)(range_size * ElementsPerWorkItem *
538+
(sizeof(KeyTy) + sizeof(ValueTy)),
539+
range_size * (1 << bits) * sizeof(uint32_t));
540+
}
541+
};
458542
} // namespace radix_sorters
459543

460544
} // namespace ext::oneapi::experimental

sycl/include/sycl/ext/oneapi/experimental/group_sort.hpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,19 @@ struct is_sorter_impl<Sorter, Group, Ptr,
6868
template <typename Sorter, typename Group, typename ValOrPtr>
6969
struct is_sorter : decltype(is_sorter_impl<Sorter, Group, ValOrPtr>::test(0)) {
7070
};
71+
72+
template <typename Sorter, typename Group, typename Key, typename Value,
73+
typename = void>
74+
struct is_key_value_sorter : std::false_type {};
75+
76+
template <typename Sorter, typename Group, typename Key, typename Value>
77+
struct is_key_value_sorter<
78+
Sorter, Group, Key, Value,
79+
std::enable_if_t<
80+
std::is_same_v<std::invoke_result_t<Sorter, Group, Key, Value>,
81+
std::tuple<Key, Value>> &&
82+
sycl::is_group_v<Group>>> : std::true_type {};
83+
7184
} // namespace detail
7285

7386
// ---- sort_over_group
@@ -131,6 +144,48 @@ joint_sort(experimental::group_with_scratchpad<Group, Extent> exec, Iter first,
131144
default_sorters::joint_sorter<>(exec.get_memory()));
132145
}
133146

147+
template <typename Group, typename KeyTy, typename ValueTy, typename Sorter>
148+
std::enable_if_t<
149+
detail::is_key_value_sorter<Sorter, Group, KeyTy, ValueTy>::value,
150+
std::tuple<KeyTy, ValueTy>>
151+
sort_key_value_over_group([[maybe_unused]] Group g, [[maybe_unused]] KeyTy key,
152+
[[maybe_unused]] ValueTy value,
153+
[[maybe_unused]] Sorter sorter) {
154+
#ifdef __SYCL_DEVICE_ONLY__
155+
return sorter(g, key, value);
156+
#else
157+
throw sycl::exception(
158+
std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
159+
"Group algorithms are not supported on host device.");
160+
#endif
161+
}
162+
163+
template <typename Group, typename KeyTy, typename ValueTy, typename Compare,
164+
std::size_t Extent>
165+
std::enable_if_t<
166+
!detail::is_key_value_sorter<Compare, Group, KeyTy, ValueTy>::value,
167+
std::tuple<KeyTy, ValueTy>>
168+
sort_key_value_over_group(
169+
experimental::group_with_scratchpad<Group, Extent> exec, KeyTy key,
170+
ValueTy value, Compare comp) {
171+
return sort_key_value_over_group(
172+
exec.get_group(), key, value,
173+
default_sorters::group_key_value_sorter<KeyTy, ValueTy, Compare>(
174+
exec.get_memory(), comp));
175+
}
176+
177+
template <typename KeyTy, typename ValueTy, typename Group, std::size_t Extent>
178+
std::enable_if_t<sycl::is_group_v<std::decay_t<Group>>,
179+
std::tuple<KeyTy, ValueTy>>
180+
sort_key_value_over_group(
181+
experimental::group_with_scratchpad<Group, Extent> exec, KeyTy key,
182+
ValueTy value) {
183+
return sort_key_value_over_group(
184+
exec.get_group(), key, value,
185+
default_sorters::group_key_value_sorter<KeyTy, ValueTy>(
186+
exec.get_memory()));
187+
}
188+
134189
} // namespace ext::oneapi::experimental
135190
} // namespace _V1
136191
} // namespace sycl
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
2+
#include <sycl/detail/core.hpp>
3+
#include <sycl/ext/oneapi/experimental/group_sort.hpp>
4+
5+
#pragma once
6+
7+
namespace oneapi_exp = sycl::ext::oneapi::experimental;
8+
9+
// Two-state tag: SubGroup has the underlying value of `true` (1) and
// WorkGroup the underlying value of `false` (0).
enum class UseGroupT { WorkGroup = false, SubGroup = true };
10+
11+
// these classes are needed to pass non-type template parameters to KernelName
12+
template <int> class IntWrapper;
13+
template <UseGroupT> class UseGroupWrapper;
14+
15+
// Thin wrapper around a single size_t value. It is implicitly constructible
// from size_t, default-constructs to 0, and offers only the <, > and ==
// comparisons, each of which compares the wrapped values.
class CustomType {
  size_t MVal = 0;

public:
  CustomType() : CustomType(size_t{0}) {}
  CustomType(size_t Val) : MVal(Val) {}

  bool operator==(const CustomType &Other) const {
    return MVal == Other.MVal;
  }
  bool operator<(const CustomType &Other) const { return MVal < Other.MVal; }
  bool operator>(const CustomType &Other) const { return MVal > Other.MVal; }
};
27+
28+
template <class T> struct ConvertToSimpleType {
29+
using Type = T;
30+
};
31+
32+
// Dummy overloads for CustomType which is not supported by radix sorter
33+
template <> struct ConvertToSimpleType<CustomType> {
34+
using Type = int;
35+
};
36+
37+
template <class SorterT> struct ConvertToSortingOrder;
38+
39+
template <class T> struct ConvertToSortingOrder<std::greater<T>> {
40+
static const auto Type = oneapi_exp::sorting_order::descending;
41+
};
42+
43+
template <class T> struct ConvertToSortingOrder<std::less<T>> {
44+
static const auto Type = oneapi_exp::sorting_order::ascending;
45+
};
46+
47+
// Sub-group size constant shared by the tests that include this header.
// `inline` gives the header-defined constant a single definition across all
// translation units instead of one internal-linkage copy per includer.
inline constexpr size_t ReqSubGroupSize = 8;
48+
49+
// Variadic class templates that are only declared, never defined, in this
// header; they accept any number of type arguments.
template <typename... Ts> class KernelNameOverGroup;
template <typename... Ts> class KernelNameJoint;

sycl/test-e2e/GroupAlgorithm/SYCL2020/sort.cpp renamed to sycl/test-e2e/GroupAlgorithm/SYCL2020/group_sort/group_and_joint_sort.cpp

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
#include <sycl/detail/core.hpp>
3131

32+
#include "common.hpp"
3233
#include <sycl/builtins.hpp>
3334
#include <sycl/ext/oneapi/experimental/group_sort.hpp>
3435
#include <sycl/group_algorithm.hpp>
@@ -39,30 +40,6 @@
3940
#include <random>
4041
#include <vector>
4142

42-
namespace oneapi_exp = sycl::ext::oneapi::experimental;
43-
44-
template <typename...> class KernelNameOverGroup;
45-
template <typename...> class KernelNameJoint;
46-
47-
enum class UseGroupT { SubGroup = true, WorkGroup = false };
48-
49-
// these classes are needed to pass non-type template parameters to KernelName
50-
template <int> class IntWrapper;
51-
template <UseGroupT> class UseGroupWrapper;
52-
53-
class CustomType {
54-
public:
55-
CustomType(size_t Val) : MVal(Val) {}
56-
CustomType() : MVal(0) {}
57-
58-
bool operator<(const CustomType &RHS) const { return MVal < RHS.MVal; }
59-
bool operator>(const CustomType &RHS) const { return MVal > RHS.MVal; }
60-
bool operator==(const CustomType &RHS) const { return MVal == RHS.MVal; }
61-
62-
private:
63-
size_t MVal = 0;
64-
};
65-
6643
#if VERSION == 1
6744
template <class CompT, class T> struct RadixSorterType;
6845

@@ -86,29 +63,8 @@ template <> struct RadixSorterType<std::greater<CustomType>, CustomType> {
8663
using Type =
8764
oneapi_exp::radix_sorter<int, oneapi_exp::sorting_order::descending>;
8865
};
89-
#else
90-
template <class T> struct ConvertToSimpleType {
91-
using Type = T;
92-
};
93-
94-
// Dummy overloads for CustomType which is not supported by radix sorter
95-
template <> struct ConvertToSimpleType<CustomType> {
96-
using Type = int;
97-
};
98-
99-
template <class SorterT> struct ConvertToSortingOrder;
100-
101-
template <class T> struct ConvertToSortingOrder<std::greater<T>> {
102-
static const auto Type = oneapi_exp::sorting_order::descending;
103-
};
104-
105-
template <class T> struct ConvertToSortingOrder<std::less<T>> {
106-
static const auto Type = oneapi_exp::sorting_order::ascending;
107-
};
10866
#endif
10967

110-
constexpr size_t ReqSubGroupSize = 8;
111-
11268
template <UseGroupT UseGroup, int Dims, class T, class Compare>
11369
void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,
11470
const Compare &Comp) {

0 commit comments

Comments
 (0)