Skip to content

Commit e58f156

Browse files
Revert "[SYCL][E2E] Refactor Basic/built-ins/marray_common.cpp (#12875)"
This reverts commit 5ec9cce.
1 parent 897b270 commit e58f156

File tree

2 files changed

+110
-80
lines changed

2 files changed

+110
-80
lines changed

sycl/test-e2e/Basic/built-ins/helpers.hpp

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
#include <sycl/sycl.hpp>
22

3-
template <typename T> bool equal(T x, T y, double delta) {
3+
template <typename T> bool equal(T x, T y) {
44
// Maybe should be C++20's std::equality_comparable.
55
if constexpr (std::is_scalar_v<T>) {
6-
return std::abs(x - y) <= delta;
6+
return x == y;
77
} else {
88
for (size_t i = 0; i < x.size(); ++i)
9-
if (std::abs(x[i] - y[i]) > delta)
9+
if (x[i] != y[i])
1010
return false;
1111

1212
return true;
@@ -15,10 +15,10 @@ template <typename T> bool equal(T x, T y, double delta) {
1515

1616
template <typename FuncTy, typename ExpectedTy,
1717
typename... ArgTys>
18-
void test(bool CheckDevice, double delta, FuncTy F, ExpectedTy Expected, ArgTys... Args) {
18+
void test(bool CheckDevice, FuncTy F, ExpectedTy Expected, ArgTys... Args) {
1919
auto R = F(Args...);
2020
static_assert(std::is_same_v<decltype(Expected), decltype(R)>);
21-
assert(equal(R, Expected, delta));
21+
assert(equal(R, Expected));
2222

2323
if (!CheckDevice)
2424
return;
@@ -29,24 +29,15 @@ void test(bool CheckDevice, double delta, FuncTy F, ExpectedTy Expected, ArgTys.
2929
cgh.single_task([=]() {
3030
auto R = F(Args...);
3131
static_assert(std::is_same_v<decltype(Expected), decltype(R)>);
32-
Success[0] = equal(R, Expected, delta);
32+
Success[0] = equal(R, Expected);
3333
});
3434
});
3535
assert(sycl::host_accessor{SuccessBuf}[0]);
3636
}
3737

3838
template <typename FuncTy, typename ExpectedTy, typename... ArgTys>
3939
void test(FuncTy F, ExpectedTy Expected, ArgTys... Args) {
40-
test(true /*CheckDevice*/, 0.0 /*delta*/, F, Expected, Args...);
41-
}
42-
template <typename FuncTy, typename ExpectedTy,
43-
typename... ArgTys>
44-
void test(bool CheckDevice, FuncTy F, ExpectedTy Expected, ArgTys... Args) {
45-
test(CheckDevice, 0.0 /*delta*/, F, Expected, Args...);
46-
}
47-
template <typename FuncTy, typename ExpectedTy, typename... ArgTys>
48-
void test(double delta, FuncTy F, ExpectedTy Expected, ArgTys... Args) {
49-
test(true /*CheckDevice*/, delta, F, Expected, Args...);
40+
test(true /*CheckDevice*/, F, Expected, Args...);
5041
}
5142

5243
// MSVC's STL spoils global namespace with math functions, so use explicit

sycl/test-e2e/Basic/built-ins/marray_common.cpp

Lines changed: 103 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -8,76 +8,115 @@
88
#endif
99
#include <cmath>
1010

11-
#include "helpers.hpp"
11+
#include <sycl/sycl.hpp>
1212

13-
int main() {
14-
using namespace sycl;
13+
#define TEST(FUNC, MARRAY_ELEM_TYPE, DIM, EXPECTED, DELTA, ...) \
14+
{ \
15+
{ \
16+
MARRAY_ELEM_TYPE result[DIM]; \
17+
{ \
18+
sycl::buffer<MARRAY_ELEM_TYPE> b(result, sycl::range{DIM}); \
19+
deviceQueue.submit([&](sycl::handler &cgh) { \
20+
sycl::accessor res_access{b, cgh}; \
21+
cgh.single_task([=]() { \
22+
sycl::marray<MARRAY_ELEM_TYPE, DIM> res = FUNC(__VA_ARGS__); \
23+
for (int i = 0; i < DIM; i++) \
24+
res_access[i] = res[i]; \
25+
}); \
26+
}); \
27+
} \
28+
for (int i = 0; i < DIM; i++) \
29+
assert(abs(result[i] - EXPECTED[i]) <= DELTA); \
30+
} \
31+
}
1532

16-
queue deviceQueue;
17-
device dev = deviceQueue.get_device();
33+
#define EXPECTED(TYPE, ...) ((TYPE[]){__VA_ARGS__})
1834

19-
marray<float, 2> ma1{1.0f, 2.0f};
20-
marray<float, 2> ma2{1.0f, 2.0f};
21-
marray<float, 2> ma3{3.0f, 2.0f};
22-
marray<double, 2> ma4{1.0, 2.0};
23-
marray<float, 3> ma5{M_PI, M_PI, M_PI};
24-
marray<double, 3> ma6{M_PI, M_PI, M_PI};
25-
marray<half, 3> ma7{M_PI, M_PI, M_PI};
26-
marray<float, 2> ma8{0.3f, 0.6f};
27-
marray<double, 2> ma9{5.0, 8.0};
28-
marray<float, 3> ma10{180, 180, 180};
29-
marray<double, 3> ma11{180, 180, 180};
30-
marray<half, 3> ma12{180, 180, 180};
31-
marray<half, 3> ma13{181, 179, 181};
32-
marray<float, 2> ma14{+0.0f, -0.6f};
33-
marray<double, 2> ma15{-0.0, 0.6f};
35+
int main() {
36+
sycl::queue deviceQueue;
37+
sycl::device dev = deviceQueue.get_device();
3438

35-
bool has_fp16 = queue{}.get_device().has(sycl::aspect::fp16);
36-
bool has_fp64 = queue{}.get_device().has(sycl::aspect::fp64);
39+
sycl::marray<float, 2> ma1{1.0f, 2.0f};
40+
sycl::marray<float, 2> ma2{1.0f, 2.0f};
41+
sycl::marray<float, 2> ma3{3.0f, 2.0f};
42+
sycl::marray<double, 2> ma4{1.0, 2.0};
43+
sycl::marray<float, 3> ma5{M_PI, M_PI, M_PI};
44+
sycl::marray<double, 3> ma6{M_PI, M_PI, M_PI};
45+
sycl::marray<sycl::half, 3> ma7{M_PI, M_PI, M_PI};
46+
sycl::marray<float, 2> ma8{0.3f, 0.6f};
47+
sycl::marray<double, 2> ma9{5.0, 8.0};
48+
sycl::marray<float, 3> ma10{180, 180, 180};
49+
sycl::marray<double, 3> ma11{180, 180, 180};
50+
sycl::marray<sycl::half, 3> ma12{180, 180, 180};
51+
sycl::marray<sycl::half, 3> ma13{181, 179, 181};
52+
sycl::marray<float, 2> ma14{+0.0f, -0.6f};
53+
sycl::marray<double, 2> ma15{-0.0, 0.6f};
3754

38-
// clamp
39-
test(F(clamp), marray<float, 2>{1.0f, 2.0f}, ma1, ma2, ma3);
40-
test(F(clamp), marray<float, 2>{1.0f, 2.0f}, ma1, 1.0f, 3.0f);
41-
test(has_fp64, F(clamp), marray<double, 2>{1.0, 2.0}, ma4, 1.0, 3.0);
42-
// degrees
43-
test(F(degrees), marray<float, 3>{180, 180, 180}, ma5);
44-
test(has_fp64, F(degrees), marray<double, 3>{180, 180, 180}, ma6);
45-
test(has_fp16, 0.2, F(degrees), marray<half, 3>{180, 180, 180}, ma7);
46-
// max
47-
test(F(max), marray<float, 2>{3.0f, 2.0f}, ma1, ma3);
48-
test(F(max), marray<float, 2>{1.5f, 2.0f}, ma1, 1.5f);
49-
test(has_fp64, F(max), marray<double, 2>{1.5, 2.0}, ma4, 1.5);
50-
// min
51-
test(F(min), marray<float, 2>{1.0f, 2.0f}, ma1, ma3);
52-
test(F(min), marray<float, 2>{1.0f, 1.5f}, ma1, 1.5f);
53-
test(has_fp64, F(min), marray<double, 2>{1.0, 1.5}, ma4, 1.5);
54-
// mix
55-
test(F(mix), marray<float, 2>{1.6f, 2.0f}, ma1, ma3, ma8);
56-
test(F(mix), marray<float, 2>{1.4f, 2.0f}, ma1, ma3, 0.2);
57-
test(has_fp64, F(mix), marray<double, 2>{3.0, 5.0}, ma4, ma9, 0.5);
58-
// radians
59-
test(F(radians), marray<float, 3>{M_PI, M_PI, M_PI}, ma10);
60-
test(has_fp64, F(radians), marray<double, 3>{M_PI, M_PI, M_PI}, ma11);
61-
test(has_fp16, 0.002, F(radians), marray<half, 3>{M_PI, M_PI, M_PI}, ma12);
62-
// step
63-
test(F(step), marray<float, 2>{1.0f, 1.0f}, ma1, ma3);
64-
test(has_fp64, F(step), marray<double, 2>{1.0, 1.0}, ma4, ma9);
65-
test(has_fp16, F(step), marray<half, 3>{1.0, 0.0, 1.0}, ma12, ma13);
66-
test(F(step), marray<float, 2>{1.0f, 0.0f}, 2.5f, ma3);
67-
test(has_fp64, F(step), marray<double, 2>{0.0f, 1.0f}, 6.0f, ma9);
68-
// smoothstep
69-
test(F(smoothstep), marray<float, 2>{1.0f, 1.0f}, ma8, ma1, ma2);
70-
test(has_fp64, 0.00000001, F(smoothstep), marray<double, 2>{1.0, 1.0f}, ma4,
71-
ma9, ma9);
72-
test(has_fp16, F(smoothstep), marray<half, 3>{1.0, 1.0, 1.0}, ma7, ma12,
73-
ma13);
74-
test(0.0000001, F(smoothstep), marray<float, 2>{0.0553936f, 0.0f}, 2.5f, 6.0f,
75-
ma3);
76-
test(has_fp64, F(smoothstep), marray<double, 2>{0.0f, 1.0f}, 6.0f, 8.0f, ma9);
55+
// sycl::clamp
56+
TEST(sycl::clamp, float, 2, EXPECTED(float, 1.0f, 2.0f), 0, ma1, ma2, ma3);
57+
TEST(sycl::clamp, float, 2, EXPECTED(float, 1.0f, 2.0f), 0, ma1, 1.0f, 3.0f);
58+
if (dev.has(sycl::aspect::fp64))
59+
TEST(sycl::clamp, double, 2, EXPECTED(double, 1.0, 2.0), 0, ma4, 1.0, 3.0);
60+
// sycl::degrees
61+
TEST(sycl::degrees, float, 3, EXPECTED(float, 180, 180, 180), 0, ma5);
62+
if (dev.has(sycl::aspect::fp64))
63+
TEST(sycl::degrees, double, 3, EXPECTED(double, 180, 180, 180), 0, ma6);
64+
if (dev.has(sycl::aspect::fp16))
65+
TEST(sycl::degrees, sycl::half, 3, EXPECTED(sycl::half, 180, 180, 180), 0.2,
66+
ma7);
67+
// sycl::max
68+
TEST(sycl::max, float, 2, EXPECTED(float, 3.0f, 2.0f), 0, ma1, ma3);
69+
TEST(sycl::max, float, 2, EXPECTED(float, 1.5f, 2.0f), 0, ma1, 1.5f);
70+
if (dev.has(sycl::aspect::fp64))
71+
TEST(sycl::max, double, 2, EXPECTED(double, 1.5, 2.0), 0, ma4, 1.5);
72+
// sycl::min
73+
TEST(sycl::min, float, 2, EXPECTED(float, 1.0f, 2.0f), 0, ma1, ma3);
74+
TEST(sycl::min, float, 2, EXPECTED(float, 1.0f, 1.5f), 0, ma1, 1.5f);
75+
if (dev.has(sycl::aspect::fp64))
76+
TEST(sycl::min, double, 2, EXPECTED(double, 1.0, 1.5), 0, ma4, 1.5);
77+
// sycl::mix
78+
TEST(sycl::mix, float, 2, EXPECTED(float, 1.6f, 2.0f), 0, ma1, ma3, ma8);
79+
TEST(sycl::mix, float, 2, EXPECTED(float, 1.4f, 2.0f), 0, ma1, ma3, 0.2);
80+
if (dev.has(sycl::aspect::fp64))
81+
TEST(sycl::mix, double, 2, EXPECTED(double, 3.0, 5.0), 0, ma4, ma9, 0.5);
82+
// sycl::radians
83+
TEST(sycl::radians, float, 3, EXPECTED(float, M_PI, M_PI, M_PI), 0, ma10);
84+
if (dev.has(sycl::aspect::fp64))
85+
TEST(sycl::radians, double, 3, EXPECTED(double, M_PI, M_PI, M_PI), 0, ma11);
86+
if (dev.has(sycl::aspect::fp16))
87+
TEST(sycl::radians, sycl::half, 3, EXPECTED(sycl::half, M_PI, M_PI, M_PI),
88+
0.002, ma12);
89+
// sycl::step
90+
TEST(sycl::step, float, 2, EXPECTED(float, 1.0f, 1.0f), 0, ma1, ma3);
91+
if (dev.has(sycl::aspect::fp64))
92+
TEST(sycl::step, double, 2, EXPECTED(double, 1.0, 1.0), 0, ma4, ma9);
93+
if (dev.has(sycl::aspect::fp16))
94+
TEST(sycl::step, sycl::half, 3, EXPECTED(sycl::half, 1.0, 0.0, 1.0), 0,
95+
ma12, ma13);
96+
TEST(sycl::step, float, 2, EXPECTED(float, 1.0f, 0.0f), 0, 2.5f, ma3);
97+
if (dev.has(sycl::aspect::fp64))
98+
TEST(sycl::step, double, 2, EXPECTED(double, 0.0f, 1.0f), 0, 6.0f, ma9);
99+
// sycl::smoothstep
100+
TEST(sycl::smoothstep, float, 2, EXPECTED(float, 1.0f, 1.0f), 0, ma8, ma1,
101+
ma2);
102+
if (dev.has(sycl::aspect::fp64))
103+
TEST(sycl::smoothstep, double, 2, EXPECTED(double, 1.0, 1.0f), 0.00000001,
104+
ma4, ma9, ma9);
105+
if (dev.has(sycl::aspect::fp16))
106+
TEST(sycl::smoothstep, sycl::half, 3, EXPECTED(sycl::half, 1.0, 1.0, 1.0),
107+
0, ma7, ma12, ma13);
108+
TEST(sycl::smoothstep, float, 2, EXPECTED(float, 0.0553936f, 0.0f), 0.0000001,
109+
2.5f, 6.0f, ma3);
110+
if (dev.has(sycl::aspect::fp64))
111+
TEST(sycl::smoothstep, double, 2, EXPECTED(double, 0.0f, 1.0f), 0, 6.0f,
112+
8.0f, ma9);
77113
// sign
78-
test(F(sign), marray<float, 2>{+0.0f, -1.0f}, ma14);
79-
test(has_fp64, F(sign), marray<double, 2>{-0.0, 1.0}, ma15);
80-
test(has_fp16, F(sign), marray<half, 3>{1.0, 1.0, 1.0}, ma12);
114+
TEST(sycl::sign, float, 2, EXPECTED(float, +0.0f, -1.0f), 0, ma14);
115+
if (dev.has(sycl::aspect::fp64))
116+
TEST(sycl::sign, double, 2, EXPECTED(double, -0.0, 1.0), 0, ma15);
117+
if (dev.has(sycl::aspect::fp16))
118+
TEST(sycl::sign, sycl::half, 3, EXPECTED(sycl::half, 1.0, 1.0, 1.0), 0,
119+
ma12);
81120

82121
return 0;
83122
}

0 commit comments

Comments
 (0)