17 | 17 | #include <ATen/native/Resize.h>
18 | 18 | #include <c10/util/MaybeOwned.h>
19 | 19 | #include <ATen/native/cuda/RowwiseScaledMM.h>
   | 20 | +#include <ATen/cuda/tunable/GemmMxUtils.h>
20 | 21 | #include <ATen/native/cuda/ScaledGroupMM.h>
21 | 22 | #include <ATen/native/cuda/GroupMM.h>
22 | 23 |
@@ -89,7 +90,8 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
89 | 90 |   if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
90 | 91 |     transpose_tensor = false;
91 | 92 |     return resolve_conj_if_indicated(tensor, true);
92 |    | -  } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
   | 93 | +  } else if ((tensor_strides[1] == 1) &&
   | 94 | +             (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
93 | 95 |     transpose_tensor = true;
94 | 96 |     return resolve_conj_if_indicated(tensor, true);
95 | 97 |   } else {
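
For context on the condition being re-wrapped above: cuBLAS consumes column-major data, so prepare_matrix_for_cublas can hand a tensor over without a copy only when one stride is 1 (a contiguous fast dimension) and the other stride is at least the matching size (a valid leading dimension). A minimal sketch of the two accepted layouts; the demo function is hypothetical, not part of this change:

    #include <ATen/ATen.h>

    // Hypothetical demo of the two stride patterns the checks above accept.
    void stride_patterns_demo() {
      auto rm = at::zeros({4, 8});        // row-major: sizes {4, 8}, strides {8, 1}
      // strides[1] == 1 && strides[0] >= sizes[1] -> usable, transpose_tensor = true
      auto cm = rm.t().contiguous().t();  // column-major: sizes {4, 8}, strides {1, 4}
      // strides[0] == 1 && strides[1] >= sizes[0] -> usable, transpose_tensor = false
    }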
@@ -1104,6 +1106,7 @@ ScalingType get_scaling_type(
1104 | 1106 |
1105 | 1107 | } // namespace
1106 | 1108 |
     | 1109 | +
1107 | 1110 | // Computes matrix multiply + bias while applying scaling to input and output matrices
1108 | 1111 | // Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default.
1109 | 1112 | // If output matrix type is 16 or 32-bit type, scale_result is not applied.
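
A dequantize-then-multiply reference makes these comments concrete. This is a sketch of the rowwise-scaling semantics only, assuming per-row scales for the first operand and per-column scales for the second; the helper name and the BFloat16 output choice are illustrative, not taken from this file:

    #include <ATen/ATen.h>

    // Reference semantics (eager, unfused): out = (scale_a * a) @ (scale_b * b).
    at::Tensor scaled_mm_reference(const at::Tensor& a_fp8, const at::Tensor& b_fp8,
                                   const at::Tensor& scale_a, const at::Tensor& scale_b) {
      auto a = a_fp8.to(at::kFloat) * scale_a;  // (m, k) * (m, 1) broadcast
      auto b = b_fp8.to(at::kFloat) * scale_b;  // (k, n) * (1, n) broadcast
      // scale_result would multiply the product here; per the comment above it
      // is skipped for 16- and 32-bit output types.
      return at::matmul(a, b).to(at::kBFloat16);
    }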
@@ -1226,17 +1229,37 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
1226 | 1229 |   }
1227 | 1230 | #else
1228 | 1231 |   if (scaling_choice == ScalingType::RowWise) {
1229 |      | -    // For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes.
     | 1232 | +    // For ROCm, match behavior of f8f8bf16_rowwise type checking
1230 | 1233 |     Tensor b = mat2;
1231 | 1234 |     if (_scaled_mm_is_fnuz()) {
1232 | 1235 |       TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fnuz);
1233 | 1236 |     }
1234 | 1237 |     else {
1235 | 1238 |       TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn);
1236 | 1239 |     }
1237 |      | -    // Until more than bf16 is supported.
     | 1240 | +    // Until more than bf16 is supported
1238 | 1241 |     TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16,
1239 |      | -        "hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
     | 1242 | +        "hipblaslt rowwise _scaled_mm only supports BFloat16 output");
     | 1243 | +  }
     | 1244 | +  else if (scaling_choice == ScalingType::BlockWise) {
     | 1245 | +    TORCH_CHECK(mat1.scalar_type() == at::kFloat8_e8m0fnu &&
     | 1246 | +                mat2.scalar_type() == at::kFloat8_e8m0fnu,
     | 1247 | +                "Block-wise scaling requires both matrices to be Float8_e8m0fnu type");
     | 1248 | +
     | 1249 | +#if ROCM_VERSION >= 60500
     | 1250 | +    TORCH_CHECK(at::cuda::tunable::IsGfx950Device(),
     | 1251 | +                "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
     | 1252 | +
     | 1253 | +    TORCH_CHECK(mat1.size(0) % 32 == 0 && mat1.size(1) % 32 == 0 &&
     | 1254 | +                mat2.size(0) % 32 == 0 && mat2.size(1) % 32 == 0,
     | 1255 | +                "Matrix dimensions must be multiples of 32 for block-wise scaling");
     | 1256 | +
     | 1257 | +    TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 ||
     | 1258 | +                out.scalar_type() == ScalarType::Half,
     | 1259 | +                "Block-wise scaling only supports BFloat16 or Half output types");
     | 1260 | +#else
     | 1261 | +    TORCH_CHECK(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 6.5 or later");
     | 1262 | +#endif
1240 | 1263 |   }
1241 | 1264 | #endif
1242 | 1265 |
@@ -1315,10 +1338,12 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
1315 | 1338 |   params.k = args.k;
1316 | 1339 |   params.a = args.mata->data_ptr();
1317 | 1340 |   params.a_scale_ptr = scale_a.data_ptr();
     | 1341 | +  params.a_scale_dtype = scale_a.scalar_type();
1318 | 1342 |   params.lda = args.lda;
1319 | 1343 |   params.a_dtype = args.mata->scalar_type();
1320 | 1344 |   params.b = args.matb->data_ptr();
1321 | 1345 |   params.b_scale_ptr = scale_b.data_ptr();
     | 1346 | +  params.b_scale_dtype = scale_b.scalar_type();
1322 | 1347 |   params.ldb = args.ldb;
1323 | 1348 |   params.b_dtype = args.matb->scalar_type();
1324 | 1349 |   params.bias_ptr = bias ? bias->data_ptr(): nullptr;
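
The two added assignments imply the tunable-op parameter block now records each scale tensor's dtype next to its pointer, letting a backend distinguish fp32 rowwise scales from e8m0 block scales at dispatch time. A simplified view of the fields involved, inferred from the assignments above; the real ScaledGemmParams has more members than shown here:

    #include <cstdint>
    #include <c10/core/ScalarType.h>

    // Assumed, simplified shape of the parameter block populated above.
    struct ScaledGemmParamsSketch {
      int64_t m, n, k;
      const void* a;
      const void* a_scale_ptr;
      c10::ScalarType a_scale_dtype;  // new in this change: dtype of *a_scale_ptr
      int64_t lda;
      c10::ScalarType a_dtype;
      const void* b;
      const void* b_scale_ptr;
      c10::ScalarType b_scale_dtype;  // new in this change: dtype of *b_scale_ptr
      int64_t ldb;
      c10::ScalarType b_dtype;
      const void* bias_ptr;
    };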
@@ -1377,6 +1402,27 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
1377 | 1402 |       scaling_choice == ScalingType::RowWise);
1378 | 1403 |   }
1379 | 1404 |
     | 1405 | +  // Add MX format validation for gfx950
     | 1406 | +  if (scaling_choice == ScalingType::RowWise) {
     | 1407 | +#ifdef USE_ROCM
     | 1408 | +    if (at::cuda::tunable::IsGfx950Device()) {
     | 1409 | +      // Validate matrix dimensions for MX format
     | 1410 | +      TORCH_CHECK(at::cuda::tunable::ValidateMXFormatRequirements(mat1.size(0), mat2.size(1), mat1.size(1)),
     | 1411 | +                  "For MX format on gfx950, matrix dimensions must be multiples of 32. ",
     | 1412 | +                  "Got dimensions: ", mat1.sizes(), " x ", mat2.sizes());
     | 1413 | +
     | 1414 | +      // Validate data types for MX format
     | 1415 | +      TORCH_CHECK(mat1.scalar_type() == at::kFloat8_e8m0fnu &&
     | 1416 | +                  mat2.scalar_type() == at::kFloat8_e8m0fnu,
     | 1417 | +                  "MX format requires Float8_e8m0fnu type for both input matrices");
     | 1418 | +
     | 1419 | +      TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 ||
     | 1420 | +                  out.scalar_type() == ScalarType::Half,
     | 1421 | +                  "MX format only supports BFloat16 or Half output types");
     | 1422 | +    }
     | 1423 | +#endif
     | 1424 | +  }
     | 1425 | +
1380 | 1426 |   return out;
1381 | 1427 | }
1382 | 1428 |
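
Both new validation blocks rely on two helpers from the newly included ATen/cuda/tunable/GemmMxUtils.h, whose definitions are not part of this diff. A plausible sketch under two assumptions: ValidateMXFormatRequirements only enforces the 32-element MX block granularity on m, n, and k, and IsGfx950Device matches the HIP architecture string (this sketch compiles only on ROCm builds, where the device properties expose gcnArchName):

    #include <cstdint>
    #include <cstring>
    #include <ATen/cuda/CUDAContext.h>

    namespace at::cuda::tunable {

    // Sketch: MX formats scale elements in blocks of 32, so every GEMM
    // dimension must be divisible by the block size.
    inline bool ValidateMXFormatRequirements(int64_t m, int64_t n, int64_t k) {
      constexpr int64_t kMXBlockSize = 32;
      return m % kMXBlockSize == 0 &&
             n % kMXBlockSize == 0 &&
             k % kMXBlockSize == 0;
    }

    // Sketch: under ROCm the device properties carry gcnArchName, e.g.
    // "gfx950:sramecc+:xnack-"; match on the leading architecture token.
    inline bool IsGfx950Device() {
      const auto* prop = at::cuda::getCurrentDeviceProperties();
      return std::strncmp(prop->gcnArchName, "gfx950", 6) == 0;
    }

    } // namespace at::cuda::tunable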