From a58ee5d99cc36468f1469065509458e5869deb77 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Wed, 21 Apr 2021 15:02:04 +0800 Subject: [PATCH] [SYCL][Matrix] Make joint_matrix_mad return A*B+C's result instead of C=A*B+C --- sycl/include/CL/sycl/ONEAPI/matrix/matrix-amx.hpp | 15 ++++++++------- sycl/test/matrix/matrix-amx-bf16-test.cpp | 2 +- sycl/test/matrix/matrix-amx-int8-test.cpp | 2 +- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/sycl/include/CL/sycl/ONEAPI/matrix/matrix-amx.hpp b/sycl/include/CL/sycl/ONEAPI/matrix/matrix-amx.hpp index 9fd294e9a9ca4..13719454ae8e8 100644 --- a/sycl/include/CL/sycl/ONEAPI/matrix/matrix-amx.hpp +++ b/sycl/include/CL/sycl/ONEAPI/matrix/matrix-amx.hpp @@ -405,26 +405,27 @@ inline __SYCL_ALWAYS_INLINE typename std::enable_if< (LayoutA == matrix_layout::row_major) && (LayoutB == matrix_layout::packed_b) && (LayoutC == matrix_layout::row_major), - void>::type + joint_matrix>::type joint_matrix_mad(Group sg, joint_matrix &jmA, joint_matrix &jmB, joint_matrix &jmC) { + joint_matrix res(jmC); constexpr size_t epd = detail::elems_per_dword::value; // If A is large and C is small, in joint_matrix_load, we do memcpy for A, and // we do tileload for C whose shape is not tile_size*tile_size*4. In // joint_matrix_mad, we do tileload for A and shape is tile_size*tile_size*4. // So we need to reshape C before we do dpbssd. - bool Cshouldreload = jmC.isSmall && !jmA.isSmall && !jmB.isSmall; + bool Cshouldreload = res.isSmall && !jmA.isSmall && !jmB.isSmall; bool Ashouldreload = jmA.isSmall && !jmB.isSmall; bool Bshouldreload = jmB.isSmall && !jmA.isSmall; - for (int m = 0; m < jmC.trows; ++m) { - for (int n = 0; n < jmC.tcols; ++n) { + for (int m = 0; m < res.trows; ++m) { + for (int n = 0; n < res.tcols; ++n) { detail::submatrix sub_c; // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - submatrix_load(sub_c, jmC, m * tile_size, n * tile_size, jmC.stride, + submatrix_load(sub_c, res, m * tile_size, n * tile_size, res.stride, matrix_layout::row_major, Cshouldreload); for (int k = 0; k < jmA.tcols; ++k) { // K->int8_t detail::submatrix sub_a; @@ -436,11 +437,11 @@ joint_matrix_mad(Group sg, jmB.stride, matrix_layout::packed_b, Bshouldreload); submatrix_mad(sub_a, sub_b, sub_c); } - submatrix_store(sub_c, jmC, m * tile_size, n * tile_size, jmC.stride, + submatrix_store(sub_c, res, m * tile_size, n * tile_size, res.stride, matrix_layout::row_major, Cshouldreload); } } - return; + return res; } } // namespace experimental::matrix diff --git a/sycl/test/matrix/matrix-amx-bf16-test.cpp b/sycl/test/matrix/matrix-amx-bf16-test.cpp index 4ace31f53f2f4..3490cb6c5f4c7 100644 --- a/sycl/test/matrix/matrix-amx-bf16-test.cpp +++ b/sycl/test/matrix/matrix-amx-bf16-test.cpp @@ -83,7 +83,7 @@ void matrix_multiply(big_matrix &C, big_matrix &C, big_matrix