@@ -405,26 +405,27 @@ inline __SYCL_ALWAYS_INLINE typename std::enable_if<
405
405
(LayoutA == matrix_layout::row_major) &&
406
406
(LayoutB == matrix_layout::packed_b) &&
407
407
(LayoutC == matrix_layout::row_major),
408
- void >::type
408
+ joint_matrix<Group, T2, NumRowsC, NumColsC, LayoutC> >::type
409
409
joint_matrix_mad (Group sg,
410
410
joint_matrix<Group, T1, NumRowsA, NumColsA, LayoutA> &jmA,
411
411
joint_matrix<Group, T1, NumRowsB, NumColsB, LayoutB> &jmB,
412
412
joint_matrix<Group, T2, NumRowsC, NumColsC, LayoutC> &jmC) {
413
+ joint_matrix<Group, T2, NumRowsC, NumColsC, LayoutC> res (jmC);
413
414
constexpr size_t epd = detail::elems_per_dword<T1>::value;
414
415
// If A is large and C is small, in joint_matrix_load, we do memcpy for A, and
415
416
// we do tileload for C whose shape is not tile_size*tile_size*4. In
416
417
// joint_matrix_mad, we do tileload for A and shape is tile_size*tile_size*4.
417
418
// So we need to reshape C before we do dpbssd.
418
- bool Cshouldreload = jmC .isSmall && !jmA.isSmall && !jmB.isSmall ;
419
+ bool Cshouldreload = res .isSmall && !jmA.isSmall && !jmB.isSmall ;
419
420
bool Ashouldreload = jmA.isSmall && !jmB.isSmall ;
420
421
bool Bshouldreload = jmB.isSmall && !jmA.isSmall ;
421
422
422
- for (int m = 0 ; m < jmC .trows ; ++m) {
423
- for (int n = 0 ; n < jmC .tcols ; ++n) {
423
+ for (int m = 0 ; m < res .trows ; ++m) {
424
+ for (int n = 0 ; n < res .tcols ; ++n) {
424
425
detail::submatrix<T2> sub_c;
425
426
426
427
// AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64
427
- submatrix_load (sub_c, jmC , m * tile_size, n * tile_size, jmC .stride ,
428
+ submatrix_load (sub_c, res , m * tile_size, n * tile_size, res .stride ,
428
429
matrix_layout::row_major, Cshouldreload);
429
430
for (int k = 0 ; k < jmA.tcols ; ++k) { // K->int8_t
430
431
detail::submatrix<T1> sub_a;
@@ -436,11 +437,11 @@ joint_matrix_mad(Group sg,
436
437
jmB.stride , matrix_layout::packed_b, Bshouldreload);
437
438
submatrix_mad (sub_a, sub_b, sub_c);
438
439
}
439
- submatrix_store (sub_c, jmC , m * tile_size, n * tile_size, jmC .stride ,
440
+ submatrix_store (sub_c, res , m * tile_size, n * tile_size, res .stride ,
440
441
matrix_layout::row_major, Cshouldreload);
441
442
}
442
443
}
443
- return ;
444
+ return res ;
444
445
}
445
446
446
447
} // namespace experimental::matrix
0 commit comments