
Commit b6e6e37 ("ocp fp8 test")
Parent: 46397f5

17 files changed: +467, -194 lines
@@ -0,0 +1,75 @@
+#!/usr/bin/env bash
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+set -e
+set -x
+
+N_BUILD_JOBS=$(grep -c ^processor /proc/cpuinfo)
+# If rocm-smi exists locally (it should) use it to find
+# out how many GPUs we have to test with.
+rocm-smi -i
+STATUS=$?
+if [ $STATUS -ne 0 ]; then TF_GPU_COUNT=1; else
+  TF_GPU_COUNT=$(rocm-smi -i | grep 'Device ID' | grep 'GPU' | wc -l)
+fi
+TF_TESTS_PER_GPU=1
+N_TEST_JOBS=$(expr ${TF_GPU_COUNT} \* ${TF_TESTS_PER_GPU})
+
+echo ""
+echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS} concurrent test job(s)."
+echo ""
+
+# First positional argument (if any) specifies the ROCM_INSTALL_DIR
+if [[ -n $1 ]]; then
+  ROCM_INSTALL_DIR=$1
+else
+  if [[ -z "${ROCM_PATH}" ]]; then
+    ROCM_INSTALL_DIR=/opt/rocm/
+  else
+    ROCM_INSTALL_DIR=$ROCM_PATH
+  fi
+fi
+
+export PYTHON_BIN_PATH=`which python3`
+PYTHON_VERSION=`python3 -c "import sys;print(f'{sys.version_info.major}.{sys.version_info.minor}')"`
+export TF_PYTHON_VERSION=$PYTHON_VERSION
+export TF_NEED_ROCM=1
+export ROCM_PATH=$ROCM_INSTALL_DIR
+
+if [ -f /usertools/rocm.bazelrc ]; then
+  # Use the bazelrc files in /usertools if available
+  if [ ! -d /tf ]; then
+    # The bazelrc files in /usertools expect /tf to exist
+    mkdir /tf
+  fi
+
+  bazel \
+    --bazelrc=/usertools/rocm.bazelrc \
+    test \
+    --config=sigbuild_local_cache \
+    --config=rocm \
+    --config=xla_cpp_filters \
+    --test_output=errors \
+    --local_test_jobs=${N_TEST_JOBS} \
+    --test_env=TF_TESTS_PER_GPU=$TF_TESTS_PER_GPU \
+    --test_env=TF_GPU_COUNT=$TF_GPU_COUNT \
+    --test_output=streamed \
+    --test_env=TF_CPP_VMODULE="gemm_rewriter=3" \
+    --test_env=XLA_FLAGS="--xla_dump_to=/tmp/generated --xla_dump_hlo_as_text --xla_dump_hlo_as_html --xla_gpu_enable_cublaslt=true --xla_gpu_autotune_level=4" \
+    --action_env=XLA_FLAGS=--xla_gpu_force_compilation_parallelism=16 \
+    -- @local_xla//xla/service/gpu/transforms:gemm_rewriter_test_gpu_amd_any
+fi
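Note on the script above: because it runs under set -e, a failing rocm-smi -i probe aborts the script before the STATUS check can select the single-GPU fallback; guarding the probe (for example, rocm-smi -i || true) would preserve the intended behavior. Also note that the bazel invocation only runs when /usertools/rocm.bazelrc exists, i.e. inside the ROCm CI containers, and that the first positional argument overrides the ROCm install directory, falling back to ROCM_PATH and then /opt/rocm/.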

third_party/xla/xla/debug_options_flags.cc (+1, -1)
@@ -110,7 +110,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_gpu_enable_cudnn_frontend(true);
-  opts.set_xla_gpu_enable_cublaslt(false);
+  opts.set_xla_gpu_enable_cublaslt(true);
   opts.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
   opts.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLAS);
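With this flip, the cuBLASLt (hipBLASLt on ROCm) GEMM path is on by default rather than opt-in. A minimal sketch of how downstream code observes the flag; the surrounding variables are illustrative, but the getter is proto-generated from the same field the setter above writes:

    // Reading the flipped default from a module's DebugOptions.
    // Assumes `module` is an xla::HloModule*.
    const xla::DebugOptions& opts = module->config().debug_options();
    if (opts.xla_gpu_enable_cublaslt()) {
      // GemmRewriter may lower eligible dots to cublasLt/hipBLASLt
      // custom calls instead of plain BLAS GEMM calls.
    }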

third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc (+13, -1)
@@ -168,9 +168,21 @@ class GemmAutotuner {
     se::DeviceMemoryBase a_scale_buffer, b_scale_buffer, c_scale_buffer,
         d_scale_buffer, d_amax_buffer, bias_buffer, aux_buffer;

+    int input_buffer_idx = 2;  // lhs is at 0, rhs is at 1
     if (has_vector_bias) {
-      bias_buffer = rz_buffers_.input_buffers().at(has_matrix_bias ? 3 : 2);
+      if (has_matrix_bias) {
+        input_buffer_idx++;
+      }
+      bias_buffer = rz_buffers_.input_buffers().at(input_buffer_idx++);
+    }
+    // In the current GemmRewriter design for FP8, the a/b scales remain active
+    // even when they are not used. Consequently, we must inform the autotuner
+    // so it can choose algorithms that properly support a/b scales.
+    if (gemm_config.is_fp8) {
+      a_scale_buffer = rz_buffers_.input_buffers().at(input_buffer_idx++);
+      b_scale_buffer = rz_buffers_.input_buffers().at(input_buffer_idx++);
     }
+
     if (has_aux_output) {
       aux_buffer = rz_buffers_.output_buffers().at(1);
     }
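For reference, the index walk above assumes the redzone input buffers are ordered lhs, rhs, optional matrix bias, optional vector bias, then the FP8 a/b scales. A standalone sketch (illustrative, not XLA code) that reproduces the resulting indices:

    #include <cstdio>

    int main() {
      const bool has_matrix_bias = true, has_vector_bias = true, is_fp8 = true;
      int input_buffer_idx = 2;  // lhs is at 0, rhs is at 1
      if (has_vector_bias) {
        if (has_matrix_bias) input_buffer_idx++;  // matrix bias occupies slot 2
        std::printf("vector bias at %d\n", input_buffer_idx++);  // prints 3
      }
      if (is_fp8) {
        std::printf("a_scale at %d, b_scale at %d\n", input_buffer_idx,
                    input_buffer_idx + 1);  // prints 4 and 5
      }
      return 0;
    }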

third_party/xla/xla/service/gpu/backend_configs.proto (+2)
@@ -105,6 +105,8 @@ message GemmBackendConfig {
   optional bool grad_x = 16;
   optional bool grad_y = 17;
   bool damax_output = 18;
+
+  bool is_fp8 = 19;
 }

 // Backend config for bitcast operation generated from MLIR MHLO dialect.
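Since this is a proto3 field, the addition yields the standard generated accessors on GemmBackendConfig; a minimal sketch (the surrounding code is illustrative):

    xla::gpu::GemmBackendConfig config;
    config.set_is_fp8(true);     // generated setter for field 19
    bool fp8 = config.is_fp8();  // generated getter, read by matmul_utils.cc below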

third_party/xla/xla/service/gpu/buffer_comparator.cc (+2, -2)
@@ -187,7 +187,7 @@ absl::StatusOr<bool> BufferComparator::CompareEqual(
                                    stream, current, expected};

   switch (shape_.element_type()) {
-#if GOOGLE_CUDA  // not available for ROCm yet..
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
     case xla::F8E4M3FN:
       return CompareEqualParameterized<tsl::float8_e4m3fn, float>(
           "fp8_e4m3fn_comparison", buffer_comparator::fp8_e4m3fn_comparison(),
@@ -196,7 +196,7 @@ absl::StatusOr<bool> BufferComparator::CompareEqual(
       return CompareEqualParameterized<tsl::float8_e5m2, float>(
           "fp8_e5m2_comparison", buffer_comparator::fp8_e5m2_comparison(),
           params);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
     case xla::F8E4M3FNUZ:
       return CompareEqualParameterized<tsl::float8_e4m3fnuz, float>(
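One readability note on the new guard: && binds tighter than || in preprocessor expressions, so the condition parses as intended, equivalent to:

    #if GOOGLE_CUDA || (TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300)

Spelling the parentheses out explicitly would make the intent self-evident.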

third_party/xla/xla/service/gpu/buffer_comparator.cu.cc (+33, -15)
@@ -54,20 +54,29 @@ __device__ __inline__ float Canonicalize(float input) {
   return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
 }

+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
+__global__ void xla_fp8_e4m3fn_comparison(
 #if GOOGLE_CUDA
-__global__ void xla_fp8_e4m3fn_comparison(__nv_fp8_storage_t* buffer_a,
-                                          __nv_fp8_storage_t* buffer_b,
-                                          float rel_error_threshold,
-                                          uint64_t buffer_length,
-                                          int* mismatch_count) {
+    __nv_fp8_storage_t* buffer_a, __nv_fp8_storage_t* buffer_b,
+#else  // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
+    __hip_fp8_storage_t* buffer_a, __hip_fp8_storage_t* buffer_b,
+#endif
+    float rel_error_threshold, uint64_t buffer_length, int* mismatch_count) {
   int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx >= buffer_length) return;
   // TODO(philipphack): Replace with direct conversion to float when this
   // functionality becomes available.
+#if GOOGLE_CUDA
   float elem_a =
       __half2float(__nv_cvt_fp8_to_halfraw(buffer_a[idx], __NV_E4M3));
   float elem_b =
       __half2float(__nv_cvt_fp8_to_halfraw(buffer_b[idx], __NV_E4M3));
+#else  // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
+  float elem_a =
+      __half2float(__hip_cvt_fp8_to_halfraw(buffer_a[idx], __HIP_E4M3));
+  float elem_b =
+      __half2float(__hip_cvt_fp8_to_halfraw(buffer_b[idx], __HIP_E4M3));
+#endif
   elem_a = Canonicalize(elem_a);
   elem_b = Canonicalize(elem_b);
   if (isnan(elem_a) && isnan(elem_b)) return;
@@ -78,19 +87,28 @@ __global__ void xla_fp8_e4m3fn_comparison(__nv_fp8_storage_t* buffer_a,
     atomicAdd(mismatch_count, 1);
 }

-__global__ void xla_fp8_e5m2_comparison(__nv_fp8_storage_t* buffer_a,
-                                        __nv_fp8_storage_t* buffer_b,
-                                        float rel_error_threshold,
-                                        uint64_t buffer_length,
-                                        int* mismatch_count) {
+__global__ void xla_fp8_e5m2_comparison(
+#if GOOGLE_CUDA
+    __nv_fp8_storage_t* buffer_a, __nv_fp8_storage_t* buffer_b,
+#else  // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
+    __hip_fp8_storage_t* buffer_a, __hip_fp8_storage_t* buffer_b,
+#endif
+    float rel_error_threshold, uint64_t buffer_length, int* mismatch_count) {
   int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx >= buffer_length) return;
-  // TODO(philipphack): Replace with direct conversion to float when this
-  // functionality becomes available.
+  // TODO(philipphack): Replace with direct conversion to float when this
+  // functionality becomes available.
+#if GOOGLE_CUDA
   float elem_a =
       __half2float(__nv_cvt_fp8_to_halfraw(buffer_a[idx], __NV_E5M2));
   float elem_b =
       __half2float(__nv_cvt_fp8_to_halfraw(buffer_b[idx], __NV_E5M2));
+#else  // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
+  float elem_a =
+      __half2float(__hip_cvt_fp8_to_halfraw(buffer_a[idx], __HIP_E5M2));
+  float elem_b =
+      __half2float(__hip_cvt_fp8_to_halfraw(buffer_b[idx], __HIP_E5M2));
+#endif
   elem_a = Canonicalize(elem_a);
   elem_b = Canonicalize(elem_b);
   if (isnan(elem_a) && isnan(elem_b)) return;
@@ -100,7 +118,7 @@ __global__ void xla_fp8_e5m2_comparison(__nv_fp8_storage_t* buffer_a,
   if (rel_error > rel_error_threshold || isnan(rel_error))
     atomicAdd(mismatch_count, 1);
 }
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300

 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200

@@ -262,15 +280,15 @@ __global__ void xla_int32_comparison(int* buffer_a, int* buffer_b,

 }  // namespace

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
 void* fp8_e4m3fn_comparison() {
   return reinterpret_cast<void*>(&xla_fp8_e4m3fn_comparison);
 }

 void* fp8_e5m2_comparison() {
   return reinterpret_cast<void*>(&xla_fp8_e5m2_comparison);
 }
-#endif
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300

 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
 void* fp8_e4m3fnuz_comparison() {
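Both kernels widen FP8 through half to float and then apply the file's common acceptance test. A host-side C++ sketch of that criterion follows; the (max(|a|,|b|) + 1) denominator is an assumption based on the comparators elsewhere in this file, since the rel_error line is outside the hunks shown:

    #include <algorithm>
    #include <cmath>

    // True when two already-widened values should count as a mismatch.
    bool Fp8Mismatch(float elem_a, float elem_b, float rel_error_threshold) {
      auto canonicalize = [](float x) {
        // Clamp to the fp16 range used by Canonicalize() in the kernels.
        return std::isnan(x) ? x : std::max(-65505.0f, std::min(x, 65505.0f));
      };
      elem_a = canonicalize(elem_a);
      elem_b = canonicalize(elem_b);
      if (std::isnan(elem_a) && std::isnan(elem_b)) return false;  // both NaN
      float rel_error = std::abs(elem_a - elem_b) /
                        (std::max(std::abs(elem_a), std::abs(elem_b)) + 1.0f);
      return rel_error > rel_error_threshold || std::isnan(rel_error);
    }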

third_party/xla/xla/service/gpu/matmul_utils.cc (+6, -5)
@@ -301,13 +301,13 @@ absl::StatusOr<bool> CanFoldTransposeOperandIntoDot(const HloInstruction& dot,
     double alpha_real, double alpha_imag, double beta,
     PrecisionConfig::Algorithm precision_algorithm,
     std::optional<int64_t> algorithm, int64_t compute_precision, bool grad_x,
-    bool grad_y) {
+    bool grad_y, bool is_fp8) {
   return GemmConfig::For(lhs_shape, lhs_batch_dims, lhs_contracting_dims,
                          rhs_shape, rhs_batch_dims, rhs_contracting_dims,
                          /*c_shape=*/output_shape, /*bias_shape_ptr=*/nullptr,
                          output_shape, alpha_real, alpha_imag, beta,
                          precision_algorithm, algorithm, compute_precision,
-                         grad_x, grad_y);
+                         grad_x, grad_y, is_fp8);
 }

 /*static*/ absl::StatusOr<GemmConfig> GemmConfig::For(
@@ -319,7 +319,7 @@ absl::StatusOr<bool> CanFoldTransposeOperandIntoDot(const HloInstruction& dot,
     double alpha_imag, double beta,
     PrecisionConfig::Algorithm precision_algorithm,
     std::optional<int64_t> algorithm, int64_t compute_precision, bool grad_x,
-    bool grad_y) {
+    bool grad_y, bool is_fp8) {
   absl::Span<const int64_t> lhs_col_dims = lhs_contracting_dims;
   TF_ASSIGN_OR_RETURN(
       std::vector<int64_t> lhs_row_dims,
@@ -436,7 +436,8 @@ absl::StatusOr<bool> CanFoldTransposeOperandIntoDot(const HloInstruction& dot,
           precision_algorithm,
           algorithm,
           grad_x,
-          grad_y};
+          grad_y,
+          is_fp8};
 }

 namespace {
@@ -509,7 +510,7 @@ bool IsTf32Allowed(PrecisionConfig::Algorithm algorithm,
       /*bias_shape_ptr=*/
       vector_bias_shape ? &vector_bias_shape.value() : nullptr, output_shape,
       config.alpha_real(), config.alpha_imag(), config.beta(),
-      precision_algorithm, algorithm, precision, grad_x, grad_y);
+      precision_algorithm, algorithm, precision, grad_x, grad_y, config.is_fp8());
 }

 absl::StatusOr<GemmConfig::DescriptorsTuple> GemmConfig::GetMatrixDescriptors(

third_party/xla/xla/service/gpu/matmul_utils.h (+2, -2)
@@ -121,7 +121,7 @@ struct GemmConfig : public se::gpu::GemmConfig {
       double alpha_real, double alpha_imag, double beta,
       PrecisionConfig::Algorithm precision_algorithm,
       std::optional<int64_t> algorithm, int64_t compute_precision, bool grad_x,
-      bool grad_y);
+      bool grad_y, bool is_fp8);

   // As above with additional `c_shape` and `bias_shape_ptr` parameter, both
   // which are only necessarily for F8 gemms.
@@ -134,7 +134,7 @@ struct GemmConfig : public se::gpu::GemmConfig {
       double alpha_imag, double beta,
       PrecisionConfig::Algorithm precision_algorithm,
       std::optional<int64_t> algorithm, int64_t compute_precision, bool grad_x,
-      bool grad_y);
+      bool grad_y, bool is_fp8);

   struct DescriptorsTuple {
     se::gpu::MatrixDescriptor lhs;
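Because is_fp8 is appended without a default value, every existing GemmConfig::For call site must now pass it explicitly; the test updates below do exactly that with /*is_fp8*/ false. Defaulting the parameter (bool is_fp8 = false) would have avoided touching the tests, at the cost of making the FP8 contract implicit at call sites.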

third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc (+2, -2)
@@ -642,7 +642,7 @@ TEST(CommandBufferThunkTest, GemmCmd) {
       ShapeUtil::MakeShape(PrimitiveType::F32, {4, 3}), {}, {0},
       ShapeUtil::MakeShape(PrimitiveType::F32, {2, 3}), 1.0,
       0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt,
-      se::blas::kDefaultComputePrecision, false, false);
+      se::blas::kDefaultComputePrecision, false, false, false);
   ASSERT_TRUE(config.ok());

   // Prepare commands sequence for constructing command buffer.
@@ -750,7 +750,7 @@ TEST(CommandBufferThunkTest, CublasLtCmd) {
       /*precision_algorithm*/ PrecisionConfig::ALG_UNSET,
       /*algorithm*/ std::nullopt,
       /*compute_precision*/ se::blas::kDefaultComputePrecision,
-      /*grad_x*/ false, /*grad_y*/ false);
+      /*grad_x*/ false, /*grad_y*/ false, /*is_fp8*/ false);
   ASSERT_TRUE(config.ok());

   // Prepare commands sequence for constructing command buffer.

third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk_test.cc (+6, -6)
@@ -126,7 +126,7 @@ TEST(DynamicSliceThunkTest, SlicedGemm) {
       ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1}), {}, {0},
       ShapeUtil::MakeShape(PrimitiveType::F32, {1, 1}), 1.0,
       0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt,
-      se::blas::kDefaultComputePrecision, false, false);
+      se::blas::kDefaultComputePrecision, false, false, false);
   ASSERT_TRUE(config.ok());

   // Creating embedded GEMM thunk.
@@ -278,7 +278,7 @@ TEST(DynamicSliceThunkTest, MulipleSlicedOperandsGemm) {
       ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1}), {}, {0},
       ShapeUtil::MakeShape(PrimitiveType::F32, {1, 1}), 1.0,
       0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt,
-      se::blas::kDefaultComputePrecision, false, false);
+      se::blas::kDefaultComputePrecision, false, false, false);
   ASSERT_TRUE(config.ok());

   // Creating embedded GEMM thunk.
@@ -797,7 +797,7 @@ TEST(DynamicSliceThunkTest, SlicedGemmArbitraryArgumentOrder) {
       ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1}), {}, {0},
       ShapeUtil::MakeShape(PrimitiveType::F32, {1, 1}), 1.0,
       0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt,
-      se::blas::kDefaultComputePrecision, false, false);
+      se::blas::kDefaultComputePrecision, false, false, false);
   ASSERT_TRUE(config.ok());

   // Creating embedded GEMM thunk.
@@ -945,7 +945,7 @@ TEST(DynamicSliceThunkTest, SlicedGemmArbitraryNumberOfArguments) {
       ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1}), {}, {0},
       ShapeUtil::MakeShape(PrimitiveType::F32, {1, 1}), 1.0,
       0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt,
-      se::blas::kDefaultComputePrecision, false, false);
+      se::blas::kDefaultComputePrecision, false, false, false);
   ASSERT_TRUE(config.ok());

   // Creating embedded GEMM thunk.
@@ -1086,7 +1086,7 @@ TEST(DynamicSliceThunkTest, SlicedTupledOperandGemm) {
       ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1}), {}, {0},
       ShapeUtil::MakeShape(PrimitiveType::F32, {1, 1}), 1.0,
       0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt,
-      se::blas::kDefaultComputePrecision, false, false);
+      se::blas::kDefaultComputePrecision, false, false, false);
   ASSERT_TRUE(config.ok());

   // Creating embedded GEMM thunk.
@@ -1439,7 +1439,7 @@ TEST(DynamicSliceThunkTest, SlicedOperandsSameBufferGemm) {
       ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1}), {}, {0},
       ShapeUtil::MakeShape(PrimitiveType::F32, {1, 1}), 1.0,
       0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt,
-      se::blas::kDefaultComputePrecision, false, false);
+      se::blas::kDefaultComputePrecision, false, false, false);
   ASSERT_TRUE(config.ok());

   // Creating embedded GEMM thunk.
