Skip to content

Commit 3800814

Browse files
authored
[SYCL] Key/Value sorting with fixed-size private array input (#14399)
Implementation of sort_key_value_over_group APIs from https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_group_sort.asciidoc#functions-with-fixed-size-arrays
1 parent 1207b15 commit 3800814

File tree

7 files changed

+807
-165
lines changed

7 files changed

+807
-165
lines changed

sycl/include/sycl/detail/group_sort_impl.hpp

Lines changed: 79 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
#include <climits>
1616

1717
#include <sycl/builtins.hpp>
18+
#include <sycl/detail/key_value_iterator.hpp>
1819
#include <sycl/group_algorithm.hpp>
1920
#include <sycl/group_barrier.hpp>
2021
#include <sycl/sycl_span.hpp>
2122

23+
#include <iterator>
2224
#include <memory>
2325

2426
namespace sycl {
@@ -52,8 +54,46 @@ static __SYCL_ALWAYS_INLINE T *align_scratch(sycl::span<std::byte> scratch,
5254
scratch_begin = sycl::group_broadcast(g, scratch_begin);
5355
return scratch_begin;
5456
}
57+
58+
template <typename KeyTy, typename ValueTy, typename Group>
59+
static __SYCL_ALWAYS_INLINE std::pair<KeyTy *, ValueTy *>
60+
align_key_value_scratch(sycl::span<std::byte> scratch, Group g,
61+
size_t number_of_elements) {
62+
size_t KeysSize = number_of_elements * sizeof(KeyTy);
63+
size_t ValuesSize = number_of_elements * sizeof(ValueTy);
64+
size_t KeysScratchSpace = KeysSize + alignof(KeyTy);
65+
size_t ValuesScratchSpace = ValuesSize + alignof(ValueTy);
66+
67+
KeyTy *keys_scratch_begin = nullptr;
68+
ValueTy *values_scratch_begin = nullptr;
69+
sycl::group_barrier(g);
70+
if (g.leader()) {
71+
void *scratch_ptr = scratch.data();
72+
scratch_ptr =
73+
std::align(alignof(KeyTy), KeysSize, scratch_ptr, KeysScratchSpace);
74+
keys_scratch_begin = ::new (scratch_ptr) KeyTy[number_of_elements];
75+
scratch_ptr = scratch.data() + KeysScratchSpace;
76+
scratch_ptr = std::align(alignof(ValueTy), ValuesSize, scratch_ptr,
77+
ValuesScratchSpace);
78+
values_scratch_begin = ::new (scratch_ptr) ValueTy[number_of_elements];
79+
}
80+
// Broadcast leader's pointer (the beginning of the scratch) to all work
81+
// items in the group.
82+
keys_scratch_begin = sycl::group_broadcast(g, keys_scratch_begin);
83+
values_scratch_begin = sycl::group_broadcast(g, values_scratch_begin);
84+
return std::make_pair(keys_scratch_begin, values_scratch_begin);
85+
}
5586
#endif
5687

88+
// Swap the objects referred to by two tuples of references.
//
// The tuples are taken by rvalue reference so that temporary proxy tuples
// (e.g. the result of std::tie, or of an iterator whose operator[] returns a
// tuple of references) can be passed directly.
template <template <typename...> class Tuple, typename... T>
void swap(Tuple<T &...> &&first, Tuple<T &...> &&second) {
  // Inside the function the named parameters are lvalues, so they can be
  // handed to std::swap as-is; no intermediate copies of the tuples are
  // needed. std::swap on the tuples swaps each pair of referenced elements.
  std::swap(first, second);
}
96+
5797
// ---- merge sort implementation
5898

5999
// following two functions could be useless if std::[lower|upper]_bound worked
@@ -83,15 +123,6 @@ size_t upper_bound(Acc acc, const size_t first, const size_t last,
83123
[comp](auto x, auto y) { return !comp(y, x); });
84124
}
85125

86-
// swap for all data types including tuple-like types
87-
template <typename T> void swap_tuples(T &a, T &b) { std::swap(a, b); }
88-
89-
template <template <typename...> class TupleLike, typename T1, typename T2>
90-
void swap_tuples(TupleLike<T1, T2> &&a, TupleLike<T1, T2> &&b) {
91-
std::swap(std::get<0>(a), std::get<0>(b));
92-
std::swap(std::get<1>(a), std::get<1>(b));
93-
}
94-
95126
// Maps an iterator type to the type of the elements it points at, as reported
// by std::iterator_traits.
template <typename Iterator>
struct GetValueType {
  using type = typename std::iterator_traits<Iterator>::value_type;
};
@@ -207,18 +238,18 @@ void bubble_sort(Iter first, const size_t begin, const size_t end,
207238
if (begin < end) {
208239
for (size_t i = begin; i < end; ++i) {
209240
// Handle intermediate items
210-
for (size_t idx = i + 1; idx < end; ++idx) {
211-
if (comp(first[idx], first[i])) {
212-
detail::swap_tuples(first[i], first[idx]);
241+
for (size_t idx = begin; idx < begin + (end - 1 - i); ++idx) {
242+
if (comp(first[idx + 1], first[idx])) {
243+
detail::swap(first[idx], first[idx + 1]);
213244
}
214245
}
215246
}
216247
}
217248
}
218249

219-
template <typename Group, typename Iter, typename T, typename Compare>
250+
template <typename Group, typename Iter, typename ScratchIter, typename Compare>
220251
void merge_sort(Group group, Iter first, const size_t n, Compare comp,
221-
T *scratch) {
252+
ScratchIter scratch) {
222253
const size_t idx = group.get_local_linear_id();
223254
const size_t local = group.get_local_range().size();
224255
const size_t chunk = (n - 1) / local + 1;
@@ -608,15 +639,41 @@ void performRadixIterDynamicSize(
608639

609640
// The iteration of radix sort for known number of elements per work item
610641
template <size_t items_per_work_item, uint32_t radix_bits, bool is_comp_asc,
611-
bool is_key_value_sort, bool is_blocked, typename KeysT,
612-
typename ValsT, typename GroupT>
642+
bool is_key_value_sort, bool is_input_blocked, bool is_output_blocked,
643+
typename KeysT, typename ValsT, typename GroupT>
613644
void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
645+
const uint32_t first_iter,
614646
const uint32_t last_iter, KeysT *keys,
615647
ValsT *vals, const ScratchMemory &memory) {
616648
const uint32_t radix_states = getStatesInBits(radix_bits);
617649
const size_t wgsize = group.get_local_linear_range();
618650
const size_t idx = group.get_local_linear_id();
619651

652+
const ScratchMemory &keys_temp = memory;
653+
const ScratchMemory vals_temp =
654+
memory + wgsize * items_per_work_item * sizeof(KeysT);
655+
656+
// If input is striped, reorder items using scratch memory before sorting,
657+
// this only needs to be done at the first iteration.
658+
if constexpr (!is_input_blocked) {
659+
if (radix_iter == first_iter) {
660+
for (uint32_t i = 0; i < items_per_work_item; ++i) {
661+
size_t shift = i * wgsize + idx;
662+
keys_temp.get<KeysT>(shift) = keys[i];
663+
if constexpr (is_key_value_sort)
664+
vals_temp.get<ValsT>(shift) = vals[i];
665+
}
666+
sycl::group_barrier(group);
667+
for (uint32_t i = 0; i < items_per_work_item; ++i) {
668+
size_t shift = idx * items_per_work_item + i;
669+
keys[i] = keys_temp.get<KeysT>(shift);
670+
if constexpr (is_key_value_sort)
671+
vals[i] = vals_temp.get<ValsT>(shift);
672+
}
673+
sycl::group_barrier(group);
674+
}
675+
}
676+
620677
// 1.1. count per witem: create a private array for storing count values
621678
uint32_t count_arr[items_per_work_item] = {0};
622679
uint32_t ranks[items_per_work_item] = {0};
@@ -666,9 +723,6 @@ void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
666723
sycl::group_barrier(group);
667724

668725
// 3. Reorder
669-
const ScratchMemory &keys_temp = memory;
670-
const ScratchMemory vals_temp =
671-
memory + wgsize * items_per_work_item * sizeof(KeysT);
672726
for (uint32_t i = 0; i < items_per_work_item; ++i) {
673727
keys_temp.get<KeysT>(ranks[i]) = keys[i];
674728
if constexpr (is_key_value_sort)
@@ -680,7 +734,7 @@ void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
680734
// 4. Copy back to input
681735
for (uint32_t i = 0; i < items_per_work_item; ++i) {
682736
size_t shift = idx * items_per_work_item + i;
683-
if constexpr (!is_blocked) {
737+
if constexpr (!is_output_blocked) {
684738
if (radix_iter == last_iter - 1)
685739
shift = i * wgsize + idx;
686740
}
@@ -728,7 +782,8 @@ void privateDynamicSort(GroupT group, KeysT *keys, ValsT *values,
728782
}
729783
}
730784

731-
template <bool is_key_value_sort, bool is_blocked, bool is_comp_asc,
785+
template <bool is_key_value_sort, bool is_intput_blocked,
786+
bool is_output_blocked, bool is_comp_asc,
732787
size_t items_per_work_item = 1, uint32_t radix_bits = 4,
733788
typename GroupT, typename T, typename U>
734789
void privateStaticSort(GroupT group, T *keys, U *values, std::byte *scratch,
@@ -739,8 +794,9 @@ void privateStaticSort(GroupT group, T *keys, U *values, std::byte *scratch,
739794

740795
for (uint32_t radix_iter = first_iter; radix_iter < last_iter; ++radix_iter) {
741796
performRadixIterStaticSize<items_per_work_item, radix_bits, is_comp_asc,
742-
is_key_value_sort, is_blocked>(
743-
group, radix_iter, last_iter, keys, values, scratch);
797+
is_key_value_sort, is_intput_blocked,
798+
is_output_blocked>(
799+
group, radix_iter, first_iter, last_iter, keys, values, scratch);
744800
sycl::group_barrier(group);
745801
}
746802
}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
//==------------ key_value_iterator.hpp ------------------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// This file includes key/value iterator implementation used for group_sort
9+
// algorithms.
10+
//
11+
12+
#pragma once
13+
#include <iterator>
14+
#include <tuple>
15+
#include <utility>
16+
17+
namespace sycl {
18+
inline namespace _V1 {
19+
namespace detail {
20+
21+
/// Iterator over two parallel arrays -- keys of type T1 and values of type
/// T2 -- that are advanced in lock-step.
///
/// Dereferencing yields a std::tuple of references to the current key and
/// value, so assigning through or swapping the result modifies both
/// underlying arrays. All distance and comparison operations look at the key
/// pointer only; the value pointer is assumed to stay in step with it.
template <typename T1, typename T2> class key_value_iterator {
public:
  using difference_type = std::ptrdiff_t;
  using value_type = std::tuple<T1, T2>;
  using reference = std::tuple<T1 &, T2 &>;
  using pointer = std::tuple<T1 *, T2 *>;
  using iterator_category = std::random_access_iterator_tag;

  /// Creates a singular iterator holding two null pointers. Forward (and
  /// stronger) iterators are required to be default-constructible.
  key_value_iterator() = default;

  /// Creates an iterator positioned at `Keys[0]` / `Values[0]`.
  key_value_iterator(T1 *Keys, T2 *Values) : KeyValue{Keys, Values} {}

  /// Returns references to the current key and value.
  reference operator*() const {
    return std::tie(*std::get<0>(KeyValue), *std::get<1>(KeyValue));
  }

  /// Returns references to the key and value `i` positions away.
  reference operator[](difference_type i) const { return *(*this + i); }

  /// Distance between two iterators, measured on the key pointers.
  difference_type operator-(const key_value_iterator &it) const {
    return std::get<0>(KeyValue) - std::get<0>(it.KeyValue);
  }

  /// Advances both pointers by `i` (which may be negative).
  key_value_iterator &operator+=(difference_type i) {
    std::get<0>(KeyValue) += i;
    std::get<1>(KeyValue) += i;
    return *this;
  }
  key_value_iterator &operator-=(difference_type i) { return *this += -i; }
  key_value_iterator &operator++() { return *this += 1; }
  key_value_iterator &operator--() { return *this -= 1; }

  /// Returns the underlying (key pointer, value pointer) pair.
  std::tuple<T1 *, T2 *> base() const { return KeyValue; }

  key_value_iterator operator++(int) {
    key_value_iterator it(*this);
    ++(*this);
    return it;
  }
  key_value_iterator operator--(int) {
    key_value_iterator it(*this);
    --(*this);
    return it;
  }

  key_value_iterator operator-(difference_type i) const {
    key_value_iterator it(*this);
    return it -= i;
  }
  key_value_iterator operator+(difference_type i) const {
    key_value_iterator it(*this);
    return it += i;
  }
  friend key_value_iterator operator+(difference_type i,
                                      const key_value_iterator &it) {
    return it + i;
  }

  // Comparisons are all derived from the key-pointer distance.
  bool operator==(const key_value_iterator &it) const {
    return *this - it == 0;
  }
  bool operator!=(const key_value_iterator &it) const { return !(*this == it); }
  bool operator<(const key_value_iterator &it) const { return *this - it < 0; }
  bool operator>(const key_value_iterator &it) const { return it < *this; }
  bool operator<=(const key_value_iterator &it) const { return !(*this > it); }
  bool operator>=(const key_value_iterator &it) const { return !(*this < it); }

private:
  // Current key pointer and value pointer; null for a default-constructed
  // iterator.
  std::tuple<T1 *, T2 *> KeyValue{};
};
87+
88+
// Swap for ordinary lvalues: simply forwards both arguments to std::swap.
template <typename T>
void swap(T &first, T &second) {
  std::swap(first, second);
}
91+
92+
} // namespace detail
93+
} // namespace _V1
94+
} // namespace sycl

0 commit comments

Comments
 (0)