@@ -180,8 +180,8 @@ template <typename Group, typename T, size_t NumRows, size_t NumCols,
180
180
matrix_layout Layout>
181
181
struct joint_matrix <
182
182
Group, T, NumRows, NumCols, Layout,
183
- typename std::enable_if <(NumRows <= tile_size) &&
184
- (NumCols * sizeof (T) / 4 <= tile_size)>::type > {
183
+ typename std::enable_if_t <(NumRows <= tile_size) &&
184
+ (NumCols * sizeof (T) / 4 <= tile_size)>> {
185
185
public:
186
186
static constexpr size_t trows = (NumRows + tile_size - 1 ) / tile_size;
187
187
// tcols: Num of tiles in column.
@@ -201,51 +201,51 @@ struct joint_matrix<
201
201
202
202
namespace detail {
203
203
204
+ using namespace experimental ;
205
+
204
206
template <typename Group, typename T, size_t NumRows, size_t NumCols,
205
- experimental::matrix::matrix_layout Layout>
206
- inline __SYCL_ALWAYS_INLINE static typename std::enable_if<
207
- (NumRows > experimental::matrix::tile_size) ||
208
- (NumCols * sizeof (T) / 4 > experimental::matrix::tile_size),
209
- void >::type
210
- submatrix_load (
211
- detail::submatrix<T> &sub_m,
212
- experimental::matrix::joint_matrix<Group, T, NumRows, NumCols, Layout> jm,
213
- uint32_t row, uint32_t col, size_t stride,
214
- experimental::matrix::matrix_layout layout, bool shouldreload) {
207
+ matrix::matrix_layout Layout>
208
+ inline __SYCL_ALWAYS_INLINE static
209
+ typename std::enable_if_t <(NumRows > matrix::tile_size) ||
210
+ (NumCols * sizeof (T) / 4 > matrix::tile_size),
211
+ void >
212
+ submatrix_load (detail::submatrix<T> &sub_m,
213
+ matrix::joint_matrix<Group, T, NumRows, NumCols, Layout> jm,
214
+ uint32_t row, uint32_t col, size_t stride,
215
+ matrix::matrix_layout layout, bool shouldreload) {
215
216
uint32_t offset = (row * stride + col);
216
217
T *ptr = reinterpret_cast <T *>(jm.raw_storage );
217
218
ptr += offset;
218
219
stride *= sizeof (T);
219
- sub_m.rows = experimental:: matrix::tile_size;
220
- sub_m.cols = experimental:: matrix::tile_size * 4 ;
221
- sub_m.tile = experimental:: matrix::tileloadd64_internal (
220
+ sub_m.rows = matrix::tile_size;
221
+ sub_m.cols = matrix::tile_size * 4 ;
222
+ sub_m.tile = matrix::tileloadd64_internal (
222
223
sub_m.rows , sub_m.cols , reinterpret_cast <char *>(ptr), stride);
223
224
}
224
225
225
226
template <typename Group, typename T, size_t NumRows, size_t NumCols,
226
- experimental:: matrix::matrix_layout Layout>
227
- inline __SYCL_ALWAYS_INLINE static typename std::enable_if<
228
- (NumRows <= experimental:: matrix::tile_size) &&
229
- (NumCols * sizeof (T) / 4 <= experimental::matrix::tile_size),
230
- void >::type
231
- submatrix_load (
232
- detail::submatrix<T> &sub_m,
233
- experimental:: matrix::joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
234
- uint32_t row, uint32_t col, size_t stride,
235
- experimental:: matrix::matrix_layout layout, bool shouldreload) {
227
+ matrix::matrix_layout Layout>
228
+ inline __SYCL_ALWAYS_INLINE static
229
+ typename std:: enable_if_t < (NumRows <= matrix::tile_size) &&
230
+ (NumCols * sizeof (T) / 4 <=
231
+ matrix::tile_size),
232
+ void >
233
+ submatrix_load ( detail::submatrix<T> &sub_m,
234
+ matrix::joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
235
+ uint32_t row, uint32_t col, size_t stride,
236
+ matrix::matrix_layout layout, bool shouldreload) {
236
237
if (shouldreload) {
237
- // Force sub_m.tile's shape to be experimental::matrix::tile_size *
238
- // experimental::matrix::tile_size * 4
239
- int8_t NewjmC[experimental::matrix::tile_size *
240
- experimental::matrix::tile_size * 4 ];
241
- experimental::matrix::tilestored64_internal (
242
- NumRows, NumCols * sizeof (T), reinterpret_cast <char *>(NewjmC),
243
- experimental::matrix::tile_size * 4 , jm.tile );
244
- sub_m.rows = experimental::matrix::tile_size;
245
- sub_m.cols = experimental::matrix::tile_size * 4 ;
246
- sub_m.tile = experimental::matrix::tileloadd64_internal (
247
- sub_m.rows , sub_m.cols , reinterpret_cast <char *>(NewjmC),
248
- experimental::matrix::tile_size * 4 );
238
+ // Force sub_m.tile's shape to be matrix::tile_size *
239
+ // matrix::tile_size * 4
240
+ int8_t NewjmC[matrix::tile_size * matrix::tile_size * 4 ];
241
+ matrix::tilestored64_internal (NumRows, NumCols * sizeof (T),
242
+ reinterpret_cast <char *>(NewjmC),
243
+ matrix::tile_size * 4 , jm.tile );
244
+ sub_m.rows = matrix::tile_size;
245
+ sub_m.cols = matrix::tile_size * 4 ;
246
+ sub_m.tile = matrix::tileloadd64_internal (sub_m.rows , sub_m.cols ,
247
+ reinterpret_cast <char *>(NewjmC),
248
+ matrix::tile_size * 4 );
249
249
return ;
250
250
}
251
251
sub_m.rows = NumRows;
@@ -258,62 +258,56 @@ inline __SYCL_ALWAYS_INLINE static void
258
258
submatrix_mad (detail::submatrix<int8_t > &sub_ma,
259
259
detail::submatrix<int8_t > &sub_mb,
260
260
detail::submatrix<int32_t > &sub_mc) {
261
- sub_mc.tile = experimental::matrix::tdpbssd_internal (
262
- sub_mc.rows , sub_mc.cols , sub_ma.cols , sub_mc.tile , sub_ma.tile ,
263
- sub_mb.tile );
261
+ sub_mc.tile = matrix::tdpbssd_internal (sub_mc.rows , sub_mc.cols , sub_ma.cols ,
262
+ sub_mc.tile , sub_ma.tile , sub_mb.tile );
264
263
}
265
264
266
265
// This handles cases where T1 is int16(bfloat16), T2 is float.
267
266
inline __SYCL_ALWAYS_INLINE static void
268
267
submatrix_mad (detail::submatrix<unsigned short > &sub_ma,
269
268
detail::submatrix<unsigned short > &sub_mb,
270
269
detail::submatrix<float > &sub_mc) {
271
- sub_mc.tile = experimental::matrix::tdpbf16ps_internal (
272
- sub_mc.rows , sub_mc.cols , sub_ma.cols , sub_mc. tile , sub_ma. tile ,
273
- sub_mb.tile );
270
+ sub_mc.tile =
271
+ matrix::tdpbf16ps_internal ( sub_mc.rows , sub_mc.cols , sub_ma.cols ,
272
+ sub_mc. tile , sub_ma. tile , sub_mb.tile );
274
273
}
275
274
276
275
template <typename Group, typename T, size_t NumRows, size_t NumCols>
277
276
inline __SYCL_ALWAYS_INLINE static
278
- typename std::enable_if<(NumRows > experimental::matrix::tile_size) ||
279
- (NumCols * sizeof (T) / 4 >
280
- experimental::matrix::tile_size),
281
- void >::type
282
- submatrix_store (
283
- detail::submatrix<T> &sub_m,
284
- experimental::matrix::joint_matrix<Group, T, NumRows, NumCols> &jm,
285
- uint32_t row, uint32_t col, size_t stride,
286
- experimental::matrix::matrix_layout layout, bool shouldreload) {
277
+ typename std::enable_if_t <(NumRows > matrix::tile_size) ||
278
+ (NumCols * sizeof (T) / 4 > matrix::tile_size),
279
+ void >
280
+ submatrix_store (detail::submatrix<T> &sub_m,
281
+ matrix::joint_matrix<Group, T, NumRows, NumCols> &jm,
282
+ uint32_t row, uint32_t col, size_t stride,
283
+ matrix::matrix_layout layout, bool shouldreload) {
287
284
uint32_t offset = (row * stride + col);
288
285
T *ptr = reinterpret_cast <T *>(jm.raw_storage );
289
286
ptr += offset;
290
287
stride *= sizeof (T);
291
- experimental:: matrix::tilestored64_internal (sub_m.rows , sub_m.cols ,
292
- reinterpret_cast <char *>(ptr),
293
- stride, sub_m.tile );
288
+ matrix::tilestored64_internal (sub_m.rows , sub_m.cols ,
289
+ reinterpret_cast <char *>(ptr), stride ,
290
+ sub_m.tile );
294
291
}
295
292
296
293
template <typename Group, typename T, size_t NumRows, size_t NumCols>
297
294
inline __SYCL_ALWAYS_INLINE static
298
- typename std::enable_if<(NumRows <= experimental::matrix::tile_size) &&
299
- (NumCols * sizeof (T) / 4 <=
300
- experimental::matrix::tile_size),
301
- void >::type
302
- submatrix_store (
303
- detail::submatrix<T> &sub_m,
304
- experimental::matrix::joint_matrix<Group, T, NumRows, NumCols> &jm,
305
- uint32_t row, uint32_t col, size_t stride,
306
- experimental::matrix::matrix_layout layout, bool shouldreload) {
295
+ typename std::enable_if_t <(NumRows <= matrix::tile_size) &&
296
+ (NumCols * sizeof (T) / 4 <=
297
+ matrix::tile_size),
298
+ void >
299
+ submatrix_store (detail::submatrix<T> &sub_m,
300
+ matrix::joint_matrix<Group, T, NumRows, NumCols> &jm,
301
+ uint32_t row, uint32_t col, size_t stride,
302
+ matrix::matrix_layout layout, bool shouldreload) {
307
303
if (shouldreload) {
308
- int8_t NewjmC[experimental::matrix::tile_size *
309
- experimental::matrix::tile_size * 4 ];
310
- experimental::matrix::tilestored64_internal (
311
- experimental::matrix::tile_size, experimental::matrix::tile_size * 4 ,
312
- reinterpret_cast <char *>(NewjmC), experimental::matrix::tile_size * 4 ,
313
- sub_m.tile );
314
- jm.tile = experimental::matrix::tileloadd64_internal (
315
- NumRows, NumCols * sizeof (T), reinterpret_cast <char *>(NewjmC),
316
- experimental::matrix::tile_size * 4 );
304
+ int8_t NewjmC[matrix::tile_size * matrix::tile_size * 4 ];
305
+ matrix::tilestored64_internal (matrix::tile_size, matrix::tile_size * 4 ,
306
+ reinterpret_cast <char *>(NewjmC),
307
+ matrix::tile_size * 4 , sub_m.tile );
308
+ jm.tile = matrix::tileloadd64_internal (NumRows, NumCols * sizeof (T),
309
+ reinterpret_cast <char *>(NewjmC),
310
+ matrix::tile_size * 4 );
317
311
return ;
318
312
}
319
313
jm.tile = sub_m.tile ;
@@ -326,8 +320,8 @@ namespace experimental::matrix {
326
320
// This handles cases where matrix can't be accommodated by a tile
327
321
template <typename Group, typename T, size_t NumRows, size_t NumCols,
328
322
matrix_layout Layout, access::address_space Space>
329
- inline __SYCL_ALWAYS_INLINE typename std::enable_if <
330
- (NumRows > tile_size) || (NumCols * sizeof (T) / 4 > tile_size), void >::type
323
+ inline __SYCL_ALWAYS_INLINE typename std::enable_if_t <
324
+ (NumRows > tile_size) || (NumCols * sizeof (T) / 4 > tile_size), void >
331
325
joint_matrix_load (Group sg,
332
326
joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
333
327
multi_ptr<T, Space> src, size_t stride,
@@ -347,14 +341,12 @@ joint_matrix_load(Group sg,
347
341
// This handles cases where matrix can be put into a tile
348
342
template <typename Group, typename T, size_t NumRows, size_t NumCols,
349
343
matrix_layout Layout, access::address_space Space>
350
- inline __SYCL_ALWAYS_INLINE
351
- typename std::enable_if<(NumRows <= tile_size) &&
352
- (NumCols * sizeof (T) / 4 <= tile_size),
353
- void >::type
354
- joint_matrix_load (Group sg,
355
- joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
356
- multi_ptr<T, Space> src, size_t stride,
357
- matrix_layout layout) {
344
+ inline __SYCL_ALWAYS_INLINE typename std::enable_if_t <
345
+ (NumRows <= tile_size) && (NumCols * sizeof (T) / 4 <= tile_size), void >
346
+ joint_matrix_load (Group sg,
347
+ joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
348
+ multi_ptr<T, Space> src, size_t stride,
349
+ matrix_layout layout) {
358
350
T *mem = src.get ();
359
351
// tileload happens!
360
352
jm.tile =
@@ -366,8 +358,8 @@ inline __SYCL_ALWAYS_INLINE
366
358
// This handles cases where matrix can't be accommodated by a tile
367
359
template <typename Group, typename T, size_t NumRows, size_t NumCols,
368
360
matrix_layout Layout, access::address_space Space>
369
- inline __SYCL_ALWAYS_INLINE typename std::enable_if <
370
- (NumRows > tile_size) || (NumCols * sizeof (T) / 4 > tile_size), void >::type
361
+ inline __SYCL_ALWAYS_INLINE typename std::enable_if_t <
362
+ (NumRows > tile_size) || (NumCols * sizeof (T) / 4 > tile_size), void >
371
363
joint_matrix_store (Group sg,
372
364
joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
373
365
multi_ptr<T, Space> dst, size_t stride,
@@ -387,9 +379,9 @@ joint_matrix_store(Group sg,
387
379
template <typename Group, typename T, size_t NumRows, size_t NumCols,
388
380
matrix_layout Layout, access::address_space Space>
389
381
inline __SYCL_ALWAYS_INLINE
390
- typename std::enable_if <(NumRows <= tile_size) &&
391
- (NumCols * sizeof (T) / 4 <= tile_size),
392
- void >::type
382
+ typename std::enable_if_t <(NumRows <= tile_size) &&
383
+ (NumCols * sizeof (T) / 4 <= tile_size),
384
+ void >
393
385
joint_matrix_store (Group sg,
394
386
joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
395
387
multi_ptr<T, Space> dst, size_t stride,
@@ -406,14 +398,14 @@ template <typename Group, typename T1, typename T2, size_t NumRowsA,
406
398
size_t NumColsA, size_t NumRowsB, size_t NumColsB, size_t NumRowsC,
407
399
size_t NumColsC, matrix_layout LayoutA, matrix_layout LayoutB,
408
400
matrix_layout LayoutC>
409
- inline __SYCL_ALWAYS_INLINE typename std::enable_if <
401
+ inline __SYCL_ALWAYS_INLINE typename std::enable_if_t <
410
402
((std::is_same<T1, int8_t >::value && std::is_same<T2, int32_t >::value) ||
411
403
(std::is_same<T1, unsigned short >::value &&
412
404
std::is_same<T2, float >::value)) &&
413
405
(LayoutA == matrix_layout::row_major) &&
414
406
(LayoutB == matrix_layout::packed_b) &&
415
407
(LayoutC == matrix_layout::row_major),
416
- void >::type
408
+ void >
417
409
joint_matrix_mad (Group sg,
418
410
joint_matrix<Group, T1, NumRowsA, NumColsA, LayoutA> &jmA,
419
411
joint_matrix<Group, T1, NumRowsB, NumColsB, LayoutB> &jmB,
0 commit comments