
Commit a264d06

ROCm MX-FP8 Gemm
Ported the patch from pytorch#147553. Commented out a few lines to avoid a compilation error (see the TODO comments). Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
1 parent 0a6e1d6 commit a264d06
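For context, a rough sketch of the call this patch targets, seen from the Python side. This is not part of the commit; the fp8-e4m3 data / e8m0 scale convention and the one-scale-per-32-elements-of-K layout are assumptions based on the VEC32_UE8M0 scale mode and the multiple-of-32 checks in the diffs below.

import torch

# Hypothetical MX-FP8 matmul on a gfx950 device (sketch, not committed code).
m, k, n = 128, 256, 64  # all dimensions multiples of 32, as the new checks require
a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn).t()  # column-major, as blasLt expects

# e8m0 scales, one per 32-element block along K (assumed layout); a biased
# exponent byte of 127 encodes a scale of 2^0 = 1.0.
scale_a = torch.full((m, k // 32), 127, dtype=torch.uint8, device="cuda").view(torch.float8_e8m0fnu)
scale_b = torch.full((n, k // 32), 127, dtype=torch.uint8, device="cuda").view(torch.float8_e8m0fnu)

out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)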

File tree

6 files changed: +131 −8 lines

- aten/src/ATen/cuda/CUDABlas.cpp (+21 −2)
- aten/src/ATen/cuda/tunable/GemmHipblaslt.h (+20 −1)
- aten/src/ATen/cuda/tunable/GemmMxUtils.h (new, +29)
- aten/src/ATen/native/cuda/Blas.cpp (+50 −4)
- torch/testing/_internal/common_cuda.py (+8 −1)
- torch/utils/hipify/cuda_to_hip_mappings.py (+3)

aten/src/ATen/cuda/CUDABlas.cpp

+21 −2
@@ -1566,6 +1566,25 @@ void scaled_gemm(
       matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
       matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
     }
+    else if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
+#if ROCM_VERSION >= 60500
+      if (at::cuda::tunable::IsGfx950Device()) {
+        // Validate matrix dimensions for MX format
+        TORCH_CHECK(at::cuda::tunable::ValidateMXFormatRequirements(m, n, k),
+                    "Matrix dimensions must be multiples of 32 for MX format. ",
+                    "Got m=", m, ", n=", n, ", k=", k);
+
+        // TODO: Check whether the MX scale block sizes need to be set
+        // explicitly for hipblaslt.
+        //constexpr int32_t block_size = 32;
+        //computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_BLOCK_SIZE_ROWS_VEC_EXT, block_size);
+        //computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_BLOCK_SIZE_COLS_VEC_EXT, block_size);
+        //computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_BLOCK_SIZE_ROWS_VEC_EXT, block_size);
+        //computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_BLOCK_SIZE_COLS_VEC_EXT, block_size);
+      }
+#endif
+    }
 #else
     // rowwise isn't supported using cublaslt or older hipblaslt
     TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
@@ -1603,11 +1622,11 @@ void scaled_gemm(
   }

   if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
-#if CUDA_VERSION >= 12080
+#if (!defined(USE_ROCM) && CUDA_VERSION >= 12080) || (defined(USE_ROCM) && ROCM_VERSION >= 60500)
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
 #else
-    TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above");
+    TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 or ROCm 6.5 (with gfx950) and above");
 #endif // if CUDA_VERSION >= 12080
   } else if (mat1_scale_dtype == kFloat8_e4m3fn && mat2_scale_dtype == kFloat8_e4m3fn) {
 #if CUDA_VERSION >= 12080

aten/src/ATen/cuda/tunable/GemmHipblaslt.h

+20 −1
@@ -14,6 +14,8 @@
 #include <hipblaslt/hipblaslt.h>
 #include <hipblaslt/hipblaslt-ext.hpp>

+#include <ATen/cuda/tunable/GemmMxUtils.h>
+
 #define TORCH_HIPBLASLT_CHECK(EXPR)   \
   do {                                \
     hipblasStatus_t __err = EXPR;     \
@@ -513,7 +515,24 @@ class HipblasltGemmOp : public Callable<ParamsT> {
       if (mat1_scale_ptr && mat2_scale_ptr) {
 #ifdef HIPBLASLT_VEC_EXT
         if (GetUseRowwiseFromParams<CT>(params)) {
-          // swapped
+          // For MX-FP8 on gfx950
+#if ROCM_VERSION >= 60500
+          if (IsGfx950Device()) {
+            // Validate matrix dimensions for MX format
+            TORCH_CHECK(ValidateMXFormatRequirements(params->m, params->n, params->k),
+                        "Matrix dimensions must be multiples of 32 for MX format. ",
+                        "Got m=", params->m, ", n=", params->n, ", k=", params->k);
+
+            // TODO: Check whether the MX scale block sizes need to be set
+            // explicitly for hipblaslt.
+            //constexpr int32_t block_size = 32;
+            //matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_BLOCK_SIZE_ROWS_VEC_EXT, block_size);
+            //matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_BLOCK_SIZE_COLS_VEC_EXT, block_size);
+            //matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_BLOCK_SIZE_ROWS_VEC_EXT, block_size);
+            //matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_BLOCK_SIZE_COLS_VEC_EXT, block_size);
+          }
+#endif
+          // Set scale pointers (swapped as before)
           matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT, mat2_scale_ptr);
           matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT, mat1_scale_ptr);
         }
aten/src/ATen/cuda/tunable/GemmMxUtils.h (new file)

+29

@@ -0,0 +1,29 @@
+#pragma once
+
+#include <ATen/cuda/CUDAContext.h>
+#include <string>
+
+namespace at::cuda::tunable {
+
+#ifdef USE_ROCM
+static bool IsGfx950Device() {
+  // Single static check - only evaluated once
+  static bool is_gfx950 = []() {
+    auto device = at::cuda::current_device();
+    hipDeviceProp_t* prop = at::cuda::getDeviceProperties(device);
+    return (std::string(prop->gcnArchName) == "gfx950");
+  }();
+  return is_gfx950;
+}
+#endif
+
+// Helper function to validate MX format requirements
+static bool ValidateMXFormatRequirements(int64_t m, int64_t n, int64_t k) {
+  constexpr int32_t required_block_size = 32;
+  return (m % required_block_size == 0) &&
+         (n % required_block_size == 0) &&
+         (k % required_block_size == 0);
+}
+
+} // namespace at::cuda::tunable
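The helper caches the architecture check in a function-local static, so device properties are queried once per process, and the validator encodes the 32-element scale-block granularity of the MX format. A minimal Python-side sketch of the same rule (illustrative only; validate_mx_dims is not a PyTorch API):

def validate_mx_dims(m: int, n: int, k: int, block: int = 32) -> bool:
    # Every GEMM dimension must be divisible by the MX scale block size.
    return m % block == 0 and n % block == 0 and k % block == 0

assert validate_mx_dims(128, 64, 256)        # multiples of 32: accepted
assert not validate_mx_dims(100, 64, 256)    # 100 % 32 != 0: rejected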

aten/src/ATen/native/cuda/Blas.cpp

+50 −4
@@ -17,6 +17,7 @@
 #include <ATen/native/Resize.h>
 #include <c10/util/MaybeOwned.h>
 #include <ATen/native/cuda/RowwiseScaledMM.h>
+#include <ATen/cuda/tunable/GemmMxUtils.h>
 #include <ATen/native/cuda/ScaledGroupMM.h>
 #include <ATen/native/cuda/GroupMM.h>

@@ -89,7 +90,8 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
   if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
     transpose_tensor = false;
     return resolve_conj_if_indicated(tensor, true);
-  } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
+  } else if ((tensor_strides[1] == 1) &&
+             (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
     transpose_tensor = true;
     return resolve_conj_if_indicated(tensor, true);
   } else {
@@ -1104,6 +1106,7 @@ ScalingType get_scaling_type(

 } // namespace

+
 // Computes matrix multiply + bias while applying scaling to input and output matrices
 // Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default.
 // If output matrix type is 16 or 32-bit type, scale_result is not applied.
@@ -1226,17 +1229,37 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   }
 #else
   if (scaling_choice == ScalingType::RowWise) {
-    // For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes.
+    // For ROCm, match behavior of f8f8bf16_rowwise type checking
     Tensor b = mat2;
     if (_scaled_mm_is_fnuz()) {
       TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fnuz);
     }
     else {
       TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn);
     }
-    // Until more than bf16 is supported.
+    // Until more than bf16 is supported
     TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16,
-                "hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
+                "hipblaslt rowwise _scaled_mm only supports BFloat16 output");
+  }
+  else if (scaling_choice == ScalingType::BlockWise) {
+    TORCH_CHECK(mat1.scalar_type() == at::kFloat8_e8m0fnu &&
+                mat2.scalar_type() == at::kFloat8_e8m0fnu,
+                "Block-wise scaling requires both matrices to be Float8_e8m0fnu type");
+
+#if ROCM_VERSION >= 60500
+    TORCH_CHECK(at::cuda::tunable::IsGfx950Device(),
+                "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
+
+    TORCH_CHECK(mat1.size(0) % 32 == 0 && mat1.size(1) % 32 == 0 &&
+                mat2.size(0) % 32 == 0 && mat2.size(1) % 32 == 0,
+                "Matrix dimensions must be multiples of 32 for block-wise scaling");
+
+    TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 ||
+                out.scalar_type() == ScalarType::Half,
+                "Block-wise scaling only supports BFloat16 or Half output types");
+#else
+    TORCH_CHECK(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 6.5 or later");
+#endif
   }
 #endif

@@ -1315,10 +1338,12 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
       params.k = args.k;
       params.a = args.mata->data_ptr();
       params.a_scale_ptr = scale_a.data_ptr();
+      params.a_scale_dtype = scale_a.scalar_type();
       params.lda = args.lda;
       params.a_dtype = args.mata->scalar_type();
       params.b = args.matb->data_ptr();
       params.b_scale_ptr = scale_b.data_ptr();
+      params.b_scale_dtype = scale_b.scalar_type();
       params.ldb = args.ldb;
       params.b_dtype = args.matb->scalar_type();
       params.bias_ptr = bias ? bias->data_ptr(): nullptr;
@@ -1377,6 +1402,27 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
           scaling_choice == ScalingType::RowWise);
   }

+  // Add MX format validation for gfx950
+  if (scaling_choice == ScalingType::RowWise) {
+#ifdef USE_ROCM
+    if (at::cuda::tunable::IsGfx950Device()) {
+      // Validate matrix dimensions for MX format
+      TORCH_CHECK(at::cuda::tunable::ValidateMXFormatRequirements(mat1.size(0), mat2.size(1), mat1.size(1)),
+                  "For MX format on gfx950, matrix dimensions must be multiples of 32. ",
+                  "Got dimensions: ", mat1.sizes(), " x ", mat2.sizes());
+
+      // Validate data types for MX format
+      TORCH_CHECK(mat1.scalar_type() == at::kFloat8_e8m0fnu &&
+                  mat2.scalar_type() == at::kFloat8_e8m0fnu,
+                  "MX format requires Float8_e8m0fnu type for both input matrices");
+
+      TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 ||
+                  out.scalar_type() == ScalarType::Half,
+                  "MX format only supports BFloat16 or Half output types");
+    }
+#endif
+  }
+
   return out;
 }
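From Python, these TORCH_CHECKs surface as RuntimeError. A standalone sketch of the dimension check firing (the MX call convention here is assumed, as in the sketch near the top of this page):

import torch

m, k, n = 100, 256, 64  # m = 100 is not a multiple of 32
a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.full((m, k // 32), 127, dtype=torch.uint8, device="cuda").view(torch.float8_e8m0fnu)
scale_b = torch.full((n, k // 32), 127, dtype=torch.uint8, device="cuda").view(torch.float8_e8m0fnu)

try:
    torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
except RuntimeError as err:
    print(err)  # expected: "... must be multiples of 32 ..." from the new checks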

torch/testing/_internal/common_cuda.py

+8 −1
@@ -104,7 +104,14 @@ def evaluate_platform_supports_fp8():

 PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())

-PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: TEST_CUDA and SM100OrLater)
+def _platform_supports_mx_gemm():
+    if TEST_CUDA:
+        return SM100OrLater
+    if TEST_WITH_ROCM:
+        return torch.cuda.get_device_properties(torch.cuda.current_device()).name.startswith('gfx950')
+    return False
+
+PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: _platform_supports_mx_gemm())

 if TEST_NUMBA:
     try:
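Tests consume the flag with the usual skipIf gating; a sketch (the decorator pattern mirrors the existing FP8 tests, while the test class here is illustrative):

import unittest
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_MX_GEMM

@unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM,
                 "MX gemm requires SM100+ (CUDA) or gfx950 (ROCm)")
class TestMXGemm(unittest.TestCase):
    def test_mx_fp8_matmul(self):
        ...  # exercise torch._scaled_mm with e8m0 scales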

torch/utils/hipify/cuda_to_hip_mappings.py

+3
@@ -7339,6 +7339,9 @@
 ("CUBLASLT_MATMUL_DESC_D_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
 ("CUBLASLT_MATMUL_DESC_AMAX_D_POINTER", ("HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER", CONV_MATH_FUNC, API_BLAS)),
 ("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)),
+("CUBLASLT_MATMUL_DESC_A_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_A_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
+("CUBLASLT_MATMUL_DESC_B_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_B_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
+("CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", ("HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", CONV_MATH_FUNC, API_BLAS)),
 ("cublasLtMatrixLayout_t", ("hipblasLtMatrixLayout_t", CONV_MATH_FUNC, API_BLAS)),
 ("cublasLtMatrixLayoutOpaque_t", ("hipblasLtMatrixLayoutOpaque_t", CONV_MATH_FUNC, API_BLAS)),
 ("cublasLtMatrixLayoutAttribute_t", ("hipblasLtMatrixLayoutAttribute_t", CONV_MATH_FUNC, API_BLAS)),
