Commit 0a1d751

[SYCL][CUDA] Joint_matrix elem wise ops inc bfloat16 (#5964)
This PR introduces full support for element-wise operations in the CUDA backend. `wi_data`, `get_matrix_fill`, and `joint_matrix.get_wi_data()` are introduced for portability with the Intel backend. In addition, in the CUDA backend users can call `joint_matrix.wi_marray` to access the marray that stores the WI-owned elements of the matrix and perform optimized element-wise operations using math functions that take marrays.

bfloat16 element-wise operation support is also included: this PR adds bfloat16 scalar/marray implementations that replace the existing uint16_t "storage type" implementations of the fma, fmax, fmin, and fabs math functions. The bfloat16 fma_relu function implementation was added directly in #5749. The existing temporary uint16_t implementations (introduced in #5748, with unmerged tests in intel/llvm-test-suite#897) have been removed, since the bfloat16 implementations replace them.

Signed-off-by: jack.kirk <[email protected]>
1 parent f4dee54 · commit 0a1d751

10 files changed: +763 −229 lines changed
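Before the diffs, a usage sketch may help orient readers. The marray overloads of fabs, fmin, fmax, and fma added to builtins.hpp below are device-only (the host path throws), so they are meant to be called from kernel code; on CUDA, per the commit message, they can also be applied to the WI-owned elements of a joint_matrix through its wi_marray member. The function names and signatures come from the diff; the function name, kernel name, and queue scaffolding here are hypothetical, for illustration only.

// Usage sketch (not part of the commit): calling the new bfloat16 marray
// overloads from a kernel.
#include <CL/sycl.hpp>

namespace exp = sycl::ext::oneapi::experimental;

void elementwise_demo(sycl::queue &q) {
  q.parallel_for<class Bf16ElemwiseDemo>(
      sycl::range<1>{32}, [=](sycl::id<1>) {
        // Broadcast-construct three 8-element bfloat16 marrays.
        sycl::marray<exp::bfloat16, 8> a{exp::bfloat16{1.5f}};
        sycl::marray<exp::bfloat16, 8> b{exp::bfloat16{-2.0f}};
        sycl::marray<exp::bfloat16, 8> c{exp::bfloat16{0.25f}};

        // Each call processes elements two at a time via the packed __clc
        // builtins; with odd N a scalar tail step handles the last element.
        auto r = exp::fma(a, b, c); // r[i] = a[i] * b[i] + c[i]
        auto m = exp::fmax(exp::fabs(r), exp::fmin(a, c));
        (void)m; // result unused in this sketch
      });
}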

sycl/include/CL/sycl.hpp (−1)
@@ -60,7 +60,6 @@
 #if SYCL_EXT_ONEAPI_BACKEND_LEVEL_ZERO
 #include <sycl/ext/oneapi/backend/level_zero.hpp>
 #endif
-#include <sycl/ext/oneapi/bf16_storage_builtins.hpp>
 #include <sycl/ext/oneapi/device_global/properties.hpp>
 #include <sycl/ext/oneapi/experimental/builtins.hpp>
 #include <sycl/ext/oneapi/experimental/cuda/barrier.hpp>

sycl/include/sycl/ext/oneapi/experimental/builtins.hpp (+157 −8)
@@ -15,6 +15,7 @@
 #include <CL/sycl/detail/type_traits.hpp>
 
 #include <CL/__spirv/spirv_ops.hpp>
+#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
 
 // TODO Decide whether to mark functions with this attribute.
 #define __NOEXC /*noexcept*/
@@ -26,10 +27,15 @@
 #endif
 
 __SYCL_INLINE_NAMESPACE(cl) {
-namespace sycl {
-namespace ext {
-namespace oneapi {
-namespace experimental {
+namespace sycl::ext::oneapi::experimental {
+namespace detail {
+template <size_t N>
+uint32_t to_uint32_t(sycl::marray<bfloat16, N> x, size_t start) {
+  uint32_t res;
+  std::memcpy(&res, &x[start], sizeof(uint32_t));
+  return res;
+}
+} // namespace detail
 
 // Provides functionality to print data from kernels in a C way:
 // - On non-host devices this function is directly mapped to printf from
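The detail::to_uint32_t helper added above type-puns two adjacent bfloat16 elements into one uint32_t so that the packed __clc builtins can operate on a pair per call. A standalone sketch of that round trip, with uint16_t standing in for bfloat16's 16-bit storage (the program and its values are illustrative, not from the commit):

// Pack two adjacent 16-bit values into a uint32_t with memcpy (well-defined
// type punning, unlike reinterpret_cast), then unpack and check the round
// trip. uint16_t stands in for bfloat16 storage here.
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  uint16_t x[4] = {0x3FC0, 0x4000, 0x4040, 0x4080}; // bfloat16 bit patterns
  uint32_t packed;
  std::memcpy(&packed, &x[0], sizeof(uint32_t)); // pack x[0] and x[1]

  uint16_t unpacked[2];
  std::memcpy(unpacked, &packed, sizeof(uint32_t)); // unpack the pair
  assert(unpacked[0] == x[0] && unpacked[1] == x[1]);
  return 0;
}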
@@ -117,11 +123,154 @@ inline __SYCL_ALWAYS_INLINE
 
 } // namespace native
 
-} // namespace experimental
-} // namespace oneapi
-} // namespace ext
+template <typename T>
+std::enable_if_t<std::is_same<T, bfloat16>::value, T> fabs(T x) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  return bfloat16::from_bits(__clc_fabs(x.raw()));
+#else
+  std::ignore = x;
+  throw runtime_error("bfloat16 is not currently supported on the host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
+
+template <size_t N>
+sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  sycl::marray<bfloat16, N> res;
+
+  for (size_t i = 0; i < N / 2; i++) {
+    auto partial_res = __clc_fabs(detail::to_uint32_t(x, i * 2));
+    std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
+  }
+
+  if constexpr (N % 2) {
+    res[N - 1] = bfloat16::from_bits(__clc_fabs(x[N - 1].raw()));
+  }
+  return res;
+#else
+  std::ignore = x;
+  throw runtime_error("bfloat16 is not currently supported on the host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
+
+template <typename T>
+std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmin(T x, T y) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  return bfloat16::from_bits(__clc_fmin(x.raw(), y.raw()));
+#else
+  std::ignore = x;
+  (void)y;
+  throw runtime_error("bfloat16 is not currently supported on the host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
+
+template <size_t N>
+sycl::marray<bfloat16, N> fmin(sycl::marray<bfloat16, N> x,
+                               sycl::marray<bfloat16, N> y) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  sycl::marray<bfloat16, N> res;
+
+  for (size_t i = 0; i < N / 2; i++) {
+    auto partial_res = __clc_fmin(detail::to_uint32_t(x, i * 2),
+                                  detail::to_uint32_t(y, i * 2));
+    std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
+  }
+
+  if constexpr (N % 2) {
+    res[N - 1] =
+        bfloat16::from_bits(__clc_fmin(x[N - 1].raw(), y[N - 1].raw()));
+  }
+
+  return res;
+#else
+  std::ignore = x;
+  (void)y;
+  throw runtime_error("bfloat16 is not currently supported on the host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
+
+template <typename T>
+std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmax(T x, T y) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  return bfloat16::from_bits(__clc_fmax(x.raw(), y.raw()));
+#else
+  std::ignore = x;
+  (void)y;
+  throw runtime_error("bfloat16 is not currently supported on the host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
+
+template <size_t N>
+sycl::marray<bfloat16, N> fmax(sycl::marray<bfloat16, N> x,
+                               sycl::marray<bfloat16, N> y) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  sycl::marray<bfloat16, N> res;
+
+  for (size_t i = 0; i < N / 2; i++) {
+    auto partial_res = __clc_fmax(detail::to_uint32_t(x, i * 2),
+                                  detail::to_uint32_t(y, i * 2));
+    std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
+  }
+
+  if constexpr (N % 2) {
+    res[N - 1] =
+        bfloat16::from_bits(__clc_fmax(x[N - 1].raw(), y[N - 1].raw()));
+  }
+  return res;
+#else
+  std::ignore = x;
+  (void)y;
+  throw runtime_error("bfloat16 is not currently supported on the host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
+
+template <typename T>
+std::enable_if_t<std::is_same<T, bfloat16>::value, T> fma(T x, T y, T z) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  return bfloat16::from_bits(__clc_fma(x.raw(), y.raw(), z.raw()));
+#else
+  std::ignore = x;
+  (void)y;
+  (void)z;
+  throw runtime_error("bfloat16 is not currently supported on the host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
+
+template <size_t N>
+sycl::marray<bfloat16, N> fma(sycl::marray<bfloat16, N> x,
+                              sycl::marray<bfloat16, N> y,
+                              sycl::marray<bfloat16, N> z) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  sycl::marray<bfloat16, N> res;
+
+  for (size_t i = 0; i < N / 2; i++) {
+    auto partial_res =
+        __clc_fma(detail::to_uint32_t(x, i * 2), detail::to_uint32_t(y, i * 2),
+                  detail::to_uint32_t(z, i * 2));
+    std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
+  }
+
+  if constexpr (N % 2) {
+    res[N - 1] = bfloat16::from_bits(
+        __clc_fma(x[N - 1].raw(), y[N - 1].raw(), z[N - 1].raw()));
+  }
+  return res;
+#else
+  std::ignore = x;
+  (void)y;
+  throw runtime_error("bfloat16 is not currently supported on the host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
 
-} // namespace sycl
+} // namespace sycl::ext::oneapi::experimental
 } // __SYCL_INLINE_NAMESPACE(cl)
 
 #undef __SYCL_CONSTANT_AS
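Every marray overload in this hunk follows the same pair-plus-tail structure: floor(N/2) packed steps, then a scalar step for the last element when N is odd. A generic restatement in plain C++ (the function name negate is hypothetical; floats and negation stand in for bfloat16 and the __clc builtins, for illustration only):

#include <array>
#include <cstddef>

// Process elements two at a time, then handle the odd tail, mirroring the
// structure of the marray overloads above.
template <size_t N>
std::array<float, N> negate(const std::array<float, N> &x) {
  std::array<float, N> res{};
  for (size_t i = 0; i < N / 2; ++i) { // packed pair step
    res[i * 2] = -x[i * 2];
    res[i * 2 + 1] = -x[i * 2 + 1];
  }
  if constexpr (N % 2) { // scalar tail for odd N
    res[N - 1] = -x[N - 1];
  }
  return res;
}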
