Skip to content

Commit d32f5da

Browse files
Address Alexander's comments
1 parent fd14f72 commit d32f5da

File tree

1 file changed

+82
-90
lines changed

1 file changed

+82
-90
lines changed

sycl/include/CL/sycl/ONEAPI/matrix/matrix-amx.hpp

Lines changed: 82 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,8 @@ template <typename Group, typename T, size_t NumRows, size_t NumCols,
180180
matrix_layout Layout>
181181
struct joint_matrix<
182182
Group, T, NumRows, NumCols, Layout,
183-
typename std::enable_if<(NumRows <= tile_size) &&
184-
(NumCols * sizeof(T) / 4 <= tile_size)>::type> {
183+
typename std::enable_if_t<(NumRows <= tile_size) &&
184+
(NumCols * sizeof(T) / 4 <= tile_size)>> {
185185
public:
186186
static constexpr size_t trows = (NumRows + tile_size - 1) / tile_size;
187187
// tcols: Num of tiles in column.
@@ -201,51 +201,51 @@ struct joint_matrix<
201201

202202
namespace detail {
203203

204+
using namespace experimental;
205+
204206
template <typename Group, typename T, size_t NumRows, size_t NumCols,
205-
experimental::matrix::matrix_layout Layout>
206-
inline __SYCL_ALWAYS_INLINE static typename std::enable_if<
207-
(NumRows > experimental::matrix::tile_size) ||
208-
(NumCols * sizeof(T) / 4 > experimental::matrix::tile_size),
209-
void>::type
210-
submatrix_load(
211-
detail::submatrix<T> &sub_m,
212-
experimental::matrix::joint_matrix<Group, T, NumRows, NumCols, Layout> jm,
213-
uint32_t row, uint32_t col, size_t stride,
214-
experimental::matrix::matrix_layout layout, bool shouldreload) {
207+
matrix::matrix_layout Layout>
208+
inline __SYCL_ALWAYS_INLINE static
209+
typename std::enable_if_t<(NumRows > matrix::tile_size) ||
210+
(NumCols * sizeof(T) / 4 > matrix::tile_size),
211+
void>
212+
submatrix_load(detail::submatrix<T> &sub_m,
213+
matrix::joint_matrix<Group, T, NumRows, NumCols, Layout> jm,
214+
uint32_t row, uint32_t col, size_t stride,
215+
matrix::matrix_layout layout, bool shouldreload) {
215216
uint32_t offset = (row * stride + col);
216217
T *ptr = reinterpret_cast<T *>(jm.raw_storage);
217218
ptr += offset;
218219
stride *= sizeof(T);
219-
sub_m.rows = experimental::matrix::tile_size;
220-
sub_m.cols = experimental::matrix::tile_size * 4;
221-
sub_m.tile = experimental::matrix::tileloadd64_internal(
220+
sub_m.rows = matrix::tile_size;
221+
sub_m.cols = matrix::tile_size * 4;
222+
sub_m.tile = matrix::tileloadd64_internal(
222223
sub_m.rows, sub_m.cols, reinterpret_cast<char *>(ptr), stride);
223224
}
224225

225226
template <typename Group, typename T, size_t NumRows, size_t NumCols,
226-
experimental::matrix::matrix_layout Layout>
227-
inline __SYCL_ALWAYS_INLINE static typename std::enable_if<
228-
(NumRows <= experimental::matrix::tile_size) &&
229-
(NumCols * sizeof(T) / 4 <= experimental::matrix::tile_size),
230-
void>::type
231-
submatrix_load(
232-
detail::submatrix<T> &sub_m,
233-
experimental::matrix::joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
234-
uint32_t row, uint32_t col, size_t stride,
235-
experimental::matrix::matrix_layout layout, bool shouldreload) {
227+
matrix::matrix_layout Layout>
228+
inline __SYCL_ALWAYS_INLINE static
229+
typename std::enable_if_t<(NumRows <= matrix::tile_size) &&
230+
(NumCols * sizeof(T) / 4 <=
231+
matrix::tile_size),
232+
void>
233+
submatrix_load(detail::submatrix<T> &sub_m,
234+
matrix::joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
235+
uint32_t row, uint32_t col, size_t stride,
236+
matrix::matrix_layout layout, bool shouldreload) {
236237
if (shouldreload) {
237-
// Force sub_m.tile's shape to be experimental::matrix::tile_size *
238-
// experimental::matrix::tile_size * 4
239-
int8_t NewjmC[experimental::matrix::tile_size *
240-
experimental::matrix::tile_size * 4];
241-
experimental::matrix::tilestored64_internal(
242-
NumRows, NumCols * sizeof(T), reinterpret_cast<char *>(NewjmC),
243-
experimental::matrix::tile_size * 4, jm.tile);
244-
sub_m.rows = experimental::matrix::tile_size;
245-
sub_m.cols = experimental::matrix::tile_size * 4;
246-
sub_m.tile = experimental::matrix::tileloadd64_internal(
247-
sub_m.rows, sub_m.cols, reinterpret_cast<char *>(NewjmC),
248-
experimental::matrix::tile_size * 4);
238+
// Force sub_m.tile's shape to be matrix::tile_size *
239+
// matrix::tile_size * 4
240+
int8_t NewjmC[matrix::tile_size * matrix::tile_size * 4];
241+
matrix::tilestored64_internal(NumRows, NumCols * sizeof(T),
242+
reinterpret_cast<char *>(NewjmC),
243+
matrix::tile_size * 4, jm.tile);
244+
sub_m.rows = matrix::tile_size;
245+
sub_m.cols = matrix::tile_size * 4;
246+
sub_m.tile = matrix::tileloadd64_internal(sub_m.rows, sub_m.cols,
247+
reinterpret_cast<char *>(NewjmC),
248+
matrix::tile_size * 4);
249249
return;
250250
}
251251
sub_m.rows = NumRows;
@@ -258,62 +258,56 @@ inline __SYCL_ALWAYS_INLINE static void
258258
submatrix_mad(detail::submatrix<int8_t> &sub_ma,
259259
detail::submatrix<int8_t> &sub_mb,
260260
detail::submatrix<int32_t> &sub_mc) {
261-
sub_mc.tile = experimental::matrix::tdpbssd_internal(
262-
sub_mc.rows, sub_mc.cols, sub_ma.cols, sub_mc.tile, sub_ma.tile,
263-
sub_mb.tile);
261+
sub_mc.tile = matrix::tdpbssd_internal(sub_mc.rows, sub_mc.cols, sub_ma.cols,
262+
sub_mc.tile, sub_ma.tile, sub_mb.tile);
264263
}
265264

266265
// This handles cases where T1 is int16(bfloat16), T2 is float.
267266
inline __SYCL_ALWAYS_INLINE static void
268267
submatrix_mad(detail::submatrix<unsigned short> &sub_ma,
269268
detail::submatrix<unsigned short> &sub_mb,
270269
detail::submatrix<float> &sub_mc) {
271-
sub_mc.tile = experimental::matrix::tdpbf16ps_internal(
272-
sub_mc.rows, sub_mc.cols, sub_ma.cols, sub_mc.tile, sub_ma.tile,
273-
sub_mb.tile);
270+
sub_mc.tile =
271+
matrix::tdpbf16ps_internal(sub_mc.rows, sub_mc.cols, sub_ma.cols,
272+
sub_mc.tile, sub_ma.tile, sub_mb.tile);
274273
}
275274

276275
template <typename Group, typename T, size_t NumRows, size_t NumCols>
277276
inline __SYCL_ALWAYS_INLINE static
278-
typename std::enable_if<(NumRows > experimental::matrix::tile_size) ||
279-
(NumCols * sizeof(T) / 4 >
280-
experimental::matrix::tile_size),
281-
void>::type
282-
submatrix_store(
283-
detail::submatrix<T> &sub_m,
284-
experimental::matrix::joint_matrix<Group, T, NumRows, NumCols> &jm,
285-
uint32_t row, uint32_t col, size_t stride,
286-
experimental::matrix::matrix_layout layout, bool shouldreload) {
277+
typename std::enable_if_t<(NumRows > matrix::tile_size) ||
278+
(NumCols * sizeof(T) / 4 > matrix::tile_size),
279+
void>
280+
submatrix_store(detail::submatrix<T> &sub_m,
281+
matrix::joint_matrix<Group, T, NumRows, NumCols> &jm,
282+
uint32_t row, uint32_t col, size_t stride,
283+
matrix::matrix_layout layout, bool shouldreload) {
287284
uint32_t offset = (row * stride + col);
288285
T *ptr = reinterpret_cast<T *>(jm.raw_storage);
289286
ptr += offset;
290287
stride *= sizeof(T);
291-
experimental::matrix::tilestored64_internal(sub_m.rows, sub_m.cols,
292-
reinterpret_cast<char *>(ptr),
293-
stride, sub_m.tile);
288+
matrix::tilestored64_internal(sub_m.rows, sub_m.cols,
289+
reinterpret_cast<char *>(ptr), stride,
290+
sub_m.tile);
294291
}
295292

296293
template <typename Group, typename T, size_t NumRows, size_t NumCols>
297294
inline __SYCL_ALWAYS_INLINE static
298-
typename std::enable_if<(NumRows <= experimental::matrix::tile_size) &&
299-
(NumCols * sizeof(T) / 4 <=
300-
experimental::matrix::tile_size),
301-
void>::type
302-
submatrix_store(
303-
detail::submatrix<T> &sub_m,
304-
experimental::matrix::joint_matrix<Group, T, NumRows, NumCols> &jm,
305-
uint32_t row, uint32_t col, size_t stride,
306-
experimental::matrix::matrix_layout layout, bool shouldreload) {
295+
typename std::enable_if_t<(NumRows <= matrix::tile_size) &&
296+
(NumCols * sizeof(T) / 4 <=
297+
matrix::tile_size),
298+
void>
299+
submatrix_store(detail::submatrix<T> &sub_m,
300+
matrix::joint_matrix<Group, T, NumRows, NumCols> &jm,
301+
uint32_t row, uint32_t col, size_t stride,
302+
matrix::matrix_layout layout, bool shouldreload) {
307303
if (shouldreload) {
308-
int8_t NewjmC[experimental::matrix::tile_size *
309-
experimental::matrix::tile_size * 4];
310-
experimental::matrix::tilestored64_internal(
311-
experimental::matrix::tile_size, experimental::matrix::tile_size * 4,
312-
reinterpret_cast<char *>(NewjmC), experimental::matrix::tile_size * 4,
313-
sub_m.tile);
314-
jm.tile = experimental::matrix::tileloadd64_internal(
315-
NumRows, NumCols * sizeof(T), reinterpret_cast<char *>(NewjmC),
316-
experimental::matrix::tile_size * 4);
304+
int8_t NewjmC[matrix::tile_size * matrix::tile_size * 4];
305+
matrix::tilestored64_internal(matrix::tile_size, matrix::tile_size * 4,
306+
reinterpret_cast<char *>(NewjmC),
307+
matrix::tile_size * 4, sub_m.tile);
308+
jm.tile = matrix::tileloadd64_internal(NumRows, NumCols * sizeof(T),
309+
reinterpret_cast<char *>(NewjmC),
310+
matrix::tile_size * 4);
317311
return;
318312
}
319313
jm.tile = sub_m.tile;
@@ -326,8 +320,8 @@ namespace experimental::matrix {
326320
// This handles cases where matrix can't be accommodated by a tile
327321
template <typename Group, typename T, size_t NumRows, size_t NumCols,
328322
matrix_layout Layout, access::address_space Space>
329-
inline __SYCL_ALWAYS_INLINE typename std::enable_if<
330-
(NumRows > tile_size) || (NumCols * sizeof(T) / 4 > tile_size), void>::type
323+
inline __SYCL_ALWAYS_INLINE typename std::enable_if_t<
324+
(NumRows > tile_size) || (NumCols * sizeof(T) / 4 > tile_size), void>
331325
joint_matrix_load(Group sg,
332326
joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
333327
multi_ptr<T, Space> src, size_t stride,
@@ -347,14 +341,12 @@ joint_matrix_load(Group sg,
347341
// This handles cases where matrix can be put into a tile
348342
template <typename Group, typename T, size_t NumRows, size_t NumCols,
349343
matrix_layout Layout, access::address_space Space>
350-
inline __SYCL_ALWAYS_INLINE
351-
typename std::enable_if<(NumRows <= tile_size) &&
352-
(NumCols * sizeof(T) / 4 <= tile_size),
353-
void>::type
354-
joint_matrix_load(Group sg,
355-
joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
356-
multi_ptr<T, Space> src, size_t stride,
357-
matrix_layout layout) {
344+
inline __SYCL_ALWAYS_INLINE typename std::enable_if_t<
345+
(NumRows <= tile_size) && (NumCols * sizeof(T) / 4 <= tile_size), void>
346+
joint_matrix_load(Group sg,
347+
joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
348+
multi_ptr<T, Space> src, size_t stride,
349+
matrix_layout layout) {
358350
T *mem = src.get();
359351
// tileload happens!
360352
jm.tile =
@@ -366,8 +358,8 @@ inline __SYCL_ALWAYS_INLINE
366358
// This handles cases where matrix can't be accommodated by a tile
367359
template <typename Group, typename T, size_t NumRows, size_t NumCols,
368360
matrix_layout Layout, access::address_space Space>
369-
inline __SYCL_ALWAYS_INLINE typename std::enable_if<
370-
(NumRows > tile_size) || (NumCols * sizeof(T) / 4 > tile_size), void>::type
361+
inline __SYCL_ALWAYS_INLINE typename std::enable_if_t<
362+
(NumRows > tile_size) || (NumCols * sizeof(T) / 4 > tile_size), void>
371363
joint_matrix_store(Group sg,
372364
joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
373365
multi_ptr<T, Space> dst, size_t stride,
@@ -387,9 +379,9 @@ joint_matrix_store(Group sg,
387379
template <typename Group, typename T, size_t NumRows, size_t NumCols,
388380
matrix_layout Layout, access::address_space Space>
389381
inline __SYCL_ALWAYS_INLINE
390-
typename std::enable_if<(NumRows <= tile_size) &&
391-
(NumCols * sizeof(T) / 4 <= tile_size),
392-
void>::type
382+
typename std::enable_if_t<(NumRows <= tile_size) &&
383+
(NumCols * sizeof(T) / 4 <= tile_size),
384+
void>
393385
joint_matrix_store(Group sg,
394386
joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
395387
multi_ptr<T, Space> dst, size_t stride,
@@ -406,14 +398,14 @@ template <typename Group, typename T1, typename T2, size_t NumRowsA,
406398
size_t NumColsA, size_t NumRowsB, size_t NumColsB, size_t NumRowsC,
407399
size_t NumColsC, matrix_layout LayoutA, matrix_layout LayoutB,
408400
matrix_layout LayoutC>
409-
inline __SYCL_ALWAYS_INLINE typename std::enable_if<
401+
inline __SYCL_ALWAYS_INLINE typename std::enable_if_t<
410402
((std::is_same<T1, int8_t>::value && std::is_same<T2, int32_t>::value) ||
411403
(std::is_same<T1, unsigned short>::value &&
412404
std::is_same<T2, float>::value)) &&
413405
(LayoutA == matrix_layout::row_major) &&
414406
(LayoutB == matrix_layout::packed_b) &&
415407
(LayoutC == matrix_layout::row_major),
416-
void>::type
408+
void>
417409
joint_matrix_mad(Group sg,
418410
joint_matrix<Group, T1, NumRowsA, NumColsA, LayoutA> &jmA,
419411
joint_matrix<Group, T1, NumRowsB, NumColsB, LayoutB> &jmB,

0 commit comments

Comments
 (0)