
Commit 5e84717

majnemer authored and amd-jianli12 committed

Use DeviceDescription instead of hard-coding warp size as 32

Cherry-picked changes:
tensorflow/tf-build-actions@600513b [ROCm] Fix flaky gpu compiler test when building with rocm
tensorflow/tf-build-actions@a35cf48 [XLA:GPU] Use DeviceDescription instead of hard-coding warp size as 32
xla@e849446 [ROCm] Pass correct warp size to Triton pipeline
xla@3e7b0fe cherry-picked warp size passing to triton calls, and globally enabled warpsize=64
xla@750ad89 Fixes.

1 parent 25a5adf commit 5e84717

86 files changed: +1027 −501 lines
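In short: several XLA:GPU passes assumed the NVIDIA warp size of 32, which is wrong for the 64-wide wavefronts of AMD GPUs targeted by ROCm; the fix is to read the warp size from se::DeviceDescription. A standalone sketch (stand-in types, not XLA code) of why the query must be per-device:

    // Standalone sketch with stand-in types: why warp size must be queried
    // from the device instead of hard-coded. NVIDIA warps are 32 threads;
    // AMD GCN/CDNA wavefronts are 64.
    #include <cstdint>
    #include <iostream>

    struct DeviceDescription {  // stand-in for se::DeviceDescription
      int64_t threads_per_warp;
    };

    // Warps needed to cover `elements` with one element per thread.
    int64_t WarpsNeeded(const DeviceDescription& device, int64_t elements) {
      return (elements + device.threads_per_warp - 1) / device.threads_per_warp;
    }

    int main() {
      const DeviceDescription nvidia{32}, amd{64};
      std::cout << WarpsNeeded(nvidia, 4096) << "\n";  // 128
      std::cout << WarpsNeeded(amd, 4096) << "\n";     // 64: hard-coding 32
                                                       // would double this
    }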

third_party/xla/xla/service/gpu/BUILD

+11-1
@@ -191,8 +191,10 @@ xla_cc_test(
     srcs = ["gpu_copy_insertion_test.cc"],
     deps = [
         ":buffer_sharing",
+        ":gpu_device_info_for_tests",
         "//xla:test",
         "//xla:test_helpers",
+        "//xla/hlo/analysis:hlo_dataflow_analysis",
         "//xla/hlo/ir:hlo",
         "//xla/service:copy_insertion",
         "//xla/tests:hlo_test_base",
@@ -263,7 +265,9 @@ xla_cc_test(
 
 cc_library(
     name = "gpu_device_info_for_tests",
-    testonly = 1,
+    # This is *not* a test library because it is used in a cc_binary which is used for testing but
+    # test_only libraries are not allowed in cc_binaries.
+    testonly = 0,
     srcs = ["gpu_device_info_for_tests.cc"],
     hdrs = ["gpu_device_info_for_tests.h"],
     compatible_with = get_compatible_with_portable(),
@@ -704,6 +708,7 @@ cc_library(
         "//xla/hlo/ir:hlo",
         "//xla/service:hlo_module_config",
         "//xla/stream_executor:semantic_version",
+        "//xla/stream_executor:device_description",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:inlined_vector",
@@ -1345,6 +1350,7 @@ cc_library(
         "//xla/service/gpu/transforms:copy_fusion",
         "//xla/service/gpu/transforms:horizontal_loop_fusion",
         "//xla/service/gpu/transforms:sanitize_constant_names",
+        "//xla/stream_executor:device_description",
     ],
 )
 
@@ -2521,6 +2527,10 @@ xla_cc_test(
         "//xla/hlo/ir:hlo",
         "//xla/service:hlo_parser",
         "//xla/tests:hlo_test_base",
+        "//xla/service:hlo_runner",
+        "//xla/service:instruction_fusion",
+        "//xla/service:platform_util",
+        "//xla/stream_executor:device_description",
         "//xla/tests:xla_internal_test_main",
         "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest_main",

third_party/xla/xla/service/gpu/autotuning/autotuner_util.h

+8-11
@@ -141,14 +141,8 @@ class AutotuneConfig {
         debug_options.xla_gpu_experimental_autotune_cache_mode()) {}
 
   std::string GetModelStr() const {
-    if (auto deviceless_config = std::get_if<DevicelessConfig>(&config_)) {
-      return AutotuneCacheKey::DeviceDescriptionToCacheKey(
-          deviceless_config->device_description);
-    }
-
-    const auto& device_config = std::get<DeviceConfig>(config_);
     return AutotuneCacheKey::DeviceDescriptionToCacheKey(
-        device_config.stream_exec->GetDeviceDescription());
+        GetDeviceDescription());
   }
 
   se::StreamExecutor* GetExecutor() const {
@@ -175,11 +169,14 @@
   }
 
   const se::GpuComputeCapability& GetGpuComputeCapability() const {
-    if (auto c = std::get_if<DeviceConfig>(&config_)) {
-      return c->stream_exec->GetDeviceDescription().gpu_compute_capability();
+    return GetDeviceDescription().gpu_compute_capability();
+  }
+
+  const se::DeviceDescription& GetDeviceDescription() const {
+    if (auto* device_config = std::get_if<DeviceConfig>(&config_)) {
+      return device_config->stream_exec->GetDeviceDescription();
     }
-    return std::get<DevicelessConfig>(config_)
-        .device_description.gpu_compute_capability();
+    return std::get<DevicelessConfig>(config_).device_description;
   }
 
   bool IsDeviceless() const {
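These hunks replace per-call-site dispatch over the DeviceConfig/DevicelessConfig variant with a single GetDeviceDescription() accessor that GetModelStr() and GetGpuComputeCapability() now reuse. A minimal self-contained sketch of the pattern (stand-in types; the real DeviceConfig reaches the description through its StreamExecutor):

    // Simplified sketch of the variant-accessor pattern; stand-in types,
    // not the real XLA classes.
    #include <cassert>
    #include <utility>
    #include <variant>

    struct DeviceDescription { int threads_per_warp = 32; };
    struct DeviceConfig { DeviceDescription desc; };      // backed by a device
    struct DevicelessConfig { DeviceDescription desc; };  // compile-only path

    class AutotuneConfig {
     public:
      explicit AutotuneConfig(std::variant<DeviceConfig, DevicelessConfig> c)
          : config_(std::move(c)) {}

      // One accessor every other getter can reuse, instead of repeating the
      // std::get_if dispatch at each call site.
      const DeviceDescription& GetDeviceDescription() const {
        if (auto* device_config = std::get_if<DeviceConfig>(&config_)) {
          return device_config->desc;
        }
        return std::get<DevicelessConfig>(config_).desc;
      }

     private:
      std::variant<DeviceConfig, DevicelessConfig> config_;
    };

    int main() {
      const AutotuneConfig config{DevicelessConfig{{64}}};
      assert(config.GetDeviceDescription().threads_per_warp == 64);
    }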

third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc

+1-2
@@ -459,8 +459,7 @@ GpuConvAlgorithmPicker::AutotuneRuntimeArguments::FromInstruction(
 
   // Get canonical HLO.
   std::string canonical_hlo(
-      AutotuneCacheKey(config.GetExecutor()->GetDeviceDescription(), *instr)
-          .GetHlo());
+      AutotuneCacheKey(config.GetDeviceDescription(), *instr).GetHlo());
 
   TF_ASSIGN_OR_RETURN(GpuConvConfig gpu_conv_config, GetGpuConvConfig(instr));

third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc

+21-18
@@ -380,7 +380,7 @@ absl::StatusOr<std::unique_ptr<HloModule>> TritonGemmAutotuneExtractor(
 
     // If the priority fusion pass above skipped some instructions, turn them
     // into fusions.
-    FusionWrapper fusion_wrapper;
+    FusionWrapper fusion_wrapper(gpu_device_info);
     TF_RETURN_IF_ERROR(fusion_wrapper.Run(new_module.get()).status());
   }
   return new_module;
@@ -528,7 +528,7 @@ absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config,
                       TritonGemmConfig::FromProto(result.triton()));
   }
   const se::DeviceDescription& device_desc =
-      autotune_config.GetExecutor()->GetDeviceDescription();
+      autotune_config.GetDeviceDescription();
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<HloModule> module,
       util.ExtractModule([&](const DebugOptions& debug_opts) {
@@ -693,12 +693,12 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
   // a sufficient number of thread block programs to occupy all available cores.
   // Around 5 full waves completely avoid the need for split-K.
   // n_tiles = split_k * (M * N) / (block_m * block_n)
-  const int kCoreCount =
-      !config_.IsDeviceless()
-          ? config_.GetExecutor()->GetDeviceDescription().core_count()
-          : 100;  // some sensible default
+  const int kCoreCount = config_.GetDeviceDescription().core_count();
+  CHECK_GE(kCoreCount, 1);
   const int64_t kSufficientNumberOfTiles = kMaxWavesForSplitK * kCoreCount;
   const int64_t result_size = ShapeUtil::ElementsIn(dot.shape());
+  const int64_t threads_per_warp =
+      config_.GetDeviceDescription().threads_per_warp();
 
   // Triton configurations are adjusted and deduplicated.
   absl::flat_hash_set<TritonGemmConfig> added;
@@ -735,7 +735,7 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
           2 * std::max(kMinTileSize, kLdmatrixGranularity / minBitWidth));
       int meta_elements = config.block_m * config.block_k / 16;
       config.num_warps =
-          std::min<int>(config.num_warps, meta_elements / WarpSize());
+          std::min<int>(config.num_warps, meta_elements / threads_per_warp);
     }
 
     if (added.insert(config).second) {
@@ -783,13 +783,13 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util,
       -> absl::StatusOr<bool> {
     std::unique_ptr<Executable> executable;
     if (std::holds_alternative<TritonGemmConfig>(config)) {
-      TF_ASSIGN_OR_RETURN(
-          executable, compile_util.Compile([&](const DebugOptions& opts) {
-            return TritonGemmAutotuneExtractor(
-                std::get<TritonGemmConfig>(config),
-                config_.GetExecutor()->GetDeviceDescription(), fusion, opts,
-                allow_filtering_kernels_spilling_registers);
-          }));
+      TF_ASSIGN_OR_RETURN(executable,
+                          compile_util.Compile([&](const DebugOptions& opts) {
+                            return TritonGemmAutotuneExtractor(
+                                std::get<TritonGemmConfig>(config),
+                                config_.GetDeviceDescription(), fusion, opts,
+                                allow_filtering_kernels_spilling_registers);
+                          }));
     } else if (std::holds_alternative<CuDnnConfig>(config)) {
       executable =
           compile_util
@@ -801,9 +801,9 @@
     } else if (std::holds_alternative<CuBlasConfig>(config)) {
       TF_ASSIGN_OR_RETURN(
          executable, compile_util.Compile([&](const DebugOptions& opts) {
-            return CublasGemmAutotuneExtractor(
-                config_, config_.GetExecutor()->GetDeviceDescription(),
-                toolkit_version_, fusion, opts);
+            return CublasGemmAutotuneExtractor(config_,
+                                               config_.GetDeviceDescription(),
+                                               toolkit_version_, fusion, opts);
          }));
    } else {
      LOG(FATAL) << "Unsupported config type: " << config.index();
@@ -1005,6 +1005,9 @@ GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const {
   bool tune_ctas =
       debug_options_.xla_gpu_enable_triton_hopper() && cc.IsAtLeastHopper();
 
+  const int64_t threads_per_warp =
+      config_.GetDeviceDescription().threads_per_warp();
+
   for (int num_stages : kNumStages) {
     // Volta doesn't support num_stages > 2.
     if (!cc.IsAtLeastAmpere() && num_stages > 2) {
@@ -1017,7 +1020,7 @@
       const int tile_rhs = tile_k * tile_n;
       for (int num_warps : kNumWarps) {
         // Each thread should read at least one input element.
-        if (num_warps * WarpSize() > std::min(tile_lhs, tile_rhs)) {
+        if (num_warps * threads_per_warp > std::min(tile_lhs, tile_rhs)) {
           break;
         }
         for (int split_k : kSplitK) {
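The third hunk above removes the deviceless fallback of 100 cores: core_count() now always comes from the device description (the deviceless test proto populates it explicitly, see gemm_fusion_autotuner_test.cc below). As a standalone illustration of the heuristic described in the comment (n_tiles = split_k * (M * N) / (block_m * block_n), targeting roughly kMaxWavesForSplitK full waves per core), here is a sketch with stand-in names and illustrative constants:

    // Sketch of the split-K sufficiency arithmetic; names and the helper
    // itself are illustrative, not the exact XLA implementation.
    #include <cstdint>
    #include <iostream>

    constexpr int64_t kMaxWavesForSplitK = 5;  // "around 5 full waves"

    // Smallest split_k so the grid launches enough tiles to fill every
    // core about kMaxWavesForSplitK times.
    int64_t MinSplitK(int64_t m, int64_t n, int64_t block_m, int64_t block_n,
                      int64_t core_count) {
      const int64_t tiles_without_split =
          ((m + block_m - 1) / block_m) * ((n + block_n - 1) / block_n);
      const int64_t sufficient_tiles = kMaxWavesForSplitK * core_count;
      if (tiles_without_split >= sufficient_tiles) return 1;
      return (sufficient_tiles + tiles_without_split - 1) / tiles_without_split;
    }

    int main() {
      // A 512x512 GEMM with 128x128 tiles yields only 16 tiles; on a
      // 100-core part we want ~500 tiles, so split_k comes out as 32.
      std::cout << MinSplitK(512, 512, 128, 128, /*core_count=*/100) << "\n";
    }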

third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc

+7-1
@@ -256,6 +256,8 @@ absl::StatusOr<std::vector<TritonGemmConfig>> GetPossibleMatmulAutotuneConfigs(
   auto ccc = deviceless_proto.mutable_cuda_compute_capability();
   ccc->set_major(compute_capability.major);
   ccc->set_minor(compute_capability.minor);
+  deviceless_proto.set_core_count(100);
+  deviceless_proto.set_threads_per_warp(32);
   DevicelessConfig test_config{se::DeviceDescription{deviceless_proto}};
   AutotuneConfig autotune_config{test_config, debug_options};
   GemmFusionAutotunerImpl autotuner(autotune_config, toolkit_version,
@@ -941,7 +943,11 @@ ENTRY wais {
       compute_capability, GetToolkitVersion(), debug_options));
   for (const auto& config : configs) {
     int metadata_size = config.block_m * config.block_k / 16;
-    EXPECT_LE(config.num_warps * WarpSize(), metadata_size);
+    EXPECT_LE(
+        config.num_warps *
+            WarpSize(
+                backend().default_stream_executor()->GetDeviceDescription()),
+        metadata_size);
     EXPECT_GT(config.block_k, 16);  // kMinTileSize
   }
 }
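With the hard-coded default gone, the deviceless test proto must now carry explicit core_count and threads_per_warp values, and the assertion calls a WarpSize overload that takes the device description. A minimal stand-in sketch of that overload shape (not the real XLA declaration):

    // Stand-in sketch of a device-aware WarpSize overload.
    #include <cstdint>

    struct DeviceDescription { int64_t threads_per_warp = 32; };

    // Old form (removed): int64_t WarpSize() { return 32; }  // wrong on ROCm
    int64_t WarpSize(const DeviceDescription& device) {
      return device.threads_per_warp;
    }

    int main() { return WarpSize(DeviceDescription{64}) == 64 ? 0 : 1; }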

third_party/xla/xla/service/gpu/buffer_sharing.cc

+10-9
@@ -40,9 +40,10 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
-                                             const HloInstruction* operand,
-                                             const ShapeIndex& user_index) {
+std::optional<bool> FusionCanShareBufferHint(
+    const HloInstruction* user, const HloInstruction* operand,
+    const ShapeIndex& user_index,
+    const se::DeviceDescription& device_description) {
   const HloFusionInstruction* fusion = DynCast<HloFusionInstruction>(user);
   if (fusion == nullptr) {
     return std::nullopt;
@@ -77,8 +78,6 @@ std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
   // Allow multiple output users, if they end in reductions.
   // This only works for the reduction emitter, as it calculates the reduction
   // first, i.e. before processing other outputs (that may overwrite the input).
-  stream_executor::GpuDeviceInfoProto device_info;
-  stream_executor::DeviceDescription device_description(device_info);
   auto analysis = HloFusionAnalysis::Create(*user, device_description);
   bool is_reduction_emitter = analysis.GetEmitterFusionKind() ==
                               HloFusionAnalysis::EmitterFusionKind::kReduction;
@@ -219,9 +218,10 @@ std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
   return found_path_to_output;
 }
 
-std::optional<bool> CanShareBufferHint(const HloInstruction* user,
-                                       const HloInstruction* operand,
-                                       const ShapeIndex& user_index) {
+std::optional<bool> CanShareBufferHint(
+    const HloInstruction* user, const HloInstruction* operand,
+    const ShapeIndex& user_index,
+    const se::DeviceDescription& device_description) {
   switch (user->opcode()) {
     case HloOpcode::kAllReduce:
     case HloOpcode::kCollectiveBroadcast:
@@ -243,7 +243,8 @@ std::optional<bool> CanShareBufferHint(const HloInstruction* user,
       }
       return false;
     case HloOpcode::kFusion:
-      return FusionCanShareBufferHint(user, operand, user_index);
+      return FusionCanShareBufferHint(user, operand, user_index,
+                                      device_description);
     default:
       return std::nullopt;
   }
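The second hunk above is the substantive fix in this file: the old code default-constructed a DeviceDescription from an empty GpuDeviceInfoProto inside the helper, so the fusion analysis saw bogus hardware parameters; now the caller threads the real description through the signature. A standalone sketch of that dependency-threading pattern with stand-in types:

    // Standalone sketch of the fix: pass the real device description down
    // instead of default-constructing one inside the helper. Stand-in types.
    #include <iostream>

    struct DeviceDescription { int threads_per_warp = 0; };  // empty proto => 0

    // Old shape: builds a description from nothing, so any analysis that
    // consults it sees meaningless hardware parameters.
    int HelperOld() {
      DeviceDescription desc;       // default-constructed stand-in
      return desc.threads_per_warp; // 0, not the real warp size
    }

    // New shape: the caller owns the real description and passes it through.
    int HelperNew(const DeviceDescription& desc) {
      return desc.threads_per_warp;
    }

    int main() {
      const DeviceDescription real{64};  // e.g. an AMD device
      std::cout << HelperOld() << "\n";      // 0
      std::cout << HelperNew(real) << "\n";  // 64
    }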

third_party/xla/xla/service/gpu/buffer_sharing.h

+10-7
@@ -20,16 +20,19 @@ limitations under the License.
 
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/shape_util.h"
+#include "xla/stream_executor/device_description.h"
 
 namespace xla {
 namespace gpu {
-std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
-                                             const HloInstruction* operand,
-                                             const ShapeIndex& user_index);
-
-std::optional<bool> CanShareBufferHint(const HloInstruction* user,
-                                       const HloInstruction* operand,
-                                       const ShapeIndex& user_index);
+std::optional<bool> FusionCanShareBufferHint(
+    const HloInstruction* user, const HloInstruction* operand,
+    const ShapeIndex& user_index,
+    const se::DeviceDescription& device_description);
+
+std::optional<bool> CanShareBufferHint(
+    const HloInstruction* user, const HloInstruction* operand,
+    const ShapeIndex& user_index,
+    const se::DeviceDescription& device_description);
 }  // namespace gpu
 }  // namespace xla

third_party/xla/xla/service/gpu/fusion_pipeline.cc

+1-1
@@ -89,7 +89,7 @@ HloPassPipeline FusionPipeline(
 HloPassPipeline HorizontalFusionPipeline(
     const se::DeviceDescription& gpu_device_info) {
   HloPassFix<HloPassPipeline> horizontal_fusion("horizontal fusion");
-  horizontal_fusion.AddPass<HorizontalLoopFusion>();
+  horizontal_fusion.AddPass<HorizontalLoopFusion>(gpu_device_info);
   horizontal_fusion.AddPass<HorizontalInputFusion>(gpu_device_info);
   horizontal_fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
                                     /*only_fusion_computations=*/true);

third_party/xla/xla/service/gpu/fusions/legacy/BUILD

+2
@@ -167,6 +167,7 @@ cc_library(
         "//xla/service/llvm_ir:kernel_support_library",
         "//xla/service/llvm_ir:llvm_loop",
         "//xla/service/llvm_ir:llvm_util",
+        "//xla/stream_executor:device_description",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
@@ -322,6 +323,7 @@ cc_library(
         "//xla/service/llvm_ir:ir_array",
         "//xla/service/llvm_ir:llvm_util",
         "//xla/service/llvm_ir:loop_emitter",
+        "//xla/stream_executor:device_description",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/log:check",
