17
17
#include < sycl/group_barrier.hpp>
18
18
#include < sycl/sycl_span.hpp>
19
19
20
+ #include < iterator>
20
21
#include < memory>
21
22
22
23
namespace sycl {
@@ -50,8 +51,117 @@ static __SYCL_ALWAYS_INLINE T *align_scratch(sycl::span<std::byte> scratch,
50
51
scratch_begin = sycl::group_broadcast (g, scratch_begin);
51
52
return scratch_begin;
52
53
}
54
+
55
+ template <typename KeyTy, typename ValueTy, typename Group>
56
+ static __SYCL_ALWAYS_INLINE std::pair<KeyTy *, ValueTy *>
57
+ align_key_value_scratch (sycl::span<std::byte> scratch, Group g,
58
+ size_t number_of_elements) {
59
+ size_t KeysSize = number_of_elements * sizeof (KeyTy);
60
+ size_t ValuesSize = number_of_elements * sizeof (ValueTy);
61
+ size_t KeysScratchSpace = KeysSize + alignof (KeyTy);
62
+ size_t ValuesScratchSpace = ValuesSize + alignof (ValueTy);
63
+
64
+ KeyTy *keys_scratch_begin = nullptr ;
65
+ ValueTy *values_scratch_begin = nullptr ;
66
+ sycl::group_barrier (g);
67
+ if (g.leader ()) {
68
+ void *scratch_ptr = scratch.data ();
69
+ scratch_ptr =
70
+ std::align (alignof (KeyTy), KeysSize, scratch_ptr, KeysScratchSpace);
71
+ keys_scratch_begin = ::new (scratch_ptr) KeyTy[number_of_elements];
72
+ scratch_ptr = scratch.data () + KeysScratchSpace;
73
+ scratch_ptr = std::align (alignof (ValueTy), ValuesSize, scratch_ptr,
74
+ ValuesScratchSpace);
75
+ values_scratch_begin = ::new (scratch_ptr) ValueTy[number_of_elements];
76
+ }
77
+ // Broadcast leader's pointer (the beginning of the scratch) to all work
78
+ // items in the group.
79
+ keys_scratch_begin = sycl::group_broadcast (g, keys_scratch_begin);
80
+ values_scratch_begin = sycl::group_broadcast (g, values_scratch_begin);
81
+ return std::make_pair (keys_scratch_begin, values_scratch_begin);
82
+ }
53
83
#endif
54
84
85
/// Random-access iterator that advances through a key array and a value array
/// in lock step.
///
/// Dereferencing yields a std::tuple of references {key, value}, so assigning
/// through (or swapping) the result modifies both underlying arrays.
/// Distance, equality and ordering are computed from the key pointer only;
/// the value pointer is assumed to move together with it.
///
/// \tparam T1 key element type.
/// \tparam T2 value element type.
template <typename T1, typename T2> class key_value_iterator {
public:
  using difference_type = std::ptrdiff_t;
  using value_type = std::tuple<T1, T2>;
  using reference = std::tuple<T1 &, T2 &>;
  using pointer = std::tuple<T1 *, T2 *>;
  using iterator_category = std::random_access_iterator_tag;

  /// Creates an iterator holding two null pointers. Forward (and stronger)
  /// iterators are required to be default constructible; such an iterator
  /// must not be dereferenced.
  key_value_iterator() = default;

  /// Creates an iterator positioned at \p Keys / \p Values.
  key_value_iterator(T1 *Keys, T2 *Values) : KeyValue{Keys, Values} {}

  /// Returns references to the current key and value.
  reference operator*() const {
    return std::tie(*(std::get<0>(KeyValue)), *(std::get<1>(KeyValue)));
  }

  /// Returns references to the key and value \p i positions away.
  reference operator[](difference_type i) const { return *(*this + i); }

  /// Distance between two iterators, measured on the key pointers.
  difference_type operator-(const key_value_iterator &it) const {
    return std::get<0>(KeyValue) - std::get<0>(it.KeyValue);
  }

  /// Moves both pointers by \p i elements.
  key_value_iterator &operator+=(difference_type i) {
    KeyValue =
        std::make_tuple(std::get<0>(KeyValue) + i, std::get<1>(KeyValue) + i);
    return *this;
  }
  key_value_iterator &operator-=(difference_type i) { return *this += -i; }
  key_value_iterator &operator++() { return *this += 1; }
  key_value_iterator &operator--() { return *this -= 1; }

  /// Returns the underlying {key pointer, value pointer} pair.
  std::tuple<T1 *, T2 *> base() const { return KeyValue; }

  key_value_iterator operator++(int) {
    key_value_iterator it(*this);
    ++(*this);
    return it;
  }
  key_value_iterator operator--(int) {
    key_value_iterator it(*this);
    --(*this);
    return it;
  }

  key_value_iterator operator-(difference_type i) const {
    key_value_iterator it(*this);
    return it -= i;
  }
  key_value_iterator operator+(difference_type i) const {
    key_value_iterator it(*this);
    return it += i;
  }
  friend key_value_iterator operator+(difference_type i,
                                      const key_value_iterator &it) {
    return it + i;
  }

  // All comparisons are derived from the key-pointer distance.
  bool operator==(const key_value_iterator &it) const {
    return *this - it == 0;
  }

  bool operator!=(const key_value_iterator &it) const { return !(*this == it); }
  bool operator<(const key_value_iterator &it) const { return *this - it < 0; }
  bool operator>(const key_value_iterator &it) const { return it < *this; }
  bool operator<=(const key_value_iterator &it) const { return !(*this > it); }
  bool operator>=(const key_value_iterator &it) const { return !(*this < it); }

private:
  // Current position: {pointer into keys, pointer into values}.
  std::tuple<T1 *, T2 *> KeyValue{};
};
151
+
152
/// Exchanges two ordinary lvalues by forwarding to std::swap.
template <typename T> void swap(T &first, T &second) {
  std::swap(first, second);
}

/// Exchanges the objects referred to by two temporary tuples of references.
///
/// Both arguments are tuples whose elements are all lvalue references, passed
/// as rvalues; after the call every referenced object of \p first holds the
/// former value of the matching object of \p second and vice versa.
template <template <typename...> class Tuple, typename... T>
void swap(Tuple<T &...> &&first, Tuple<T &...> &&second) {
  // Copies of the reference tuples still refer to the caller's objects.
  Tuple<T &...> left_refs(first);
  Tuple<T &...> right_refs(second);
  // Let std::swap exchange the tuples element by element, which for
  // reference elements exchanges the referenced values.
  std::swap(left_refs, right_refs);
}
164
+
55
165
// ---- merge sort implementation
56
166
57
167
// following two functions could be useless if std::[lower|upper]_bound worked
@@ -81,15 +191,6 @@ size_t upper_bound(Acc acc, const size_t first, const size_t last,
81
191
[comp](auto x, auto y) { return !comp (y, x); });
82
192
}
83
193
84
/// Swap helper for ordinary lvalues: forwards to std::swap.
template <typename T> void swap_tuples(T &a, T &b) { std::swap(a, b); }

/// Swap helper for temporaries of a two-element tuple-like type: exchanges
/// element 0 with element 0 and element 1 with element 1.
template <template <typename...> class TupleLike, typename T1, typename T2>
void swap_tuples(TupleLike<T1, T2> &&a, TupleLike<T1, T2> &&b) {
  // Bind every element first, then exchange them pairwise.
  auto &a_first = std::get<0>(a);
  auto &b_first = std::get<0>(b);
  auto &a_second = std::get<1>(a);
  auto &b_second = std::get<1>(b);
  std::swap(a_first, b_first);
  std::swap(a_second, b_second);
}
92
-
93
194
/// Maps an iterator type to the type of the elements it designates, as
/// reported by std::iterator_traits.
template <typename Iter> struct GetValueType {
private:
  using traits = std::iterator_traits<Iter>;

public:
  using type = typename traits::value_type;
};
@@ -205,18 +306,18 @@ void bubble_sort(Iter first, const size_t begin, const size_t end,
205
306
if (begin < end) {
206
307
for (size_t i = begin; i < end; ++i) {
207
308
// Handle intermediate items
208
- for (size_t idx = i + 1 ; idx < end; ++idx) {
209
- if (comp (first[idx], first[i ])) {
210
- detail::swap_tuples (first[i ], first[idx]);
309
+ for (size_t idx = begin ; idx < begin + ( end - 1 - i) ; ++idx) {
310
+ if (comp (first[idx + 1 ], first[idx ])) {
311
+ detail::swap (first[idx ], first[idx + 1 ]);
211
312
}
212
313
}
213
314
}
214
315
}
215
316
}
216
317
217
- template <typename Group, typename Iter, typename T , typename Compare>
318
+ template <typename Group, typename Iter, typename ScratchIter , typename Compare>
218
319
void merge_sort (Group group, Iter first, const size_t n, Compare comp,
219
- T * scratch) {
320
+ ScratchIter scratch) {
220
321
const size_t idx = group.get_local_linear_id ();
221
322
const size_t local = group.get_local_range ().size ();
222
323
const size_t chunk = (n - 1 ) / local + 1 ;
@@ -606,15 +707,41 @@ void performRadixIterDynamicSize(
606
707
607
708
// The iteration of radix sort for known number of elements per work item
608
709
template <size_t items_per_work_item, uint32_t radix_bits, bool is_comp_asc,
609
- bool is_key_value_sort, bool is_blocked, typename KeysT ,
610
- typename ValsT, typename GroupT>
710
+ bool is_key_value_sort, bool is_input_blocked, bool is_output_blocked ,
711
+ typename KeysT, typename ValsT, typename GroupT>
611
712
void performRadixIterStaticSize (GroupT group, const uint32_t radix_iter,
713
+ const uint32_t first_iter,
612
714
const uint32_t last_iter, KeysT *keys,
613
715
ValsT *vals, const ScratchMemory &memory) {
614
716
const uint32_t radix_states = getStatesInBits (radix_bits);
615
717
const size_t wgsize = group.get_local_linear_range ();
616
718
const size_t idx = group.get_local_linear_id ();
617
719
720
+ const ScratchMemory &keys_temp = memory;
721
+ const ScratchMemory vals_temp =
722
+ memory + wgsize * items_per_work_item * sizeof (KeysT);
723
+
724
+ // If input is striped, reorder items using scratch memory before sorting,
725
+ // this only needs to be done at the first iteration.
726
+ if constexpr (!is_input_blocked) {
727
+ if (radix_iter == first_iter) {
728
+ for (uint32_t i = 0 ; i < items_per_work_item; ++i) {
729
+ size_t shift = i * wgsize + idx;
730
+ keys_temp.get <KeysT>(shift) = keys[i];
731
+ if constexpr (is_key_value_sort)
732
+ vals_temp.get <ValsT>(shift) = vals[i];
733
+ }
734
+ sycl::group_barrier (group);
735
+ for (uint32_t i = 0 ; i < items_per_work_item; ++i) {
736
+ size_t shift = idx * items_per_work_item + i;
737
+ keys[i] = keys_temp.get <KeysT>(shift);
738
+ if constexpr (is_key_value_sort)
739
+ vals[i] = vals_temp.get <ValsT>(shift);
740
+ }
741
+ sycl::group_barrier (group);
742
+ }
743
+ }
744
+
618
745
// 1.1. count per witem: create a private array for storing count values
619
746
uint32_t count_arr[items_per_work_item] = {0 };
620
747
uint32_t ranks[items_per_work_item] = {0 };
@@ -664,9 +791,6 @@ void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
664
791
sycl::group_barrier (group);
665
792
666
793
// 3. Reorder
667
- const ScratchMemory &keys_temp = memory;
668
- const ScratchMemory vals_temp =
669
- memory + wgsize * items_per_work_item * sizeof (KeysT);
670
794
for (uint32_t i = 0 ; i < items_per_work_item; ++i) {
671
795
keys_temp.get <KeysT>(ranks[i]) = keys[i];
672
796
if constexpr (is_key_value_sort)
@@ -678,7 +802,7 @@ void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
678
802
// 4. Copy back to input
679
803
for (uint32_t i = 0 ; i < items_per_work_item; ++i) {
680
804
size_t shift = idx * items_per_work_item + i;
681
- if constexpr (!is_blocked ) {
805
+ if constexpr (!is_output_blocked ) {
682
806
if (radix_iter == last_iter - 1 )
683
807
shift = i * wgsize + idx;
684
808
}
@@ -726,7 +850,8 @@ void privateDynamicSort(GroupT group, KeysT *keys, ValsT *values,
726
850
}
727
851
}
728
852
729
- template <bool is_key_value_sort, bool is_blocked, bool is_comp_asc,
853
+ template <bool is_key_value_sort, bool is_intput_blocked,
854
+ bool is_output_blocked, bool is_comp_asc,
730
855
size_t items_per_work_item = 1 , uint32_t radix_bits = 4 ,
731
856
typename GroupT, typename T, typename U>
732
857
void privateStaticSort (GroupT group, T *keys, U *values, std::byte *scratch,
@@ -737,8 +862,9 @@ void privateStaticSort(GroupT group, T *keys, U *values, std::byte *scratch,
737
862
738
863
for (uint32_t radix_iter = first_iter; radix_iter < last_iter; ++radix_iter) {
739
864
performRadixIterStaticSize<items_per_work_item, radix_bits, is_comp_asc,
740
- is_key_value_sort, is_blocked>(
741
- group, radix_iter, last_iter, keys, values, scratch);
865
+ is_key_value_sort, is_intput_blocked,
866
+ is_output_blocked>(
867
+ group, radix_iter, first_iter, last_iter, keys, values, scratch);
742
868
sycl::group_barrier (group);
743
869
}
744
870
}
0 commit comments