This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit a94982c

[SYCL][CUDA] Test cases for bfloat16 math/elem wise joint_matrix (#975)
Requires intel/llvm#5964.

bfloat16_builtins.cpp covers the bfloat16 scalar math function cases introduced by intel/llvm#5964, using the tests from #897 (which cover all "storage type" uint16_t implementation cases).

elem_wise_all_ops_cuda.cpp covers the portable element-wise ops using `wi_data`. Since CUDA does not support `joint_matrix_store` for data types that are only used in a/b matrices, such as bfloat16 and int8, the test performs a `joint_matrix_mad` operation and then calls `joint_matrix_store` on the accumulator matrix in order to reach the host-side check. Intel backend devices could still use this test in the future, provided that a backend check is introduced; ideally, both backends could eventually use the same test code.

Signed-off-by: jack.kirk <[email protected]>
1 parent 7f0ca80 commit a94982c
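
For illustration, the mad-then-store workaround described above boils down to the kernel pattern below. This is a minimal sketch, not the contents of elem_wise_all_ops_cuda.cpp: it assumes the experimental joint_matrix interface of this period (matrix_use/matrix_layout template parameters, joint_matrix_fill, and get_wi_data() for element-wise access), and the 16x16x16 tile shape, buffer names, and the "+ 2" element-wise op are placeholders.

// Sketch only: the API spellings below are assumptions based on the
// experimental matrix extension of this era and may not match the
// committed test exactly.
#include <sycl/sycl.hpp>

#include <cassert>
#include <vector>

using namespace cl::sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
using sycl::ext::oneapi::experimental::bfloat16;

int main() {
  constexpr int M = 16, K = 16, NCOLS = 16; // hypothetical sm_80 tile shape
  queue q;

  std::vector<bfloat16> A(M * K, bfloat16{1.0f});
  std::vector<bfloat16> B(K * NCOLS, bfloat16{1.0f});
  std::vector<float> D(M * NCOLS, 0.0f);

  {
    buffer<bfloat16, 1> bufA(A.data(), range<1>(M * K));
    buffer<bfloat16, 1> bufB(B.data(), range<1>(K * NCOLS));
    buffer<float, 1> bufD(D.data(), range<1>(M * NCOLS));

    q.submit([&](handler &cgh) {
      accessor accA{bufA, cgh, read_only};
      accessor accB{bufB, cgh, read_only};
      accessor accD{bufD, cgh, write_only};

      // A single sub-group (one warp) owns the whole tile.
      cgh.parallel_for(nd_range<2>({1, 32}, {1, 32}), [=](nd_item<2> item) {
        sub_group sg = item.get_sub_group();

        joint_matrix<bfloat16, matrix_use::a, M, K, matrix_layout::row_major> sub_a;
        joint_matrix<bfloat16, matrix_use::b, K, NCOLS, matrix_layout::row_major> sub_b;
        joint_matrix<float, matrix_use::accumulator, M, NCOLS, matrix_layout::row_major> sub_c;

        joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
        joint_matrix_load(sg, sub_b, accB.get_pointer(), NCOLS);
        joint_matrix_fill(sg, sub_c, 0.0f);

        // Portable element-wise op on the bfloat16 "a" fragment via wi_data.
        auto wi_a = sub_a.get_wi_data();
        for (int i = 0; i < wi_a.length(); i++) {
          wi_a[i] = wi_a[i] + bfloat16{2.0f};
        }

        // joint_matrix_store is not available for bfloat16 a/b fragments on
        // CUDA, so route the element-wise result through a mad into the float
        // accumulator and store that instead.
        sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
        joint_matrix_store(sg, sub_c, accD.get_pointer(), NCOLS,
                           matrix_layout::row_major);
      });
    });
  } // buffer destruction copies D back to the host

  // Each output element should be the sum over K of (1 + 2) * 1 = 3 * K.
  assert(D[0] == static_cast<float>(3 * K));
  return 0;
}

Because the bfloat16 "a" fragment itself is never stored, the host check has to infer the element-wise result from the accumulator values.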

File tree: 4 files changed (+651, -101 lines)


SYCL/BFloat16/bfloat16_builtins.cpp (+246)
@@ -0,0 +1,246 @@
// REQUIRES: cuda
//
// Currently this test fails to compile for backends other than cuda.
// Other backends could use this test when bfloat16 math function support is
// added.
//
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out -Xsycl-target-backend --cuda-gpu-arch=sm_80
// RUN: %t.out

#include <sycl/sycl.hpp>

#include <cmath>
#include <vector>

using namespace cl::sycl;
using sycl::ext::oneapi::experimental::bfloat16;

constexpr int N = 60; // divisible by all tested array sizes
constexpr float bf16_eps = 0.00390625;

// Reinterpret a bfloat16 bit pattern (carried in a uint16_t) as the upper 16
// bits of an IEEE-754 float.
float make_fp32(uint16_t x) {
  uint32_t y = x;
  y = y << 16;
  auto res = reinterpret_cast<float *>(&y);
  return *res;
}

// Returns true (flags a mismatch) when the difference between a and b,
// relative to their mean, exceeds twice the bfloat16 epsilon.
bool check(float a, float b) {
  return fabs(2 * (a - b) / (a + b)) > bf16_eps * 2;
}

#define TEST_BUILTIN_1_SCAL_IMPL(NAME) \
  { \
    buffer<float> a_buf(&a[0], N); \
    buffer<int> err_buf(&err, 1); \
    q.submit([&](handler &cgh) { \
      accessor<float, 1, access::mode::read_write, target::device> A(a_buf, cgh); \
      accessor<int, 1, access::mode::write, target::device> ERR(err_buf, cgh); \
      cgh.parallel_for(N, [=](id<1> index) { \
        if (check(NAME(bfloat16{A[index]}), NAME(A[index]))) { \
          ERR[0] = 1; \
        } \
      }); \
    }); \
  } \
  assert(err == 0);

#define TEST_BUILTIN_1_ARR_IMPL(NAME, SZ) \
  { \
    buffer<float, 2> a_buf{range<2>{N / SZ, SZ}}; \
    buffer<int> err_buf(&err, 1); \
    q.submit([&](handler &cgh) { \
      accessor<float, 2, access::mode::read_write, target::device> A(a_buf, cgh); \
      accessor<int, 1, access::mode::write, target::device> ERR(err_buf, cgh); \
      cgh.parallel_for(N / SZ, [=](id<1> index) { \
        marray<bfloat16, SZ> arg; \
        for (int i = 0; i < SZ; i++) { \
          arg[i] = A[index][i]; \
        } \
        marray<bfloat16, SZ> res = NAME(arg); \
        for (int i = 0; i < SZ; i++) { \
          if (check(res[i], NAME(A[index][i]))) { \
            ERR[0] = 1; \
          } \
        } \
      }); \
    }); \
  } \
  assert(err == 0);

#define TEST_BUILTIN_1(NAME) \
  TEST_BUILTIN_1_SCAL_IMPL(NAME) \
  TEST_BUILTIN_1_ARR_IMPL(NAME, 1) \
  TEST_BUILTIN_1_ARR_IMPL(NAME, 2) \
  TEST_BUILTIN_1_ARR_IMPL(NAME, 3) \
  TEST_BUILTIN_1_ARR_IMPL(NAME, 4) \
  TEST_BUILTIN_1_ARR_IMPL(NAME, 5)

#define TEST_BUILTIN_2_SCAL_IMPL(NAME) \
  { \
    buffer<float> a_buf(&a[0], N); \
    buffer<float> b_buf(&b[0], N); \
    buffer<int> err_buf(&err, 1); \
    q.submit([&](handler &cgh) { \
      accessor<float, 1, access::mode::read_write, target::device> A(a_buf, cgh); \
      accessor<float, 1, access::mode::read_write, target::device> B(b_buf, cgh); \
      accessor<int, 1, access::mode::write, target::device> ERR(err_buf, cgh); \
      cgh.parallel_for(N, [=](id<1> index) { \
        if (check(NAME(bfloat16{A[index]}, bfloat16{B[index]}), \
                  NAME(A[index], B[index]))) { \
          ERR[0] = 1; \
        } \
      }); \
    }); \
  } \
  assert(err == 0);

#define TEST_BUILTIN_2_ARR_IMPL(NAME, SZ) \
  { \
    buffer<float, 2> a_buf{range<2>{N / SZ, SZ}}; \
    buffer<float, 2> b_buf{range<2>{N / SZ, SZ}}; \
    buffer<int> err_buf(&err, 1); \
    q.submit([&](handler &cgh) { \
      accessor<float, 2, access::mode::read_write, target::device> A(a_buf, cgh); \
      accessor<float, 2, access::mode::read_write, target::device> B(b_buf, cgh); \
      accessor<int, 1, access::mode::write, target::device> ERR(err_buf, cgh); \
      cgh.parallel_for(N / SZ, [=](id<1> index) { \
        marray<bfloat16, SZ> arg0, arg1; \
        for (int i = 0; i < SZ; i++) { \
          arg0[i] = A[index][i]; \
          arg1[i] = B[index][i]; \
        } \
        marray<bfloat16, SZ> res = NAME(arg0, arg1); \
        for (int i = 0; i < SZ; i++) { \
          if (check(res[i], NAME(A[index][i], B[index][i]))) { \
            ERR[0] = 1; \
          } \
        } \
      }); \
    }); \
  } \
  assert(err == 0);

#define TEST_BUILTIN_2(NAME) \
  TEST_BUILTIN_2_SCAL_IMPL(NAME) \
  TEST_BUILTIN_2_ARR_IMPL(NAME, 1) \
  TEST_BUILTIN_2_ARR_IMPL(NAME, 2) \
  TEST_BUILTIN_2_ARR_IMPL(NAME, 3) \
  TEST_BUILTIN_2_ARR_IMPL(NAME, 4) \
  TEST_BUILTIN_2_ARR_IMPL(NAME, 5)

#define TEST_BUILTIN_3_SCAL_IMPL(NAME) \
  { \
    buffer<float> a_buf(&a[0], N); \
    buffer<float> b_buf(&b[0], N); \
    buffer<float> c_buf(&c[0], N); \
    buffer<int> err_buf(&err, 1); \
    q.submit([&](handler &cgh) { \
      accessor<float, 1, access::mode::read_write, target::device> A(a_buf, cgh); \
      accessor<float, 1, access::mode::read_write, target::device> B(b_buf, cgh); \
      accessor<float, 1, access::mode::read_write, target::device> C(c_buf, cgh); \
      accessor<int, 1, access::mode::write, target::device> ERR(err_buf, cgh); \
      cgh.parallel_for(N, [=](id<1> index) { \
        if (check(NAME(bfloat16{A[index]}, bfloat16{B[index]}, \
                       bfloat16{C[index]}), \
                  NAME(A[index], B[index], C[index]))) { \
          ERR[0] = 1; \
        } \
      }); \
    }); \
  } \
  assert(err == 0);

#define TEST_BUILTIN_3_ARR_IMPL(NAME, SZ) \
  { \
    buffer<float, 2> a_buf{range<2>{N / SZ, SZ}}; \
    buffer<float, 2> b_buf{range<2>{N / SZ, SZ}}; \
    buffer<float, 2> c_buf{range<2>{N / SZ, SZ}}; \
    buffer<int> err_buf(&err, 1); \
    q.submit([&](handler &cgh) { \
      accessor<float, 2, access::mode::read_write, target::device> A(a_buf, cgh); \
      accessor<float, 2, access::mode::read_write, target::device> B(b_buf, cgh); \
      accessor<float, 2, access::mode::read_write, target::device> C(c_buf, cgh); \
      accessor<int, 1, access::mode::write, target::device> ERR(err_buf, cgh); \
      cgh.parallel_for(N / SZ, [=](id<1> index) { \
        marray<bfloat16, SZ> arg0, arg1, arg2; \
        for (int i = 0; i < SZ; i++) { \
          arg0[i] = A[index][i]; \
          arg1[i] = B[index][i]; \
          arg2[i] = C[index][i]; \
        } \
        marray<bfloat16, SZ> res = NAME(arg0, arg1, arg2); \
        for (int i = 0; i < SZ; i++) { \
          if (check(res[i], NAME(A[index][i], B[index][i], C[index][i]))) { \
            ERR[0] = 1; \
          } \
        } \
      }); \
    }); \
  } \
  assert(err == 0);

#define TEST_BUILTIN_3(NAME) \
  TEST_BUILTIN_3_SCAL_IMPL(NAME) \
  TEST_BUILTIN_3_ARR_IMPL(NAME, 1) \
  TEST_BUILTIN_3_ARR_IMPL(NAME, 2) \
  TEST_BUILTIN_3_ARR_IMPL(NAME, 3) \
  TEST_BUILTIN_3_ARR_IMPL(NAME, 4) \
  TEST_BUILTIN_3_ARR_IMPL(NAME, 5)

#define TEST_BUILTIN_2_NAN(NAME) \
  { \
    buffer<int> err_buf(&err, 1); \
    buffer<float> nan_buf(&check_nan, 1); \
    q.submit([&](handler &cgh) { \
      accessor<int, 1, access::mode::write, target::device> ERR(err_buf, cgh); \
      accessor<float, 1, access::mode::write, target::device> checkNAN(nan_buf, cgh); \
      cgh.single_task([=]() { \
        checkNAN[0] = NAME(bfloat16{NAN}, bfloat16{NAN}); \
        if ((NAME(bfloat16{2}, bfloat16{NAN}) != 2) || \
            (NAME(bfloat16{NAN}, bfloat16{2}) != 2)) { \
          ERR[0] = 1; \
        } \
      }); \
    }); \
  } \
  assert(err == 0); \
  assert(std::isnan(check_nan));

int main() {
  queue q;

  if (q.get_device().has(aspect::ext_oneapi_bfloat16)) {
    std::vector<float> a(N), b(N), c(N);
    int err = 0;

    for (int i = 0; i < N; i++) {
      a[i] = (i - N / 2) / (float)N;
      b[i] = (N / 2 - i) / (float)N;
      c[i] = (float)(3 * i);
    }

    TEST_BUILTIN_1(fabs);
    TEST_BUILTIN_2(fmin);
    TEST_BUILTIN_2(fmax);
    TEST_BUILTIN_3(fma);

    float check_nan = 0;
    TEST_BUILTIN_2_NAN(fmin);
    TEST_BUILTIN_2_NAN(fmax);
  }
  return 0;
}
