[In Progress] Enable mx f8 fp4 support on ROCm #2046
base: rocm6.5_internal_testing
Conversation
Ported the patch from pytorch#147553. Commented out a few lines to avoid compilation errors (check for the todo comments).
Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
Jenkins build for 1b8bf596817e0403f7038b90b5c656aa30b6df82 commit finished as FAILURE. Detected error during base docker image building.
Thanks @jagadish-amd! I left a few comments; let's discuss offline as well.
@@ -1566,6 +1566,25 @@ void scaled_gemm(
    matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
    matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
  }
  else if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
Do we support a usage in hipblaslt where one operand uses an mx-format scale and the other a non-mx-format scale? If yes, this should be || instead of &&.
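If mixed scaling turns out not to be supported, one way to keep && while making the limitation explicit is sketched below (an illustration only, reusing the variable names from the surrounding scaled_gemm code):

  // Illustrative sketch: take the MX path only when both scales are e8m0, and
  // reject the mixed case explicitly instead of silently falling through.
  const bool a_scale_is_mx = (mat1_scale_dtype == kFloat8_e8m0fnu);
  const bool b_scale_is_mx = (mat2_scale_dtype == kFloat8_e8m0fnu);
  if (a_scale_is_mx && b_scale_is_mx) {
    // ... MX-FP8 (e8m0 block-scale) descriptor setup ...
  } else {
    TORCH_CHECK(!a_scale_is_mx && !b_scale_is_mx,
        "scaled_gemm: mixing mx (e8m0) and non-mx scale dtypes is not supported");
    // ... existing non-MX scale handling ...
  }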
@@ -1566,6 +1566,25 @@ void scaled_gemm(
    matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
    matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
  }
  else if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
#if ROCM_VERSION >= 60500
    if (at::cuda::tunable::IsGfx950Device()) {
How about this:

bool _is_gfx950_supported() {
#if ROCM_VERSION >= 60500
  return at::detail::getCUDAHooks().isGPUArch({"gfx950"});
#else
  return false;
#endif
}

Maybe slightly better perf than a device query...
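With such a helper, the call site shown above could drop its own version guard; a sketch (assuming the helper is visible from scaled_gemm):

  else if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
    // The ROCm version check lives inside _is_gfx950_supported(), so the
    // branch reads as a plain capability test.
    if (_is_gfx950_supported()) {
      // ... MX-FP8 (e8m0 block-scale) descriptor setup ...
    }
  }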
"Matrix dimensions must be multiples of 32 for MX format. ", | ||
"Got m=", m, ", n=", n, ", k=", k); | ||
|
||
//todo |
hipblaslt provides these APIs, but I guess we don't need to set this explicitly, at least for gfx950.
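If it ever does have to be set explicitly, a rough sketch using the enums mapped later in this PR could look like the following (an assumption, not verified against hipblaslt; matmul_desc is a placeholder for the existing descriptor, and the exact value type of the scale-mode attribute may differ):

  // Sketch: advertise 32-element-block UE8M0 scales for A and B.
  int32_t scale_mode = HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
  hipblasLtMatmulDescSetAttribute(
      matmul_desc, HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, &scale_mode, sizeof(scale_mode));
  hipblasLtMatmulDescSetAttribute(
      matmul_desc, HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, &scale_mode, sizeof(scale_mode));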
@@ -513,7 +515,24 @@ class HipblasltGemmOp : public Callable<ParamsT> {
      if (mat1_scale_ptr && mat2_scale_ptr) {
#ifdef HIPBLASLT_VEC_EXT
        if (GetUseRowwiseFromParams<CT>(params)) {
          // swapped
          // For MX-FP8 on gfx950
#if ROCM_VERSION >= 60500
This seems to be duplicated logic.
namespace at::cuda::tunable {

#ifdef USE_ROCM
static bool IsGfx950Device() {
How about

bool _is_gfx950_supported() {
#if ROCM_VERSION >= 60500
  return at::detail::getCUDAHooks().isGPUArch({"gfx950"});
#else
  return false;
#endif
}

instead of a device query? And maybe evaluate it only once.
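A sketch of the evaluate-once variant (assuming a function-local static is acceptable here; same helper name as suggested above):

bool _is_gfx950_supported() {
#if ROCM_VERSION >= 60500
  // The static local is initialized exactly once, so the arch query runs on
  // the first call and later calls return the cached result.
  static const bool is_gfx950 =
      at::detail::getCUDAHooks().isGPUArch({"gfx950"});
  return is_gfx950;
#else
  return false;
#endif
}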
#endif

// Helper function to validate MX format requirements
static bool ValidateMXFormatRequirements(int64_t m, int64_t n, int64_t k) {
This is a basic check, but I think it should be OK for now. Looking into the hipblaslt implementation, there seem to be a few other supported shapes.
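For context, the basic check being discussed amounts to the multiple-of-32 constraint from the error message above; a minimal sketch:

// Validate MX format requirements: MX block scaling works on 32-element
// blocks, so require m, n, and k to all be multiples of 32.
static bool ValidateMXFormatRequirements(int64_t m, int64_t n, int64_t k) {
  constexpr int64_t kMXBlockSize = 32;
  return (m % kMXBlockSize == 0) &&
         (n % kMXBlockSize == 0) &&
         (k % kMXBlockSize == 0);
}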
@@ -7339,6 +7339,9 @@
    ("CUBLASLT_MATMUL_DESC_D_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
    ("CUBLASLT_MATMUL_DESC_AMAX_D_POINTER", ("HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER", CONV_MATH_FUNC, API_BLAS)),
    ("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)),
    ("CUBLASLT_MATMUL_DESC_A_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_A_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
    ("CUBLASLT_MATMUL_DESC_B_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_B_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
    ("CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", ("HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", CONV_MATH_FUNC, API_BLAS)),
Let's focus on MX-FP8 in this PR; for MX-FP4 we have other mappings.
Jenkins build for 1b8bf596817e0403f7038b90b5c656aa30b6df82 commit is in progress.
This PR enables MX data type support on ROCm.
FP8 MX data type sample test case:
PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py TestFP8MatmulCudaCUDA.test_blockwise_mxfp8_nvfp4_numerics_test_case_name_a_eye_b_eye_fast_accum_False_128_128_128_recipe_mxfp8_cuda -v
HipblasLT log:
hipblaslt-bench --api_method c -m 128 -n 128 -k 128 --lda 128 --ldb 128 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1 --beta 0 --transA T --transB N --batch_count 1 --scaleA 3 --scaleB 3 --a_type f8_r --b_type f8_r --c_type bf16_r --d_type bf16_r --compute_type f32_r --algo_method index --solution_index -2147220478 --rotating 0 --cold_iters 0 --iters 0
FP4 MX data type sample test case: TBD.
Commits:
ROCm MX-FP8 Gemm (PR from @petrex)
Ported the patch from ROCm MX-FP8 Gemm pytorch/pytorch#147553.
Commented out a few lines to avoid compilation errors (check for todo comments).
Refine _platform_supports_mx_gemm check.
For MX FP8, A and B need not be of kFloat8_e8m0fnu type; only the scale tensors use that dtype (see the sketch below).
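For illustration, the refined expectation is roughly the following (a sketch with a hypothetical helper name; the real logic lives along the _platform_supports_mx_gemm path):

// Hypothetical helper: MX-FP8 is keyed off the scale dtypes, not the operands.
bool is_mx_fp8_gemm(at::ScalarType mat1_dtype, at::ScalarType mat2_dtype,
                    at::ScalarType mat1_scale_dtype, at::ScalarType mat2_scale_dtype) {
  // Operands stay in FP8 (e4m3fn / e5m2)...
  const bool operands_fp8 =
      (mat1_dtype == at::kFloat8_e4m3fn || mat1_dtype == at::kFloat8_e5m2) &&
      (mat2_dtype == at::kFloat8_e4m3fn || mat2_dtype == at::kFloat8_e5m2);
  // ...while only the block scales use the e8m0 format.
  const bool scales_e8m0 =
      (mat1_scale_dtype == at::kFloat8_e8m0fnu) &&
      (mat2_scale_dtype == at::kFloat8_e8m0fnu);
  return operands_fp8 && scales_e8m0;
}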