diff --git a/sycl/include/CL/sycl.hpp b/sycl/include/CL/sycl.hpp
index 47dbaf6f05bc6..9b517ced4b659 100644
--- a/sycl/include/CL/sycl.hpp
+++ b/sycl/include/CL/sycl.hpp
@@ -60,7 +60,6 @@
 #if SYCL_EXT_ONEAPI_BACKEND_LEVEL_ZERO
 #include
 #endif
-#include
 #include
 #include
 #include
diff --git a/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp b/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp
index c8fa033d8c79e..c2326c6563efc 100644
--- a/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp
+++ b/sycl/include/sycl/ext/oneapi/experimental/builtins.hpp
@@ -15,6 +15,7 @@
 #include
 #include
+#include
 
 // TODO Decide whether to mark functions with this attribute.
 #define __NOEXC /*noexcept*/
@@ -26,10 +27,15 @@
 #endif
 
 __SYCL_INLINE_NAMESPACE(cl) {
-namespace sycl {
-namespace ext {
-namespace oneapi {
-namespace experimental {
+namespace sycl::ext::oneapi::experimental {
+namespace detail {
+template <size_t N>
+uint32_t to_uint32_t(sycl::marray<bfloat16, N> x, size_t start) {
+  uint32_t res;
+  std::memcpy(&res, &x[start], sizeof(uint32_t));
+  return res;
+}
+} // namespace detail
 
 // Provides functionality to print data from kernels in a C way:
 // - On non-host devices this function is directly mapped to printf from
@@ -117,11 +123,154 @@ inline __SYCL_ALWAYS_INLINE
 } // namespace native
 
-} // namespace experimental
-} // namespace oneapi
-} // namespace ext
+template <typename T>
+std::enable_if_t<std::is_same<T, bfloat16>::value, T> fabs(T x) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  return bfloat16::from_bits(__clc_fabs(x.raw()));
+#else
+  std::ignore = x;
+  throw runtime_error("bfloat16 is not currently supported on the host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
+
+template <size_t N>
+sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  sycl::marray<bfloat16, N> res;
+
+  for (size_t i = 0; i < N / 2; i++) {
+    auto partial_res = __clc_fabs(detail::to_uint32_t(x, i * 2));
+    std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
+  }
+
+  if constexpr (N % 2) {
+    res[N - 1] = bfloat16::from_bits(__clc_fabs(x[N - 1].raw()));
+  }
+  return res;
+#else
+  std::ignore = x;
+  throw runtime_error("bfloat16 is not currently supported on the host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
+
+template <typename T>
+std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmin(T x, T y) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  return bfloat16::from_bits(__clc_fmin(x.raw(), y.raw()));
+#else
+  std::ignore = x;
+  (void)y;
+  throw runtime_error("bfloat16 is not currently supported on the host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
+
+template <size_t N>
+sycl::marray<bfloat16, N> fmin(sycl::marray<bfloat16, N> x,
+                               sycl::marray<bfloat16, N> y) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  sycl::marray<bfloat16, N> res;
+
+  for (size_t i = 0; i < N / 2; i++) {
+    auto partial_res = __clc_fmin(detail::to_uint32_t(x, i * 2),
+                                  detail::to_uint32_t(y, i * 2));
+    std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
+  }
+
+  if constexpr (N % 2) {
+    res[N - 1] =
+        bfloat16::from_bits(__clc_fmin(x[N - 1].raw(), y[N - 1].raw()));
+  }
+
+  return res;
+#else
+  std::ignore = x;
+  (void)y;
+  throw runtime_error("bfloat16 is not currently supported on the host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
+
+template <typename T>
+std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmax(T x, T y) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  return bfloat16::from_bits(__clc_fmax(x.raw(), y.raw()));
+#else
+  std::ignore = x;
+  (void)y;
+  throw runtime_error("bfloat16 is not currently supported on the host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
+
+template <size_t N>
+sycl::marray<bfloat16, N> fmax(sycl::marray<bfloat16, N> x,
+                               sycl::marray<bfloat16, N> y) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  sycl::marray<bfloat16, N> res;
+
+  for (size_t i = 0; i < N / 2; i++) {
+    auto partial_res = __clc_fmax(detail::to_uint32_t(x, i * 2),
+                                  detail::to_uint32_t(y, i * 2));
+    std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
+  }
+
+  if constexpr (N % 2) {
+    res[N - 1] =
+        bfloat16::from_bits(__clc_fmax(x[N - 1].raw(), y[N - 1].raw()));
+  }
+  return res;
+#else
+  std::ignore = x;
+  (void)y;
+  throw runtime_error("bfloat16 is not currently supported on the host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
+
+template <typename T>
+std::enable_if_t<std::is_same<T, bfloat16>::value, T> fma(T x, T y, T z) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  return bfloat16::from_bits(__clc_fma(x.raw(), y.raw(), z.raw()));
+#else
+  std::ignore = x;
+  (void)y;
+  (void)z;
+  throw runtime_error("bfloat16 is not currently supported on the host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
+
+template <size_t N>
+sycl::marray<bfloat16, N> fma(sycl::marray<bfloat16, N> x,
+                              sycl::marray<bfloat16, N> y,
+                              sycl::marray<bfloat16, N> z) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  sycl::marray<bfloat16, N> res;
+
+  for (size_t i = 0; i < N / 2; i++) {
+    auto partial_res =
+        __clc_fma(detail::to_uint32_t(x, i * 2), detail::to_uint32_t(y, i * 2),
+                  detail::to_uint32_t(z, i * 2));
+    std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
+  }
+
+  if constexpr (N % 2) {
+    res[N - 1] = bfloat16::from_bits(
+        __clc_fma(x[N - 1].raw(), y[N - 1].raw(), z[N - 1].raw()));
+  }
+  return res;
+#else
+  std::ignore = x;
+  (void)y;
+  (void)z;
+  throw runtime_error("bfloat16 is not currently supported on the host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+}
 
-} // namespace sycl
+} // namespace sycl::ext::oneapi::experimental
 } // __SYCL_INLINE_NAMESPACE(cl)
 
 #undef __SYCL_CONSTANT_AS
diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
index 7a7ecbf948907..6de66d0f8590c 100644
--- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
+++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
@@ -7,11 +7,10 @@
 // ===--------------------------------------------------------------------=== //
 
 #pragma once
+#include
 
 __SYCL_INLINE_NAMESPACE(cl) {
-namespace sycl {
-namespace ext {
-namespace oneapi {
+namespace sycl::ext::oneapi {
 namespace experimental::matrix {
 
 enum class matrix_use { a, b, accumulator };
@@ -28,68 +27,148 @@
 template <typename T, matrix_use Use, size_t Rows, size_t Cols,
           matrix_layout Layout = matrix_layout::row_major,
           typename Group = sycl::sub_group, typename Cond = void>
 struct joint_matrix;
 
-#define __SYCL_JOINT_MATRIX_OVERLOAD(type, use, M, N, frag_type, frag_size)   \
+template <typename type, size_t size> class wi_data {
+  marray<type, size> &data;
+  wi_data(marray<type, size> &wi_data) : data(wi_data){};
+  template <typename T, matrix_use Use, size_t Rows, size_t Cols,
+            matrix_layout Layout, typename Group, typename Cond>
+  friend struct joint_matrix;
+
+public:
+  size_t length() {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+    return data.size();
+#else
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  };
+
+  type &operator[](size_t i) {
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+    return data[i];
+#else
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  };
+};
+
+#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR(type, use, M, N, size)               \
   template <matrix_layout Layout>                                             \
   struct joint_matrix<                                                        \
       type, matrix_use::use, M, N, Layout, sycl::sub_group,                   \
       typename std::enable_if_t<Layout == matrix_layout::row_major ||         \
                                 Layout == matrix_layout::col_major>> {        \
-    frag_type data[frag_size];                                                \
+    marray<type, size> wi_marray;                                             \
+    inline __SYCL_ALWAYS_INLINE wi_data<type, size> get_wi_data() {           \
+      return wi_data(wi_marray);                                              \
+    };                                                                        \
   };
 
+// m8n32k16
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 8, 16, 4)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 32, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 8, 16, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 32, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, accumulator, 8, 32, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(float, accumulator, 8, 32, 8)
+
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 8, 16, 4)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 32, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 8, 16, 4)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 32, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int32_t, accumulator, 8, 32, 8)
+// m32n8k16
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 32, 16, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 8, 4)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 32, 16, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 8, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, accumulator, 32, 8, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(float, accumulator, 32, 8, 8)
+
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 32, 16, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 8, 4)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 32, 16, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 8, 4)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int32_t, accumulator, 32, 8, 8)
+// m16n16k16
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 16, 16, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 16, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 16, 16, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 16, 16)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, accumulator, 16, 16, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(float, accumulator, 16, 16, 8)
+
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 16, 16, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 16, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 16, 16, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 16, 8)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int32_t, accumulator, 16, 16, 8)
 // m8n8k4 double only
-__SYCL_JOINT_MATRIX_OVERLOAD(double, a, 8, 4, double, 1)
-__SYCL_JOINT_MATRIX_OVERLOAD(double, b, 4, 8, double, 1)
-__SYCL_JOINT_MATRIX_OVERLOAD(double, accumulator, 8, 8, double, 2)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, a, 8, 4, 1)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, b, 4, 8, 1)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, accumulator, 8, 8, 2)
+
+#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR
+
+#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision, use, M, N, type, \
+                                                   size)                       \
+  template <matrix_layout Layout>                                              \
+  struct joint_matrix<                                                         \
+      precision, matrix_use::use, M, N, Layout, sycl::sub_group,               \
+      typename std::enable_if_t<Layout == matrix_layout::row_major ||          \
+                                Layout == matrix_layout::col_major>> {         \
+    marray<type, size> wi_marray;                                              \
+    inline __SYCL_ALWAYS_INLINE wi_data<type, size> get_wi_data() {            \
+      return wi_data(wi_marray);                                               \
+    };                                                                         \
+  };
+// m16n16k8 tf32 only
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision::tf32, a, 16, 8, float, 4)
+__SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(precision::tf32, b, 8, 16, float, 4)
+#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION
+
+#define __SYCL_JOINT_MATRIX_OVERLOAD(type, use, M, N, frag_type, frag_size)   \
+  template <matrix_layout Layout>                                             \
+
struct joint_matrix< \ + type, matrix_use::use, M, N, Layout, sycl::sub_group, \ + typename std::enable_if_t> { \ + frag_type wi_marray[frag_size]; \ + }; + +// bf16 data format uint16_t implementation is deprecated // m8n32k16 -// bf16 data format uses uint16_t data type __SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 8, 16, int32_t, 2) __SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 32, int32_t, 8) -__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 8, 16, int32_t, 8) -__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 32, int32_t, 8) -__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 8, 32, float, 8) -__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 8, 32, int32_t, 4) - -__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 8, 16, int32_t, 1) -__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 32, int32_t, 4) -__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 8, 16, int32_t, 1) -__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 32, int32_t, 4) -__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 8, 32, int32_t, 8) - // m32n8k16 __SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 32, 16, int32_t, 8) __SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 8, int32_t, 2) -__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 32, 16, int32_t, 8) -__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 8, int32_t, 8) -__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 32, 8, float, 8) -__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 32, 8, int32_t, 4) - -__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 32, 16, int32_t, 4) -__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 8, int32_t, 1) -__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 32, 16, int32_t, 4) -__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 8, int32_t, 1) -__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 32, 8, int32_t, 8) - // m16n16k16 __SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 16, 16, int32_t, 4) __SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 16, int32_t, 4) -__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 16, 16, int32_t, 8) -__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 16, int32_t, 8) -__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 16, 16, float, 8) -__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 16, 16, int32_t, 4) -__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 16, 16, int32_t, 2) -__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 16, int32_t, 2) -__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2) -__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2) -__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 16, 16, int32_t, 8) +#undef __SYCL_JOINT_MATRIX_OVERLOAD -// m16n16k8 tf32 -__SYCL_JOINT_MATRIX_OVERLOAD(precision::tf32, a, 16, 8, float, 4) -__SYCL_JOINT_MATRIX_OVERLOAD(precision::tf32, b, 8, 16, float, 4) +template +inline __SYCL_ALWAYS_INLINE void +joint_matrix_fill(Group sg, + joint_matrix &res, + const T2 v) { + // We kept the unused "sg" in joint_matrix_fill to match the other DPC++ + // functions + std::ignore = sg; +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + res.wi_marray = v; +#else + std::ignore = res; + std::ignore = v; +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +} -#undef __SYCL_JOINT_MATRIX_OVERLOAD } // namespace experimental::matrix namespace detail { @@ -134,164 +213,168 @@ struct joint_matrix_load_impl< void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res, multi_ptr src, size_t stride) { - if constexpr (std::is_same::value) { - int32_t *tileptr = reinterpret_cast(src.get()); + if constexpr (std::is_same::value || + std::is_same< + T, sycl::ext::oneapi::experimental::bfloat16>::value) { + auto tileptr = reinterpret_cast(src.get()); + auto destptr = 
reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { - __mma_bf16_m16n16k16_ld_a(res.data, tileptr, stride, + __mma_bf16_m16n16k16_ld_a(destptr, tileptr, stride, get_layout_id()); } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::b) { - __mma_bf16_m16n16k16_ld_b(res.data, tileptr, stride, + __mma_bf16_m16n16k16_ld_b(destptr, tileptr, stride, get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 16) { - __mma_bf16_m8n32k16_ld_a(res.data, tileptr, stride, + __mma_bf16_m8n32k16_ld_a(destptr, tileptr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 32) { - __mma_bf16_m8n32k16_ld_b(res.data, tileptr, stride, + __mma_bf16_m8n32k16_ld_b(destptr, tileptr, stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 16) { - __mma_bf16_m32n8k16_ld_a(res.data, tileptr, stride, + __mma_bf16_m32n8k16_ld_a(destptr, tileptr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 8) { - __mma_bf16_m32n8k16_ld_b(res.data, tileptr, stride, + __mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride, get_layout_id()); } } else if constexpr (std::is_same::value) { - int32_t *tileptr = reinterpret_cast(src.get()); + auto tileptr = reinterpret_cast(src.get()); + auto destptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { - __imma_m16n16k16_ld_a_u8(res.data, tileptr, stride, + __imma_m16n16k16_ld_a_u8(destptr, tileptr, stride, get_layout_id()); } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::b) { - __imma_m16n16k16_ld_b_u8(res.data, tileptr, stride, + __imma_m16n16k16_ld_b_u8(destptr, tileptr, stride, get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 16) { - __imma_m8n32k16_ld_a_u8(res.data, tileptr, stride, + __imma_m8n32k16_ld_a_u8(destptr, tileptr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 32) { - __imma_m8n32k16_ld_b_u8(res.data, tileptr, stride, + __imma_m8n32k16_ld_b_u8(destptr, tileptr, stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 16) { - __imma_m32n8k16_ld_a_u8(res.data, tileptr, stride, + __imma_m32n8k16_ld_a_u8(destptr, tileptr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 8) { - __imma_m32n8k16_ld_b_u8(res.data, tileptr, stride, + __imma_m32n8k16_ld_b_u8(destptr, tileptr, stride, get_layout_id()); } } else if constexpr (std::is_same::value) { - int32_t *tileptr = reinterpret_cast(src.get()); + auto tileptr = reinterpret_cast(src.get()); + auto destptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { - __imma_m16n16k16_ld_a_s8(res.data, tileptr, stride, + __imma_m16n16k16_ld_a_s8(destptr, tileptr, stride, get_layout_id()); } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::b) { - __imma_m16n16k16_ld_b_s8(res.data, tileptr, stride, + __imma_m16n16k16_ld_b_s8(destptr, tileptr, stride, get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 16) { - __imma_m8n32k16_ld_a_s8(res.data, tileptr, stride, + __imma_m8n32k16_ld_a_s8(destptr, tileptr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 32) { - __imma_m8n32k16_ld_b_s8(res.data, tileptr, stride, + 
__imma_m8n32k16_ld_b_s8(destptr, tileptr, stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 16) { - __imma_m32n8k16_ld_a_s8(res.data, tileptr, stride, + __imma_m32n8k16_ld_a_s8(destptr, tileptr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 8) { - __imma_m32n8k16_ld_b_s8(res.data, tileptr, stride, + __imma_m32n8k16_ld_b_s8(destptr, tileptr, stride, get_layout_id()); } } else if constexpr (std::is_same::value) { - int32_t *tileptr = reinterpret_cast(src.get()); + auto tileptr = reinterpret_cast(src.get()); + auto dstptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { - __hmma_m16n16k16_ld_a(res.data, tileptr, stride, + __hmma_m16n16k16_ld_a(dstptr, tileptr, stride, get_layout_id()); } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::b) { - __hmma_m16n16k16_ld_b(res.data, tileptr, stride, + __hmma_m16n16k16_ld_b(dstptr, tileptr, stride, get_layout_id()); } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::accumulator) { - __hmma_m16n16k16_ld_c_f16(res.data, tileptr, stride, + __hmma_m16n16k16_ld_c_f16(dstptr, tileptr, stride, get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 16) { - __hmma_m8n32k16_ld_a(res.data, tileptr, stride, - get_layout_id()); + __hmma_m8n32k16_ld_a(dstptr, tileptr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 32) { - __hmma_m8n32k16_ld_b(res.data, tileptr, stride, - get_layout_id()); + __hmma_m8n32k16_ld_b(dstptr, tileptr, stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 16) { - __hmma_m32n8k16_ld_a(res.data, tileptr, stride, - get_layout_id()); + __hmma_m32n8k16_ld_a(dstptr, tileptr, stride, get_layout_id()); } else if constexpr (NumRows == 16 && NumCols == 8) { - __hmma_m32n8k16_ld_b(res.data, tileptr, stride, - get_layout_id()); + __hmma_m32n8k16_ld_b(dstptr, tileptr, stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 8) { - __hmma_m32n8k16_ld_c_f16(res.data, tileptr, stride, + __hmma_m32n8k16_ld_c_f16(dstptr, tileptr, stride, get_layout_id()); } else if constexpr (NumRows == 8 && NumCols == 32) { - __hmma_m8n32k16_ld_c_f16(res.data, tileptr, stride, + __hmma_m8n32k16_ld_c_f16(dstptr, tileptr, stride, get_layout_id()); } } else if constexpr (std::is_same::value) { + auto destptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { - __imma_m16n16k16_ld_c(res.data, src.get(), stride, + __imma_m16n16k16_ld_c(destptr, src.get(), stride, get_layout_id()); } else if constexpr (NumRows == 8 && NumCols == 32) { - __imma_m8n32k16_ld_c(res.data, src.get(), stride, + __imma_m8n32k16_ld_c(destptr, src.get(), stride, get_layout_id()); } else if constexpr (NumRows == 32 && NumCols == 8) { - __imma_m32n8k16_ld_c(res.data, src.get(), stride, + __imma_m32n8k16_ld_c(destptr, src.get(), stride, get_layout_id()); } } else if constexpr (std::is_same::value) { - if (std::is_same::value) { + if constexpr (std::is_same::value) { + auto dstptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 16) { - __hmma_m16n16k16_ld_c_f32(res.data, src.get(), stride, + __hmma_m16n16k16_ld_c_f32(dstptr, src.get(), stride, get_layout_id()); } else if constexpr (NumRows == 8 && NumCols == 32) { - __hmma_m8n32k16_ld_c_f32(res.data, src.get(), stride, + __hmma_m8n32k16_ld_c_f32(dstptr, src.get(), stride, get_layout_id()); } 
else if constexpr (NumRows == 32 && NumCols == 8) { - __hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride, + __hmma_m32n8k16_ld_c_f32(dstptr, src.get(), stride, get_layout_id()); } - } else if (std::is_same::value) { - int32_t *tileptr = reinterpret_cast(src.get()); + } else if constexpr (std::is_same::value) { + auto tileptr = reinterpret_cast(src.get()); + auto dstptr = reinterpret_cast(&res.wi_marray); if constexpr (NumRows == 16 && NumCols == 8) { - __mma_tf32_m16n16k8_ld_a(reinterpret_cast(res.data), - tileptr, stride, get_layout_id()); + __mma_tf32_m16n16k8_ld_a(dstptr, tileptr, stride, + get_layout_id()); } else if constexpr (NumRows == 8 && NumCols == 16) { - __mma_tf32_m16n16k8_ld_b(reinterpret_cast(res.data), - tileptr, stride, get_layout_id()); + __mma_tf32_m16n16k8_ld_b(dstptr, tileptr, stride, + get_layout_id()); } } } else if constexpr (std::is_same::value) { + auto dstptr = reinterpret_cast(&res.wi_marray); if constexpr (Use == sycl::ext::oneapi::experimental::matrix::matrix_use::a) { - __dmma_m8n8k4_ld_a(res.data, src.get(), stride, - get_layout_id()); + __dmma_m8n8k4_ld_a(dstptr, src.get(), stride, get_layout_id()); } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::b) { - __dmma_m8n8k4_ld_b(res.data, src.get(), stride, - get_layout_id()); + __dmma_m8n8k4_ld_b(dstptr, src.get(), stride, get_layout_id()); } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: matrix_use::accumulator) { - __dmma_m8n8k4_ld_c(res.data, src.get(), stride, - get_layout_id()); + __dmma_m8n8k4_ld_c(dstptr, src.get(), stride, get_layout_id()); } } } @@ -322,44 +405,51 @@ struct joint_matrix_store_impl< T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, NumRows, NumCols, Layout, sycl::sub_group> &src, multi_ptr dst, size_t stride) { - if (NumRows == 16 && NumCols == 16) { + if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (std::is_same::value) { - __hmma_m16n16k16_st_c_f32(dst.get(), src.data, stride, - get_layout_id()); + __hmma_m16n16k16_st_c_f32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); } else if constexpr (std::is_same::value) { - __imma_m16n16k16_st_c_i32(dst.get(), src.data, stride, - get_layout_id()); + __imma_m16n16k16_st_c_i32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); } else if constexpr (std::is_same::value) { - int32_t *tileptr = reinterpret_cast(dst.get()); - __hmma_m16n16k16_st_c_f16(tileptr, src.data, stride, - get_layout_id()); + __hmma_m16n16k16_st_c_f16(reinterpret_cast(dst.get()), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); } - } else if (NumRows == 8 && NumCols == 32) { + } else if constexpr (NumRows == 8 && NumCols == 32) { if constexpr (std::is_same::value) { - __hmma_m8n32k16_st_c_f32(dst.get(), src.data, stride, - get_layout_id()); + __hmma_m8n32k16_st_c_f32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); } else if constexpr (std::is_same::value) { - __imma_m8n32k16_st_c_i32(dst.get(), src.data, stride, - get_layout_id()); + __imma_m8n32k16_st_c_i32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); } else if constexpr (std::is_same::value) { - int32_t *tileptr = reinterpret_cast(dst.get()); - __hmma_m8n32k16_st_c_f16(tileptr, src.data, stride, - get_layout_id()); + __hmma_m8n32k16_st_c_f16(reinterpret_cast(dst.get()), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); } - } else if (NumRows == 32 && NumCols == 8) { + } else if constexpr (NumRows == 32 
&& NumCols == 8) { if constexpr (std::is_same::value) { - __hmma_m32n8k16_st_c_f32(dst.get(), src.data, stride, - get_layout_id()); + __hmma_m32n8k16_st_c_f32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); } else if constexpr (std::is_same::value) { - __imma_m32n8k16_st_c_i32(dst.get(), src.data, stride, - get_layout_id()); + __imma_m32n8k16_st_c_i32(dst.get(), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); } else if constexpr (std::is_same::value) { - int32_t *tileptr = reinterpret_cast(dst.get()); - __hmma_m32n8k16_st_c_f16(tileptr, src.data, stride, - get_layout_id()); + __hmma_m32n8k16_st_c_f16(reinterpret_cast(dst.get()), + reinterpret_cast(&src.wi_marray), + stride, get_layout_id()); } } else if constexpr (std::is_same::value) { - __dmma_m8n8k4_st_c_f64(dst.get(), src.data, stride, + __dmma_m8n8k4_st_c_f64(dst.get(), + reinterpret_cast(&src.wi_marray), stride, get_layout_id()); } } @@ -459,70 +549,127 @@ struct joint_matrix_mad_impl< N, LayoutC, sycl::sub_group> D; if constexpr (M == 16 && N == 16 && K == 16) { - if constexpr (std::is_same::value) { - __imma_m16n16k16_mma_s8(D.data, A.data, B.data, C.data, - get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { - __imma_m16n16k16_mma_u8(D.data, A.data, B.data, C.data, - get_layout_pair_id(), 0); + if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrD = reinterpret_cast(&D.wi_marray); + if constexpr (std::is_same::value) { + __imma_m16n16k16_mma_s8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __imma_m16n16k16_mma_u8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } } else if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); if constexpr (std::is_same::value) { - __hmma_m16n16k16_mma_f32f32(D.data, A.data, B.data, C.data, - get_layout_pair_id(), - 0); + __hmma_m16n16k16_mma_f32f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); } else if constexpr (std::is_same::value) { - __hmma_m16n16k16_mma_f16f16(D.data, A.data, B.data, C.data, - get_layout_pair_id(), - 0); + __hmma_m16n16k16_mma_f16f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); } - } else if constexpr (std::is_same::value) { - __mma_bf16_m16n16k16_mma_f32(D.data, A.data, B.data, C.data, - get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value || + std::is_same::value) { + __mma_bf16_m16n16k16_mma_f32( + reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); } } else if constexpr (M == 8 && N == 32 && K == 16) { - if constexpr (std::is_same::value) { - __imma_m8n32k16_mma_s8(D.data, A.data, B.data, C.data, - get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { - __imma_m8n32k16_mma_u8(D.data, A.data, B.data, C.data, - get_layout_pair_id(), 0); + if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrD = reinterpret_cast(&D.wi_marray); + if constexpr (std::is_same::value) { + __imma_m8n32k16_mma_s8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } else if 
constexpr (std::is_same::value) { + __imma_m8n32k16_mma_u8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } } else if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); if constexpr (std::is_same::value) { - __hmma_m8n32k16_mma_f32f32(D.data, A.data, B.data, C.data, - get_layout_pair_id(), 0); + __hmma_m8n32k16_mma_f32f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); } else if constexpr (std::is_same::value) { - __hmma_m8n32k16_mma_f16f16(D.data, A.data, B.data, C.data, - get_layout_pair_id(), 0); + __hmma_m8n32k16_mma_f16f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); } - } else if constexpr (std::is_same::value) { - __mma_bf16_m8n32k16_mma_f32(D.data, A.data, B.data, C.data, - get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value || + std::is_same::value) { + __mma_bf16_m8n32k16_mma_f32( + reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); } } else if constexpr (M == 32 && N == 8 && K == 16) { - if constexpr (std::is_same::value) { - __imma_m32n8k16_mma_s8(D.data, A.data, B.data, C.data, - get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { - __imma_m32n8k16_mma_u8(D.data, A.data, B.data, C.data, - get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { - __mma_bf16_m32n8k16_mma_f32(D.data, A.data, B.data, C.data, - get_layout_pair_id(), 0); + if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); + auto ptrC = reinterpret_cast(&C.wi_marray); + auto ptrD = reinterpret_cast(&D.wi_marray); + if constexpr (std::is_same::value) { + __imma_m32n8k16_mma_s8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __imma_m32n8k16_mma_u8(ptrD, ptrA, ptrB, ptrC, + get_layout_pair_id(), 0); + } + } else if constexpr (std::is_same::value || + std::is_same::value) { + __mma_bf16_m32n8k16_mma_f32( + reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); } else if constexpr (std::is_same::value) { + auto ptrA = reinterpret_cast(&A.wi_marray); + auto ptrB = reinterpret_cast(&B.wi_marray); if constexpr (std::is_same::value) { - __hmma_m32n8k16_mma_f32f32(D.data, A.data, B.data, C.data, - get_layout_pair_id(), 0); + __hmma_m32n8k16_mma_f32f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); } else if constexpr (std::is_same::value) { - __hmma_m32n8k16_mma_f16f16(D.data, A.data, B.data, C.data, - get_layout_pair_id(), 0); + __hmma_m32n8k16_mma_f16f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); } } } else if constexpr (M == 16 && N == 16 && K == 8) { - __mma_tf32_m16n16k8_mma_f32(D.data, reinterpret_cast(A.data), - reinterpret_cast(B.data), C.data, + __mma_tf32_m16n16k8_mma_f32(reinterpret_cast(&D.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } else if constexpr (std::is_same::value) { - __dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data, + __dmma_m8n8k4_mma_f64(reinterpret_cast(&D.wi_marray), + 
reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } return D; @@ -548,10 +695,10 @@ void joint_matrix_load( Layout, Space>{} .load(res, src, stride); #else - (void)sg; - (void)res; - (void)src; - (void)stride; + std::ignore = sg; + std::ignore = res; + std::ignore = src; + std::ignore = stride; throw runtime_error( "When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_load is " "only supported by CUDA devices", @@ -570,10 +717,10 @@ void joint_matrix_store(Group sg, Layout, Space>{} .store(src, dst, stride); #else - (void)sg; - (void)src; - (void)dst; - (void)stride; + std::ignore = sg; + std::ignore = src; + std::ignore = dst; + std::ignore = stride; throw runtime_error( "When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_store is " "only supported by CUDA devices", @@ -594,10 +741,10 @@ joint_matrix_mad( T1, T2, M, K, N, LayoutA, LayoutB, LayoutC>{} .mad(A, B, C); #else - (void)sg; - (void)A; - (void)B; - (void)C; + std::ignore = sg; + std::ignore = A; + std::ignore = B; + std::ignore = C; throw runtime_error("When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_mad is " "only supported by CUDA devices", PI_ERROR_INVALID_DEVICE); @@ -620,7 +767,5 @@ float round_to_tf32(float a) { } } // namespace experimental::matrix -} // namespace oneapi -} // namespace ext -} // namespace sycl +} // namespace sycl::ext::oneapi } // __SYCL_INLINE_NAMESPACE(cl) diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-bfloat16-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-bfloat16-test.cpp new file mode 100644 index 0000000000000..724093bf13e77 --- /dev/null +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-bfloat16-test.cpp @@ -0,0 +1,208 @@ +// REQUIRES: cuda + +// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s + +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; +using sycl::ext::oneapi::experimental::bfloat16; + +constexpr int stride = 16; + +int main() { + + buffer bufA(nullptr, range<1>(1)); + buffer bufB(nullptr, range<1>(1)); + buffer bufC(nullptr, range<1>(1)); + buffer bufD(nullptr, range<1>(1)); + + queue q; + + q.submit([&](handler &cgh) { + sycl::accessor + accA(bufA, cgh); + sycl::accessor + accB(bufB, cgh); + sycl::accessor + accC(bufC, cgh); + sycl::accessor + accD(bufD, cgh); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %17, 
i32 %18, i32 %19, i32 %20, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %22, float %23, float %24, float %25, float %26, float %27, float %28, float %29, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %17, i32 %18, i32 %19, i32 %20, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %22, float %23, float %24, float %25, float %26, float %27, float %28, float %29, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %15, i32 %16, i32 %17, i32 %18, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + cgh.parallel_for( + 
nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %15, i32 %16, i32 %17, i32 %18, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %15, i32 %16, i32 %17, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), 
stride); + // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %15, i32 %16, i32 %17, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + }); + + return 0; +}; diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp index 7d9ba2966d866..f5a3fdfbca6af 100644 --- a/sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp @@ -2,7 +2,7 @@ // RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s -#include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; @@ -30,10 +30,18 @@ int main() { queue q; q.submit([&](handler &cgh) { - auto accC = bufC.get_access(cgh); - auto accA = bufA.get_access(cgh); - auto accB = bufB.get_access(cgh); - auto accD = bufD.get_access(cgh); + sycl::accessor + accA(bufA, cgh); + sycl::accessor + accB(bufB, cgh); + sycl::accessor + accC(bufC, cgh); + sycl::accessor + accD(bufD, cgh); cgh.parallel_for( nd_range<2>({1, 32}, {1, 32}), @@ -61,13 +69,6 @@ int main() { //CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1f64(double addrspace(1)* %_arg_accD, double %6, double %7, i32 8) #{{.*}} joint_matrix_store(sg, sub_c, accD.get_pointer(), N); }); - }); - - q.submit([&](handler &cgh) { - auto accC = bufC.get_access(cgh); - auto accA = bufA.get_access(cgh); - auto accB = bufB.get_access(cgh); - auto accD = bufD.get_access(cgh); cgh.parallel_for( nd_range<2>({1, 32}, {1, 32}), diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-half-float-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-half-float-test.cpp index 538c04bb783e0..d53e6ad37b2ca 100644 --- a/sycl/test/check_device_code/matrix/matrix-nvptx-half-float-test.cpp +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-half-float-test.cpp @@ -2,7 +2,7 @@ // RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_70 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s -#include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; @@ -19,10 +19,18 @@ int main() { queue q; q.submit([&](handler &cgh) { - auto accC = bufC.get_access(cgh); - auto accA = bufA.get_access(cgh); - auto accB = bufB.get_access(cgh); - auto accD = bufD.get_access(cgh); + sycl::accessor + accA(bufA, cgh); + sycl::accessor + accB(bufB, cgh); + sycl::accessor + accC(bufC, 
cgh); + sycl::accessor + accD(bufD, cgh); cgh.parallel_for( nd_range<2>({1, 32}, {1, 32}), diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-half-half-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-half-half-test.cpp index 54d7c427d917f..ec425cb228ed8 100644 --- a/sycl/test/check_device_code/matrix/matrix-nvptx-half-half-test.cpp +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-half-half-test.cpp @@ -2,7 +2,7 @@ // RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_70 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s -#include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; @@ -19,10 +19,18 @@ int main() { queue q; q.submit([&](handler &cgh) { - auto accC = bufC.get_access(cgh); - auto accA = bufA.get_access(cgh); - auto accB = bufB.get_access(cgh); - auto accD = bufD.get_access(cgh); + sycl::accessor + accA(bufA, cgh); + sycl::accessor + accB(bufB, cgh); + sycl::accessor + accC(bufC, cgh); + sycl::accessor + accD(bufD, cgh); cgh.parallel_for( nd_range<2>({1, 32}, {1, 32}), diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-int8-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-int8-test.cpp index 805c7fe02eff5..28166db16b3e3 100644 --- a/sycl/test/check_device_code/matrix/matrix-nvptx-int8-test.cpp +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-int8-test.cpp @@ -2,7 +2,7 @@ // RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_72 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s -#include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; @@ -19,10 +19,18 @@ int main() { queue q; q.submit([&](handler &cgh) { - auto accC = bufC.get_access(cgh); - auto accA = bufA.get_access(cgh); - auto accB = bufB.get_access(cgh); - auto accD = bufD.get_access(cgh); + sycl::accessor + accA(bufA, cgh); + sycl::accessor + accB(bufB, cgh); + sycl::accessor + accC(bufC, cgh); + sycl::accessor + accD(bufD, cgh); cgh.parallel_for( nd_range<2>({1, 32}, {1, 32}), diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-tf32-test.cpp index 9cdd5e739b00a..9381bf709c2d3 100644 --- a/sycl/test/check_device_code/matrix/matrix-nvptx-tf32-test.cpp +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-tf32-test.cpp @@ -80,10 +80,10 @@ int main() { // CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}} // Round a, b to tf32 for (auto i = 0; i < 4; ++i) - sub_a.data[i] = round_to_tf32(sub_a.data[i]); + sub_a.wi_marray[i] = round_to_tf32(sub_a.wi_marray[i]); for (auto i = 0; i < 4; ++i) - sub_b.data[i] = round_to_tf32(sub_b.data[i]); + sub_b.wi_marray[i] = round_to_tf32(sub_b.wi_marray[i]); //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.row.row.tf32(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 %{{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) #{{.*}} sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); @@ -125,10 +125,10 @@ int main() { // CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}} // Round a, b to tf32 for (auto i = 0; i < 4; ++i) - sub_a.data[i] = round_to_tf32(sub_a.data[i]); + sub_a.wi_marray[i] = round_to_tf32(sub_a.wi_marray[i]); for (auto i = 0; i < 4; ++i) - sub_b.data[i] = 
round_to_tf32(sub_b.data[i]);
+          sub_b.wi_marray[i] = round_to_tf32(sub_b.wi_marray[i]);
 
         //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) #{{.*}}
         sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-uint8-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-uint8-test.cpp
index 2962810436716..aee19c4fc0ce8 100644
--- a/sycl/test/check_device_code/matrix/matrix-nvptx-uint8-test.cpp
+++ b/sycl/test/check_device_code/matrix/matrix-nvptx-uint8-test.cpp
@@ -2,7 +2,7 @@
 
 // RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_72 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s
 
-#include
+#include
 
 using namespace sycl;
 using namespace sycl::ext::oneapi::experimental::matrix;
@@ -19,10 +19,18 @@
   queue q;
 
   q.submit([&](handler &cgh) {
-    auto accC = bufC.get_access<access::mode::read_write>(cgh);
-    auto accA = bufA.get_access<access::mode::read_write>(cgh);
-    auto accB = bufB.get_access<access::mode::read_write>(cgh);
-    auto accD = bufD.get_access<access::mode::read_write>(cgh);
+    sycl::accessor<uint8_t, 1, access::mode::read_write, target::device>
+        accA(bufA, cgh);
+    sycl::accessor<uint8_t, 1, access::mode::read_write, target::device>
+        accB(bufB, cgh);
+    sycl::accessor<int32_t, 1, access::mode::read_write, target::device>
+        accC(bufC, cgh);
+    sycl::accessor<int32_t, 1, access::mode::read_write, target::device>
+        accD(bufD, cgh);
 
     cgh.parallel_for(
         nd_range<2>({1, 32}, {1, 32}),
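
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the new marray-based bfloat16 builtins and the wi_data view introduced above would be exercised from a kernel. It assumes a build with -fsycl-targets=nvptx64-nvidia-cuda -DSYCL_EXT_ONEAPI_MATRIX=3 --cuda-gpu-arch=sm_80 and a <sycl/sycl.hpp> umbrella include; the kernel shape, matrix dimensions, and values are hypothetical.

// Illustrative sketch; only the fabs/fmin/fmax/fma overloads, wi_data,
// joint_matrix_fill, and get_wi_data() come from the patch above.
#include <sycl/sycl.hpp>

namespace exp = sycl::ext::oneapi::experimental;
using namespace sycl::ext::oneapi::experimental::matrix;
using sycl::ext::oneapi::experimental::bfloat16;

int main() {
  sycl::queue q;
  q.submit([&](sycl::handler &cgh) {
    cgh.parallel_for(
        sycl::nd_range<2>({1, 32}, {1, 32}), [=](sycl::nd_item<2> item) {
          sycl::sub_group sg = item.get_sub_group();

          // Element-wise bfloat16 math on an marray: on NVPTX each pair of
          // elements goes through one packed __clc_* builtin, with a scalar
          // fix-up for a trailing odd element (host execution throws).
          sycl::marray<bfloat16, 4> a(1.0f), b(-2.0f), c(0.5f);
          sycl::marray<bfloat16, 4> d = exp::fma(a, exp::fabs(b), c);
          d = exp::fmax(d, exp::fmin(a, b));

          // Per-work-item element access to a fragment through the new
          // wi_data view: fill the accumulator, then update the elements
          // this work item owns.
          joint_matrix<float, matrix_use::accumulator, 16, 16,
                       matrix_layout::row_major>
              sub_c;
          joint_matrix_fill(sg, sub_c, 0.0f);
          auto wi = sub_c.get_wi_data();
          for (size_t i = 0; i < wi.length(); ++i)
            wi[i] = wi[i] * 2.0f + static_cast<float>(d[i % 4]);
        });
  });
  q.wait();
  return 0;
}

Compared with the old raw frag_type data[frag_size] member, routing element-wise work through marray plus get_wi_data() keeps the fragment size an implementation detail while still letting each work item iterate over exactly the elements it holds.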