-
Notifications
You must be signed in to change notification settings - Fork 768
[SYCL][CUDA] Joint_matrix elem wise ops inc bfloat16 #5964
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 47 commits
025cf7e
66b4e33
2d04406
9418f74
65fddfa
4d99f3f
3982001
8d2d11f
d8bc53f
2f9b7d7
8a29c44
f1fba08
461ddb8
e433fbc
48ee8ff
4a30a27
5cb7b09
603ef6e
081008b
25877c0
4b38281
7ed380c
a53ce3d
b0badd2
a8f8041
fbd9a98
d12b6c3
003ac9e
3c62249
b18a0c7
342a83a
863450a
9106ddc
575da1c
e608f84
1393371
43c2bc5
4081add
48177cc
18c71df
a5ebf2a
e46997b
b104b30
3c46f46
a67d1fd
0f0215b
a9c2901
7141fdc
22f8650
6ee55f1
49c962d
81f8ba0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
#include <CL/sycl/detail/type_traits.hpp> | ||
|
||
#include <CL/__spirv/spirv_ops.hpp> | ||
#include <sycl/ext/oneapi/experimental/bfloat16.hpp> | ||
|
||
// TODO Decide whether to mark functions with this attribute. | ||
#define __NOEXC /*noexcept*/ | ||
|
@@ -26,10 +27,7 @@ | |
#endif | ||
|
||
__SYCL_INLINE_NAMESPACE(cl) { | ||
namespace sycl { | ||
namespace ext { | ||
namespace oneapi { | ||
namespace experimental { | ||
namespace sycl::ext::oneapi::experimental { | ||
|
||
// Provides functionality to print data from kernels in a C way: | ||
// - On non-host devices this function is directly mapped to printf from | ||
|
@@ -117,11 +115,154 @@ inline __SYCL_ALWAYS_INLINE | |
|
||
} // namespace native | ||
|
||
} // namespace experimental | ||
} // namespace oneapi | ||
} // namespace ext | ||
// Scalar fabs overload for bfloat16.
//
// The enable_if return type removes this template from overload resolution
// unless T is exactly bfloat16.
//
// When both __SYCL_DEVICE_ONLY__ and __NVPTX__ are defined, the raw bits of x
// are passed to __clc_fabs and the returned bits are turned back into a
// bfloat16 with bfloat16::from_bits. In every other build configuration the
// function throws runtime_error with PI_ERROR_INVALID_DEVICE.
template <typename T> | ||
std::enable_if_t<std::is_same<T, bfloat16>::value, T> fabs(T x) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this need to be a template function or could it be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I remove There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah! It's probably confused by some ambiguity between |
||
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) | ||
return bfloat16::from_bits(__clc_fabs(x.raw())); | ||
#else | ||
// x is otherwise unreferenced in this branch; the assignment to std::ignore
// is a no-op that keeps the parameter "used".
std::ignore = x; | ||
throw runtime_error("bfloat16 is not currently supported on the host device.", | ||
PI_ERROR_INVALID_DEVICE); | ||
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) | ||
} | ||
|
||
template <size_t N> | ||
sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) { | ||
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) | ||
sycl::marray<bfloat16, N> res; | ||
auto x_storage = reinterpret_cast<uint32_t const *>(&x); | ||
auto res_storage = reinterpret_cast<uint32_t *>(&res); | ||
|
||
for (size_t i = 0; i < N / 2; i++) | ||
res_storage[i] = __clc_fabs(x_storage[i]); | ||
|
||
if constexpr (N % 2) { | ||
res[N - 1] = bfloat16::from_bits(__clc_fabs(x[N - 1].raw())); | ||
} | ||
return res; | ||
#else | ||
std::ignore = x; | ||
throw runtime_error("bfloat16 is not currently supported on the host device.", | ||
PI_ERROR_INVALID_DEVICE); | ||
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) | ||
} | ||
|
||
template <typename T> | ||
std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmin(T x, T y) { | ||
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) | ||
return bfloat16::from_bits(__clc_fmin(x.raw(), y.raw())); | ||
#else | ||
std::ignore = x; | ||
(void)y; | ||
throw runtime_error("bfloat16 is not currently supported on the host device.", | ||
PI_ERROR_INVALID_DEVICE); | ||
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) | ||
} | ||
|
||
template <size_t N> | ||
sycl::marray<bfloat16, N> fmin(sycl::marray<bfloat16, N> x, | ||
sycl::marray<bfloat16, N> y) { | ||
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) | ||
sycl::marray<bfloat16, N> res; | ||
auto x_storage = reinterpret_cast<uint32_t const *>(&x); | ||
auto y_storage = reinterpret_cast<uint32_t const *>(&y); | ||
auto res_storage = reinterpret_cast<uint32_t *>(&res); | ||
|
||
for (size_t i = 0; i < N / 2; i++) | ||
res_storage[i] = __clc_fmin(x_storage[i], y_storage[i]); | ||
|
||
if constexpr (N % 2) { | ||
res[N - 1] = | ||
bfloat16::from_bits(__clc_fmin(x[N - 1].raw(), y[N - 1].raw())); | ||
} | ||
|
||
return res; | ||
#else | ||
std::ignore = x; | ||
(void)y; | ||
throw runtime_error("bfloat16 is not currently supported on the host device.", | ||
PI_ERROR_INVALID_DEVICE); | ||
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) | ||
} | ||
|
||
template <typename T> | ||
std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmax(T x, T y) { | ||
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) | ||
return bfloat16::from_bits(__clc_fmax(x.raw(), y.raw())); | ||
#else | ||
std::ignore = x; | ||
(void)y; | ||
throw runtime_error("bfloat16 is not currently supported on the host device.", | ||
PI_ERROR_INVALID_DEVICE); | ||
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) | ||
} | ||
|
||
template <size_t N> | ||
sycl::marray<bfloat16, N> fmax(sycl::marray<bfloat16, N> x, | ||
sycl::marray<bfloat16, N> y) { | ||
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) | ||
sycl::marray<bfloat16, N> res; | ||
auto x_storage = reinterpret_cast<uint32_t const *>(&x); | ||
auto y_storage = reinterpret_cast<uint32_t const *>(&y); | ||
auto res_storage = reinterpret_cast<uint32_t *>(&res); | ||
|
||
for (size_t i = 0; i < N / 2; i++) | ||
res_storage[i] = __clc_fmax(x_storage[i], y_storage[i]); | ||
|
||
if constexpr (N % 2) { | ||
res[N - 1] = | ||
bfloat16::from_bits(__clc_fmax(x[N - 1].raw(), y[N - 1].raw())); | ||
} | ||
return res; | ||
#else | ||
std::ignore = x; | ||
(void)y; | ||
throw runtime_error("bfloat16 is not currently supported on the host device.", | ||
PI_ERROR_INVALID_DEVICE); | ||
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) | ||
} | ||
|
||
template <typename T> | ||
std::enable_if_t<std::is_same<T, bfloat16>::value, T> fma(T x, T y, T z) { | ||
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) | ||
return bfloat16::from_bits(__clc_fma(x.raw(), y.raw(), z.raw())); | ||
#else | ||
std::ignore = x; | ||
(void)y; | ||
(void)z; | ||
throw runtime_error("bfloat16 is not currently supported on the host device.", | ||
PI_ERROR_INVALID_DEVICE); | ||
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) | ||
} | ||
|
||
template <size_t N> | ||
sycl::marray<bfloat16, N> fma(sycl::marray<bfloat16, N> x, | ||
sycl::marray<bfloat16, N> y, | ||
sycl::marray<bfloat16, N> z) { | ||
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) | ||
sycl::marray<bfloat16, N> res; | ||
auto x_storage = reinterpret_cast<uint32_t const *>(&x); | ||
auto y_storage = reinterpret_cast<uint32_t const *>(&y); | ||
auto z_storage = reinterpret_cast<uint32_t const *>(&z); | ||
auto res_storage = reinterpret_cast<uint32_t *>(&res); | ||
|
||
for (size_t i = 0; i < N / 2; i++) | ||
res_storage[i] = __clc_fma(x_storage[i], y_storage[i], z_storage[i]); | ||
|
||
if constexpr (N % 2) { | ||
res[N - 1] = bfloat16::from_bits( | ||
__clc_fma(x[N - 1].raw(), y[N - 1].raw(), z[N - 1].raw())); | ||
} | ||
return res; | ||
#else | ||
std::ignore = x; | ||
(void)y; | ||
throw runtime_error("bfloat16 is not currently supported on the host device.", | ||
PI_ERROR_INVALID_DEVICE); | ||
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) | ||
} | ||
|
||
} // namespace sycl | ||
} // namespace sycl::ext::oneapi::experimental | ||
} // __SYCL_INLINE_NAMESPACE(cl) | ||
|
||
#undef __SYCL_CONSTANT_AS |
Uh oh!
There was an error while loading. Please reload this page.