Skip to content

Commit 2525570

Browse files
authored
[SYCL][COMPAT] Adds sycl::length wrapper to SYCLcompat (#12968)
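Adds a templated `syclcompat::length` wrapper around `sycl::length`. For arrays of 2 to 4 elements it forwards to `sycl::length` on the matching `sycl::vec<ValueT, N>`; for more than 4 elements it computes the length directly as the square root of the sum of squares. --------- Signed-off-by: Alberto Cabrera <[email protected]>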
1 parent 3e14dc0 commit 2525570

File tree

3 files changed

+66
-17
lines changed

3 files changed

+66
-17
lines changed

sycl/doc/syclcompat/README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1199,7 +1199,8 @@ static kernel_function_info get_kernel_function_info(const void *function);
11991199

12001200
`syclcompat::fast_length` provides a wrapper to SYCL's
12011201
`fast_length(sycl::vec<float,N>)` that accepts arguments for a C++ array and a
1202-
length.
1202+
length. `syclcompat::length` provides a templated version that wraps over
1203+
`sycl::length`.
12031204

12041205
`vectorized_max` and `vectorized_min` are binary operations returning the
12051206
max/min of two arguments, where each argument is treated as a `sycl::vec` type.
@@ -1213,6 +1214,9 @@ which accept `sycl::vec<T,2>` arguments representing complex values.
12131214
```cpp
12141215
inline float fast_length(const float *a, int len);
12151216

1217+
template <typename ValueT>
1218+
inline ValueT length(const ValueT *a, const int len);
1219+
12161220
template <typename S, typename T> inline T vectorized_max(T a, T b);
12171221

12181222
template <typename S, typename T> inline T vectorized_min(T a, T b);

sycl/include/syclcompat/math.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,29 @@ inline float fast_length(const float *a, int len) {
7373
}
7474
}
7575

76+
/// Calculate the square root of the input array.
77+
/// \param [in] a The array pointer
78+
/// \param [in] len Length of the array
79+
/// \returns The square root
80+
template <typename ValueT>
81+
inline ValueT length(const ValueT *a, const int len) {
82+
switch (len) {
83+
case 1:
84+
return a[0];
85+
case 2:
86+
return sycl::length(sycl::vec<ValueT, 2>(a[0], a[1]));
87+
case 3:
88+
return sycl::length(sycl::vec<ValueT, 3>(a[0], a[1], a[2]));
89+
case 4:
90+
return sycl::length(sycl::vec<ValueT, 4>(a[0], a[1], a[2], a[3]));
91+
default:
92+
ValueT ret = 0;
93+
for (int i = 0; i < len; ++i)
94+
ret += a[i] * a[i];
95+
return sycl::sqrt(ret);
96+
}
97+
}
98+
7699
/// Compute vectorized max for two values, with each value treated as a vector
77100
/// type \p S
78101
/// \param [in] S The type of the vector

sycl/test-e2e/syclcompat/math/math_fast_length_test.cpp renamed to sycl/test-e2e/syclcompat/math/math_length_test.cpp

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414
*
1515
* SYCLcompat API
1616
*
17-
* math_fast_length_test.cpp
17+
* math_length_test.cpp
1818
*
1919
* Description:
20-
* Fast length tests
20+
* vector length tests
2121
**************************************************************************/
2222

2323
// The original source was under the license below:
@@ -41,22 +41,26 @@
4141

4242
#define MAX_LEN 5
4343

44-
void compute_length(float *d_A, size_t n, float *ans) {
44+
// Evaluates syclcompat::fast_length over the first n floats of d_A and stores
// the value through ans.
void compute_fast_length(float *d_A, size_t n, float *ans) {
  const float fast_norm = syclcompat::fast_length(d_A, n);
  *ans = fast_norm;
}
4747

48-
class FastLengthLauncher {
48+
// Evaluates syclcompat::length over the first n floats of d_A and stores the
// value through ans.
void compute_length(float *d_A, size_t n, float *ans) {
  const float norm = syclcompat::length(d_A, n);
  *ans = norm;
}
51+
52+
class LengthLauncher {
4953
protected:
5054
float *data_;
5155
float *result_;
5256
float host_result_{0.0};
5357

5458
public:
55-
FastLengthLauncher() {
59+
LengthLauncher() {
5660
data_ = (float *)syclcompat::malloc(MAX_LEN * sizeof(float));
5761
result_ = (float *)syclcompat::malloc(sizeof(float));
5862
};
59-
~FastLengthLauncher() {
63+
~LengthLauncher() {
6064
syclcompat::free(data_);
6165
syclcompat::free(result_);
6266
}
@@ -68,13 +72,13 @@ class FastLengthLauncher {
6872
assert(diff <= 1.e-5);
6973
}
7074

71-
void launch(std::vector<float> vec) {
75+
template <auto F> void launch(std::vector<float> vec) {
7276
size_t n = vec.size();
7377
syclcompat::memcpy(data_, vec.data(), sizeof(float) * n);
7478
auto data = data_;
7579
auto result = result_;
7680
syclcompat::get_default_queue().single_task(
77-
[data, result, n]() { compute_length(data, n, result); });
81+
[data, result, n]() { F(data, n, result); });
7882
syclcompat::memcpy(&host_result_, result_, sizeof(float));
7983
check_result(vec);
8084
}
@@ -83,18 +87,36 @@ class FastLengthLauncher {
8387
void test_fast_length() {
8488
std::cout << __PRETTY_FUNCTION__ << std::endl;
8589

86-
auto launcher = FastLengthLauncher();
87-
launcher.launch(std::vector<float>{0.8970062715});
88-
launcher.launch(std::vector<float>{0.8335529744, 0.7346600673});
89-
launcher.launch(std::vector<float>{0.1658983906, 0.590226484, 0.4891553616});
90-
launcher.launch(std::vector<float>{0.6041178723, 0.7760620605, 0.2944284976,
91-
0.6851913766});
92-
launcher.launch(std::vector<float>{0.6041178723, 0.7760620605, 0.2944284976,
93-
0.6851913766, 0.6851913766});
90+
auto launcher = LengthLauncher();
91+
launcher.launch<compute_fast_length>(std::vector<float>{0.8970062715});
92+
launcher.launch<compute_fast_length>(
93+
std::vector<float>{0.8335529744, 0.7346600673});
94+
launcher.launch<compute_fast_length>(
95+
std::vector<float>{0.1658983906, 0.590226484, 0.4891553616});
96+
launcher.launch<compute_fast_length>(std::vector<float>{
97+
0.6041178723, 0.7760620605, 0.2944284976, 0.6851913766});
98+
launcher.launch<compute_fast_length>(std::vector<float>{
99+
0.6041178723, 0.7760620605, 0.2944284976, 0.6851913766, 0.6851913766});
100+
}
101+
102+
void test_length() {
103+
std::cout << __PRETTY_FUNCTION__ << std::endl;
104+
105+
auto launcher = LengthLauncher();
106+
launcher.launch<compute_length>(std::vector<float>{0.8970062715});
107+
launcher.launch<compute_length>(
108+
std::vector<float>{0.8335529744, 0.7346600673});
109+
launcher.launch<compute_length>(
110+
std::vector<float>{0.1658983906, 0.590226484, 0.4891553616});
111+
launcher.launch<compute_length>(std::vector<float>{
112+
0.6041178723, 0.7760620605, 0.2944284976, 0.6851913766});
113+
launcher.launch<compute_length>(std::vector<float>{
114+
0.6041178723, 0.7760620605, 0.2944284976, 0.6851913766, 0.6851913766});
94115
}
95116

96117
int main() {
97118
test_fast_length();
119+
test_length();
98120

99121
return 0;
100122
}

0 commit comments

Comments
 (0)