Skip to content

Commit ec80987

Browse files
haonanyakeryell
andauthored
[SYCL][Test] Add Devicelib tests (#1256)
Test for all math function. Signed-off-by: haonanya <[email protected]> Co-Authored-By: Ronan Keryell <[email protected]>
1 parent 2438f61 commit ec80987

11 files changed

+436
-296
lines changed

sycl/test/devicelib/c99_complex_math_fp64_test.cpp

+10-10
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
#define CMPLX(r, i) ((double __complex__){ (double)r, (double)i })
1010
#endif
1111

12-
bool is_about_C99_CMPLX(double __complex__ x, double __complex__ y) {
13-
return is_about_FP(creal(x), creal(y)) && is_about_FP(cimag(x), cimag(y));
12+
bool approx_equal_c99_cmplx(double __complex__ x, double __complex__ y) {
13+
return approx_equal_fp(creal(x), creal(y)) && approx_equal_fp(cimag(x), cimag(y));
1414
}
1515

1616
namespace s = cl::sycl;
@@ -44,7 +44,7 @@ void device_c99_complex_times(s::queue &deviceQueue) {
4444
}
4545

4646
for (size_t idx = 0; idx < 4; ++idx) {
47-
assert(is_about_C99_CMPLX(buf_out2[idx], ref_results2[idx]));
47+
assert(approx_equal_c99_cmplx(buf_out2[idx], ref_results2[idx]));
4848
}
4949
}
5050

@@ -81,7 +81,7 @@ void device_c99_complex_divides(s::queue &deviceQueue) {
8181
}
8282

8383
for (size_t idx = 0; idx < 8; ++idx) {
84-
assert(is_about_C99_CMPLX(buf_out2[idx], ref_results2[idx]));
84+
assert(approx_equal_c99_cmplx(buf_out2[idx], ref_results2[idx]));
8585
}
8686
}
8787

@@ -107,7 +107,7 @@ void device_c99_complex_sqrt(s::queue &deviceQueue) {
107107
}
108108

109109
for (size_t idx = 0; idx < 4; ++idx) {
110-
assert(is_about_C99_CMPLX(buf_out2[idx], ref_results2[idx]));
110+
assert(approx_equal_c99_cmplx(buf_out2[idx], ref_results2[idx]));
111111
}
112112
}
113113

@@ -132,7 +132,7 @@ void device_c99_complex_abs(s::queue &deviceQueue) {
132132
}
133133

134134
for (size_t idx = 0; idx < 4; ++idx) {
135-
assert(is_about_FP(buf_out2[idx], ref_results2[idx]));
135+
assert(approx_equal_fp(buf_out2[idx], ref_results2[idx]));
136136
}
137137
}
138138

@@ -158,7 +158,7 @@ void device_c99_complex_exp(s::queue &deviceQueue) {
158158
}
159159

160160
for (size_t idx = 0; idx < 4; ++idx) {
161-
assert(is_about_C99_CMPLX(buf_out2[idx], ref_results2[idx]));
161+
assert(approx_equal_c99_cmplx(buf_out2[idx], ref_results2[idx]));
162162
}
163163
}
164164

@@ -184,7 +184,7 @@ void device_c99_complex_log(s::queue &deviceQueue) {
184184
}
185185

186186
for (size_t idx = 0; idx < 4; ++idx) {
187-
assert(is_about_C99_CMPLX(buf_out2[idx], ref_results2[idx]));
187+
assert(approx_equal_c99_cmplx(buf_out2[idx], ref_results2[idx]));
188188
}
189189
}
190190

@@ -208,7 +208,7 @@ void device_c99_complex_sin(s::queue &deviceQueue) {
208208
}
209209

210210
for (size_t idx = 0; idx < 2; ++idx) {
211-
assert(is_about_C99_CMPLX(buf_out2[idx], ref_results2[idx]));
211+
assert(approx_equal_c99_cmplx(buf_out2[idx], ref_results2[idx]));
212212
}
213213
}
214214

@@ -232,7 +232,7 @@ void device_c99_complex_cos(s::queue &deviceQueue) {
232232
}
233233

234234
for (size_t idx = 0; idx < 2; ++idx) {
235-
assert(is_about_C99_CMPLX(buf_out2[idx], ref_results2[idx]));
235+
assert(approx_equal_c99_cmplx(buf_out2[idx], ref_results2[idx]));
236236
}
237237
}
238238

sycl/test/devicelib/c99_complex_math_test.cpp

+10-10
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
#define CMPLXF(r, i) ((float __complex__){ (float)r, (float)i })
1111
#endif
1212

13-
bool is_about_C99_CMPLXF(float __complex__ x, float __complex__ y) {
14-
return is_about_FP(crealf(x), crealf(y)) && is_about_FP(cimagf(x), cimagf(y));
13+
bool approx_equal_c99_cmplxf(float __complex__ x, float __complex__ y) {
14+
return approx_equal_fp(crealf(x), crealf(y)) && approx_equal_fp(cimagf(x), cimagf(y));
1515
}
1616

1717
namespace s = cl::sycl;
@@ -46,7 +46,7 @@ void device_c99_complex_times(s::queue &deviceQueue) {
4646
}
4747

4848
for (size_t idx = 0; idx < 4; ++idx) {
49-
assert(is_about_C99_CMPLXF(buf_out1[idx], ref_results1[idx]));
49+
assert(approx_equal_c99_cmplxf(buf_out1[idx], ref_results1[idx]));
5050
}
5151
}
5252

@@ -83,7 +83,7 @@ void device_c99_complex_divides(s::queue &deviceQueue) {
8383
}
8484

8585
for (size_t idx = 0; idx < 8; ++idx) {
86-
assert(is_about_C99_CMPLXF(buf_out1[idx], ref_results1[idx]));
86+
assert(approx_equal_c99_cmplxf(buf_out1[idx], ref_results1[idx]));
8787
}
8888
}
8989

@@ -110,7 +110,7 @@ void device_c99_complex_sqrt(s::queue &deviceQueue) {
110110
}
111111

112112
for (size_t idx = 0; idx < 4; ++idx) {
113-
assert(is_about_C99_CMPLXF(buf_out1[idx], ref_results1[idx]));
113+
assert(approx_equal_c99_cmplxf(buf_out1[idx], ref_results1[idx]));
114114
}
115115
}
116116

@@ -136,7 +136,7 @@ void device_c99_complex_abs(s::queue &deviceQueue) {
136136
}
137137

138138
for (size_t idx = 0; idx < 4; ++idx) {
139-
assert(is_about_FP(buf_out1[idx], ref_results1[idx]));
139+
assert(approx_equal_fp(buf_out1[idx], ref_results1[idx]));
140140
}
141141
}
142142

@@ -162,7 +162,7 @@ void device_c99_complex_exp(s::queue &deviceQueue) {
162162
}
163163

164164
for (size_t idx = 0; idx < 4; ++idx) {
165-
assert(is_about_C99_CMPLXF(buf_out1[idx], ref_results1[idx]));
165+
assert(approx_equal_c99_cmplxf(buf_out1[idx], ref_results1[idx]));
166166
}
167167
}
168168

@@ -188,7 +188,7 @@ void device_c99_complex_log(s::queue &deviceQueue) {
188188
}
189189

190190
for (size_t idx = 0; idx < 4; ++idx) {
191-
assert(is_about_C99_CMPLXF(buf_out1[idx], ref_results1[idx]));
191+
assert(approx_equal_c99_cmplxf(buf_out1[idx], ref_results1[idx]));
192192
}
193193
}
194194

@@ -212,7 +212,7 @@ void device_c99_complex_sin(s::queue &deviceQueue) {
212212
}
213213

214214
for (size_t idx = 0; idx < 2; ++idx) {
215-
assert(is_about_C99_CMPLXF(buf_out1[idx], ref_results1[idx]));
215+
assert(approx_equal_c99_cmplxf(buf_out1[idx], ref_results1[idx]));
216216
}
217217
}
218218

@@ -236,7 +236,7 @@ void device_c99_complex_cos(s::queue &deviceQueue) {
236236
}
237237

238238
for (size_t idx = 0; idx < 2; ++idx) {
239-
assert(is_about_C99_CMPLXF(buf_out1[idx], ref_results1[idx]));
239+
assert(approx_equal_c99_cmplxf(buf_out1[idx], ref_results1[idx]));
240240
}
241241
}
242242

+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// UNSUPPORTED: windows
2+
// RUN: %clangxx -fsycl -c %s -o %t.o
3+
// RUN: %clangxx -fsycl %t.o %sycl_libs_dir/libsycl-cmath-fp64.o -o %t.out
4+
// RUN: env SYCL_DEVICE_TYPE=HOST %t.out
5+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
6+
// RUN: %ACC_RUN_PLACEHOLDER %t.out
7+
#include <CL/sycl.hpp>
8+
#include <cmath>
9+
#include <iostream>
10+
#include "math_utils.hpp"
11+
12+
namespace s = cl::sycl;
13+
constexpr s::access::mode sycl_read = s::access::mode::read;
14+
constexpr s::access::mode sycl_write = s::access::mode::write;
15+
16+
#define TEST_NUM 38
17+
18+
double ref[TEST_NUM] = {
19+
1, 0, 0, 0, 0, 0, 0, 1, 1, 0.5,
20+
0, 2, 0, 0, 1, 0, 2, 0, 0, 0,
21+
0, 0, 1, 0, 1, 2, 0, 1, 2, 5,
22+
0, 0, 0, 0, 0.5, 0.5, NAN, NAN,};
23+
24+
double refIptr = 1;
25+
26+
template <class T>
27+
void device_cmath_test(s::queue &deviceQueue) {
28+
s::range<1> numOfItems{TEST_NUM};
29+
T result[TEST_NUM] = {-1};
30+
31+
// Variable exponent is an integer value to store the exponent in frexp function
32+
int exponent = -1;
33+
34+
// Variable iptr stores the integral part of float point in modf function
35+
T iptr = -1;
36+
37+
// Variable quo stores the sign and some bits of x/y in remquo function
38+
int quo = -1;
39+
{
40+
s::buffer<T, 1> buffer1(result, numOfItems);
41+
s::buffer<int, 1> buffer2(&exponent, s::range<1>{1});
42+
s::buffer<T, 1> buffer3(&iptr, s::range<1>{1});
43+
s::buffer<int, 1> buffer4(&quo, s::range<1>{1});
44+
deviceQueue.submit([&](cl::sycl::handler &cgh) {
45+
auto res_access = buffer1.template get_access<sycl_write>(cgh);
46+
auto exp_access = buffer2.template get_access<sycl_write>(cgh);
47+
auto iptr_access = buffer3.template get_access<sycl_write>(cgh);
48+
auto quo_access = buffer4.template get_access<sycl_write>(cgh);
49+
cgh.single_task<class DeviceMathTest>([=]() {
50+
int i = 0;
51+
res_access[i++] = std::cos(0.0);
52+
res_access[i++] = std::sin(0.0);
53+
res_access[i++] = std::log(1.0);
54+
res_access[i++] = std::acos(1.0);
55+
res_access[i++] = std::asin(0.0);
56+
res_access[i++] = std::atan(0.0);
57+
res_access[i++] = std::atan2(0.0, 1.0);
58+
res_access[i++] = std::cosh(0.0);
59+
res_access[i++] = std::exp(0.0);
60+
res_access[i++] = std::fmod(1.5, 1.0);
61+
res_access[i++] = std::frexp(0.0, &exp_access[0]);
62+
res_access[i++] = std::ldexp(1.0, 1);
63+
res_access[i++] = std::log10(1.0);
64+
res_access[i++] = std::modf(1.0, &iptr_access[0]);
65+
res_access[i++] = std::pow(1.0, 1.0);
66+
res_access[i++] = std::sinh(0.0);
67+
res_access[i++] = std::sqrt(4.0);
68+
res_access[i++] = std::tan(0.0);
69+
res_access[i++] = std::tanh(0.0);
70+
res_access[i++] = std::acosh(1.0);
71+
res_access[i++] = std::asinh(0.0);
72+
res_access[i++] = std::atanh(0.0);
73+
res_access[i++] = std::cbrt(1.0);
74+
res_access[i++] = std::erf(0.0);
75+
res_access[i++] = std::erfc(0.0);
76+
res_access[i++] = std::exp2(1.0);
77+
res_access[i++] = std::expm1(0.0);
78+
res_access[i++] = std::fdim(1.0, 0.0);
79+
res_access[i++] = std::fma(1.0, 1.0, 1.0);
80+
res_access[i++] = std::hypot(3.0, 4.0);
81+
res_access[i++] = std::ilogb(1.0);
82+
res_access[i++] = std::log1p(0.0);
83+
res_access[i++] = std::log2(1.0);
84+
res_access[i++] = std::logb(1.0);
85+
res_access[i++] = std::remainder(0.5, 1.0);
86+
res_access[i++] = std::remquo(0.5, 1.0, &quo_access[0]);
87+
T a = NAN;
88+
res_access[i++] = std::tgamma(a);
89+
res_access[i++] = std::lgamma(a);
90+
});
91+
});
92+
}
93+
94+
// Compare result with reference
95+
for (int i = 0; i < TEST_NUM; ++i) {
96+
assert(approx_equal_fp(result[i], ref[i]));
97+
}
98+
99+
// Test modf integral part
100+
assert(approx_equal_fp(iptr, refIptr));
101+
102+
// Test frexp exponent
103+
assert(exponent == 0);
104+
105+
// Test remquo sign
106+
assert(quo == 0);
107+
}
108+
109+
int main() {
110+
s::queue deviceQueue;
111+
if (deviceQueue.get_device().has_extension("cl_khr_fp64")) {
112+
device_cmath_test<double>(deviceQueue);
113+
std::cout << "Pass" << std::endl;
114+
}
115+
return 0;
116+
}

0 commit comments

Comments
 (0)