
[SYCL][CUDA] Joint_matrix elem wise ops inc bfloat16 #5964


Merged · Jun 30, 2022 · 52 commits (changes shown from 47 of the 52 commits)

Commits
025cf7e  Added bfloat16 support for cuda backend. (JackAKirk, Jan 25, 2022)
66b4e33  deleted intel namespace bfloat16. (JackAKirk, Jan 25, 2022)
2d04406  Format. (JackAKirk, Jan 25, 2022)
9418f74  Changed extension macro name. (JackAKirk, Jan 25, 2022)
65fddfa  Merge branch 'sycl' into bf16-cvt-ext (JackAKirk, Feb 17, 2022)
4d99f3f  fixed test. (JackAKirk, Feb 17, 2022)
3982001  Used neg ptx7.0 builtin for unary minus (JackAKirk, Mar 4, 2022)
8d2d11f  Replaced SYCL_EXT_INTEL_BF16_CONVERSION.asciidoc with SYCL_EXT_ONEAPI… (JackAKirk, Mar 7, 2022)
d8bc53f  Merge branch 'sycl' into bf16-cvt-ext (JackAKirk, Mar 8, 2022)
2f9b7d7  Merge branch 'sycl' into bf16-cvt-ext (JackAKirk, Mar 15, 2022)
8a29c44  Renamed extension to cover all bfloat16 funct. (JackAKirk, Mar 15, 2022)
f1fba08  Updated macro name (JackAKirk, Mar 16, 2022)
461ddb8  Removed old extension doc (JackAKirk, Mar 16, 2022)
e433fbc  typo (JackAKirk, Mar 31, 2022)
48ee8ff  Initial bfloat16 function impl. (JackAKirk, Apr 1, 2022)
4a30a27  Added other bfloat16 scalar cases (JackAKirk, Apr 4, 2022)
5cb7b09  added bfloat16 device code test. (JackAKirk, Apr 4, 2022)
603ef6e  Merge branch 'sycl' into bfloat16-joint-matrix (JackAKirk, Apr 5, 2022)
081008b  Clarified error msg (JackAKirk, Apr 5, 2022)
25877c0  format (JackAKirk, Apr 5, 2022)
4b38281  format (JackAKirk, Apr 5, 2022)
7ed380c  removed deleted header from sycl.hpp (JackAKirk, Apr 5, 2022)
a53ce3d  example optimized impl using marray. (JackAKirk, Apr 8, 2022)
b0badd2  Array impls of bfloat16 math fcts. (JackAKirk, Apr 11, 2022)
a8f8041  matrix device tests use sycl::accessor. (JackAKirk, Apr 11, 2022)
fbd9a98  mixed float array impl for volta testing. (JackAKirk, Apr 12, 2022)
d12b6c3  fragments now marray for wi_data use (JackAKirk, Apr 12, 2022)
003ac9e  format (JackAKirk, Apr 12, 2022)
3c62249  format (JackAKirk, Apr 12, 2022)
b18a0c7  if constexpr optimisation (JackAKirk, Apr 12, 2022)
342a83a  commit for demonstrative purposes (JackAKirk, Apr 13, 2022)
863450a  std C++17 to sycl::detail C++17 (JackAKirk, Apr 14, 2022)
9106ddc  format (JackAKirk, Apr 14, 2022)
575da1c  Merge branch 'sycl' into bfloat16-joint-matrix (May 9, 2022)
e608f84  Merge branch 'intel:sycl' into bfloat16-joint-matrix (JackAKirk, May 9, 2022)
1393371  Switched back to c++14. (JackAKirk, May 9, 2022)
43c2bc5  Switched backed to std c++14 (JackAKirk, May 10, 2022)
4081add  format (JackAKirk, May 10, 2022)
48177cc  Impls of joint_matrix_fill, wi_data, get_wi_data(). (JackAKirk, May 27, 2022)
18c71df  Added runtime errors for host. (JackAKirk, May 31, 2022)
a5ebf2a  Removed storage type impls. (JackAKirk, Jun 2, 2022)
e46997b  wi_data constructor made private. (JackAKirk, Jun 6, 2022)
b104b30  Use std::ignore. (JackAKirk, Jun 8, 2022)
3c46f46  Merge branch 'sycl' into bfloat16-joint-matrix (JackAKirk, Jun 9, 2022)
a67d1fd  PI_INVALID_DEVICE -> PI_ERROR_INVALID_DEVICE (JackAKirk, Jun 9, 2022)
0f0215b  format (JackAKirk, Jun 9, 2022)
a9c2901  data -> wi_marray (JackAKirk, Jun 9, 2022)
7141fdc  Replaced type punning with memcpy. (JackAKirk, Jun 22, 2022)
22f8650  format (JackAKirk, Jun 22, 2022)
6ee55f1  Format (JackAKirk, Jun 22, 2022)
49c962d  add back partial_res decl. (JackAKirk, Jun 22, 2022)
81f8ba0  format (JackAKirk, Jun 22, 2022)
sycl/include/CL/sycl.hpp (1 change: 0 additions & 1 deletion)

@@ -60,7 +60,6 @@
#if SYCL_EXT_ONEAPI_BACKEND_LEVEL_ZERO
#include <sycl/ext/oneapi/backend/level_zero.hpp>
#endif
-#include <sycl/ext/oneapi/bf16_storage_builtins.hpp>
#include <sycl/ext/oneapi/device_global/properties.hpp>
#include <sycl/ext/oneapi/experimental/builtins.hpp>
#include <sycl/ext/oneapi/experimental/cuda/barrier.hpp>
sycl/include/sycl/ext/oneapi/experimental/builtins.hpp (157 changes: 149 additions & 8 deletions)

@@ -15,6 +15,7 @@
#include <CL/sycl/detail/type_traits.hpp>

#include <CL/__spirv/spirv_ops.hpp>
+#include <sycl/ext/oneapi/experimental/bfloat16.hpp>

// TODO Decide whether to mark functions with this attribute.
#define __NOEXC /*noexcept*/
@@ -26,10 +27,7 @@
#endif

__SYCL_INLINE_NAMESPACE(cl) {
-namespace sycl {
-namespace ext {
-namespace oneapi {
-namespace experimental {
+namespace sycl::ext::oneapi::experimental {

// Provides functionality to print data from kernels in a C way:
// - On non-host devices this function is directly mapped to printf from
@@ -117,11 +115,154 @@ inline __SYCL_ALWAYS_INLINE

} // namespace native

-} // namespace experimental
-} // namespace oneapi
-} // namespace ext
template <typename T>
std::enable_if_t<std::is_same<T, bfloat16>::value, T> fabs(T x) {

[Review thread on this function]
Contributor: Does this need to be a template function, or could it be bfloat16 fabs(bfloat16 x)? The same question applies to the other similar functions.
Contributor (author): If I remove template <typename T> and the use of enable_if_t, the compiler sees multiple definitions of bfloat16 fabs() with the same mangled name, based on uint16_t (the bfloat16 storage type). I'm not completely sure why this is, or why the templating and enable_if_t resolve it, but I guess the compiler confuses it with the other marray definition.
Contributor: Ah! It's probably confused by some ambiguity between bfloat16 and its storage class, thinking this could be called if passed uint16_t through an implicit conversion. Interesting! Thank you for clarifying. 😄
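
To see the distinction from this thread in isolation, here is a minimal standalone sketch (a hypothetical toy type, not the real SYCL headers): a plain bfloat16 overload would also bind to a uint16_t argument through the implicit storage-type conversion, while the enable_if_t-constrained template drops out of overload resolution unless the deduced type is exactly the bfloat16 type.

#include <cstdint>
#include <type_traits>

struct toy_bf16 {
  uint16_t bits;
  toy_bf16(uint16_t b) : bits(b) {} // implicit ctor from the storage type
};

// Plain overload (commented out): toy_fabs(uint16_t{1}) would also compile
// here, via the implicit toy_bf16(uint16_t) conversion.
// toy_bf16 toy_fabs(toy_bf16 x) { return x; }

// Constrained template: for a uint16_t argument, deduction yields
// T = uint16_t, the enable_if_t condition fails, and this overload is
// removed from the candidate set (SFINAE) instead of converting.
template <typename T>
std::enable_if_t<std::is_same<T, toy_bf16>::value, T> toy_fabs(T x) {
  return x;
}

int main() {
  toy_fabs(toy_bf16{uint16_t{0x3f80}}); // OK: T deduced as toy_bf16
  // toy_fabs(uint16_t{1});             // error: no matching overload
}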

#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
  return bfloat16::from_bits(__clc_fabs(x.raw()));
#else
  std::ignore = x;
  throw runtime_error("bfloat16 is not currently supported on the host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

template <size_t N>
sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
  sycl::marray<bfloat16, N> res;
  auto x_storage = reinterpret_cast<uint32_t const *>(&x);
  auto res_storage = reinterpret_cast<uint32_t *>(&res);

  for (size_t i = 0; i < N / 2; i++)
    res_storage[i] = __clc_fabs(x_storage[i]);

  if constexpr (N % 2) {
    res[N - 1] = bfloat16::from_bits(__clc_fabs(x[N - 1].raw()));
  }
  return res;
#else
  std::ignore = x;
  throw runtime_error("bfloat16 is not currently supported on the host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

template <typename T>
std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmin(T x, T y) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
  return bfloat16::from_bits(__clc_fmin(x.raw(), y.raw()));
#else
  std::ignore = x;
  (void)y;
  throw runtime_error("bfloat16 is not currently supported on the host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

template <size_t N>
sycl::marray<bfloat16, N> fmin(sycl::marray<bfloat16, N> x,
                               sycl::marray<bfloat16, N> y) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
  sycl::marray<bfloat16, N> res;
  auto x_storage = reinterpret_cast<uint32_t const *>(&x);
  auto y_storage = reinterpret_cast<uint32_t const *>(&y);
  auto res_storage = reinterpret_cast<uint32_t *>(&res);

  for (size_t i = 0; i < N / 2; i++)
    res_storage[i] = __clc_fmin(x_storage[i], y_storage[i]);

  if constexpr (N % 2) {
    res[N - 1] =
        bfloat16::from_bits(__clc_fmin(x[N - 1].raw(), y[N - 1].raw()));
  }

  return res;
#else
  std::ignore = x;
  (void)y;
  throw runtime_error("bfloat16 is not currently supported on the host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

template <typename T>
std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmax(T x, T y) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
  return bfloat16::from_bits(__clc_fmax(x.raw(), y.raw()));
#else
  std::ignore = x;
  (void)y;
  throw runtime_error("bfloat16 is not currently supported on the host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

template <size_t N>
sycl::marray<bfloat16, N> fmax(sycl::marray<bfloat16, N> x,
                               sycl::marray<bfloat16, N> y) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
  sycl::marray<bfloat16, N> res;
  auto x_storage = reinterpret_cast<uint32_t const *>(&x);
  auto y_storage = reinterpret_cast<uint32_t const *>(&y);
  auto res_storage = reinterpret_cast<uint32_t *>(&res);

  for (size_t i = 0; i < N / 2; i++)
    res_storage[i] = __clc_fmax(x_storage[i], y_storage[i]);

  if constexpr (N % 2) {
    res[N - 1] =
        bfloat16::from_bits(__clc_fmax(x[N - 1].raw(), y[N - 1].raw()));
  }
  return res;
#else
  std::ignore = x;
  (void)y;
  throw runtime_error("bfloat16 is not currently supported on the host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

template <typename T>
std::enable_if_t<std::is_same<T, bfloat16>::value, T> fma(T x, T y, T z) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
  return bfloat16::from_bits(__clc_fma(x.raw(), y.raw(), z.raw()));
#else
  std::ignore = x;
  (void)y;
  (void)z;
  throw runtime_error("bfloat16 is not currently supported on the host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

template <size_t N>
sycl::marray<bfloat16, N> fma(sycl::marray<bfloat16, N> x,
                              sycl::marray<bfloat16, N> y,
                              sycl::marray<bfloat16, N> z) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
  sycl::marray<bfloat16, N> res;
  auto x_storage = reinterpret_cast<uint32_t const *>(&x);
  auto y_storage = reinterpret_cast<uint32_t const *>(&y);
  auto z_storage = reinterpret_cast<uint32_t const *>(&z);
  auto res_storage = reinterpret_cast<uint32_t *>(&res);

  for (size_t i = 0; i < N / 2; i++)
    res_storage[i] = __clc_fma(x_storage[i], y_storage[i], z_storage[i]);

  if constexpr (N % 2) {
    res[N - 1] = bfloat16::from_bits(
        __clc_fma(x[N - 1].raw(), y[N - 1].raw(), z[N - 1].raw()));
  }
  return res;
#else
  std::ignore = x;
  (void)y;
  (void)z;
  throw runtime_error("bfloat16 is not currently supported on the host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

-} // namespace sycl
+} // namespace sycl::ext::oneapi::experimental
} // __SYCL_INLINE_NAMESPACE(cl)

#undef __SYCL_CONSTANT_AS
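
The marray overloads above walk the array two bfloat16 elements at a time, reinterpreting each pair as one 32-bit lane for the packed __clc_* builtins. One commit message in this PR ("Replaced type punning with memcpy", 7141fdc) points at the safer way to do such reinterpretation; as a standalone sketch with toy stand-in types (not the PR's exact code), the same pair packing done via memcpy looks like this:

#include <cstdint>
#include <cstring>

// Pack two 16-bit bfloat16 bit patterns into one 32-bit lane. memcpy keeps
// the reinterpretation well-defined, unlike casting a bfloat16* to uint32_t*.
uint32_t pack_pair(uint16_t lo, uint16_t hi) {
  uint16_t pair[2] = {lo, hi};
  uint32_t lane;
  static_assert(sizeof(pair) == sizeof(lane), "two bf16 fill one 32-bit lane");
  std::memcpy(&lane, pair, sizeof(lane));
  return lane;
}

// Unpack a 32-bit lane back into its two 16-bit halves.
void unpack_pair(uint32_t lane, uint16_t &lo, uint16_t &hi) {
  uint16_t pair[2];
  std::memcpy(pair, &lane, sizeof(pair));
  lo = pair[0];
  hi = pair[1];
}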
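Finally, a usage sketch of the new element-wise functions. It assumes a CUDA-enabled DPC++ build (sm_80 or newer for the bfloat16 builtins); the queue setup and variable names are illustrative, not taken from the PR's tests:

#include <CL/sycl.hpp>

namespace sycl_exp = sycl::ext::oneapi::experimental;

int main() {
  sycl::queue q;
  // An odd-sized marray: the first two elements take the packed 32-bit
  // path, the trailing element takes the scalar builtin path.
  sycl::marray<sycl_exp::bfloat16, 3> m{
      sycl_exp::bfloat16(-1.0f), sycl_exp::bfloat16(2.0f),
      sycl_exp::bfloat16(-0.5f)};
  {
    sycl::buffer<sycl::marray<sycl_exp::bfloat16, 3>, 1> buf(&m, 1);
    q.submit([&](sycl::handler &cgh) {
      auto acc = buf.get_access<sycl::access::mode::read_write>(cgh);
      cgh.single_task([=] { acc[0] = sycl_exp::fabs(acc[0]); });
    });
  } // buffer destruction waits and copies the result back into m
  // m now holds {1.0, 2.0, 0.5} as bfloat16 values.
}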