Skip to content

Commit af2b274

Browse files
committed
[SYCL] Key/Value sorting with fixed-size private array input
1 parent 5447301 commit af2b274

File tree

6 files changed

+678
-68
lines changed

6 files changed

+678
-68
lines changed

sycl/include/sycl/detail/group_sort_impl.hpp

Lines changed: 149 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <sycl/group_barrier.hpp>
1818
#include <sycl/sycl_span.hpp>
1919

20+
#include <iterator>
2021
#include <memory>
2122

2223
namespace sycl {
@@ -50,8 +51,117 @@ static __SYCL_ALWAYS_INLINE T *align_scratch(sycl::span<std::byte> scratch,
5051
scratch_begin = sycl::group_broadcast(g, scratch_begin);
5152
return scratch_begin;
5253
}
54+
55+
template <typename KeyTy, typename ValueTy, typename Group>
56+
static __SYCL_ALWAYS_INLINE std::pair<KeyTy *, ValueTy *>
57+
align_key_value_scratch(sycl::span<std::byte> scratch, Group g,
58+
size_t number_of_elements) {
59+
size_t KeysSize = number_of_elements * sizeof(KeyTy);
60+
size_t ValuesSize = number_of_elements * sizeof(ValueTy);
61+
size_t KeysScratchSpace = KeysSize + alignof(KeyTy);
62+
size_t ValuesScratchSpace = ValuesSize + alignof(ValueTy);
63+
64+
KeyTy *keys_scratch_begin = nullptr;
65+
ValueTy *values_scratch_begin = nullptr;
66+
sycl::group_barrier(g);
67+
if (g.leader()) {
68+
void *scratch_ptr = scratch.data();
69+
scratch_ptr =
70+
std::align(alignof(KeyTy), KeysSize, scratch_ptr, KeysScratchSpace);
71+
keys_scratch_begin = ::new (scratch_ptr) KeyTy[number_of_elements];
72+
scratch_ptr = scratch.data() + KeysScratchSpace;
73+
scratch_ptr = std::align(alignof(ValueTy), ValuesSize, scratch_ptr,
74+
ValuesScratchSpace);
75+
values_scratch_begin = ::new (scratch_ptr) ValueTy[number_of_elements];
76+
}
77+
// Broadcast leader's pointer (the beginning of the scratch) to all work
78+
// items in the group.
79+
keys_scratch_begin = sycl::group_broadcast(g, keys_scratch_begin);
80+
values_scratch_begin = sycl::group_broadcast(g, values_scratch_begin);
81+
return std::make_pair(keys_scratch_begin, values_scratch_begin);
82+
}
5383
#endif
5484

85+
/// Random-access iterator over two parallel arrays (keys and values) that are
/// advanced in lockstep.
///
/// Dereferencing yields a tuple of references `(key&, value&)`, so assigning
/// through the result writes into both underlying arrays. Distance and all
/// comparisons are computed from the key pointer only; the value pointer is
/// assumed to move together with it.
template <typename T1, typename T2> class key_value_iterator {
public:
  /// Creates an iterator holding two null pointers. Needed because iterators
  /// that advertise random_access_iterator_tag must be default-constructible.
  key_value_iterator() = default;

  /// Creates an iterator positioned at \p Keys / \p Values.
  key_value_iterator(T1 *Keys, T2 *Values) : KeyValue{Keys, Values} {}

  using difference_type = std::ptrdiff_t;
  using value_type = std::tuple<T1, T2>;
  using reference = std::tuple<T1 &, T2 &>;
  using pointer = std::tuple<T1 *, T2 *>;
  using iterator_category = std::random_access_iterator_tag;

  /// Returns references to the current key and the current value.
  reference operator*() const {
    return std::tie(*(std::get<0>(KeyValue)), *(std::get<1>(KeyValue)));
  }

  /// Returns references to the key/value pair \p i positions away.
  reference operator[](difference_type i) const { return *(*this + i); }

  /// Distance between two iterators, measured on the key pointers.
  difference_type operator-(const key_value_iterator &it) const {
    return std::get<0>(KeyValue) - std::get<0>(it.KeyValue);
  }

  /// Moves both the key and the value pointer by \p i elements.
  key_value_iterator &operator+=(difference_type i) {
    KeyValue =
        std::make_tuple(std::get<0>(KeyValue) + i, std::get<1>(KeyValue) + i);
    return *this;
  }
  key_value_iterator &operator-=(difference_type i) { return *this += -i; }
  key_value_iterator &operator++() { return *this += 1; }
  key_value_iterator &operator--() { return *this -= 1; }

  /// Returns the underlying (key pointer, value pointer) pair.
  std::tuple<T1 *, T2 *> base() const { return KeyValue; }

  key_value_iterator operator++(int) {
    key_value_iterator it(*this);
    ++(*this);
    return it;
  }
  key_value_iterator operator--(int) {
    key_value_iterator it(*this);
    --(*this);
    return it;
  }

  key_value_iterator operator-(difference_type i) const {
    key_value_iterator it(*this);
    return it -= i;
  }
  key_value_iterator operator+(difference_type i) const {
    key_value_iterator it(*this);
    return it += i;
  }
  friend key_value_iterator operator+(difference_type i,
                                      const key_value_iterator &it) {
    return it + i;
  }

  // All comparisons are derived from operator-, i.e. from the key pointers.
  bool operator==(const key_value_iterator &it) const {
    return *this - it == 0;
  }

  bool operator!=(const key_value_iterator &it) const { return !(*this == it); }
  bool operator<(const key_value_iterator &it) const { return *this - it < 0; }
  bool operator>(const key_value_iterator &it) const { return it < *this; }
  bool operator<=(const key_value_iterator &it) const { return !(*this > it); }
  bool operator>=(const key_value_iterator &it) const { return !(*this < it); }

private:
  // Current position: pointer into the keys array and into the values array.
  // Value-initialized so a default-constructed iterator holds null pointers.
  std::tuple<T1 *, T2 *> KeyValue{};
};
151+
152+
// Swap for ordinary lvalues of any type: forwards straight to std::swap.
template <typename T>
void swap(T &first, T &second) {
  std::swap(first, second);
}
155+
156+
// Swap overload for temporaries that are tuples of references (for example
// the result of dereferencing an iterator that returns std::tie(...)).
// The temporaries are copied into named tuples first; for std::tuple the copy
// still refers to the same objects, so swapping the copies exchanges the
// referenced objects element by element.
template <template <typename...> class Tuple, typename... T>
void swap(Tuple<T &...> &&first, Tuple<T &...> &&second) {
  Tuple<T &...> left(first);
  Tuple<T &...> right(second);
  std::swap(left, right);
}
164+
55165
// ---- merge sort implementation
56166

57167
// following two functions could be useless if std::[lower|upper]_bound worked
@@ -81,15 +191,6 @@ size_t upper_bound(Acc acc, const size_t first, const size_t last,
81191
[comp](auto x, auto y) { return !comp(y, x); });
82192
}
83193

84-
// Generic swap entry point for plain lvalues; tuple-like temporaries are
// handled by a separate overload. Forwards to std::swap.
template <typename T>
void swap_tuples(T &a, T &b) {
  std::swap(a, b);
}
86-
87-
// Swap for two-element tuple-like temporaries: exchanges element 0 of `a`
// with element 0 of `b`, then element 1 with element 1. When the elements are
// references, the referenced objects are exchanged.
template <template <typename...> class TupleLike, typename T1, typename T2>
void swap_tuples(TupleLike<T1, T2> &&a, TupleLike<T1, T2> &&b) {
  auto &a_first = std::get<0>(a);
  auto &b_first = std::get<0>(b);
  auto &a_second = std::get<1>(a);
  auto &b_second = std::get<1>(b);
  std::swap(a_first, b_first);
  std::swap(a_second, b_second);
}
92-
93194
// Primary template: the element type of an iterator, taken from
// std::iterator_traits.
template <typename Iter>
struct GetValueType {
  using type = typename std::iterator_traits<Iter>::value_type;
};
@@ -205,18 +306,18 @@ void bubble_sort(Iter first, const size_t begin, const size_t end,
205306
if (begin < end) {
206307
for (size_t i = begin; i < end; ++i) {
207308
// Handle intermediate items
208-
for (size_t idx = i + 1; idx < end; ++idx) {
209-
if (comp(first[idx], first[i])) {
210-
detail::swap_tuples(first[i], first[idx]);
309+
for (size_t idx = begin; idx < begin + (end - 1 - i); ++idx) {
310+
if (comp(first[idx + 1], first[idx])) {
311+
detail::swap(first[idx], first[idx + 1]);
211312
}
212313
}
213314
}
214315
}
215316
}
216317

217-
template <typename Group, typename Iter, typename T, typename Compare>
318+
template <typename Group, typename Iter, typename ScratchIter, typename Compare>
218319
void merge_sort(Group group, Iter first, const size_t n, Compare comp,
219-
T *scratch) {
320+
ScratchIter scratch) {
220321
const size_t idx = group.get_local_linear_id();
221322
const size_t local = group.get_local_range().size();
222323
const size_t chunk = (n - 1) / local + 1;
@@ -606,15 +707,41 @@ void performRadixIterDynamicSize(
606707

607708
// The iteration of radix sort for known number of elements per work item
608709
template <size_t items_per_work_item, uint32_t radix_bits, bool is_comp_asc,
609-
bool is_key_value_sort, bool is_blocked, typename KeysT,
610-
typename ValsT, typename GroupT>
710+
bool is_key_value_sort, bool is_input_blocked, bool is_output_blocked,
711+
typename KeysT, typename ValsT, typename GroupT>
611712
void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
713+
const uint32_t first_iter,
612714
const uint32_t last_iter, KeysT *keys,
613715
ValsT *vals, const ScratchMemory &memory) {
614716
const uint32_t radix_states = getStatesInBits(radix_bits);
615717
const size_t wgsize = group.get_local_linear_range();
616718
const size_t idx = group.get_local_linear_id();
617719

720+
const ScratchMemory &keys_temp = memory;
721+
const ScratchMemory vals_temp =
722+
memory + wgsize * items_per_work_item * sizeof(KeysT);
723+
724+
// If input is striped, reorder items using scratch memory before sorting,
725+
// this only needs to be done at the first iteration.
726+
if constexpr (!is_input_blocked) {
727+
if (radix_iter == first_iter) {
728+
for (uint32_t i = 0; i < items_per_work_item; ++i) {
729+
size_t shift = i * wgsize + idx;
730+
keys_temp.get<KeysT>(shift) = keys[i];
731+
if constexpr (is_key_value_sort)
732+
vals_temp.get<ValsT>(shift) = vals[i];
733+
}
734+
sycl::group_barrier(group);
735+
for (uint32_t i = 0; i < items_per_work_item; ++i) {
736+
size_t shift = idx * items_per_work_item + i;
737+
keys[i] = keys_temp.get<KeysT>(shift);
738+
if constexpr (is_key_value_sort)
739+
vals[i] = vals_temp.get<ValsT>(shift);
740+
}
741+
sycl::group_barrier(group);
742+
}
743+
}
744+
618745
// 1.1. count per witem: create a private array for storing count values
619746
uint32_t count_arr[items_per_work_item] = {0};
620747
uint32_t ranks[items_per_work_item] = {0};
@@ -664,9 +791,6 @@ void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
664791
sycl::group_barrier(group);
665792

666793
// 3. Reorder
667-
const ScratchMemory &keys_temp = memory;
668-
const ScratchMemory vals_temp =
669-
memory + wgsize * items_per_work_item * sizeof(KeysT);
670794
for (uint32_t i = 0; i < items_per_work_item; ++i) {
671795
keys_temp.get<KeysT>(ranks[i]) = keys[i];
672796
if constexpr (is_key_value_sort)
@@ -678,7 +802,7 @@ void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
678802
// 4. Copy back to input
679803
for (uint32_t i = 0; i < items_per_work_item; ++i) {
680804
size_t shift = idx * items_per_work_item + i;
681-
if constexpr (!is_blocked) {
805+
if constexpr (!is_output_blocked) {
682806
if (radix_iter == last_iter - 1)
683807
shift = i * wgsize + idx;
684808
}
@@ -726,7 +850,8 @@ void privateDynamicSort(GroupT group, KeysT *keys, ValsT *values,
726850
}
727851
}
728852

729-
template <bool is_key_value_sort, bool is_blocked, bool is_comp_asc,
853+
template <bool is_key_value_sort, bool is_intput_blocked,
854+
bool is_output_blocked, bool is_comp_asc,
730855
size_t items_per_work_item = 1, uint32_t radix_bits = 4,
731856
typename GroupT, typename T, typename U>
732857
void privateStaticSort(GroupT group, T *keys, U *values, std::byte *scratch,
@@ -737,8 +862,9 @@ void privateStaticSort(GroupT group, T *keys, U *values, std::byte *scratch,
737862

738863
for (uint32_t radix_iter = first_iter; radix_iter < last_iter; ++radix_iter) {
739864
performRadixIterStaticSize<items_per_work_item, radix_bits, is_comp_asc,
740-
is_key_value_sort, is_blocked>(
741-
group, radix_iter, last_iter, keys, values, scratch);
865+
is_key_value_sort, is_intput_blocked,
866+
is_output_blocked>(
867+
group, radix_iter, first_iter, last_iter, keys, values, scratch);
742868
sycl::group_barrier(group);
743869
}
744870
}

0 commit comments

Comments
 (0)