#include <climits>

#include <sycl/builtins.hpp>
+ #include <sycl/detail/key_value_iterator.hpp>
#include <sycl/group_algorithm.hpp>
#include <sycl/group_barrier.hpp>
#include <sycl/sycl_span.hpp>

+ #include <iterator>
#include <memory>

namespace sycl {
@@ -52,8 +54,46 @@ static __SYCL_ALWAYS_INLINE T *align_scratch(sycl::span<std::byte> scratch,
  scratch_begin = sycl::group_broadcast(g, scratch_begin);
  return scratch_begin;
}
+
+ template <typename KeyTy, typename ValueTy, typename Group>
+ static __SYCL_ALWAYS_INLINE std::pair<KeyTy *, ValueTy *>
+ align_key_value_scratch(sycl::span<std::byte> scratch, Group g,
+                         size_t number_of_elements) {
+   size_t KeysSize = number_of_elements * sizeof(KeyTy);
+   size_t ValuesSize = number_of_elements * sizeof(ValueTy);
+   size_t KeysScratchSpace = KeysSize + alignof(KeyTy);
+   size_t ValuesScratchSpace = ValuesSize + alignof(ValueTy);
+
+   KeyTy *keys_scratch_begin = nullptr;
+   ValueTy *values_scratch_begin = nullptr;
+   sycl::group_barrier(g);
+   if (g.leader()) {
+     void *scratch_ptr = scratch.data();
+     scratch_ptr =
+         std::align(alignof(KeyTy), KeysSize, scratch_ptr, KeysScratchSpace);
+     keys_scratch_begin = ::new (scratch_ptr) KeyTy[number_of_elements];
+     scratch_ptr = scratch.data() + KeysScratchSpace;
+     scratch_ptr = std::align(alignof(ValueTy), ValuesSize, scratch_ptr,
+                              ValuesScratchSpace);
+     values_scratch_begin = ::new (scratch_ptr) ValueTy[number_of_elements];
+   }
+   // Broadcast the leader's pointers (the beginnings of the key and value
+   // scratch regions) to all work items in the group.
+   keys_scratch_begin = sycl::group_broadcast(g, keys_scratch_begin);
+   values_scratch_begin = sycl::group_broadcast(g, values_scratch_begin);
+   return std::make_pair(keys_scratch_begin, values_scratch_begin);
+ }
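For context, the new helper carves two aligned, typed regions (keys, then values) out of one raw byte buffer; only the group leader computes the pointers, which are then broadcast. A minimal host-side sketch of the same carving logic (illustrative only: `carve` is a hypothetical name, there is no broadcast, and it assumes the buffer is large enough for both regions plus alignment padding):

```cpp
#include <cstddef>
#include <memory>
#include <utility>

template <typename KeyTy, typename ValueTy>
std::pair<KeyTy *, ValueTy *> carve(std::byte *scratch, std::size_t n) {
  std::size_t KeysSize = n * sizeof(KeyTy);
  std::size_t KeysSpace = KeysSize + alignof(KeyTy);
  void *p = scratch;
  // Bump p up to the next KeyTy boundary, then start the keys array there.
  p = std::align(alignof(KeyTy), KeysSize, p, KeysSpace);
  KeyTy *keys = ::new (p) KeyTy[n];

  std::size_t ValuesSize = n * sizeof(ValueTy);
  std::size_t ValuesSpace = ValuesSize + alignof(ValueTy);
  // Skip the whole keys slot (payload plus worst-case alignment padding)
  // and repeat the same dance for the values array.
  p = scratch + KeysSize + alignof(KeyTy);
  p = std::align(alignof(ValueTy), ValuesSize, p, ValuesSpace);
  ValueTy *values = ::new (p) ValueTy[n];
  return {keys, values};
}
```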
#endif

+ // Swap tuples of references.
+ template <template <typename...> class Tuple, typename... T>
+ void swap(Tuple<T &...> &&first, Tuple<T &...> &&second) {
+   auto lhs = first;
+   auto rhs = second;
+   // Do std::swap for each element of the tuple.
+   std::swap(lhs, rhs);
+ }
+
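A quick host-side check of the idiom this overload relies on (standalone snippet, not part of the patch): swapping two `std::tuple`s of lvalue references swaps the referenced values themselves, which is exactly what a sort needs when dereferencing a key/value iterator yields proxy tuples rather than real elements:

```cpp
#include <cassert>
#include <tuple>
#include <utility>

int main() {
  int k0 = 1, k1 = 2;
  float v0 = 1.0f, v1 = 2.0f;
  std::tuple<int &, float &> a{k0, v0};
  std::tuple<int &, float &> b{k1, v1};
  std::swap(a, b); // element-wise swap through the references
  assert(k0 == 2 && k1 == 1 && v0 == 2.0f && v1 == 1.0f);
}
```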
// ---- merge sort implementation

// following two functions could be useless if std::[lower|upper]_bound worked
@@ -83,15 +123,6 @@ size_t upper_bound(Acc acc, const size_t first, const size_t last,
                     [comp](auto x, auto y) { return !comp(y, x); });
}
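The argument-swapped, negated comparator is the standard identity that expresses `upper_bound` through `lower_bound`: the first element for which `!comp(value, elem)` is false is the first element greater than `value`. A host-side sanity check of the identity using the standard algorithms (illustrative only):

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

int main() {
  std::vector<int> v{1, 2, 2, 3, 5};
  auto comp = std::less<int>{};
  int value = 2;
  auto ub = std::upper_bound(v.begin(), v.end(), value, comp);
  auto ub_via_lb = std::lower_bound(
      v.begin(), v.end(), value,
      [comp](auto x, auto y) { return !comp(y, x); });
  assert(ub == ub_via_lb); // both point at 3, the first element > 2
}
```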

- // swap for all data types including tuple-like types
- template <typename T> void swap_tuples(T &a, T &b) { std::swap(a, b); }
-
- template <template <typename...> class TupleLike, typename T1, typename T2>
- void swap_tuples(TupleLike<T1, T2> &&a, TupleLike<T1, T2> &&b) {
-   std::swap(std::get<0>(a), std::get<0>(b));
-   std::swap(std::get<1>(a), std::get<1>(b));
- }
-

template <typename Iter> struct GetValueType {
  using type = typename std::iterator_traits<Iter>::value_type;
};
@@ -207,18 +238,18 @@ void bubble_sort(Iter first, const size_t begin, const size_t end,
  if (begin < end) {
    for (size_t i = begin; i < end; ++i) {
      // Handle intermediate items
-       for (size_t idx = i + 1; idx < end; ++idx) {
-         if (comp(first[idx], first[i])) {
-           detail::swap_tuples(first[i], first[idx]);
+       for (size_t idx = begin; idx < begin + (end - 1 - i); ++idx) {
+         if (comp(first[idx + 1], first[idx])) {
+           detail::swap(first[idx], first[idx + 1]);
        }
      }
    }
  }
}

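The inner loop now performs textbook adjacent-element bubble passes over `[begin, end)`: each pass bubbles the largest remaining element toward the end, and the pass length shrinks by one each time. This pairs with the new `detail::swap`, which swaps through the proxy tuples that `first[idx]` and `first[idx + 1]` may yield. A standalone host-side sketch of the same index scheme with plain integers (illustrative only):

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>

int main() {
  int a[] = {5, 2, 4, 1, 3};
  const std::size_t begin = 0, end = 5;
  auto comp = std::less<int>{};
  for (std::size_t i = begin; i < end; ++i)
    for (std::size_t idx = begin; idx < begin + (end - 1 - i); ++idx)
      if (comp(a[idx + 1], a[idx]))
        std::swap(a[idx], a[idx + 1]); // bubble the larger element rightward
  for (std::size_t k = 1; k < end; ++k)
    assert(a[k - 1] <= a[k]);
}
```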
- template <typename Group, typename Iter, typename T, typename Compare>
+ template <typename Group, typename Iter, typename ScratchIter, typename Compare>
void merge_sort(Group group, Iter first, const size_t n, Compare comp,
-                 T *scratch) {
+                 ScratchIter scratch) {
  const size_t idx = group.get_local_linear_id();
  const size_t local = group.get_local_range().size();
  const size_t chunk = (n - 1) / local + 1;
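Widening the scratch parameter from `T *` to an arbitrary iterator lets `merge_sort` run over key/value data through the newly included `<sycl/detail/key_value_iterator.hpp>`. A simplified sketch of the idea behind such a zip-style iterator (hypothetical and heavily trimmed, not the real API): dereferencing yields a tuple of references, so assignments and the `detail::swap` overload above write through to both underlying arrays:

```cpp
#include <cstddef>
#include <tuple>

template <typename KeyT, typename ValT> class KeyValueIt {
  KeyT *MKeys;
  ValT *MVals;

public:
  KeyValueIt(KeyT *Keys, ValT *Vals) : MKeys(Keys), MVals(Vals) {}
  // Dereference/index to a tuple of references into both arrays.
  std::tuple<KeyT &, ValT &> operator*() const { return {*MKeys, *MVals}; }
  std::tuple<KeyT &, ValT &> operator[](std::size_t I) const {
    return {MKeys[I], MVals[I]};
  }
  KeyValueIt &operator++() {
    ++MKeys;
    ++MVals;
    return *this;
  }
  KeyValueIt operator+(std::size_t N) const { return {MKeys + N, MVals + N}; }
};
```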
@@ -608,15 +639,41 @@ void performRadixIterDynamicSize(

// The iteration of radix sort for known number of elements per work item
template <size_t items_per_work_item, uint32_t radix_bits, bool is_comp_asc,
-           bool is_key_value_sort, bool is_blocked, typename KeysT,
-           typename ValsT, typename GroupT>
+           bool is_key_value_sort, bool is_input_blocked, bool is_output_blocked,
+           typename KeysT, typename ValsT, typename GroupT>
void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
+                                 const uint32_t first_iter,
                                const uint32_t last_iter, KeysT *keys,
                                ValsT *vals, const ScratchMemory &memory) {
  const uint32_t radix_states = getStatesInBits(radix_bits);
  const size_t wgsize = group.get_local_linear_range();
  const size_t idx = group.get_local_linear_id();

+   const ScratchMemory &keys_temp = memory;
+   const ScratchMemory vals_temp =
+       memory + wgsize * items_per_work_item * sizeof(KeysT);
+
+   // If the input is striped, reorder the items using scratch memory before
+   // sorting; this only needs to be done on the first iteration.
+   if constexpr (!is_input_blocked) {
+     if (radix_iter == first_iter) {
+       for (uint32_t i = 0; i < items_per_work_item; ++i) {
+         size_t shift = i * wgsize + idx;
+         keys_temp.get<KeysT>(shift) = keys[i];
+         if constexpr (is_key_value_sort)
+           vals_temp.get<ValsT>(shift) = vals[i];
+       }
+       sycl::group_barrier(group);
+       for (uint32_t i = 0; i < items_per_work_item; ++i) {
+         size_t shift = idx * items_per_work_item + i;
+         keys[i] = keys_temp.get<KeysT>(shift);
+         if constexpr (is_key_value_sort)
+           vals[i] = vals_temp.get<ValsT>(shift);
+       }
+       sycl::group_barrier(group);
+     }
+   }
+
  // 1.1. count per witem: create a private array for storing count values
  uint32_t count_arr[items_per_work_item] = {0};
  uint32_t ranks[items_per_work_item] = {0};
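For reference, the striped-to-blocked fixup above uses two complementary index maps: in a striped layout, global element `j` lives in slot `j / wgsize` of work item `j % wgsize`; in a blocked layout it lives in slot `j % items_per_work_item` of work item `j / items_per_work_item`. A standalone illustration with hypothetical sizes (`wgsize = 4`, `items_per_work_item = 2`):

```cpp
#include <cstdio>

int main() {
  const unsigned wgsize = 4, items = 2;
  for (unsigned idx = 0; idx < wgsize; ++idx)   // work-item id
    for (unsigned i = 0; i < items; ++i)        // private slot
      std::printf("witem %u slot %u holds striped element %u / blocked element %u\n",
                  idx, i, i * wgsize + idx, idx * items + i);
}
```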
@@ -666,9 +723,6 @@ void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
  sycl::group_barrier(group);

  // 3. Reorder
-   const ScratchMemory &keys_temp = memory;
-   const ScratchMemory vals_temp =
-       memory + wgsize * items_per_work_item * sizeof(KeysT);
  for (uint32_t i = 0; i < items_per_work_item; ++i) {
    keys_temp.get<KeysT>(ranks[i]) = keys[i];
    if constexpr (is_key_value_sort)
@@ -680,7 +734,7 @@ void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
  // 4. Copy back to input
  for (uint32_t i = 0; i < items_per_work_item; ++i) {
    size_t shift = idx * items_per_work_item + i;
-     if constexpr (!is_blocked) {
+     if constexpr (!is_output_blocked) {
      if (radix_iter == last_iter - 1)
        shift = i * wgsize + idx;
    }
@@ -728,7 +782,8 @@ void privateDynamicSort(GroupT group, KeysT *keys, ValsT *values,
  }
}

- template <bool is_key_value_sort, bool is_blocked, bool is_comp_asc,
+ template <bool is_key_value_sort, bool is_input_blocked,
+           bool is_output_blocked, bool is_comp_asc,
          size_t items_per_work_item = 1, uint32_t radix_bits = 4,
          typename GroupT, typename T, typename U>
void privateStaticSort(GroupT group, T *keys, U *values, std::byte *scratch,
@@ -739,8 +794,9 @@ void privateStaticSort(GroupT group, T *keys, U *values, std::byte *scratch,

  for (uint32_t radix_iter = first_iter; radix_iter < last_iter; ++radix_iter) {
    performRadixIterStaticSize<items_per_work_item, radix_bits, is_comp_asc,
-                                is_key_value_sort, is_blocked>(
-         group, radix_iter, last_iter, keys, values, scratch);
+                                is_key_value_sort, is_input_blocked,
+                                is_output_blocked>(
+         group, radix_iter, first_iter, last_iter, keys, values, scratch);
    sycl::group_barrier(group);
  }
}
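A hypothetical instantiation to show the new template-parameter order (the stub below mirrors only what this hunk reveals; the trailing run-time parameters are outside the hunk and stay elided):

```cpp
#include <cstddef>
#include <cstdint>

// Stub declaration so the fragment is self-contained; the real definition
// lives in the full header.
template <bool is_key_value_sort, bool is_input_blocked, bool is_output_blocked,
          bool is_comp_asc, std::size_t items_per_work_item = 1,
          uint32_t radix_bits = 4, typename GroupT, typename T, typename U>
void privateStaticSort(GroupT group, T *keys, U *values, std::byte *scratch
                       /*, trailing parameters elided in this hunk */);

// Inside a kernel one would now write, e.g. (key/value sort, striped input,
// blocked output, ascending, 4 items per work item):
//   privateStaticSort<true, false, true, true, 4>(g, keys, vals, scratch, ...);
```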