Skip to content

Commit a58ee5d

Browse files
[SYCL][Matrix] Make joint_matrix_mad return A*B+C's result instead of C=A*B+C
1 parent d482ac3 commit a58ee5d

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

sycl/include/CL/sycl/ONEAPI/matrix/matrix-amx.hpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -405,26 +405,27 @@ inline __SYCL_ALWAYS_INLINE typename std::enable_if<
405405
(LayoutA == matrix_layout::row_major) &&
406406
(LayoutB == matrix_layout::packed_b) &&
407407
(LayoutC == matrix_layout::row_major),
408-
void>::type
408+
joint_matrix<Group, T2, NumRowsC, NumColsC, LayoutC>>::type
409409
joint_matrix_mad(Group sg,
410410
joint_matrix<Group, T1, NumRowsA, NumColsA, LayoutA> &jmA,
411411
joint_matrix<Group, T1, NumRowsB, NumColsB, LayoutB> &jmB,
412412
joint_matrix<Group, T2, NumRowsC, NumColsC, LayoutC> &jmC) {
413+
joint_matrix<Group, T2, NumRowsC, NumColsC, LayoutC> res(jmC);
413414
constexpr size_t epd = detail::elems_per_dword<T1>::value;
414415
// If A is large and C is small, in joint_matrix_load, we do memcpy for A, and
415416
// we do tileload for C whose shape is not tile_size*tile_size*4. In
416417
// joint_matrix_mad, we do tileload for A and shape is tile_size*tile_size*4.
417418
// So we need to reshape C before we do dpbssd.
418-
bool Cshouldreload = jmC.isSmall && !jmA.isSmall && !jmB.isSmall;
419+
bool Cshouldreload = res.isSmall && !jmA.isSmall && !jmB.isSmall;
419420
bool Ashouldreload = jmA.isSmall && !jmB.isSmall;
420421
bool Bshouldreload = jmB.isSmall && !jmA.isSmall;
421422

422-
for (int m = 0; m < jmC.trows; ++m) {
423-
for (int n = 0; n < jmC.tcols; ++n) {
423+
for (int m = 0; m < res.trows; ++m) {
424+
for (int n = 0; n < res.tcols; ++n) {
424425
detail::submatrix<T2> sub_c;
425426

426427
// AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64
427-
submatrix_load(sub_c, jmC, m * tile_size, n * tile_size, jmC.stride,
428+
submatrix_load(sub_c, res, m * tile_size, n * tile_size, res.stride,
428429
matrix_layout::row_major, Cshouldreload);
429430
for (int k = 0; k < jmA.tcols; ++k) { // K->int8_t
430431
detail::submatrix<T1> sub_a;
@@ -436,11 +437,11 @@ joint_matrix_mad(Group sg,
436437
jmB.stride, matrix_layout::packed_b, Bshouldreload);
437438
submatrix_mad(sub_a, sub_b, sub_c);
438439
}
439-
submatrix_store(sub_c, jmC, m * tile_size, n * tile_size, jmC.stride,
440+
submatrix_store(sub_c, res, m * tile_size, n * tile_size, res.stride,
440441
matrix_layout::row_major, Cshouldreload);
441442
}
442443
}
443-
return;
444+
return res;
444445
}
445446

446447
} // namespace experimental::matrix

sycl/test/matrix/matrix-amx-bf16-test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C, big_matrix<T2, N
8383
accB.get_pointer() +
8484
(k * TK / 2) * (N * 2) + sg_starty * TN * 2,
8585
N * 2, matrix_layout::packed_b);
86-
joint_matrix_mad(sg, sub_a, sub_b, sub_c);
86+
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
8787
}
8888
joint_matrix_store(sg, sub_c,
8989
accC.get_pointer() + (sg_startx * TM) * N +

sycl/test/matrix/matrix-amx-int8-test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C, big_matrix<T2, N
8383
accB.get_pointer() +
8484
(k * TK / 4) * (N * 4) + sg_starty * TN * 4,
8585
N * 4, matrix_layout::packed_b);
86-
joint_matrix_mad(sg, sub_a, sub_b, sub_c);
86+
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
8787
}
8888
joint_matrix_store(sg, sub_c,
8989
accC.get_pointer() + (sg_startx * TM) * N +

0 commit comments

Comments
 (0)