[In Progress] Enable mx f8 fp4 support on ROCm #2046


Open
wants to merge 3 commits into base: rocm6.5_internal_testing

Conversation

jagadish-amd
@jagadish-amd jagadish-amd commented Apr 23, 2025

This PR enables MX data type support on ROCm.

FP8 MX data type sample test case:
PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py TestFP8MatmulCudaCUDA.test_blockwise_mxfp8_nvfp4_numerics_test_case_name_a_eye_b_eye_fast_accum_False_128_128_128_recipe_mxfp8_cuda -v

hipBLASLt log:
hipblaslt-bench --api_method c -m 128 -n 128 -k 128 --lda 128 --ldb 128 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1 --beta 0 --transA T --transB N --batch_count 1 --scaleA 3 --scaleB 3 --a_type f8_r --b_type f8_r --c_type bf16_r --d_type bf16_r --compute_type f32_r --algo_method index --solution_index -2147220478 --rotating 0 --cold_iters 0 --iters 0

FP4 MX data type sample test case: TBD
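For context on what the test above exercises: MX formats attach one shared power-of-two (e8m0) scale to each 32-element block of FP8 values. A minimal CPU sketch of picking such a scale (a simplified, hypothetical helper for illustration, not code from this PR):

```python
import math

def e8m0_block_scale(block, fp8_max=448.0):
    """Pick a power-of-two scale for one 32-element MX block.

    e8m0 stores only an 8-bit exponent, so the scale must be 2**e.
    Simplified sketch: we only scale down blocks that overflow the
    FP8 e4m3 range (max 448); real MX encoders also rescale small
    blocks to use the full dynamic range.
    """
    amax = max(abs(x) for x in block)
    if amax == 0.0:
        return 1.0
    # Smallest e such that amax / 2**e <= fp8_max.
    e = math.ceil(math.log2(amax / fp8_max))
    return 2.0 ** max(e, 0)

block = [1000.0] + [0.0] * 31
scale = e8m0_block_scale(block)
print(scale)                      # 4.0: 1000 / 4 = 250 fits in e4m3
print(max(block) / scale <= 448)  # True
```

The GEMM then multiplies the FP8 payloads and reapplies the per-block scales, which is what the `--scaleA 3 --scaleB 3` arguments to hipblaslt-bench select.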

Commits:

  1. ROCm MX-FP8 GEMM (PR from @petrex)
    Ported the patch from ROCm MX-FP8 GEMM pytorch/pytorch#147553.
    Commented out a few lines to avoid compilation errors (see the TODO comments).

  2. Refine the _platform_supports_mx_gemm check.

  3. For MX FP8, A and B need not be of kFloat8_e8m0fnu type.

@jagadish-amd jagadish-amd changed the title Enable mx f8 fp4 support on ROCm [In Progress] Enable mx f8 fp4 support on ROCm Apr 23, 2025
@rocm-repo-management-api

rocm-repo-management-api bot commented Apr 23, 2025

Jenkins build for 1b8bf596817e0403f7038b90b5c656aa30b6df82 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

Detected error during base docker image building:

#31 11.26 The following packages have unmet dependencies:
#31 11.32  rocm-dev : Depends: rocm-cmake (= 0.14.0.60304-76~22.04) but 5.0.0-1 is to be installed
#31 11.32             Depends: rocm-device-libs (= 1.0.0.60304-76~22.04) but 5.0.0-1 is to be installed
#31 11.32  rocm-utils : Depends: rocminfo (= 1.0.0.60304-76~22.04) but 5.0.0-1 is to be installed
#31 11.32               Depends: rocm-cmake (= 0.14.0.60304-76~22.04) but 5.0.0-1 is to be installed
#31 11.33 E: Unable to correct problems, you have held broken packages.
#31 ERROR: process "/bin/sh -c bash ./install_rocm.sh" did not complete successfully: exit code: 100
------
 > [stage-0 23/61] RUN bash ./install_rocm.sh:
11.26 distribution that some required packages have not yet been created
11.26 or been moved out of Incoming.


@petrex petrex left a comment


Thanks @jagadish-amd!
Left a few comments; let's discuss offline as well.

@@ -1566,6 +1566,25 @@ void scaled_gemm(
matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
}
else if(mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
@petrex petrex Apr 24, 2025

Do we support a usage like this in hipblaslt: MX format * non-MX format?

If yes, this should be || instead of &&.

@@ -1566,6 +1566,25 @@ void scaled_gemm(
matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
}
else if(mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
#if ROCM_VERSION >= 60500
if (at::cuda::tunable::IsGfx950Device()) {

How about this:

```cpp
bool _is_gfx950_supported() {
#if ROCM_VERSION >= 60500
    return at::detail::getCUDAHooks().isGPUArch({"gfx950"});
#else
    return false;
#endif
}
```

Maybe slightly better perf than a device query...

"Matrix dimensions must be multiples of 32 for MX format. ",
"Got m=", m, ", n=", n, ", k=", k);

//todo
hipblaslt provides these APIs, but I guess we don't need to set this explicitly, at least for gfx950.

@@ -513,7 +515,24 @@ class HipblasltGemmOp : public Callable<ParamsT> {
if (mat1_scale_ptr && mat2_scale_ptr) {
#ifdef HIPBLASLT_VEC_EXT
if (GetUseRowwiseFromParams<CT>(params)) {
// swapped
// For MX-FP8 on gfx950
#if ROCM_VERSION >= 60500

seems to be duplicate logic

namespace at::cuda::tunable {

#ifdef USE_ROCM
static bool IsGfx950Device() {

How about

```cpp
bool _is_gfx950_supported() {
#if ROCM_VERSION >= 60500
    return at::detail::getCUDAHooks().isGPUArch({"gfx950"});
#else
    return false;
#endif
}
```

instead of a device query, and maybe evaluate it only once.

#endif

// Helper function to validate MX format requirements
static bool ValidateMXFormatRequirements(int64_t m, int64_t n, int64_t k) {

This is a basic check, but I think it should be OK for now. Looking into the hipblaslt implementation, there seem to be a few other shapes.

@@ -7339,6 +7339,9 @@
("CUBLASLT_MATMUL_DESC_D_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
("CUBLASLT_MATMUL_DESC_AMAX_D_POINTER", ("HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER", CONV_MATH_FUNC, API_BLAS)),
("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)),
("CUBLASLT_MATMUL_DESC_A_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_A_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
("CUBLASLT_MATMUL_DESC_B_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_B_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
("CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", ("HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", CONV_MATH_FUNC, API_BLAS)),

Let's focus on MX-FP8 in this PR; for MX-FP4 we have other mappings.

@rocm-repo-management-api

Jenkins build for 1b8bf596817e0403f7038b90b5c656aa30b6df82 commit is in progress
Links: Blue Ocean view / Build artifacts

