-
Notifications
You must be signed in to change notification settings - Fork 769
[SYCL][Group algorithms] Add group sorting algorithms implementation #4439
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
a8a5b33
c27f0ac
a1d9c09
0159ff8
d8aab75
f40e9cc
6b55761
cb54021
68f2d8e
339d043
eddd7a3
b4aa9c3
768d3a3
972f6ef
8a07d59
b6ae2e5
e162168
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
//==------------ group_sort_impl.hpp ---------------------------------------==// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// This file includes some functions for group sorting algorithm implementations | ||
// | ||
|
||
#pragma once | ||
|
||
#include <CL/sycl/detail/helpers.hpp> | ||
|
||
#ifdef __SYCL_DEVICE_ONLY__ | ||
|
||
__SYCL_INLINE_NAMESPACE(cl) { | ||
namespace sycl { | ||
namespace detail { | ||
|
||
// ---- merge sort implementation | ||
|
||
// following two functions could be useless if std::[lower|upper]_bound worked | ||
// well | ||
// Binary search over the index range [first, last) of `acc`.
//
// Returns the smallest index `i` in [first, last] for which
// `comp(acc[i], value)` is false; returns `last` when `comp` holds for every
// element. The range is expected to be partitioned with respect to
// `comp(element, value)` (for example, sorted by `comp`).
//
// `acc` only has to support `operator[]` with an integral index.
template <typename Acc, typename Value, typename Compare>
std::size_t lower_bound(Acc acc, std::size_t first, std::size_t last,
                        const Value &value, Compare comp) {
  std::size_t n = last - first;
  while (n > 0) {
    const std::size_t half = n / 2;
    const std::size_t mid = first + half;
    if (comp(acc[mid], value)) {
      // acc[mid] is ordered before value: continue in the right half.
      first = mid + 1;
      n -= half + 1;
    } else {
      // acc[mid] is a candidate: continue in the left half.
      n = half;
    }
  }
  return first;
}
|
||
template <typename Acc, typename Value, typename Compare> | ||
std::size_t upper_bound(Acc acc, const std::size_t first, | ||
const std::size_t last, const Value &value, | ||
Compare comp) { | ||
return detail::lower_bound(acc, first, last, value, | ||
[comp](auto x, auto y) { return !comp(y, x); }); | ||
} | ||
|
||
// swap for all data types including tuple-like types | ||
// Swaps two lvalues of the same type through std::swap. Tuple-like
// temporaries are handled by the rvalue-reference overload of swap_tuples.
template <typename T> void swap_tuples(T &a, T &b) { std::swap(a, b); }
|
||
// Swaps two rvalue tuple-like objects that have exactly two elements by
// swapping the elements pairwise through std::get.
template <template <typename...> class TupleLike, typename T1, typename T2>
void swap_tuples(TupleLike<T1, T2> &&a, TupleLike<T1, T2> &&b) {
  std::swap(std::get<0>(a), std::get<0>(b));
  std::swap(std::get<1>(a), std::get<1>(b));
}
|
||
// Maps an iterator type to the type of the elements it refers to, as reported
// by std::iterator_traits.
template <typename Iter> struct GetValueType {
  using type = typename std::iterator_traits<Iter>::value_type;
};
|
||
template <typename ElementType, access::address_space Space> | ||
struct GetValueType<sycl::multi_ptr<ElementType, Space>> { | ||
using type = ElementType; | ||
}; | ||
|
||
template <typename InAcc, typename OutAcc, typename Compare> | ||
void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1, | ||
const std::size_t start_1, const std::size_t end_1, | ||
const std::size_t end_2, const std::size_t start_out, Compare comp, | ||
const std::size_t chunk) { | ||
const std::size_t start_2 = end_1; | ||
// Borders of the sequences to merge within this call | ||
const std::size_t local_start_1 = | ||
sycl::min(static_cast<std::size_t>(offset + start_1), end_1); | ||
const std::size_t local_end_1 = | ||
sycl::min(static_cast<std::size_t>(local_start_1 + chunk), end_1); | ||
const std::size_t local_start_2 = | ||
sycl::min(static_cast<std::size_t>(offset + start_2), end_2); | ||
const std::size_t local_end_2 = | ||
sycl::min(static_cast<std::size_t>(local_start_2 + chunk), end_2); | ||
|
||
const std::size_t local_size_1 = local_end_1 - local_start_1; | ||
const std::size_t local_size_2 = local_end_2 - local_start_2; | ||
|
||
// TODO: process cases where all elements of 1st sequence > 2nd, 2nd > 1st to | ||
// improve performance | ||
|
||
// Process 1st sequence | ||
if (local_start_1 < local_end_1) { | ||
// Reduce the range for searching within the 2nd sequence and handle bound | ||
// items find left border in 2nd sequence | ||
const auto local_l_item_1 = in_acc1[local_start_1]; | ||
std::size_t l_search_bound_2 = | ||
detail::lower_bound(in_acc1, start_2, end_2, local_l_item_1, comp); | ||
const std::size_t l_shift_1 = local_start_1 - start_1; | ||
const std::size_t l_shift_2 = l_search_bound_2 - start_2; | ||
|
||
out_acc1[start_out + l_shift_1 + l_shift_2] = local_l_item_1; | ||
|
||
std::size_t r_search_bound_2{}; | ||
// find right border in 2nd sequence | ||
if (local_size_1 > 1) { | ||
const auto local_r_item_1 = in_acc1[local_end_1 - 1]; | ||
r_search_bound_2 = detail::lower_bound(in_acc1, l_search_bound_2, end_2, | ||
local_r_item_1, comp); | ||
const auto r_shift_1 = local_end_1 - 1 - start_1; | ||
const auto r_shift_2 = r_search_bound_2 - start_2; | ||
|
||
out_acc1[start_out + r_shift_1 + r_shift_2] = local_r_item_1; | ||
} | ||
|
||
// Handle intermediate items | ||
for (std::size_t idx = local_start_1 + 1; idx < local_end_1 - 1; ++idx) { | ||
const auto intermediate_item_1 = in_acc1[idx]; | ||
// we shouldn't seek in whole 2nd sequence. Just for the part where the | ||
// 1st sequence should be | ||
l_search_bound_2 = | ||
detail::lower_bound(in_acc1, l_search_bound_2, r_search_bound_2, | ||
intermediate_item_1, comp); | ||
const std::size_t shift_1 = idx - start_1; | ||
const std::size_t shift_2 = l_search_bound_2 - start_2; | ||
|
||
out_acc1[start_out + shift_1 + shift_2] = intermediate_item_1; | ||
} | ||
} | ||
// Process 2nd sequence | ||
if (local_start_2 < local_end_2) { | ||
// Reduce the range for searching within the 1st sequence and handle bound | ||
// items find left border in 1st sequence | ||
const auto local_l_item_2 = in_acc1[local_start_2]; | ||
std::size_t l_search_bound_1 = | ||
detail::upper_bound(in_acc1, start_1, end_1, local_l_item_2, comp); | ||
const std::size_t l_shift_1 = l_search_bound_1 - start_1; | ||
const std::size_t l_shift_2 = local_start_2 - start_2; | ||
|
||
out_acc1[start_out + l_shift_1 + l_shift_2] = local_l_item_2; | ||
|
||
std::size_t r_search_bound_1{}; | ||
// find right border in 1st sequence | ||
if (local_size_2 > 1) { | ||
const auto local_r_item_2 = in_acc1[local_end_2 - 1]; | ||
r_search_bound_1 = detail::upper_bound(in_acc1, l_search_bound_1, end_1, | ||
local_r_item_2, comp); | ||
const std::size_t r_shift_1 = r_search_bound_1 - start_1; | ||
const std::size_t r_shift_2 = local_end_2 - 1 - start_2; | ||
|
||
out_acc1[start_out + r_shift_1 + r_shift_2] = local_r_item_2; | ||
} | ||
|
||
// Handle intermediate items | ||
for (auto idx = local_start_2 + 1; idx < local_end_2 - 1; ++idx) { | ||
const auto intermediate_item_2 = in_acc1[idx]; | ||
// we shouldn't seek in whole 1st sequence. Just for the part where the | ||
// 2nd sequence should be | ||
l_search_bound_1 = | ||
detail::upper_bound(in_acc1, l_search_bound_1, r_search_bound_1, | ||
intermediate_item_2, comp); | ||
const std::size_t shift_1 = l_search_bound_1 - start_1; | ||
const std::size_t shift_2 = idx - start_2; | ||
|
||
out_acc1[start_out + shift_1 + shift_2] = intermediate_item_2; | ||
} | ||
} | ||
} | ||
|
||
template <typename Iter, typename Compare> | ||
void bubble_sort(Iter first, const std::size_t begin, const std::size_t end, | ||
Compare comp) { | ||
if (begin < end) { | ||
for (std::size_t i = begin; i < end; ++i) { | ||
// Handle intermediate items | ||
for (std::size_t idx = i + 1; idx < end; ++idx) { | ||
if (comp(first[idx], first[i])) { | ||
detail::swap_tuples(first[i], first[idx]); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
template <typename Group, typename Iter, typename Compare> | ||
void merge_sort(Group group, Iter first, const std::size_t n, Compare comp, | ||
std::uint8_t *scratch) { | ||
using T = typename GetValueType<Iter>::type; | ||
auto id = sycl::detail::Builder::getNDItem<Group::dimensions>(); | ||
const std::size_t idx = id.get_local_id(); | ||
const std::size_t local = group.get_local_range().size(); | ||
const std::size_t chunk = (n - 1) / local + 1; | ||
|
||
// we need to sort within work item first | ||
bubble_sort(first, idx * chunk, sycl::min((idx + 1) * chunk, n), comp); | ||
id.barrier(); | ||
|
||
T *temp = reinterpret_cast<T *>(scratch); | ||
andreyfe1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
bool data_in_temp = false; | ||
std::size_t sorted_size = 1; | ||
while (sorted_size * chunk < n) { | ||
const std::size_t start_1 = | ||
sycl::min(2 * sorted_size * chunk * (idx / sorted_size), n); | ||
const std::size_t end_1 = sycl::min(start_1 + sorted_size * chunk, n); | ||
const std::size_t end_2 = sycl::min(end_1 + sorted_size * chunk, n); | ||
const std::size_t offset = chunk * (idx % sorted_size); | ||
|
||
if (!data_in_temp) { | ||
merge(offset, first, temp, start_1, end_1, end_2, start_1, comp, chunk); | ||
} else { | ||
merge(offset, temp, first, start_1, end_1, end_2, start_1, comp, chunk); | ||
} | ||
id.barrier(); | ||
|
||
data_in_temp = !data_in_temp; | ||
sorted_size *= 2; | ||
} | ||
|
||
// copy back if data is in a temporary storage | ||
if (data_in_temp) { | ||
for (std::size_t i = 0; i < chunk; ++i) { | ||
if (idx * chunk + i < n) { | ||
first[idx * chunk + i] = temp[idx * chunk + i]; | ||
} | ||
} | ||
id.barrier(); | ||
} | ||
} | ||
|
||
} // namespace detail | ||
} // namespace sycl | ||
} // __SYCL_INLINE_NAMESPACE(cl) | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
//==------- group_helpers_sorters.hpp - SYCL sorters and group helpers -----==// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#pragma once | ||
|
||
#include <CL/sycl/detail/group_sort_impl.hpp> | ||
|
||
__SYCL_INLINE_NAMESPACE(cl) { | ||
namespace sycl { | ||
namespace ext { | ||
namespace oneapi { | ||
namespace experimental { | ||
|
||
// ---- group helpers | ||
template <typename Group, std::size_t Extent> class group_with_scratchpad { | ||
Group g; | ||
sycl::span<std::uint8_t, Extent> scratch; | ||
|
||
public: | ||
group_with_scratchpad(Group g_, sycl::span<std::uint8_t, Extent> scratch_) | ||
: g(g_), scratch(scratch_) {} | ||
Group get_group() const { return g; } | ||
sycl::span<std::uint8_t, Extent> get_memory() const { return scratch; } | ||
}; | ||
|
||
// ---- sorters | ||
template <typename Compare = std::less<>> class default_sorter { | ||
Compare comp; | ||
std::uint8_t *scratch; | ||
std::size_t scratch_size; | ||
Review comment: This is very similar to my previous comment. A pair of a pointer to an array and its size is either a range or a span. Since we have span backported to SYCL 2020, can we use it here? Reply: We considered this option previously during the API discussion. The thing is that if we replaced this with sycl::span, we would need an additional template parameter for the sorter (extent). However, extent is not relevant to the |
||
|
||
public: | ||
template <std::size_t Extent> | ||
default_sorter(sycl::span<std::uint8_t, Extent> scratch_, | ||
Compare comp_ = Compare()) | ||
: comp(comp_), scratch(scratch_.data()), scratch_size(scratch_.size()) {} | ||
|
||
template <typename Group, typename Ptr> | ||
void operator()(Group g, Ptr first, Ptr last) { | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
using T = typename sycl::detail::GetValueType<Ptr>::type; | ||
if (scratch_size >= memory_required<T>(Group::fence_scope, last - first)) | ||
sycl::detail::merge_sort(g, first, last - first, comp, scratch); | ||
// TODO: it's better to add else branch | ||
#else | ||
(void)g; | ||
(void)first; | ||
(void)last; | ||
throw runtime_error( | ||
"default_sorter constructor is not supported on host device.", | ||
PI_INVALID_DEVICE); | ||
#endif | ||
andreyfe1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
template <typename Group, typename T> T operator()(Group g, T val) { | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
auto range_size = g.get_local_range().size(); | ||
if (scratch_size >= memory_required<T>(Group::fence_scope, range_size)) { | ||
auto id = sycl::detail::Builder::getNDItem<Group::dimensions>(); | ||
uint32_t local_id = id.get_local_id(); | ||
andreyfe1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
T *temp = reinterpret_cast<T *>(scratch); | ||
temp[local_id] = val; | ||
sycl::detail::merge_sort(g, temp, range_size, comp, | ||
scratch + range_size * sizeof(T)); | ||
val = temp[local_id]; | ||
} | ||
// TODO: it's better to add else branch | ||
#else | ||
(void)g; | ||
(void)val; | ||
throw runtime_error( | ||
"default_sorter operator() is not supported on host device.", | ||
PI_INVALID_DEVICE); | ||
#endif | ||
return val; | ||
} | ||
|
||
template <typename T> | ||
static constexpr std::size_t memory_required(sycl::memory_scope scope, | ||
std::size_t range_size) { | ||
return range_size * sizeof(T); | ||
} | ||
|
||
template <typename T, int dim = 1> | ||
static constexpr std::size_t memory_required(sycl::memory_scope scope, | ||
sycl::range<dim> r) { | ||
return 2 * r.size() * sizeof(T); | ||
andreyfe1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
}; | ||
|
||
} // namespace experimental | ||
} // namespace oneapi | ||
} // namespace ext | ||
} // namespace sycl | ||
} // __SYCL_INLINE_NAMESPACE(cl) |
Uh oh!
There was an error while loading. Please reload this page.