Skip to content

Commit 5ec9cce

Browse files
[SYCL][E2E] Refactor Basic/built-ins/marray_common.cpp (#12875)
1 parent e151c5c commit 5ec9cce

File tree

2 files changed

+80
-110
lines changed

2 files changed

+80
-110
lines changed

sycl/test-e2e/Basic/built-ins/helpers.hpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
#include <sycl/sycl.hpp>
22

3-
template <typename T> bool equal(T x, T y) {
3+
template <typename T> bool equal(T x, T y, double delta) {
44
// Maybe should be C++20's std::equality_comparable.
55
if constexpr (std::is_scalar_v<T>) {
6-
return x == y;
6+
return std::abs(x - y) <= delta;
77
} else {
88
for (size_t i = 0; i < x.size(); ++i)
9-
if (x[i] != y[i])
9+
if (std::abs(x[i] - y[i]) > delta)
1010
return false;
1111

1212
return true;
@@ -15,10 +15,10 @@ template <typename T> bool equal(T x, T y) {
1515

1616
template <typename FuncTy, typename ExpectedTy,
1717
typename... ArgTys>
18-
void test(bool CheckDevice, FuncTy F, ExpectedTy Expected, ArgTys... Args) {
18+
void test(bool CheckDevice, double delta, FuncTy F, ExpectedTy Expected, ArgTys... Args) {
1919
auto R = F(Args...);
2020
static_assert(std::is_same_v<decltype(Expected), decltype(R)>);
21-
assert(equal(R, Expected));
21+
assert(equal(R, Expected, delta));
2222

2323
if (!CheckDevice)
2424
return;
@@ -29,15 +29,24 @@ void test(bool CheckDevice, FuncTy F, ExpectedTy Expected, ArgTys... Args) {
2929
cgh.single_task([=]() {
3030
auto R = F(Args...);
3131
static_assert(std::is_same_v<decltype(Expected), decltype(R)>);
32-
Success[0] = equal(R, Expected);
32+
Success[0] = equal(R, Expected, delta);
3333
});
3434
});
3535
assert(sycl::host_accessor{SuccessBuf}[0]);
3636
}
3737

3838
template <typename FuncTy, typename ExpectedTy, typename... ArgTys>
3939
void test(FuncTy F, ExpectedTy Expected, ArgTys... Args) {
40-
test(true /*CheckDevice*/, F, Expected, Args...);
40+
test(true /*CheckDevice*/, 0.0 /*delta*/, F, Expected, Args...);
41+
}
42+
template <typename FuncTy, typename ExpectedTy,
43+
typename... ArgTys>
44+
void test(bool CheckDevice, FuncTy F, ExpectedTy Expected, ArgTys... Args) {
45+
test(CheckDevice, 0.0 /*delta*/, F, Expected, Args...);
46+
}
47+
template <typename FuncTy, typename ExpectedTy, typename... ArgTys>
48+
void test(double delta, FuncTy F, ExpectedTy Expected, ArgTys... Args) {
49+
test(true /*CheckDevice*/, delta, F, Expected, Args...);
4150
}
4251

4352
// MSVC's STL spoils global namespace with math functions, so use explicit

sycl/test-e2e/Basic/built-ins/marray_common.cpp

Lines changed: 64 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -8,115 +8,76 @@
88
#endif
99
#include <cmath>
1010

11-
#include <sycl/sycl.hpp>
11+
#include "helpers.hpp"
1212

13-
#define TEST(FUNC, MARRAY_ELEM_TYPE, DIM, EXPECTED, DELTA, ...) \
14-
{ \
15-
{ \
16-
MARRAY_ELEM_TYPE result[DIM]; \
17-
{ \
18-
sycl::buffer<MARRAY_ELEM_TYPE> b(result, sycl::range{DIM}); \
19-
deviceQueue.submit([&](sycl::handler &cgh) { \
20-
sycl::accessor res_access{b, cgh}; \
21-
cgh.single_task([=]() { \
22-
sycl::marray<MARRAY_ELEM_TYPE, DIM> res = FUNC(__VA_ARGS__); \
23-
for (int i = 0; i < DIM; i++) \
24-
res_access[i] = res[i]; \
25-
}); \
26-
}); \
27-
} \
28-
for (int i = 0; i < DIM; i++) \
29-
assert(abs(result[i] - EXPECTED[i]) <= DELTA); \
30-
} \
31-
}
13+
int main() {
14+
using namespace sycl;
3215

33-
#define EXPECTED(TYPE, ...) ((TYPE[]){__VA_ARGS__})
16+
queue deviceQueue;
17+
device dev = deviceQueue.get_device();
3418

35-
int main() {
36-
sycl::queue deviceQueue;
37-
sycl::device dev = deviceQueue.get_device();
19+
marray<float, 2> ma1{1.0f, 2.0f};
20+
marray<float, 2> ma2{1.0f, 2.0f};
21+
marray<float, 2> ma3{3.0f, 2.0f};
22+
marray<double, 2> ma4{1.0, 2.0};
23+
marray<float, 3> ma5{M_PI, M_PI, M_PI};
24+
marray<double, 3> ma6{M_PI, M_PI, M_PI};
25+
marray<half, 3> ma7{M_PI, M_PI, M_PI};
26+
marray<float, 2> ma8{0.3f, 0.6f};
27+
marray<double, 2> ma9{5.0, 8.0};
28+
marray<float, 3> ma10{180, 180, 180};
29+
marray<double, 3> ma11{180, 180, 180};
30+
marray<half, 3> ma12{180, 180, 180};
31+
marray<half, 3> ma13{181, 179, 181};
32+
marray<float, 2> ma14{+0.0f, -0.6f};
33+
marray<double, 2> ma15{-0.0, 0.6f};
3834

39-
sycl::marray<float, 2> ma1{1.0f, 2.0f};
40-
sycl::marray<float, 2> ma2{1.0f, 2.0f};
41-
sycl::marray<float, 2> ma3{3.0f, 2.0f};
42-
sycl::marray<double, 2> ma4{1.0, 2.0};
43-
sycl::marray<float, 3> ma5{M_PI, M_PI, M_PI};
44-
sycl::marray<double, 3> ma6{M_PI, M_PI, M_PI};
45-
sycl::marray<sycl::half, 3> ma7{M_PI, M_PI, M_PI};
46-
sycl::marray<float, 2> ma8{0.3f, 0.6f};
47-
sycl::marray<double, 2> ma9{5.0, 8.0};
48-
sycl::marray<float, 3> ma10{180, 180, 180};
49-
sycl::marray<double, 3> ma11{180, 180, 180};
50-
sycl::marray<sycl::half, 3> ma12{180, 180, 180};
51-
sycl::marray<sycl::half, 3> ma13{181, 179, 181};
52-
sycl::marray<float, 2> ma14{+0.0f, -0.6f};
53-
sycl::marray<double, 2> ma15{-0.0, 0.6f};
35+
bool has_fp16 = queue{}.get_device().has(sycl::aspect::fp16);
36+
bool has_fp64 = queue{}.get_device().has(sycl::aspect::fp64);
5437

55-
// sycl::clamp
56-
TEST(sycl::clamp, float, 2, EXPECTED(float, 1.0f, 2.0f), 0, ma1, ma2, ma3);
57-
TEST(sycl::clamp, float, 2, EXPECTED(float, 1.0f, 2.0f), 0, ma1, 1.0f, 3.0f);
58-
if (dev.has(sycl::aspect::fp64))
59-
TEST(sycl::clamp, double, 2, EXPECTED(double, 1.0, 2.0), 0, ma4, 1.0, 3.0);
60-
// sycl::degrees
61-
TEST(sycl::degrees, float, 3, EXPECTED(float, 180, 180, 180), 0, ma5);
62-
if (dev.has(sycl::aspect::fp64))
63-
TEST(sycl::degrees, double, 3, EXPECTED(double, 180, 180, 180), 0, ma6);
64-
if (dev.has(sycl::aspect::fp16))
65-
TEST(sycl::degrees, sycl::half, 3, EXPECTED(sycl::half, 180, 180, 180), 0.2,
66-
ma7);
67-
// sycl::max
68-
TEST(sycl::max, float, 2, EXPECTED(float, 3.0f, 2.0f), 0, ma1, ma3);
69-
TEST(sycl::max, float, 2, EXPECTED(float, 1.5f, 2.0f), 0, ma1, 1.5f);
70-
if (dev.has(sycl::aspect::fp64))
71-
TEST(sycl::max, double, 2, EXPECTED(double, 1.5, 2.0), 0, ma4, 1.5);
72-
// sycl::min
73-
TEST(sycl::min, float, 2, EXPECTED(float, 1.0f, 2.0f), 0, ma1, ma3);
74-
TEST(sycl::min, float, 2, EXPECTED(float, 1.0f, 1.5f), 0, ma1, 1.5f);
75-
if (dev.has(sycl::aspect::fp64))
76-
TEST(sycl::min, double, 2, EXPECTED(double, 1.0, 1.5), 0, ma4, 1.5);
77-
// sycl::mix
78-
TEST(sycl::mix, float, 2, EXPECTED(float, 1.6f, 2.0f), 0, ma1, ma3, ma8);
79-
TEST(sycl::mix, float, 2, EXPECTED(float, 1.4f, 2.0f), 0, ma1, ma3, 0.2);
80-
if (dev.has(sycl::aspect::fp64))
81-
TEST(sycl::mix, double, 2, EXPECTED(double, 3.0, 5.0), 0, ma4, ma9, 0.5);
82-
// sycl::radians
83-
TEST(sycl::radians, float, 3, EXPECTED(float, M_PI, M_PI, M_PI), 0, ma10);
84-
if (dev.has(sycl::aspect::fp64))
85-
TEST(sycl::radians, double, 3, EXPECTED(double, M_PI, M_PI, M_PI), 0, ma11);
86-
if (dev.has(sycl::aspect::fp16))
87-
TEST(sycl::radians, sycl::half, 3, EXPECTED(sycl::half, M_PI, M_PI, M_PI),
88-
0.002, ma12);
89-
// sycl::step
90-
TEST(sycl::step, float, 2, EXPECTED(float, 1.0f, 1.0f), 0, ma1, ma3);
91-
if (dev.has(sycl::aspect::fp64))
92-
TEST(sycl::step, double, 2, EXPECTED(double, 1.0, 1.0), 0, ma4, ma9);
93-
if (dev.has(sycl::aspect::fp16))
94-
TEST(sycl::step, sycl::half, 3, EXPECTED(sycl::half, 1.0, 0.0, 1.0), 0,
95-
ma12, ma13);
96-
TEST(sycl::step, float, 2, EXPECTED(float, 1.0f, 0.0f), 0, 2.5f, ma3);
97-
if (dev.has(sycl::aspect::fp64))
98-
TEST(sycl::step, double, 2, EXPECTED(double, 0.0f, 1.0f), 0, 6.0f, ma9);
99-
// sycl::smoothstep
100-
TEST(sycl::smoothstep, float, 2, EXPECTED(float, 1.0f, 1.0f), 0, ma8, ma1,
101-
ma2);
102-
if (dev.has(sycl::aspect::fp64))
103-
TEST(sycl::smoothstep, double, 2, EXPECTED(double, 1.0, 1.0f), 0.00000001,
104-
ma4, ma9, ma9);
105-
if (dev.has(sycl::aspect::fp16))
106-
TEST(sycl::smoothstep, sycl::half, 3, EXPECTED(sycl::half, 1.0, 1.0, 1.0),
107-
0, ma7, ma12, ma13);
108-
TEST(sycl::smoothstep, float, 2, EXPECTED(float, 0.0553936f, 0.0f), 0.0000001,
109-
2.5f, 6.0f, ma3);
110-
if (dev.has(sycl::aspect::fp64))
111-
TEST(sycl::smoothstep, double, 2, EXPECTED(double, 0.0f, 1.0f), 0, 6.0f,
112-
8.0f, ma9);
38+
// clamp
39+
test(F(clamp), marray<float, 2>{1.0f, 2.0f}, ma1, ma2, ma3);
40+
test(F(clamp), marray<float, 2>{1.0f, 2.0f}, ma1, 1.0f, 3.0f);
41+
test(has_fp64, F(clamp), marray<double, 2>{1.0, 2.0}, ma4, 1.0, 3.0);
42+
// degrees
43+
test(F(degrees), marray<float, 3>{180, 180, 180}, ma5);
44+
test(has_fp64, F(degrees), marray<double, 3>{180, 180, 180}, ma6);
45+
test(has_fp16, 0.2, F(degrees), marray<half, 3>{180, 180, 180}, ma7);
46+
// max
47+
test(F(max), marray<float, 2>{3.0f, 2.0f}, ma1, ma3);
48+
test(F(max), marray<float, 2>{1.5f, 2.0f}, ma1, 1.5f);
49+
test(has_fp64, F(max), marray<double, 2>{1.5, 2.0}, ma4, 1.5);
50+
// min
51+
test(F(min), marray<float, 2>{1.0f, 2.0f}, ma1, ma3);
52+
test(F(min), marray<float, 2>{1.0f, 1.5f}, ma1, 1.5f);
53+
test(has_fp64, F(min), marray<double, 2>{1.0, 1.5}, ma4, 1.5);
54+
// mix
55+
test(F(mix), marray<float, 2>{1.6f, 2.0f}, ma1, ma3, ma8);
56+
test(F(mix), marray<float, 2>{1.4f, 2.0f}, ma1, ma3, 0.2);
57+
test(has_fp64, F(mix), marray<double, 2>{3.0, 5.0}, ma4, ma9, 0.5);
58+
// radians
59+
test(F(radians), marray<float, 3>{M_PI, M_PI, M_PI}, ma10);
60+
test(has_fp64, F(radians), marray<double, 3>{M_PI, M_PI, M_PI}, ma11);
61+
test(has_fp16, 0.002, F(radians), marray<half, 3>{M_PI, M_PI, M_PI}, ma12);
62+
// step
63+
test(F(step), marray<float, 2>{1.0f, 1.0f}, ma1, ma3);
64+
test(has_fp64, F(step), marray<double, 2>{1.0, 1.0}, ma4, ma9);
65+
test(has_fp16, F(step), marray<half, 3>{1.0, 0.0, 1.0}, ma12, ma13);
66+
test(F(step), marray<float, 2>{1.0f, 0.0f}, 2.5f, ma3);
67+
test(has_fp64, F(step), marray<double, 2>{0.0f, 1.0f}, 6.0f, ma9);
68+
// smoothstep
69+
test(F(smoothstep), marray<float, 2>{1.0f, 1.0f}, ma8, ma1, ma2);
70+
test(has_fp64, 0.00000001, F(smoothstep), marray<double, 2>{1.0, 1.0f}, ma4,
71+
ma9, ma9);
72+
test(has_fp16, F(smoothstep), marray<half, 3>{1.0, 1.0, 1.0}, ma7, ma12,
73+
ma13);
74+
test(0.0000001, F(smoothstep), marray<float, 2>{0.0553936f, 0.0f}, 2.5f, 6.0f,
75+
ma3);
76+
test(has_fp64, F(smoothstep), marray<double, 2>{0.0f, 1.0f}, 6.0f, 8.0f, ma9);
11377
// sign
114-
TEST(sycl::sign, float, 2, EXPECTED(float, +0.0f, -1.0f), 0, ma14);
115-
if (dev.has(sycl::aspect::fp64))
116-
TEST(sycl::sign, double, 2, EXPECTED(double, -0.0, 1.0), 0, ma15);
117-
if (dev.has(sycl::aspect::fp16))
118-
TEST(sycl::sign, sycl::half, 3, EXPECTED(sycl::half, 1.0, 1.0, 1.0), 0,
119-
ma12);
78+
test(F(sign), marray<float, 2>{+0.0f, -1.0f}, ma14);
79+
test(has_fp64, F(sign), marray<double, 2>{-0.0, 1.0}, ma15);
80+
test(has_fp16, F(sign), marray<half, 3>{1.0, 1.0, 1.0}, ma12);
12081

12182
return 0;
12283
}

0 commit comments

Comments
 (0)