From 5e847170cf0d0ac457de655ee168f1677389b032 Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Wed, 27 Nov 2024 13:08:26 -0800 Subject: [PATCH] Use DeviceDescription instead of hard-coding warp size as 32 tensorflow@600513b [ROCm] Fix flaky gpu compiler test when building with rocm tensorflow@a35cf48 [XLA:GPU] Use DeviceDescription instead of hard-coding warp size as 32 xla@e849446 [ROCm] Pass correct warp size to Triton pipeline xla@3e7b0fe cherry-picked warp size passing to triton calls, and globally enabled warpsize=64 xla@750ad89 Fixes. --- third_party/xla/xla/service/gpu/BUILD | 12 +- .../service/gpu/autotuning/autotuner_util.h | 19 +-- .../gpu/autotuning/conv_algorithm_picker.cc | 3 +- .../gpu/autotuning/gemm_fusion_autotuner.cc | 39 ++--- .../autotuning/gemm_fusion_autotuner_test.cc | 8 +- .../xla/xla/service/gpu/buffer_sharing.cc | 19 +-- .../xla/xla/service/gpu/buffer_sharing.h | 17 +- .../xla/xla/service/gpu/fusion_pipeline.cc | 2 +- .../xla/xla/service/gpu/fusions/legacy/BUILD | 2 + .../service/gpu/fusions/legacy/reduction.cc | 58 ++++--- .../service/gpu/fusions/legacy/tiling_util.cc | 18 ++- .../service/gpu/fusions/legacy/tiling_util.h | 4 +- .../service/gpu/fusions/legacy/transpose.cc | 9 +- .../gpu/fusions/mlir/mlir_fusion_emitter.cc | 8 +- .../gpu/fusions/mlir/mlir_fusion_emitter.h | 3 +- .../xla/service/gpu/fusions/reduction_base.cc | 12 +- .../xla/service/gpu/fusions/reduction_base.h | 2 +- .../xla/service/gpu/fusions/reduction_mlir.cc | 61 +++++--- .../xla/service/gpu/fusions/reduction_mlir.h | 8 + .../xla/xla/service/gpu/fusions/tools/BUILD | 1 + .../gpu/fusions/tools/mlir_fusions_opt.cc | 4 +- .../transforms/lower_xla_gpu_to_scf.cc | 45 ++++-- .../service/gpu/fusions/transforms/passes.h | 5 +- .../service/gpu/fusions/transforms/passes.td | 6 + .../fusions/transforms/rewrite_reductions.cc | 40 +++-- .../xla/service/gpu/fusions/transpose_mlir.cc | 11 +- .../xla/service/gpu/fusions/transpose_mlir.h | 1 + .../xla/xla/service/gpu/fusions/triton.cc | 6 +- .../triton/compilation_pipeline_cuda.cc | 4 +- .../triton/compilation_pipeline_rocm.cc | 8 +- .../fusions/triton/triton_fusion_emitter.cc | 19 ++- .../fusions/triton/triton_fusion_emitter.h | 4 +- .../triton/triton_fusion_emitter_stub.cc | 4 +- .../xla/xla/service/gpu/gpu_compiler.cc | 59 ++++--- .../xla/xla/service/gpu/gpu_compiler.h | 17 +- .../service/gpu/gpu_copy_insertion_test.cc | 57 ++++++- .../xla/xla/service/gpu/gpu_fusible.cc | 148 +++++++++++------- third_party/xla/xla/service/gpu/gpu_fusible.h | 45 ++++-- .../xla/xla/service/gpu/gpu_fusible_test.cc | 87 ++++++++-- .../xla/service/gpu/gpu_offloading_test.cc | 4 +- .../xla/service/gpu/hlo_fusion_analysis.cc | 9 +- .../service/gpu/hlo_fusion_analysis_test.cc | 1 + .../xla/xla/service/gpu/ir_emission_utils.cc | 11 +- .../xla/xla/service/gpu/ir_emission_utils.h | 5 +- .../xla/service/gpu/ir_emitter_unnested.cc | 18 +-- .../service/gpu/model/coalescing_analysis.cc | 11 +- .../service/gpu/model/coalescing_analysis.h | 1 + .../gpu/model/coalescing_analysis_test.cc | 4 +- .../model/gpu_indexing_performance_model.cc | 16 +- .../model/gpu_indexing_performance_model.h | 3 +- .../gpu_indexing_performance_model_test.cc | 5 +- .../xla/xla/service/gpu/nvptx_compiler.cc | 8 +- .../xla/xla/service/gpu/nvptx_compiler.h | 3 +- .../prepare_hlo_for_ir_emitting_pipeline.cc | 9 +- .../prepare_hlo_for_ir_emitting_pipeline.h | 5 +- .../xla/xla/service/gpu/reduction_utils.cc | 47 +++--- .../xla/xla/service/gpu/reduction_utils.h | 18 ++- .../xla/xla/service/gpu/transforms/BUILD | 8 + 
.../xla/service/gpu/transforms/copy_fusion.cc | 5 +- .../xla/service/gpu/transforms/copy_fusion.h | 6 +- .../gpu/transforms/copy_fusion_test.cc | 11 ++ .../service/gpu/transforms/fusion_merger.cc | 21 ++- .../service/gpu/transforms/fusion_wrapper.cc | 4 +- .../service/gpu/transforms/fusion_wrapper.h | 6 + .../gpu/transforms/fusion_wrapper_test.cc | 33 +++- .../gpu/transforms/horizontal_input_fusion.cc | 21 ++- .../gpu/transforms/horizontal_loop_fusion.cc | 34 ++-- .../gpu/transforms/horizontal_loop_fusion.h | 8 +- .../transforms/horizontal_loop_fusion_test.cc | 42 +++-- .../gpu/transforms/instruction_fusion.cc | 11 +- .../gpu/transforms/layout_assignment.cc | 3 +- .../gpu/transforms/layout_assignment.h | 5 +- .../gpu/transforms/layout_assignment_test.cc | 53 ++++--- .../gpu/transforms/multi_output_fusion.cc | 24 +-- .../service/gpu/transforms/priority_fusion.cc | 1 + .../gpu/transforms/reduction_splitter.cc | 19 ++- .../gpu/transforms/reduction_splitter.h | 12 +- .../gpu/transforms/reduction_splitter_test.cc | 47 ++++-- .../gpu/transforms/softmax_rewriter_triton.cc | 4 +- .../transforms/stream_attribute_annotator.cc | 10 +- .../transforms/stream_attribute_annotator.h | 8 + .../stream_attribute_annotator_test.cc | 38 +++-- .../gpu/transforms/tree_reduction_rewriter.cc | 18 ++- .../gpu/transforms/tree_reduction_rewriter.h | 7 +- .../tree_reduction_rewriter_test.cc | 13 +- third_party/xla/xla/tools/hlo_opt/gpu_opt.cc | 3 +- 86 files changed, 1027 insertions(+), 501 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index a7acaf2c5146ab..8833eef19ec6e5 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -191,8 +191,10 @@ xla_cc_test( srcs = ["gpu_copy_insertion_test.cc"], deps = [ ":buffer_sharing", + ":gpu_device_info_for_tests", "//xla:test", "//xla:test_helpers", + "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "//xla/service:copy_insertion", "//xla/tests:hlo_test_base", @@ -263,7 +265,9 @@ xla_cc_test( cc_library( name = "gpu_device_info_for_tests", - testonly = 1, + # This is *not* a test library because it is used in a cc_binary which is used for testing but + # test_only libraries are not allowed in cc_binaries. 
+ testonly = 0, srcs = ["gpu_device_info_for_tests.cc"], hdrs = ["gpu_device_info_for_tests.h"], compatible_with = get_compatible_with_portable(), @@ -704,6 +708,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", "//xla/stream_executor:semantic_version", + "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", @@ -1345,6 +1350,7 @@ cc_library( "//xla/service/gpu/transforms:copy_fusion", "//xla/service/gpu/transforms:horizontal_loop_fusion", "//xla/service/gpu/transforms:sanitize_constant_names", + "//xla/stream_executor:device_description", ], ) @@ -2521,6 +2527,10 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/service:hlo_parser", "//xla/tests:hlo_test_base", + "//xla/service:hlo_runner", + "//xla/service:instruction_fusion", + "//xla/service:platform_util", + "//xla/stream_executor:device_description", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h index e70b252abb30a0..3de7738f416803 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h @@ -141,14 +141,8 @@ class AutotuneConfig { debug_options.xla_gpu_experimental_autotune_cache_mode()) {} std::string GetModelStr() const { - if (auto deviceless_config = std::get_if(&config_)) { - return AutotuneCacheKey::DeviceDescriptionToCacheKey( - deviceless_config->device_description); - } - - const auto& device_config = std::get(config_); return AutotuneCacheKey::DeviceDescriptionToCacheKey( - device_config.stream_exec->GetDeviceDescription()); + GetDeviceDescription()); } se::StreamExecutor* GetExecutor() const { @@ -175,11 +169,14 @@ class AutotuneConfig { } const se::GpuComputeCapability& GetGpuComputeCapability() const { - if (auto c = std::get_if(&config_)) { - return c->stream_exec->GetDeviceDescription().gpu_compute_capability(); + return GetDeviceDescription().gpu_compute_capability(); + } + + const se::DeviceDescription& GetDeviceDescription() const { + if (auto* device_config = std::get_if(&config_)) { + return device_config->stream_exec->GetDeviceDescription(); } - return std::get(config_) - .device_description.gpu_compute_capability(); + return std::get(config_).device_description; } bool IsDeviceless() const { diff --git a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc index 90437a5633f509..e641c6f865bfc5 100644 --- a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc @@ -459,8 +459,7 @@ GpuConvAlgorithmPicker::AutotuneRuntimeArguments::FromInstruction( // Get canonical HLO. 
std::string canonical_hlo( - AutotuneCacheKey(config.GetExecutor()->GetDeviceDescription(), *instr) - .GetHlo()); + AutotuneCacheKey(config.GetDeviceDescription(), *instr).GetHlo()); TF_ASSIGN_OR_RETURN(GpuConvConfig gpu_conv_config, GetGpuConvConfig(instr)); diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index 79524924584c97..8b0679613bb6e2 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -380,7 +380,7 @@ absl::StatusOr> TritonGemmAutotuneExtractor( // If the priority fusion pass above skipped some instructions, turn them // into fusions. - FusionWrapper fusion_wrapper; + FusionWrapper fusion_wrapper(gpu_device_info); TF_RETURN_IF_ERROR(fusion_wrapper.Run(new_module.get()).status()); } return new_module; @@ -528,7 +528,7 @@ absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config, TritonGemmConfig::FromProto(result.triton())); } const se::DeviceDescription& device_desc = - autotune_config.GetExecutor()->GetDeviceDescription(); + autotune_config.GetDeviceDescription(); TF_ASSIGN_OR_RETURN( std::unique_ptr module, util.ExtractModule([&](const DebugOptions& debug_opts) { @@ -693,12 +693,12 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) { // a sufficient number of thread block programs to occupy all available cores. // Around 5 full waves completely avoid the need for split-K. // n_tiles = split_k * (M * N) / (block_m * block_n) - const int kCoreCount = - !config_.IsDeviceless() - ? config_.GetExecutor()->GetDeviceDescription().core_count() - : 100; // some sensible default + const int kCoreCount = config_.GetDeviceDescription().core_count(); + CHECK_GE(kCoreCount, 1); const int64_t kSufficientNumberOfTiles = kMaxWavesForSplitK * kCoreCount; const int64_t result_size = ShapeUtil::ElementsIn(dot.shape()); + const int64_t threads_per_warp = + config_.GetDeviceDescription().threads_per_warp(); // Triton configurations are adjusted and deduplicated. 
absl::flat_hash_set added; @@ -735,7 +735,7 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) { 2 * std::max(kMinTileSize, kLdmatrixGranularity / minBitWidth)); int meta_elements = config.block_m * config.block_k / 16; config.num_warps = - std::min(config.num_warps, meta_elements / WarpSize()); + std::min(config.num_warps, meta_elements / threads_per_warp); } if (added.insert(config).second) { @@ -783,13 +783,13 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, -> absl::StatusOr { std::unique_ptr executable; if (std::holds_alternative(config)) { - TF_ASSIGN_OR_RETURN( - executable, compile_util.Compile([&](const DebugOptions& opts) { - return TritonGemmAutotuneExtractor( - std::get(config), - config_.GetExecutor()->GetDeviceDescription(), fusion, opts, - allow_filtering_kernels_spilling_registers); - })); + TF_ASSIGN_OR_RETURN(executable, + compile_util.Compile([&](const DebugOptions& opts) { + return TritonGemmAutotuneExtractor( + std::get(config), + config_.GetDeviceDescription(), fusion, opts, + allow_filtering_kernels_spilling_registers); + })); } else if (std::holds_alternative(config)) { executable = compile_util @@ -801,9 +801,9 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, } else if (std::holds_alternative(config)) { TF_ASSIGN_OR_RETURN( executable, compile_util.Compile([&](const DebugOptions& opts) { - return CublasGemmAutotuneExtractor( - config_, config_.GetExecutor()->GetDeviceDescription(), - toolkit_version_, fusion, opts); + return CublasGemmAutotuneExtractor(config_, + config_.GetDeviceDescription(), + toolkit_version_, fusion, opts); })); } else { LOG(FATAL) << "Unsupported config type: " << config.index(); @@ -1005,6 +1005,9 @@ GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const { bool tune_ctas = debug_options_.xla_gpu_enable_triton_hopper() && cc.IsAtLeastHopper(); + const int64_t threads_per_warp = + config_.GetDeviceDescription().threads_per_warp(); + for (int num_stages : kNumStages) { // Volta doesn't support num_stages > 2. if (!cc.IsAtLeastAmpere() && num_stages > 2) { @@ -1017,7 +1020,7 @@ GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const { const int tile_rhs = tile_k * tile_n; for (int num_warps : kNumWarps) { // Each thread should read at least one input element. 
- if (num_warps * WarpSize() > std::min(tile_lhs, tile_rhs)) { + if (num_warps * threads_per_warp > std::min(tile_lhs, tile_rhs)) { break; } for (int split_k : kSplitK) { diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index d9bec3a09906a8..4f380785dcc5da 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -256,6 +256,8 @@ absl::StatusOr> GetPossibleMatmulAutotuneConfigs( auto ccc = deviceless_proto.mutable_cuda_compute_capability(); ccc->set_major(compute_capability.major); ccc->set_minor(compute_capability.minor); + deviceless_proto.set_core_count(100); + deviceless_proto.set_threads_per_warp(32); DevicelessConfig test_config{se::DeviceDescription{deviceless_proto}}; AutotuneConfig autotune_config{test_config, debug_options}; GemmFusionAutotunerImpl autotuner(autotune_config, toolkit_version, @@ -941,7 +943,11 @@ ENTRY wais { compute_capability, GetToolkitVersion(), debug_options)); for (const auto& config : configs) { int metadata_size = config.block_m * config.block_k / 16; - EXPECT_LE(config.num_warps * WarpSize(), metadata_size); + EXPECT_LE( + config.num_warps * + WarpSize( + backend().default_stream_executor()->GetDeviceDescription()), + metadata_size); EXPECT_GT(config.block_k, 16); // kMinTileSize } } diff --git a/third_party/xla/xla/service/gpu/buffer_sharing.cc b/third_party/xla/xla/service/gpu/buffer_sharing.cc index 0ffb8e3fe63de9..9c4610005e5fb9 100644 --- a/third_party/xla/xla/service/gpu/buffer_sharing.cc +++ b/third_party/xla/xla/service/gpu/buffer_sharing.cc @@ -40,9 +40,10 @@ limitations under the License. namespace xla { namespace gpu { -std::optional FusionCanShareBufferHint(const HloInstruction* user, - const HloInstruction* operand, - const ShapeIndex& user_index) { +std::optional FusionCanShareBufferHint( + const HloInstruction* user, const HloInstruction* operand, + const ShapeIndex& user_index, + const se::DeviceDescription& device_description) { const HloFusionInstruction* fusion = DynCast(user); if (fusion == nullptr) { return std::nullopt; @@ -77,8 +78,6 @@ std::optional FusionCanShareBufferHint(const HloInstruction* user, // Allow multiple output users, if they end in reductions. // This only works for the reduction emitter, as it calculates the reduction // first, i.e. before processing other outputs (that may overwrite the input). 
- stream_executor::GpuDeviceInfoProto device_info; - stream_executor::DeviceDescription device_description(device_info); auto analysis = HloFusionAnalysis::Create(*user, device_description); bool is_reduction_emitter = analysis.GetEmitterFusionKind() == HloFusionAnalysis::EmitterFusionKind::kReduction; @@ -219,9 +218,10 @@ std::optional FusionCanShareBufferHint(const HloInstruction* user, return found_path_to_output; } -std::optional CanShareBufferHint(const HloInstruction* user, - const HloInstruction* operand, - const ShapeIndex& user_index) { +std::optional CanShareBufferHint( + const HloInstruction* user, const HloInstruction* operand, + const ShapeIndex& user_index, + const se::DeviceDescription& device_description) { switch (user->opcode()) { case HloOpcode::kAllReduce: case HloOpcode::kCollectiveBroadcast: @@ -243,7 +243,8 @@ std::optional CanShareBufferHint(const HloInstruction* user, } return false; case HloOpcode::kFusion: - return FusionCanShareBufferHint(user, operand, user_index); + return FusionCanShareBufferHint(user, operand, user_index, + device_description); default: return std::nullopt; } diff --git a/third_party/xla/xla/service/gpu/buffer_sharing.h b/third_party/xla/xla/service/gpu/buffer_sharing.h index 7fdf4af78c11c7..a42f65fbdd9f24 100644 --- a/third_party/xla/xla/service/gpu/buffer_sharing.h +++ b/third_party/xla/xla/service/gpu/buffer_sharing.h @@ -20,16 +20,19 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { -std::optional FusionCanShareBufferHint(const HloInstruction* user, - const HloInstruction* operand, - const ShapeIndex& user_index); - -std::optional CanShareBufferHint(const HloInstruction* user, - const HloInstruction* operand, - const ShapeIndex& user_index); +std::optional FusionCanShareBufferHint( + const HloInstruction* user, const HloInstruction* operand, + const ShapeIndex& user_index, + const se::DeviceDescription& device_description); + +std::optional CanShareBufferHint( + const HloInstruction* user, const HloInstruction* operand, + const ShapeIndex& user_index, + const se::DeviceDescription& device_description); } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusion_pipeline.cc b/third_party/xla/xla/service/gpu/fusion_pipeline.cc index e27865c06c63d1..6100774e3766aa 100644 --- a/third_party/xla/xla/service/gpu/fusion_pipeline.cc +++ b/third_party/xla/xla/service/gpu/fusion_pipeline.cc @@ -89,7 +89,7 @@ HloPassPipeline FusionPipeline( HloPassPipeline HorizontalFusionPipeline( const se::DeviceDescription& gpu_device_info) { HloPassFix horizontal_fusion("horizontal fusion"); - horizontal_fusion.AddPass(); + horizontal_fusion.AddPass(gpu_device_info); horizontal_fusion.AddPass(gpu_device_info); horizontal_fusion.AddPass(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/BUILD b/third_party/xla/xla/service/gpu/fusions/legacy/BUILD index 98d8ade7c5e5c3..bf1c9abb27f706 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/legacy/BUILD @@ -167,6 +167,7 @@ cc_library( "//xla/service/llvm_ir:kernel_support_library", "//xla/service/llvm_ir:llvm_loop", "//xla/service/llvm_ir:llvm_util", + "//xla/stream_executor:device_description", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", @@ -322,6 
+323,7 @@ cc_library( "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", + "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc index e009ea18e0b48c..2694fdd3def406 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc @@ -182,6 +182,10 @@ class ReductionEmitter { return reduction_codegen_info_.GetTiling().GetShape()[2]; } + int64_t WarpSize() const { + return ::xla::gpu::WarpSize(analysis_.device_info()); + } + llvm::IRBuilder<>* builder_; GpuElementalIrEmitter elemental_emitter_; const HloFusionAnalysis& analysis_; @@ -311,19 +315,20 @@ ReductionGroupEmitter::ReductionGroupEmitter( if (reduction_info.IsRowReduction()) { // Multi-row reductions do not use shared memory. if (RowReductionGetRowsPerWarp( - reduction_emitter_.ReducedDimensionSize()) > 1) { + reduction_emitter_.ReducedDimensionSize(), + reduction_emitter_.WarpSize()) > 1) { return std::nullopt; } // Allocate one shared memory element per warp. auto block_size = tiling.GetThreadsPerBlock(); CHECK_EQ(block_size[ReductionDimensions::kRowMinorReducedDimension] % - WarpSize(), + reduction_emitter_.WarpSize(), 0); return llvm_ir::AllocateSharedMemoryTile( module, element_type, {block_size[ReductionDimensions::kRowKeptDimension], block_size[ReductionDimensions::kRowMinorReducedDimension] / - WarpSize()}, + reduction_emitter_.WarpSize()}, "shared_cache"); } const auto& num_threads = tiling.GetThreadsPerBlock(); @@ -528,7 +533,7 @@ void ReductionGroupEmitter::EmitFullWarpShuffleDownLoopForReduce( // We check this here as a mistake in the number of threads per // block is very hard to detect. 
CHECK_EQ(threads_per_block % 32, 0); - CHECK_EQ(WarpSize() % num_results_per_warp, 0); + CHECK_EQ(reduction_emitter_.WarpSize() % num_results_per_warp, 0); auto* builder = reduction_emitter_.builder_; for (int distance = 16 / num_results_per_warp; distance >= 1; distance /= 2) { @@ -682,14 +687,15 @@ void ReductionGroupEmitter::EmitReductionOutputForRowReduction( const auto& reduction_info = reduction_emitter_.reduction_codegen_info_; const Tiling& tiling = reduction_info.GetTiling(); - int num_rows_per_warp = - RowReductionGetRowsPerWarp(reduction_emitter_.ReducedDimensionSize()); + int num_rows_per_warp = RowReductionGetRowsPerWarp( + reduction_emitter_.ReducedDimensionSize(), reduction_emitter_.WarpSize()); EmitFullWarpShuffleDownLoopForReduce(reducer, absl::MakeSpan(current_outputs), tiling.GetNumThreadsPerBlock(), num_rows_per_warp); KernelSupportLibrary ksl(builder); - llvm::Value* warp_id = builder->CreateUDiv(thread_id_x, constant(WarpSize())); + llvm::Value* warp_id = + builder->CreateUDiv(thread_id_x, constant(reduction_emitter_.WarpSize())); auto emit_write_output = [&](llvm::Value* write_condition, const absl::Span values) { @@ -753,7 +759,7 @@ void ReductionGroupEmitter::EmitReductionOutputForRowReduction( thread_id_x, constant(tiling.GetThreadsPerBlock() [ReductionDimensions::kRowMinorReducedDimension] / - WarpSize())); + reduction_emitter_.WarpSize())); llvm::Value* selected_value = builder->CreateSelect( warp_exists, block_accum_addr, initial_value_addr); @@ -768,7 +774,8 @@ void ReductionGroupEmitter::EmitReductionOutputForRowReduction( // communication using shared memory and synchronization using barrier is // also unnecessary and should be removed. if (tiling.GetThreadsPerBlock() - [ReductionDimensions::kRowMinorReducedDimension] > WarpSize()) { + [ReductionDimensions::kRowMinorReducedDimension] > + reduction_emitter_.WarpSize()) { EmitFullWarpShuffleDownLoopForReduce( reducer, absl::MakeSpan(selected_values), tiling.GetNumThreadsPerBlock(), /*num_results_per_warp=*/1); @@ -913,7 +920,7 @@ absl::Status ReductionEmitter::EmitIRForReduction( for (const HloInstruction* hlo : instr_index_group) { auto& hero = FindNonTrivialHero(*hlo); - if (IsRealReductionHero(*hlo, hero)) { + if (IsRealReductionHero(*hlo, hero, analysis_.device_info())) { auto reduction = Cast(&hero); if (heroes_to_roots.find(reduction) == heroes_to_roots.end()) { heroes.push_back(reduction); @@ -956,7 +963,8 @@ absl::Status ReductionEmitter::EmitIRForReduction( }; EmitTile(builder_, reduction_codegen_info_.GetTiling(), thread_id_info, tile_dimensions, emit_element); - })); + }, + analysis_.device_info())); KernelSupportLibrary ksl(builder_); for (auto reduce : heroes) { @@ -1012,7 +1020,8 @@ absl::StatusOr ReductionEmitter::EmitInitializers() { for (int i = 0; i < fusion_roots.size(); ++i) { const HloInstruction* fusion_root = &fusion_roots[i].instruction(); - if (IsReductionFromOrToContiguousDimensions(*fusion_root)) { + if (IsReductionFromOrToContiguousDimensions(*fusion_root, + analysis_.device_info())) { TF_ASSIGN_OR_RETURN( result.thunks.emplace_back(), BuildFusedInitializerThunk(fusion_root, slices[i], i)); @@ -1102,7 +1111,8 @@ absl::Status ReductionFusion::EmitKernel(IrEmitterContext& ir_emitter_context, int ReductionInfo::GetRowsPerWarp() const { if (!is_row_reduction_) return 1; return RowReductionGetRowsPerWarp( - tiling_.GetShape()[ReductionDimensions::kRowMinorReducedDimension]); + tiling_.GetShape()[ReductionDimensions::kRowMinorReducedDimension], + WarpSize(analysis_.device_info())); } 
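// Illustration (a hedged sketch, not code from this file): GetRowsPerWarp above
// now passes the device's warp width explicitly, so multi-row packing depends on
// the hardware. With RowReductionGetRowsPerWarp(reduced_size, warp_size) returning
// warp_size / reduced_size whenever reduced_size divides warp_size and is smaller
// than it (and 1 otherwise), a reduced row of 8 elements packs 4 rows per 32-lane
// CUDA warp but 8 rows per 64-lane ROCm wavefront, and a 32-element row packs
// 1 row per warp on CUDA but 2 on ROCm.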
LaunchDimensions ReductionInfo::launch_dimensions() const { @@ -1124,12 +1134,14 @@ ReductionInfo ReductionInfo::Create(const HloFusionAnalysis& analysis) { << " " << shape[0] << " " << shape[1] << " " << shape[2]; Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions); - int64_t num_threads_y = - reduction_dimensions.is_row_reduction ? 1 : WarpSize(); + int64_t num_threads_y = reduction_dimensions.is_row_reduction + ? 1 + : WarpSize(analysis.device_info()); int64_t rows_per_warp = reduction_dimensions.is_row_reduction ? RowReductionGetRowsPerWarp( - shape[ReductionDimensions::kRowMinorReducedDimension]) + shape[ReductionDimensions::kRowMinorReducedDimension], + WarpSize(analysis.device_info())) : 1; int64_t num_threads_x = [&] { if (reduction_dimensions.is_row_reduction) { @@ -1144,9 +1156,9 @@ ReductionInfo ReductionInfo::Create(const HloFusionAnalysis& analysis) { CeilOfRatio(shape[ReductionDimensions::kRowMinorReducedDimension], reduction_tiling [ReductionDimensions::kRowMinorReducedDimension]), - WarpSize())); + WarpSize(analysis.device_info()))); } - return WarpSize(); + return WarpSize(analysis.device_info()); }(); // If we're limited by the size of the x dimension, add additional parallelism @@ -1197,8 +1209,9 @@ ReductionInfo ReductionInfo::Create(const HloFusionAnalysis& analysis) { Tiling tiling(tiled_shape, tile_per_thread, num_threads, /*loops_to_unroll=*/{false, false, true, false}); - bool reduction_is_race_free = ReductionIsRaceFree( - hero_reduction->GetModule()->config(), reduction_dimensions); + bool reduction_is_race_free = + ReductionIsRaceFree(hero_reduction->GetModule()->config(), + reduction_dimensions, analysis.device_info()); return ReductionInfo(analysis, tiling, reduction_dimensions.is_row_reduction, reduction_is_race_free, GroupDisjointReductions(analysis, /*for_mlir=*/false), @@ -1250,7 +1263,8 @@ std::optional ReductionInfo::ComputeThreadIdToOutputIndexing( int rows_per_warp = GetRowsPerWarp(); if (rows_per_warp > 1) { linear_index.AddConstraint( - thread_ids[kRowMinorReduced] % (WarpSize() / rows_per_warp), + thread_ids[kRowMinorReduced] % + (WarpSize(analysis_.device_info()) / rows_per_warp), {0, 0}); } else { linear_index.AddConstraint(thread_ids[kRowMinorReduced], {0, 0}); @@ -1277,7 +1291,7 @@ std::optional ReductionInfo::ComputeThreadIdToOutputIndexing( projected_index.AddConstraint( mlir::getAffineDimExpr( KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx) % - WarpSize(), + WarpSize(analysis_.device_info()), {0, 0}); if (!is_row_reduction_) { projected_index.AddConstraint( diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc index a1a7acb58388a7..a04986e537f36f 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc @@ -167,9 +167,9 @@ llvm::Value* EmitThreadId(llvm::IRBuilder<>* builder, int64_t threads_per_block, // Emits the LLVM values for thread_id, block_id, coordinates of the current // tile and strides of the loops to iterate over the current tile. 
-absl::StatusOr EmitThreadIdInfo(llvm::IRBuilder<>* builder, - const Tiling& tiling, - llvm::Type* index_ty) { +absl::StatusOr EmitThreadIdInfo( + llvm::IRBuilder<>* builder, const Tiling& tiling, llvm::Type* index_ty, + const se::DeviceDescription& gpu_device_info) { auto constant = [&](uint64_t c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; @@ -202,8 +202,8 @@ absl::StatusOr EmitThreadIdInfo(llvm::IRBuilder<>* builder, } } - info.lane_id = - builder->CreateURem(info.thread_id, constant(WarpSize()), "lane_id"); + info.lane_id = builder->CreateURem( + info.thread_id, constant(WarpSize(gpu_device_info)), "lane_id"); return info; } @@ -218,15 +218,17 @@ AffineMap GetTilingAffineMap(llvm::ArrayRef exprs, absl::StatusOr EmitTilingKernel( llvm::IRBuilder<>* builder, const Tiling& tiling, llvm::Type* index_ty, - const TileGenerator& tile_element_generator) { + const TileGenerator& tile_element_generator, + const se::DeviceDescription& gpu_device_info) { absl::Span dims_in_elems = tiling.GetShape(); const auto& block_counts = tiling.GetBlockCounts(); auto constant = [&](uint64_t c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; - TF_ASSIGN_OR_RETURN(TilingThreadIdInfo thread_id_info, - EmitThreadIdInfo(builder, tiling, index_ty)); + TF_ASSIGN_OR_RETURN( + TilingThreadIdInfo thread_id_info, + EmitThreadIdInfo(builder, tiling, index_ty, gpu_device_info)); KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll); diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.h b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.h index de367e36addb61..7cb9fd1ec3e5f1 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.h +++ b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.h @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/llvm_ir/ir_array.h" +#include "xla/stream_executor/device_description.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -178,7 +179,8 @@ void EmitTile(llvm::IRBuilder<>* builder, const Tiling& tiling, // scheme. absl::StatusOr EmitTilingKernel( llvm::IRBuilder<>* builder, const Tiling& tiling, llvm::Type* index_ty, - const TileGenerator& tile_element_generator); + const TileGenerator& tile_element_generator, + const se::DeviceDescription& gpu_device_info); // Creates an indexing map from thread and block IDs to elements of the tiled // shape. Uses the same convention as KernelFusionInterface: dimensions 0 to 2 diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc index f91a0a4b6b120f..e631583456ef98 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc @@ -62,7 +62,7 @@ namespace { Tiling ComputeTransposeTiling(const se::DeviceDescription& gpu_device_info, const TransposeDescription& tiled_transpose) { constexpr int kNumRows = 4; - static_assert(WarpSize() % kNumRows == 0); + assert(WarpSize(gpu_device_info) % kNumRows == 0); // 3D view over the output shape. absl::InlinedVector transposed_dims = tiled_transpose.dimensions; @@ -79,8 +79,8 @@ Tiling ComputeTransposeTiling(const se::DeviceDescription& gpu_device_info, // We tile along the minor dimensions pre- and post-transpose. 
absl::InlinedVector tile_sizes{1, 1, 1}; - tile_sizes[permutation[2]] = WarpSize() / kNumRows; - absl::InlinedVector num_threads{1, 1, WarpSize()}; + tile_sizes[permutation[2]] = WarpSize(gpu_device_info) / kNumRows; + absl::InlinedVector num_threads{1, 1, WarpSize(gpu_device_info)}; num_threads[permutation[2]] = kNumRows; auto capability = gpu_device_info.gpu_compute_capability(); @@ -298,7 +298,8 @@ absl::Status TransposeFusion::EmitKernel(IrEmitterContext& ir_emitter_context, llvm::Type* index_type = GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder); - return EmitTilingKernel(builder, tiling_, index_type, tile_generator) + return EmitTilingKernel(builder, tiling_, index_type, tile_generator, + analysis_.device_info()) .status(); } diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc index 448d3050c30bd4..a671901862ec1c 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc @@ -305,7 +305,7 @@ MlirFusionEmitterBase::CreateLLVMModule( mlir::PassManager pm(&mlir_context); AddXlaGpuOpsOptimizationPasses(pm); - AddLoopTransformationPasses(pm); + AddLoopTransformationPasses(pm, device); AddLoweringPasses(pm, device); auto pipeline_status = RunPassPipeline(module.get(), pm, trace.get()); if (trace) { @@ -539,8 +539,10 @@ void AddXlaGpuOpsOptimizationPasses(mlir::OpPassManager& pm) { pm.addPass(mlir::createCSEPass()); } -void AddLoopTransformationPasses(mlir::OpPassManager& pm) { - pm.addNestedPass(CreateLowerXlaGpuToScfPass()); +void AddLoopTransformationPasses(mlir::OpPassManager& pm, + const se::DeviceDescription& device) { + pm.addNestedPass( + CreateLowerXlaGpuToScfPass(device.threads_per_warp())); pm.addPass(mlir::createInlinerPass({}, [&](mlir::OpPassManager& pm) { // CSE after inlining because inlining can introduce duplicates. pm.addPass(mlir::createCSEPass()); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h index 68ce87f4374aab..1a47306e1b1284 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h @@ -118,7 +118,8 @@ void AddXlaGpuOpsOptimizationPasses(mlir::OpPassManager& pm); // Adds passes that transform XLA_GPU and SCF loops, e.g. peel, pipeline, // vectorize. -void AddLoopTransformationPasses(mlir::OpPassManager& pm); +void AddLoopTransformationPasses(mlir::OpPassManager& pm, + const se::DeviceDescription& device); // Adds passes that lower transformed loops to LLVM. void AddLoweringPasses(mlir::OpPassManager& pm, diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc index b7f62e3c7d1d54..5e788f14a45bee 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc @@ -55,12 +55,12 @@ limitations under the License. 
namespace xla { namespace gpu { -int RowReductionGetRowsPerWarp(int reduced_dimension_size) { - if (WarpSize() % reduced_dimension_size != 0 || - reduced_dimension_size >= WarpSize()) { +int RowReductionGetRowsPerWarp(int reduced_dimension_size, int64_t warp_size) { + if (warp_size % reduced_dimension_size != 0 || + reduced_dimension_size >= warp_size) { return 1; } - return WarpSize() / reduced_dimension_size; + return warp_size / reduced_dimension_size; } int GetVectorSize(const HloFusionAnalysis& analysis, @@ -168,8 +168,8 @@ ReductionGroups GroupDisjointReductions(const HloFusionAnalysis& analysis, auto [it, inserted] = disjoint_sets.try_emplace(root, root); CHECK(inserted) << "Duplicate root " << root.ToString(); // Crash OK reachable_outputs[root].insert(root); - result.is_reduction_root.push_back( - IsRealReductionHero(root.instruction(), hero.instruction())); + result.is_reduction_root.push_back(IsRealReductionHero( + root.instruction(), hero.instruction(), analysis.device_info())); if (result.is_reduction_root.back()) { roots_with_reduction.insert(root); } else if (first_non_reduction_root != nullptr) { diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_base.h b/third_party/xla/xla/service/gpu/fusions/reduction_base.h index ad99e6f40140c9..782b1a1ad1e93a 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_base.h +++ b/third_party/xla/xla/service/gpu/fusions/reduction_base.h @@ -41,7 +41,7 @@ struct ReductionGroups { ReductionGroups GroupDisjointReductions(const HloFusionAnalysis& analysis, bool for_mlir); -int RowReductionGetRowsPerWarp(int reduced_dimension_size); +int RowReductionGetRowsPerWarp(int reduced_dimension_size, int64_t warp_size); int GetVectorSize(const HloFusionAnalysis& analysis, const ReductionDimensions& reduction_dimensions, diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc index b2db8fa5cd1730..22300d64a16e8e 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc @@ -117,10 +117,18 @@ struct MlirReductionFusion::EmitterState { PerThreadOutputs EmitPerThreadElements(int group_id, const HloValueMap& inits, const SmallVector& outputs); + mlir::ValueRange ReduceViaSharedMemory(int group_id, + const PerThreadOutputs& per_thread, + const HloValueMap& inits, + std::optional padding, + int max_dist); + mlir::ValueRange ReduceViaSharedMemory( int group_id, const PerThreadOutputs& per_thread, - const HloValueMap& inits, std::optional padding = std::nullopt, - int max_dist = WarpSize() / 2); + const HloValueMap& inits, std::optional padding = std::nullopt) { + return ReduceViaSharedMemory(group_id, per_thread, inits, padding, + owner.WarpSize() / 2); + } mlir::func::FuncOp GetReducer(const HloInstruction* hero) const { return call_target(hero->called_computations()[0]->root_instruction()); @@ -133,8 +141,12 @@ struct MlirReductionFusion::EmitterState { const HloValueMap& values, std::optional padding = std::nullopt); HloValueMap ShuffleReduce(absl::Span reductions, - const HloValueMap& per_thread_values, - int max_dist = WarpSize() / 2); + const HloValueMap& per_thread_values, int max_dist); + + HloValueMap ShuffleReduce(absl::Span reductions, + const HloValueMap& per_thread_values) { + return ShuffleReduce(reductions, per_thread_values, owner.WarpSize() / 2); + } SmallVector FusionParams() { return ValueRange(entry_function.getArguments().take_front( @@ -253,7 +265,7 @@ SmallVector 
MlirReductionFusion::EmitterState::WriteToSharedMemory( } if (padding) { shape.back() += *padding; - } else if ((shape.back() % WarpSize()) == 0) { + } else if ((shape.back() % owner.WarpSize()) == 0) { // Avoid bank conflicts. ++shape.back(); } @@ -325,7 +337,7 @@ mlir::ValueRange MlirReductionFusion::EmitterState::ReduceViaSharedMemory( // The constraints may have reduced the upper bound of the dimension. If // that's the case, we reset it to a multiple of the warp size. auto& bound = loop_indexing.GetMutableDimensionBound(0); - bound.upper = RoundUpTo(bound.upper + 1, WarpSize()) - 1; + bound.upper = RoundUpTo(bound.upper + 1, owner.WarpSize()) - 1; auto tiles = WriteToSharedMemory(reductions, per_thread.reduction_scalars, padding); @@ -365,7 +377,7 @@ MlirReductionFusion::MlirReductionFusion(const HloFusionAnalysis& analysis) VLOG(10) << reduction_dimensions_; CHECK(ReductionIsRaceFree(hero_reduction->GetModule()->config(), - reduction_dimensions_)) + reduction_dimensions_, analysis.device_info())) << "Non-race-free reductions should have been decomposed. Did " "tree_reduction_rewriter run?"; @@ -583,9 +595,9 @@ MlirColumnReductionFusion::MlirColumnReductionFusion( reduction_dimensions_.dimensions[1], reduction_dimensions_.dimensions[2]}; vector_size_ = GetVectorSizeForMlir( - analysis, /*minor_dim=*/input_shape_.back(), WarpSize()); - int64_t num_warps_per_column = WarpSize(); - num_threads_ = {num_warps_per_column, WarpSize()}; + analysis, /*minor_dim=*/input_shape_.back(), kTileSize); + int64_t num_warps_per_column = kTileSize; + num_threads_ = {num_warps_per_column, kTileSize}; int64_t num_col_elements_per_thread = CeilOfRatio(reduction_dimensions_ .dimensions[ReductionDimensions::kColReducedDimension], @@ -599,7 +611,7 @@ MlirColumnReductionFusion::MlirColumnReductionFusion( reduction_dimensions_ .dimensions[ReductionDimensions::kColMinorKeptDimension]; int64_t num_blocks_per_row = - CeilOfRatio(minor_kept_dim, WarpSize() * vector_size_); + CeilOfRatio(minor_kept_dim, kTileSize * vector_size_); num_blocks_ = {major_kept_dim, num_blocks_per_row}; } @@ -612,7 +624,7 @@ IndexingMap MlirColumnReductionFusion::ComputeReductionOutputIndexing( auto vector_index = getAffineSymbolExpr(0, ctx); SmallVector results{ block_id[0], - (block_id[1] * WarpSize() + thread_id[0]) * vector_size_ + vector_index}; + (block_id[1] * kTileSize + thread_id[0]) * vector_size_ + vector_index}; IndexingMap projected_index = GetIndexingMap(results, /*symbol_sizes=*/{vector_size_}); projected_index.AddConstraint(thread_id[1], {0, 0}); @@ -630,7 +642,7 @@ IndexingMap MlirColumnReductionFusion::ComputeReductionInputIndexing( SmallVector results{ block_id[0], thread_id[0] + element_index * num_threads_[1], - (block_id[1] * WarpSize() + thread_id[1]) * vector_size_ + vector_index}; + (block_id[1] * kTileSize + thread_id[1]) * vector_size_ + vector_index}; IndexingMap map = GetIndexingMap(results, tile_sizes_per_thread_); for (auto [result, dim_size] : llvm::zip(results, reduction_dimensions_.dimensions)) { @@ -678,20 +690,20 @@ MlirSmallColumnReductionFusion::MlirSmallColumnReductionFusion( // We emit a single loop over the dimensions 1 and 2, so we use their total // size when computing the vector size. 
vector_size_ = GetVectorSizeForMlir( - analysis, /*minor_dim=*/input_shape_[1] * input_shape_[2], WarpSize()); + analysis, /*minor_dim=*/input_shape_[1] * input_shape_[2], kTileSize); num_threads_ = {128}; shared_rows_ = vector_size_ * num_threads_[0] / input_shape_[kColMinorKept]; // If we have more than 32 shared rows, we'd have to go through shared // memory one extra time. We don't currently support that, and it's not been // tried, so we have to reduce the vector size/number of threads. - while (shared_rows_ > WarpSize() && vector_size_ > 1) { + while (shared_rows_ > kTileSize && vector_size_ > 1) { vector_size_ /= 2; shared_rows_ /= 2; } - if (shared_rows_ > WarpSize()) { - num_threads_[0] /= (shared_rows_ / WarpSize()); - shared_rows_ = WarpSize(); + if (shared_rows_ > kTileSize) { + num_threads_[0] /= (shared_rows_ / kTileSize); + shared_rows_ = kTileSize; } num_blocks_ = {input_shape_[kColMajorKept]}; @@ -776,15 +788,16 @@ std::unique_ptr CreateMlirReductionFusion( CHECK_NE(hero_reduction, nullptr); ReductionDimensions reduction_dimensions = GetReductionKindAndContiguousComponents(*hero_reduction); + const int64_t warp_size = analysis.device_info().threads_per_warp(); if (reduction_dimensions.is_row_reduction) { if (RowReductionGetRowsPerWarp( - reduction_dimensions.dimensions[kRowMinorReduced]) > 1) { + reduction_dimensions.dimensions[kRowMinorReduced], warp_size) > 1) { return std::make_unique(analysis); } return std::make_unique(analysis); } - if (WarpSize() % reduction_dimensions.dimensions[kColMinorKept] == 0) { + if (warp_size % reduction_dimensions.dimensions[kColMinorKept] == 0) { return std::make_unique(analysis); } return std::make_unique(analysis); @@ -795,7 +808,7 @@ MlirRowReductionFusion::MlirRowReductionFusion( : MlirReductionFusion(analysis) { CHECK(reduction_dimensions_.is_row_reduction); Vector3 shape = reduction_dimensions_.dimensions; - CHECK_EQ(RowReductionGetRowsPerWarp(shape[kRowMinorReduced]), 1); + CHECK_EQ(RowReductionGetRowsPerWarp(shape[kRowMinorReduced], WarpSize()), 1); constexpr int64_t kMinorReducedElementsPerThread = 16; int64_t num_threads_kept = 1; @@ -935,7 +948,8 @@ MlirMultiRowReductionFusion::MlirMultiRowReductionFusion( : MlirReductionFusion(analysis) { CHECK(reduction_dimensions_.is_row_reduction); Vector3 shape = reduction_dimensions_.dimensions; - int64_t rows_per_warp = RowReductionGetRowsPerWarp(shape[kRowMinorReduced]); + int64_t rows_per_warp = + RowReductionGetRowsPerWarp(shape[kRowMinorReduced], WarpSize()); input_shape_ = {shape[0], shape[1], shape[2]}; CHECK_GT(rows_per_warp, 1); @@ -1023,7 +1037,8 @@ IndexingMap MlirMultiRowReductionFusion::ComputeReductionOutputIndexing( int MlirMultiRowReductionFusion::GetRowsPerWarp() const { return RowReductionGetRowsPerWarp( - input_shape_[ReductionDimensions::kRowMinorReducedDimension]) * + input_shape_[ReductionDimensions::kRowMinorReducedDimension], + WarpSize()) * tile_sizes_per_thread_[1]; } diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h index 838729254070ac..8b6e8ecc967b55 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h @@ -123,6 +123,10 @@ class MlirReductionFusion : public MlirFusionEmitterBase { return IndexingMap::GetUndefined(); } + int64_t WarpSize() const { + return ::xla::gpu::WarpSize(analysis_.device_info()); + } + // The reduction heroes for each reduction group. 
std::vector> reduction_heroes_; // The roots that have reduction heroes for each reduction group. @@ -194,6 +198,8 @@ class MlirColumnReductionFusion : public MlirReductionFusion { IndexingMap GetSharedMemoryReductionReadMap( mlir::MLIRContext* ctx) const override; IndexingMap GetSharedMemoryWriteMap(mlir::MLIRContext* ctx) const override; + + const int64_t kTileSize = 32; }; // Special emitter for column reductions whose minor reduced dimension divides @@ -213,6 +219,8 @@ class MlirSmallColumnReductionFusion : public MlirReductionFusion { mlir::MLIRContext* ctx) const override; IndexingMap GetSharedMemoryWriteMap(mlir::MLIRContext* ctx) const override; + const int64_t kTileSize = 32; + int64_t shared_rows_; int64_t loop_size_; }; diff --git a/third_party/xla/xla/service/gpu/fusions/tools/BUILD b/third_party/xla/xla/service/gpu/fusions/tools/BUILD index 66079f5dcb02b3..9e5da86adad2a9 100644 --- a/third_party/xla/xla/service/gpu/fusions/tools/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/tools/BUILD @@ -11,6 +11,7 @@ xla_cc_binary( visibility = ["//xla/service/gpu/fusions:__subpackages__"], deps = [ "//xla/mlir_hlo", + "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "//xla/service/gpu/fusions/transforms:passes", diff --git a/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc b/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc index 0db20fb3a3bbe0..43a1f708286456 100644 --- a/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc +++ b/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" #include "xla/service/gpu/fusions/transforms/passes.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" int main(int argc, char** argv) { mlir::DialectRegistry registry; @@ -76,7 +77,8 @@ int main(int argc, char** argv) { llvm::function_ref errorHandler) { if (!options.empty()) return mlir::failure(); - xla::gpu::AddLoopTransformationPasses(pm); + xla::gpu::AddLoopTransformationPasses( + pm, xla::gpu::TestGpuDeviceInfo::RTXA6000DeviceInfo()); return mlir::success(); }, [](llvm::function_ref) {}); diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc b/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc index be1686164d656f..3bde1caf1c5d98 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc @@ -44,6 +44,7 @@ limitations under the License. 
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" +#include "xla/service/gpu/fusions/transforms/passes.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/util.h" @@ -66,7 +67,9 @@ using mlir::ValueRange; using mlir::scf::IfOp; struct RewritePredicatedInsert : mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + RewritePredicatedInsert(mlir::MLIRContext* context, + const LowerXlaGpuToScfPassOptions& options) + : OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( PredicatedInsertOp op, mlir::PatternRewriter& rewriter) const override { @@ -86,7 +89,9 @@ struct RewritePredicatedInsert : mlir::OpRewritePattern { }; struct RewritePredicatedExtract : mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + RewritePredicatedExtract(mlir::MLIRContext* context, + const LowerXlaGpuToScfPassOptions& options) + : OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( PredicatedExtractOp op, mlir::PatternRewriter& rewriter) const override { @@ -106,15 +111,19 @@ struct RewritePredicatedExtract : mlir::OpRewritePattern { }; struct RewriteShuffleReduce : mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + const int64_t warp_size; + + RewriteShuffleReduce(mlir::MLIRContext* context, + const LowerXlaGpuToScfPassOptions& options) + : OpRewritePattern(context), warp_size(options.warp_size) {} mlir::LogicalResult matchAndRewrite( ShuffleReduceOp op, mlir::PatternRewriter& rewriter) const override { int max_distance = mlir::cast(op->getAttr("max_distance")).getInt(); // TODO(jreiffers): Do this in a verifier. - if (max_distance & (max_distance - 1) || max_distance >= WarpSize()) { - return op->emitOpError("max_distance must be a power of 2 < WarpSize()"); + if (max_distance & (max_distance - 1) || max_distance >= warp_size) { + return op->emitOpError("max_distance must be a power of 2 < warp_size_"); } ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -123,7 +132,7 @@ struct RewriteShuffleReduce : mlir::OpRewritePattern { namespace ml = mlir::LLVM; auto shuffle_32 = [&](Value v) { return b - .create(v, distance, WarpSize(), + .create(v, distance, warp_size, mlir::gpu::ShuffleMode::DOWN) .getShuffleResult(); }; @@ -259,7 +268,9 @@ mlir::VectorType getThreadLevelVectorType(IndexedVectorType indexed_vector) { } struct RewriteMaterialize : mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + RewriteMaterialize(mlir::MLIRContext* context, + const LowerXlaGpuToScfPassOptions& options) + : OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( MaterializeOp op, mlir::PatternRewriter& rewriter) const override { @@ -316,7 +327,9 @@ struct RewriteMaterialize : mlir::OpRewritePattern { }; struct RewriteInsert : mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + RewriteInsert(mlir::MLIRContext* context, + const LowerXlaGpuToScfPassOptions& options) + : OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( InsertOp op, mlir::PatternRewriter& rewriter) const override { @@ -368,16 +381,23 @@ struct RewriteInsert : mlir::OpRewritePattern { class LowerXlaGpuToScfPass : public impl::LowerXlaGpuToScfPassBase { public: + explicit LowerXlaGpuToScfPass(const LowerXlaGpuToScfPassOptions& options) + : options_(options) {} + void runOnOperation() override { auto* ctx = &getContext(); mlir::RewritePatternSet 
patterns(ctx); patterns.add(ctx); + RewriteShuffleReduce, RewriteMaterialize, RewriteInsert>( + ctx, options_); if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } } + + private: + const LowerXlaGpuToScfPassOptions options_; }; class LowerXlaGpuLoopsToScfPass @@ -396,8 +416,11 @@ class LowerXlaGpuLoopsToScfPass } // namespace -std::unique_ptr<::mlir::Pass> CreateLowerXlaGpuToScfPass() { - return std::make_unique(); +std::unique_ptr<::mlir::Pass> CreateLowerXlaGpuToScfPass( + const int64_t warp_size) { + LowerXlaGpuToScfPassOptions options; + options.warp_size = warp_size; + return std::make_unique(options); } std::unique_ptr<::mlir::Pass> CreateLowerXlaGpuLoopsToScfPass() { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/passes.h b/third_party/xla/xla/service/gpu/fusions/transforms/passes.h index 470a333f70ccca..8c799a4b87d831 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/passes.h +++ b/third_party/xla/xla/service/gpu/fusions/transforms/passes.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_H_ #define XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_H_ +#include #include #include #include @@ -47,13 +48,13 @@ std::unique_ptr CreateFlattenTensorsPass(); std::unique_ptr CreateLowerTensorsPass( bool is_amd_gpu = false, const std::string& gpu_arch = "6.0"); std::unique_ptr CreateLowerToLLVMPass(bool use_rocdl); -std::unique_ptr CreateLowerXlaGpuToScfPass(); +std::unique_ptr CreateLowerXlaGpuToScfPass(int64_t warp_size = 32); std::unique_ptr CreateLowerXlaGpuLoopsToScfPass(); std::unique_ptr CreateMergePointersToSameSlicePass(); std::unique_ptr CreateOptimizeLoopsPass(); std::unique_ptr CreatePeelLoopsPass(); std::unique_ptr CreatePropagateSliceIndicesPass(); -std::unique_ptr CreateRewriteReductionsPass(); +std::unique_ptr CreateRewriteReductionsPass(int64_t warp_size = 32); std::unique_ptr CreateSimplifyAffinePass(); std::unique_ptr CreateSimplifyArithPass(); std::unique_ptr CreateUnswitchLoopsPass(); diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/passes.td b/third_party/xla/xla/service/gpu/fusions/transforms/passes.td index 52a0dacbc3db8f..61427ca8748bc2 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/passes.td +++ b/third_party/xla/xla/service/gpu/fusions/transforms/passes.td @@ -178,6 +178,9 @@ def LowerXlaGpuToScfPass : "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect", "mlir::vector::VectorDialect", ]; + let options = [ + Option<"warp_size", "warp_size", "int64_t", /*default=*/"32", "Warp size.">, + ]; let constructor = "CreateLowerXlaGpuToScfPass()"; } @@ -253,6 +256,9 @@ def RewriteReductionsPass : Pass< "xla::gpu::XlaGpuDialect", ]; + let options = [ + Option<"warp_size", "warp_size", "int64_t", /*default=*/"32", "Warp size.">, + ]; let constructor = "CreateRewriteReductionsPass()"; } diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/rewrite_reductions.cc b/third_party/xla/xla/service/gpu/fusions/transforms/rewrite_reductions.cc index 50969b8bd6bbd8..e32571972c2b58 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/rewrite_reductions.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/rewrite_reductions.cc @@ -32,6 +32,7 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/transforms/passes.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" @@ -48,7 +49,13 @@ namespace { class RewriteReductionsPass : public impl::RewriteReductionsPassBase { public: + explicit RewriteReductionsPass(const RewriteReductionsPassOptions& options) + : options_(options) {} + void runOnOperation() override; + + private: + const RewriteReductionsPassOptions options_; }; mlir::ShapedType GetInputType(ReduceOp op) { @@ -125,7 +132,11 @@ llvm::SmallVector ReindexTensors( // This also pads the input if the number of threads does not divide the row // size. struct RewriteRowReduction : mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + const int64_t warp_size; + + RewriteRowReduction(mlir::MLIRContext* context, + const RewriteReductionsPassOptions& options) + : OpRewritePattern(context), warp_size(options.warp_size) {} mlir::LogicalResult matchAndRewrite( ReduceOp op, mlir::PatternRewriter& rewriter) const override { @@ -136,12 +147,12 @@ struct RewriteRowReduction : mlir::OpRewritePattern { return rewriter.notifyMatchFailure(op, "not a row reduction"); } - if (minor_reduction.size <= WarpSize()) { + if (minor_reduction.size <= warp_size) { return rewriter.notifyMatchFailure(op, "small minor dimension"); } int64_t num_threads = GetNumThreads(op); - assert(num_threads % WarpSize() == 0); + assert(num_threads % warp_size == 0); llvm::ArrayRef input_shape = GetInputType(op).getShape(); auto projected_input_shape = llvm::to_vector( @@ -165,8 +176,8 @@ struct RewriteRowReduction : mlir::OpRewritePattern { auto per_thread_reduction_input_shape = llvm::to_vector( input_shape.take_front(minor_reduction.first_dimension)); per_thread_reduction_input_shape.push_back(padded_size / num_threads); - per_thread_reduction_input_shape.push_back(num_threads / WarpSize()); - per_thread_reduction_input_shape.push_back(WarpSize()); + per_thread_reduction_input_shape.push_back(num_threads / warp_size); + per_thread_reduction_input_shape.push_back(warp_size); int per_thread_input_rank = per_thread_reduction_input_shape.size(); @@ -207,7 +218,11 @@ struct RewriteRowReduction : mlir::OpRewritePattern { // Rewrites column reductions to a reduce-transpose-reduce. struct RewriteColumnReduction : mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + const int64_t warp_size; + + RewriteColumnReduction(mlir::MLIRContext* context, + const RewriteReductionsPassOptions& options) + : OpRewritePattern(context), warp_size(options.warp_size) {} mlir::LogicalResult matchAndRewrite( ReduceOp op, mlir::PatternRewriter& rewriter) const override { @@ -256,7 +271,7 @@ struct RewriteColumnReduction : mlir::OpRewritePattern { // handle is the warp size. 
assert(num_threads > minor_reduction.stride); - int64_t c = std::min(WarpSize(), num_threads / minor_reduction.stride); + int64_t c = std::min(warp_size, num_threads / minor_reduction.stride); llvm::ArrayRef input_shape = GetInputType(op).getShape(); auto projected_input_shape = llvm::to_vector( @@ -328,7 +343,8 @@ struct RewriteColumnReduction : mlir::OpRewritePattern { void RewriteReductionsPass::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); + patterns.add(&getContext(), + options_); if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); @@ -337,9 +353,11 @@ void RewriteReductionsPass::runOnOperation() { } // namespace -std::unique_ptr> -CreateRewriteReductionsPass() { - return std::make_unique(); +std::unique_ptr CreateRewriteReductionsPass( + const int64_t warp_size) { + RewriteReductionsPassOptions options; + options.warp_size = warp_size; + return std::make_unique(options); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc index fd18cef310a8fb..225109f5de9148 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc @@ -73,8 +73,8 @@ using mlir::func::FuncOp; using mlir::func::ReturnOp; using mlir_converter::ApplyIndexing; +constexpr int kTileSize = 32; constexpr int kNumRows = 4; -constexpr int kBaseBlockSize = WarpSize(); constexpr int kNumThreadsPerBlock = 128; constexpr int kMaxVectorizedBytes = 4; @@ -85,7 +85,8 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) transpose_(analysis.tiled_transpose()), permutation_(transpose_.permutation), input_shape_( - Permute(transpose_.dimensions, InversePermutation(permutation_))) { + Permute(transpose_.dimensions, InversePermutation(permutation_))), + base_block_size_(kTileSize) { ConstHloInstructionSet transposes_to_tile; int index = 0; int64_t shmem_usage = 0; @@ -103,7 +104,7 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) size *= input_shape_.back(); } max_element_bytes = std::max(max_element_bytes, size); - shmem_usage += kBaseBlockSize * (kBaseBlockSize + 1) * size; + shmem_usage += base_block_size_ * (base_block_size_ + 1) * size; shmem_transpose_root_indices_.push_back(index); } else { side_output_roots_.push_back(&root.instruction()); @@ -117,12 +118,12 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) vector_size_ = vector_size; block_sizes_.assign(input_shape_.size(), 1); if (MostMinorDimensionUnchanged()) { - block_size_ = kBaseBlockSize; + block_size_ = base_block_size_; block_sizes_.back() = vector_size_; block_sizes_[block_sizes_.size() - 2] = block_size_; block_sizes_[permutation_[block_sizes_.size() - 2]] = block_size_; } else { - block_size_ = kBaseBlockSize * vector_size_; + block_size_ = base_block_size_ * vector_size_; block_sizes_.back() = block_size_; block_sizes_[permutation_.back()] = block_size_; } diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h index 9602242fe4745a..2cd089f674b69e 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h @@ -109,6 +109,7 @@ class MlirTransposeFusion : public MlirFusionEmitterBase { std::vector block_counts_; int vector_size_; int block_size_; + int64_t 
base_block_size_; std::vector shmem_transposes_; std::vector shmem_transpose_roots_; diff --git a/third_party/xla/xla/service/gpu/fusions/triton.cc b/third_party/xla/xla/service/gpu/fusions/triton.cc index 7d235c132989c4..f7c0c167799d3b 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton.cc @@ -174,7 +174,8 @@ absl::StatusOr TritonFusion::Emit( TF_ASSIGN_OR_RETURN( launch_dimensions, - GetMatMulLaunchDimensions(analysis, analysis_.fusion(), config)); + GetMatMulLaunchDimensions(analysis, analysis_.fusion(), config, + analysis_.device_info())); } llvm::Function* impl_fn = @@ -233,7 +234,8 @@ std::optional TritonFusion::launch_config() const { LaunchConfig launch_config; launch_config.launch_dimensions = LaunchDimensions{ static_cast(num_blocks), - static_cast(block_level_parameters.num_warps * WarpSize())}; + static_cast(block_level_parameters.num_warps * + WarpSize(analysis_.device_info()))}; launch_config.block_level_parameters = std::move(block_level_parameters); return launch_config; } diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc index 46a569d265bcdd..0c0e244b3e21d3 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc @@ -41,10 +41,10 @@ namespace gpu { namespace mt = ::mlir::triton; absl::Status CreateTritonPipeline( - mlir::OpPassManager& pm, const se::GpuComputeCapability& cc, + mlir::OpPassManager& pm, const se::DeviceDescription& device_info, const BlockLevelParameters& block_level_parameters, mt::nvidia_gpu::ClusterInfo& out_cluster_info) { - auto ccCuda = std::get(cc); + auto ccCuda = device_info.cuda_compute_capability(); const int ccAsInt = ccCuda.major * 10 + ccCuda.minor; const int threadsPerWarp = 32; diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc index a48e65ab3a6953..840329eb2cba3b 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc @@ -53,12 +53,12 @@ using ::mlir::Value; using mlir::ValueRange; absl::Status CreateTritonPipeline( - mlir::OpPassManager& pm, const se::GpuComputeCapability& cc, + mlir::OpPassManager& pm, const se::DeviceDescription& device_info, const BlockLevelParameters& block_level_parameters, mt::nvidia_gpu::ClusterInfo& out_cluster_info) { // TODO(ROCm): Check why some test fail when threadsPerWarp is set to 64. 
- const int threadsPerWarp = 32; - auto ccRocm = std::get(cc); + const int threadsPerWarp = device_info.threads_per_warp(); + auto ccRocm = device_info.rocm_compute_capability(); // Based on make_ttir() in // @triton//:third_party/amd/backend/compiler.py @@ -80,7 +80,7 @@ absl::Status CreateTritonPipeline( pm.addPass(mt::gpu::createTritonGPUCoalesce()); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); pm.addPass(mt::gpu::createTritonGPUOptimizeThreadLocality()); - pm.addPass(mt::gpu::createTritonGPUAccelerateMatmul()); + pm.addPass(mlir::createTritonAMDGPUAccelerateMatmulPass()); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); // TODO ROCm Check if we want to compare MI100 and greater pm.addPass(mlir::createTritonAMDGPUOptimizeEpiloguePass()); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index e9974d3ce1584f..6b297366cb9e76 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -1290,7 +1290,8 @@ struct MatMulDims { struct MatMulLaunchConfig { explicit MatMulLaunchConfig(const TritonGemmConfig& config, const HloDotInstruction& dot, - const MatMulDims& dims); + const MatMulDims& dims, + const se::DeviceDescription& device_info); int64_t grid_m; int64_t grid_n; @@ -1387,7 +1388,8 @@ struct MatMulLaunchConfig { MatMulLaunchConfig::MatMulLaunchConfig(const TritonGemmConfig& config, const HloDotInstruction& dot, - const MatMulDims& dims) + const MatMulDims& dims, + const se::DeviceDescription& device_info) : grid_m((dims.m + config.block_m - 1) / config.block_m), grid_n((dims.n + config.block_n - 1) / config.block_n) { int64_t batch_size = dims.lhs_noncontracting_split.value_or( @@ -1409,13 +1411,13 @@ MatMulLaunchConfig::MatMulLaunchConfig(const TritonGemmConfig& config, noncontracting_program_id_dim = mt::ProgramIDDim::Y; launch_dims = LaunchDimensions( se::BlockDim(batch_size, grid_m * grid_n, config.split_k), - se::ThreadDim(config.num_warps * WarpSize(), 1, 1)); + se::ThreadDim(config.num_warps * WarpSize(device_info), 1, 1)); } else { batch_program_id_dim = mt::ProgramIDDim::Y; noncontracting_program_id_dim = mt::ProgramIDDim::X; launch_dims = LaunchDimensions( se::BlockDim(grid_m * grid_n, batch_size, config.split_k), - se::ThreadDim(config.num_warps * WarpSize(), 1, 1)); + se::ThreadDim(config.num_warps * WarpSize(device_info), 1, 1)); } } @@ -1962,7 +1964,7 @@ class MatMulEmitterHelper { absl::StatusOr GetMatMulLaunchDimensions( const TritonFusionAnalysis& analysis, const HloFusionAdaptor& fusion, - const TritonGemmConfig& config) { + const TritonGemmConfig& config, const se::DeviceDescription& device_info) { auto dot = HloBfsFindIf(fusion.GetRoots(), fusion, [](auto node) { return node.opcode() == HloOpcode::kDot; }); @@ -1971,7 +1973,7 @@ absl::StatusOr GetMatMulLaunchDimensions( *static_cast(&dot->instruction()); TF_ASSIGN_OR_RETURN(MatMulDims dims, MatMulDims::Create(config, dot_instr, analysis)); - MatMulLaunchConfig launch_config(config, dot_instr, dims); + MatMulLaunchConfig launch_config(config, dot_instr, dims, device_info); return launch_config.launch_dims; } @@ -2516,7 +2518,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, TF_ASSIGN_OR_RETURN(const MatMulDims dims, MatMulDims::Create(config, *dot_instr, analysis)); - const MatMulLaunchConfig launch_config(config, *dot_instr, dims); + const MatMulLaunchConfig 
launch_config(config, *dot_instr, dims, device_info); VLOG(6) << analysis.ToString(); MatMulEmitterHelper emitter(libdevice_path, device_info, dot_instr, b, @@ -3194,7 +3196,8 @@ absl::StatusOr CompileTritonToLLVM( pm.addPass(CreateSimplifyAffinePass()); mlir::triton::nvidia_gpu::ClusterInfo cluster_info; - if (!CreateTritonPipeline(pm, cc, block_level_parameters, cluster_info) + if (!CreateTritonPipeline(pm, device_info, block_level_parameters, + cluster_info) .ok()) { return Internal("Failed to create Triton pipeline."); } diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h index 9c7cd49cd3d862..70d67a5fcfa0a4 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h @@ -77,7 +77,7 @@ absl::Status EmitGeneric(mlir::OpBuilder b, absl::string_view libdevice_path, // Compute the launch dimensions for the given Triton MatMul. absl::StatusOr GetMatMulLaunchDimensions( const TritonFusionAnalysis& analysis, const HloFusionAdaptor& fusion, - const TritonGemmConfig& config); + const TritonGemmConfig& config, const se::DeviceDescription& device_info); // Use tiling and execution parameters from 'config'. output_tile_sizes is // ignored. @@ -129,7 +129,7 @@ absl::StatusOr CompileTritonToLLVM( // parameter which would give a hint to Triton which cluster dims we prefer to // use, but that's not the case currently. absl::Status CreateTritonPipeline( - mlir::OpPassManager& pm, const se::GpuComputeCapability& cc, + mlir::OpPassManager& pm, const se::DeviceDescription& device_info, const BlockLevelParameters& block_level_parameters, mt::nvidia_gpu::ClusterInfo& out_cluster_info); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub.cc index 33c1e0666dd90d..a9e9350a13e6f1 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub.cc @@ -58,7 +58,7 @@ absl::Status EmitGeneric(mlir::OpBuilder b, absl::string_view libdevice_path, absl::StatusOr GetMatMulLaunchDimensions( const TritonFusionAnalysis& analysis, const HloFusionAdaptor& fusion, - const TritonGemmConfig& config) { + const TritonGemmConfig& config, const se::DeviceDescription& device_info) { return absl::UnimplementedError("not supported for this build configuration"); } @@ -108,7 +108,7 @@ absl::StatusOr CompileTritonToLLVM( } absl::Status CreateTritonPipeline( - mlir::OpPassManager& pm, const se::GpuComputeCapability& cc, + mlir::OpPassManager& pm, const se::DeviceDescription& device_info, const BlockLevelParameters& block_level_parameters, mt::nvidia_gpu::ClusterInfo& out_cluster_info) { return absl::UnimplementedError("not supported for this build configuration"); diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 0db21cdf7173d2..c16f1be692c0fb 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -974,9 +974,10 @@ absl::Status RunCollectiveOptimizationPasses( return collectives_pipeline.Run(hlo_module).status(); } -absl::Status RunLayoutAssignmentPasses(HloModule* hlo_module, - se::GpuComputeCapability gpu_version, - se::dnn::VersionInfo dnn_version) { +absl::Status RunLayoutAssignmentPasses( 
+ HloModule* hlo_module, se::GpuComputeCapability gpu_version, + se::dnn::VersionInfo dnn_version, + const se::DeviceDescription& device_description) { // Run layout assignment in a separate pipeline from // "post-layout-assignment" because we want everything after layout // assignment to have a layout-sensitive invariant-checker, but @@ -990,7 +991,7 @@ absl::Status RunLayoutAssignmentPasses(HloModule* hlo_module, ChannelLayoutConstraints layout_constraints; pipeline.AddPass( hlo_module->mutable_entry_computation_layout(), gpu_version, dnn_version, - &layout_constraints); + device_description, &layout_constraints); // Run SubByteNormalization because GpuLayoutAssignment may modify a // Layout's element_size_in_bits field. pipeline.AddPass( @@ -1151,7 +1152,8 @@ absl::Status RunPostFusionCollectiveOptimizationPasses(HloModule* hlo_module) { absl::Status RunPostFusionSimplificationPasses( HloModule* hlo_module, const AlgebraicSimplifierOptions& layout_insensitive_algsimp_opts, - se::GpuComputeCapability gpu_version) { + se::GpuComputeCapability gpu_version, + const Compiler::TargetConfig& gpu_target_config) { HloPassPipeline pipeline("post-fusion-simplification-pipeline optimization"); AlgebraicSimplifierOptions options = layout_insensitive_algsimp_opts; options.set_is_layout_sensitive(true); @@ -1166,7 +1168,8 @@ absl::Status RunPostFusionSimplificationPasses( if (hlo_module->config() .debug_options() .xla_gpu_multi_streamed_windowed_einsum()) { - pipeline.AddPass(); + pipeline.AddPass( + gpu_target_config.device_description); pipeline.AddPass(); } @@ -1314,7 +1317,8 @@ absl::Status GpuCompiler::OptimizeHloModule( gpu_target_config.device_description.runtime_version())); TF_RETURN_IF_ERROR( - RunLayoutAssignmentPasses(hlo_module, gpu_version, dnn_version)); + RunLayoutAssignmentPasses(hlo_module, gpu_version, dnn_version, + gpu_target_config.device_description)); TF_RETURN_IF_ERROR(RunLayoutNormalizationPasses(hlo_module, gpu_version)); @@ -1338,7 +1342,8 @@ absl::Status GpuCompiler::OptimizeHloModule( })); TF_RETURN_IF_ERROR(RunPostFusionCollectiveOptimizationPasses(hlo_module)); TF_RETURN_IF_ERROR(RunPostFusionSimplificationPasses( - hlo_module, layout_insensitive_algsimp_opts, gpu_version)); + hlo_module, layout_insensitive_algsimp_opts, gpu_version, + gpu_target_config)); TF_RETURN_IF_ERROR(RunPostFusionVerificationPasses( hlo_module, stream_exec, options, gpu_target_config)); @@ -1361,8 +1366,11 @@ AlgebraicSimplifierOptions GpuCompiler::GetAlgebraicSimplifierOptions( // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. -absl::Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) { - return PrepareHloModuleForIrEmittingPipeline(*hlo_module, GetCanShareBuffer()) +absl::Status GpuCompiler::PrepareHloModuleForIrEmitting( + HloModule* hlo_module, const se::DeviceDescription& device_description) { + return PrepareHloModuleForIrEmittingPipeline( + *hlo_module, GetCanShareBuffer(device_description), + device_description) .Run(hlo_module) .status(); } @@ -1451,7 +1459,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass(); pipeline.AddPass([&](const HloInstruction* r) { - return IsReductionFromOrToContiguousDimensions(*r); + return IsReductionFromOrToContiguousDimensions( + *r, gpu_target_config.device_description); }); // Greedy pattern matching for custom kernel fusions. 
We run it before @@ -1537,8 +1546,11 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( // enabled, which handles such cases well. bool ignore_small_reduce_dims = !debug_options.xla_gpu_enable_priority_fusion(); - pipeline.AddPass>(ignore_small_reduce_dims); - pipeline.AddPass>(gpu_version); + pipeline.AddPass>( + gpu_target_config.device_description, + ignore_small_reduce_dims); + pipeline.AddPass>( + gpu_target_config.device_description); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } @@ -1695,7 +1707,8 @@ absl::StatusOr> GpuCompiler::RunHloPasses( is_deviceless ? nullptr : stream_exec, options, gpu_target_config)); - TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get())); + TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting( + module.get(), gpu_target_config.device_description)); uint64_t end_usecs = tsl::Env::Default()->NowMicros(); @@ -2160,7 +2173,7 @@ GpuCompiler::CompileToBackendResult( const se::DeviceDescription& gpu_device_info) { tsl::profiler::TraceMe traceme("GpuCompiler::CompileToBackendResult"); - TF_RETURN_IF_ERROR(RunPreSchedulingPasses(module, executor)); + TF_RETURN_IF_ERROR(RunPreSchedulingPasses(module, executor, gpu_device_info)); TF_ASSIGN_OR_RETURN( ScheduleMetadata schedule_metadata, ScheduleGpuModule(module, pointer_size_, gpu_device_info)); @@ -2190,7 +2203,8 @@ GpuCompiler::CompileToBackendResult( CompileModuleResults compile_module_results, CompileModuleToLlvmIr(module, llvm_context, target_triple_, data_layout_, platform->Name(), platform->id(), gpu_device_info, - GetCanShareBuffer(), BufferSizeBytesFunction(), + GetCanShareBuffer(gpu_device_info), + BufferSizeBytesFunction(), /*split_constants_module=*/use_cache)); if (user_pre_optimization_hook_) { @@ -2454,9 +2468,10 @@ absl::StatusOr> GpuCompiler::Export( } absl::Status GpuCompiler::RunPreSchedulingPasses( - HloModule* module, se::StreamExecutor* stream_exec) { + HloModule* module, se::StreamExecutor* stream_exec, + const se::DeviceDescription& gpu_device_info) { HloPassPipeline pipeline("pre-scheduling-passes"); - pipeline.AddPass(); + pipeline.AddPass(gpu_device_info); return pipeline.Run(module).status(); } @@ -2524,8 +2539,8 @@ HloRematerialization::Options CreateRematOpts( absl::Status GpuCompiler::RunPostSchedulingPipelines( HloModule* module, int64_t scheduler_mem_limit, const se::DeviceDescription& gpu_device_info) const { - TF_RETURN_IF_ERROR( - RunPostSchedulingCopyInsertion(module, GetCanShareBuffer())); + TF_RETURN_IF_ERROR(RunPostSchedulingCopyInsertion( + module, GetCanShareBuffer(gpu_device_info))); HloPassPipeline main_pipeline("post-scheduling-passes"); // Pipeline for async -> sync conversion on for non-overlapped async ops. @@ -2559,7 +2574,7 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( main_pipeline.AddPass("remat-pipeline"); pipeline.AddPass(remat_opts, sizes); - pipeline.AddPass(); + pipeline.AddPass(gpu_device_info); pipeline.AddPass(); } @@ -2569,7 +2584,7 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( { HloPassPipeline& pipeline = main_pipeline.AddPass("fusion-wrapper"); - pipeline.AddPass(); + pipeline.AddPass(gpu_device_info); } // Pipeline with passes which wrap a scheduled module into command buffers. 
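The gpu_compiler.cc hunks above stop treating buffer sharing as device-independent: GetCanShareBuffer now receives the se::DeviceDescription and hands back a callback bound to that device. As an illustrative sketch (not part of the patch), this is the adapter pattern those call sites rely on, assuming the four-argument FusionCanShareBufferHint overload introduced by this change; the helper name below is hypothetical, and, as in the patched GetCanShareBuffer, the lambda captures the description by reference, so the description must outlive the returned callback.

// Illustrative sketch only; MakeDeviceBoundCanShareBuffer is a hypothetical
// helper, not part of the patch.
HloDataflowAnalysis::CanShareBuffer MakeDeviceBoundCanShareBuffer(
    const se::DeviceDescription& device_description) {
  // Captures the description by reference: callers must keep it alive for as
  // long as the callback is used (e.g. for the duration of copy insertion).
  return [&device_description](const HloInstruction* user,
                               const HloInstruction* operand,
                               const ShapeIndex& user_index) {
    return FusionCanShareBufferHint(user, operand, user_index,
                                    device_description);
  };
}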
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.h b/third_party/xla/xla/service/gpu/gpu_compiler.h index b18b48abfcc4d9..fbdd183647c551 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.h +++ b/third_party/xla/xla/service/gpu/gpu_compiler.h @@ -114,8 +114,13 @@ class GpuCompiler : public LLVMCompiler { const Compiler::CompileOptions& options, const DebugOptions& debug_opts, se::StreamExecutor* executor); - virtual HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() const { - return &FusionCanShareBufferHint; + virtual HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer( + const se::DeviceDescription& device_description) const { + return [&](const HloInstruction* user, const HloInstruction* operand, + const ShapeIndex& user_index) { + return FusionCanShareBufferHint(user, operand, user_index, + device_description); + }; } virtual absl::StatusOr CanUseLinkModules( @@ -211,8 +216,9 @@ class GpuCompiler : public LLVMCompiler { absl::Status SerializeAutotuneResultsToFile( const DebugOptions& debug_options); - absl::Status RunPreSchedulingPasses(HloModule* module, - se::StreamExecutor* stream_exec); + absl::Status RunPreSchedulingPasses( + HloModule* module, se::StreamExecutor* stream_exec, + const se::DeviceDescription& gpu_device_info); absl::Status RunCollectiveScheduleLinearizerPasses( HloModule* hlo_module, se::StreamExecutor* stream_exec); @@ -237,7 +243,8 @@ class GpuCompiler : public LLVMCompiler { se::GpuComputeCapability gpu_version, bool relocatable, const HloModule* debug_module, const CompileOptions& options) = 0; - absl::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module); + absl::Status PrepareHloModuleForIrEmitting( + HloModule* hlo_module, const se::DeviceDescription& device_description); virtual absl::StatusOr> LinkModules( se::GpuComputeCapability gpu_compute_capability, diff --git a/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc b/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc index 874fa087af7ba2..ba7ce8841772ed 100644 --- a/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc @@ -19,12 +19,14 @@ limitations under the License. 
#include "absl/log/check.h" #include "absl/log/log.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/copy_insertion.h" #include "xla/service/gpu/buffer_sharing.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" @@ -64,7 +66,38 @@ void ExpectOptionalFalse(std::optional value) { EXPECT_FALSE(*value); } -using GpuCopyInsertionTest = HloTestBase; +class CanShareBufferWrapper { + public: + CanShareBufferWrapper() + : can_share_buffer_([&](const HloInstruction* fusion, + const HloInstruction* operand, + const ShapeIndex& user_index) { + return FusionCanShareBufferHint(fusion, operand, user_index, + device_description_); + }) {} + + HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() const { + return can_share_buffer_; + } + + private: + const se::DeviceDescription device_description_{ + xla::gpu::TestGpuDeviceInfo::CudaOrRocmDeviceInfo()}; + const HloDataflowAnalysis::CanShareBuffer can_share_buffer_; +}; + +class GpuCopyInsertionTest : public HloTestBase { + public: + using HloTestBase::HloTestBase; + + CopyInsertion CreateCopyInsertion() const { + return CopyInsertion(can_share_buffer_wrapper_.GetCanShareBuffer(), + /*use_region_based_live_range_analysis=*/0); + } + + private: + const CanShareBufferWrapper can_share_buffer_wrapper_; +}; // This is some kind of end-to-end test for FusionCanShareBufferHint. TEST_F(GpuCopyInsertionTest, DUSBitcastNoCopy) { @@ -116,8 +149,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); - CopyInsertion copy_insertion(FusionCanShareBufferHint, - /*use_region_based_live_range_analysis=*/0); + CopyInsertion copy_insertion = CreateCopyInsertion(); ASSERT_IS_OK(copy_insertion.Run(module.get(), {"foobar"}).status()); VLOG(2) << module->ToString(); // Copy insertion adds two copies inside the entry computation. 
@@ -127,7 +159,21 @@ ENTRY main { EXPECT_EQ(CountCopies(*module), 2); } -using FusionCanShareBufferHintTest = HloTestBase; +class FusionCanShareBufferHintTest : public HloTestBase { + public: + FusionCanShareBufferHintTest() + : can_share_buffer_(can_share_buffer_wrapper_.GetCanShareBuffer()) {} + + std::optional FusionCanShareBufferHint(const HloInstruction* fusion, + const HloInstruction* operand, + const ShapeIndex& user_index) { + return can_share_buffer_(fusion, operand, user_index); + } + + private: + const CanShareBufferWrapper can_share_buffer_wrapper_; + const HloDataflowAnalysis::CanShareBuffer can_share_buffer_; +}; TEST_F(FusionCanShareBufferHintTest, BufferCanBeSharedSameShape) { const char* const kModuleString = R"( @@ -990,8 +1036,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); - CopyInsertion copy_insertion(FusionCanShareBufferHint, - /*use_region_based_live_range_analysis=*/0); + CopyInsertion copy_insertion = CreateCopyInsertion(); ASSERT_IS_OK(copy_insertion.Run(module.get(), {"foobar"}).status()); VLOG(2) << module->ToString(); EXPECT_EQ(CountCopies(*module), 0); diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.cc b/third_party/xla/xla/service/gpu/gpu_fusible.cc index 94e67e43c1adb6..447930bc6d70ab 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible.cc @@ -101,13 +101,14 @@ bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) { return false; } -bool IsExpensiveToRepeat(const HloInstruction& instr) { +bool IsExpensiveToRepeat(const HloInstruction& instr, + const se::DeviceDescription& device_info) { CHECK_NE(instr.opcode(), HloOpcode::kFusion) << "`instr` has to be unfused."; // Reductions which use many input elements to calculate one output element // are both memory and computationally heavy. 
constexpr int kMaxInputsPerOutput = 10; if (instr.opcode() == HloOpcode::kReduce && - !IsReductionFromOrToContiguousDimensions(instr)) { + !IsReductionFromOrToContiguousDimensions(instr, device_info)) { int64_t reduction_ratio = ShapeUtil::ElementsIn(instr.operand(0)->shape()) / ShapeUtil::ElementsIn(instr.shape()); if (reduction_ratio > kMaxInputsPerOutput) return true; @@ -192,24 +193,27 @@ bool TransposesMinorDimension(const HloInstruction* instr) { } } -bool IsReduceInputFusion(const HloInstruction& instr) { +bool IsReduceInputFusion(const HloInstruction& instr, + const se::DeviceDescription& device_info) { return instr.opcode() == HloOpcode::kFusion && absl::c_any_of(GetFusionRoots(*instr.called_computations()[0]), - [](const HloInstruction* root) { - return IsRealReductionHero(*root, - FindNonTrivialHero(*root)); + [&](const HloInstruction* root) { + return IsRealReductionHero( + *root, FindNonTrivialHero(*root), device_info); }); } -bool IsInputFusibleReduction(const HloInstruction& instr) { - return IsReduceInputFusion(instr) || - IsReductionFromOrToContiguousDimensions(instr); +bool IsInputFusibleReduction(const HloInstruction& instr, + const se::DeviceDescription& device_info) { + return IsReduceInputFusion(instr, device_info) || + IsReductionFromOrToContiguousDimensions(instr, device_info); } -bool IsNestableVariadicReduction(const HloInstruction& instr) { +bool IsNestableVariadicReduction(const HloInstruction& instr, + const se::DeviceDescription& device_info) { return instr.shape().IsTuple() && ((instr.opcode() == HloOpcode::kReduce && - !IsReductionFromOrToContiguousDimensions(instr)) || + !IsReductionFromOrToContiguousDimensions(instr, device_info)) || (instr.opcode() == HloOpcode::kFusion && instr.fusion_kind() == HloInstruction::FusionKind::kLoop && instr.fused_expression_root()->opcode() == HloOpcode::kReduce)); @@ -226,14 +230,14 @@ bool IsInputFusibleTranspose(const HloInstruction& instr) { } const HloInstruction* GetRealHeroForMultiOutputFusion( - const HloInstruction& instr) { + const HloInstruction& instr, const se::DeviceDescription& device_info) { if (instr.opcode() != HloOpcode::kFusion) { return &instr; } auto fused_expression_root = instr.fused_expression_root(); if (!instr.IsMultiOutputFusion()) { const auto& hero = FindNonTrivialHero(*fused_expression_root); - if (IsRealReductionHero(*fused_expression_root, hero) || + if (IsRealReductionHero(*fused_expression_root, hero, device_info) || GetDescriptionForTiledTransposeEmitter(hero).has_value()) { return &hero; } @@ -245,7 +249,7 @@ const HloInstruction* GetRealHeroForMultiOutputFusion( // we find any, we can immediately return it. 
for (auto* inst : fused_expression_root->mutable_operands()) { const auto& hero = FindNonTrivialHero(*inst); - if (IsRealReductionHero(*inst, hero) || + if (IsRealReductionHero(*inst, hero, device_info) || GetDescriptionForTiledTransposeEmitter(hero).has_value()) { return &hero; } @@ -253,14 +257,15 @@ const HloInstruction* GetRealHeroForMultiOutputFusion( return fused_expression_root->operands()[0]; } -FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, - const HloInstruction* hero2) { +FusionDecision FusionHeroesAreCompatible( + const HloInstruction* hero1, const HloInstruction* hero2, + const se::DeviceDescription& device_info) { auto hero1_is_unnested_reduce = - IsReductionFromOrToContiguousDimensions(*hero1); + IsReductionFromOrToContiguousDimensions(*hero1, device_info); auto tiled_transpose_hero1 = GetDescriptionForTiledTransposeEmitter(*hero1); bool hero1_is_unnested_transpose = tiled_transpose_hero1.has_value(); bool hero2_is_unnested_reduce = - IsReductionFromOrToContiguousDimensions(*hero2); + IsReductionFromOrToContiguousDimensions(*hero2, device_info); auto tiled_transpose_hero2 = GetDescriptionForTiledTransposeEmitter(*hero2); bool hero2_is_unnested_transpose = tiled_transpose_hero2.has_value(); @@ -318,7 +323,8 @@ FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, } FusionDecision ShapesCompatibleForMultiOutputFusion( - const HloInstruction& instr1, const HloInstruction& instr2) { + const HloInstruction& instr1, const HloInstruction& instr2, + const se::DeviceDescription& device_info) { // Multi-output fusion kernels share a common parallel loop. The loop // dimensions are determined by instruction shapes. auto get_loop_shape = [&](const HloInstruction* element_instr) { @@ -328,7 +334,7 @@ FusionDecision ShapesCompatibleForMultiOutputFusion( const auto& hero = element_instr->parent()->IsFusionComputation() ? FindNonTrivialHero(*element_instr) : *element_instr; - if (IsReductionFromOrToContiguousDimensions(*element_instr) || + if (IsReductionFromOrToContiguousDimensions(*element_instr, device_info) || GetDescriptionForTiledTransposeEmitter(hero).has_value()) { return hero.operand(0)->shape(); } @@ -339,10 +345,13 @@ FusionDecision ShapesCompatibleForMultiOutputFusion( // root ops should have equal output shapes. An exception are // reduction-to-vector ops. Here the input shapes of the reduction (first // operand shape) and the reduction dimensions need to match. - const HloInstruction* hero1 = GetRealHeroForMultiOutputFusion(instr1); - const HloInstruction* hero2 = GetRealHeroForMultiOutputFusion(instr2); + const HloInstruction* hero1 = + GetRealHeroForMultiOutputFusion(instr1, device_info); + const HloInstruction* hero2 = + GetRealHeroForMultiOutputFusion(instr2, device_info); - if (auto compatible = FusionHeroesAreCompatible(hero1, hero2); !compatible) { + if (auto compatible = FusionHeroesAreCompatible(hero1, hero2, device_info); + !compatible) { return compatible; } @@ -371,11 +380,12 @@ bool IsInputFusibleScatter(const HloInstruction& instr) { return false; } -bool IsInputFusible(const HloInstruction& instr) { +bool IsInputFusible(const HloInstruction& instr, + const se::DeviceDescription& device_info) { // Input fusion only handles non-elemental reduction and scatter operations. 
return instr.IsFusible() && - (IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr) || - IsInputFusibleTranspose(instr)); + (IsInputFusibleReduction(instr, device_info) || + IsInputFusibleScatter(instr) || IsInputFusibleTranspose(instr)); } // Returns true if `instr` can be fused as a producer or as a consumer into a @@ -414,7 +424,8 @@ bool IsUniversallyLoopFusible(const HloInstruction& instr) { } // Returns true if `instr` can be fused as a consumer into a kLoop fusion. -bool IsLoopFusibleAsConsumer(const HloInstruction& instr) { +bool IsLoopFusibleAsConsumer(const HloInstruction& instr, + const se::DeviceDescription& device_info) { // Instr should be fusible. if (!instr.IsFusible()) return false; @@ -429,7 +440,8 @@ bool IsLoopFusibleAsConsumer(const HloInstruction& instr) { // We may have input fusions which effectively have turned into loop // fusions. Those should still be considered as loop fusible consumers, // but they are not universally loop fusible. - if (!IsInputFusible(instr) && instr.opcode() == HloOpcode::kFusion && + if (!IsInputFusible(instr, device_info) && + instr.opcode() == HloOpcode::kFusion && instr.fusion_kind() == HloInstruction::FusionKind::kInput) { return true; } @@ -495,14 +507,15 @@ FusionDecision CanEmitInputFusedScatter(const HloInstruction& producer, return FusionDecision::Allow(); } -FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, - const HloInstruction& consumer) { +FusionDecision IsProducerConsumerFusible( + const HloInstruction& producer, const HloInstruction& consumer, + const se::DeviceDescription& device_info) { if (!IsLoopFusibleAsProducer(producer) && !IsInputFusibleTranspose(producer)) { return FusionDecision::Forbid("the producer is not loop-fusible"); } - if (IsInputFusibleReduction(producer)) { + if (IsInputFusibleReduction(producer, device_info)) { if (!producer.GetModule() ->config() .debug_options() @@ -516,7 +529,8 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, : producer; if (!ReductionIsRaceFree( reduce_hero.GetModule()->config(), - GetReductionKindAndContiguousComponents(reduce_hero))) { + GetReductionKindAndContiguousComponents(reduce_hero), + device_info)) { return FusionDecision::Forbid( "Reduction output fusion only works for race free reductions"); } @@ -537,7 +551,8 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, return can_fuse; } - if (!IsInputFusible(consumer) && !IsLoopFusibleAsConsumer(consumer)) { + if (!IsInputFusible(consumer, device_info) && + !IsLoopFusibleAsConsumer(consumer, device_info)) { return FusionDecision::Forbid( "the consumer is not input-fusible and not loop-fusible"); } @@ -566,7 +581,8 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, return InstructionFusion::ShouldFuseInPlaceOp(&producer, &consumer); } -FusionDecision IsProducerMultiOutputFusible(const HloInstruction& producer) { +FusionDecision IsProducerMultiOutputFusible( + const HloInstruction& producer, const se::DeviceDescription& device_info) { // Skip multiple output fusion. It's not yet supported. if (producer.IsMultiOutputFusion()) { return FusionDecision::Forbid("Producer is a multi-output fusion"); @@ -613,16 +629,17 @@ FusionDecision IsProducerMultiOutputFusible(const HloInstruction& producer) { // Returns an estimate of the shared memory usage for a given instruction in // bytes. 
-static int64_t SharedMemoryUsageNoCache(const HloInstruction& instr) { +static int64_t SharedMemoryUsageNoCache( + const HloInstruction& instr, const se::DeviceDescription& device_info) { if (instr.opcode() == HloOpcode::kFusion) { int64_t sum = 0; for (const HloInstruction* hlo : instr.fused_instructions_computation()->instructions()) { - sum += SharedMemoryUsageNoCache(*hlo); + sum += SharedMemoryUsageNoCache(*hlo, device_info); } return sum; } else if (instr.opcode() == HloOpcode::kReduce && - IsReductionFromOrToContiguousDimensions(instr)) { + IsReductionFromOrToContiguousDimensions(instr, device_info)) { ReductionDimensions reduction_info = GetReductionKindAndContiguousComponents(instr); int64_t primitive_size = ShapeUtil::ByteSizeOfPrimitiveType( @@ -665,16 +682,17 @@ int64_t FusionInfoCache::GetSharedMemoryUsage(const HloInstruction& instr) { // instructions, not instructions inside fusion nodes. Therefore we can only // cache top-level instructions; it would not be valid to pass the cache to // SharedMemoryUsageNoCache and use the cache *within* the fusion. - int64_t shared_memory_usage = SharedMemoryUsageNoCache(instr); + int64_t shared_memory_usage = SharedMemoryUsageNoCache(instr, device_info_); absl::MutexLock lock(&mutex_); shared_memory_usage_.emplace(&instr, shared_memory_usage); return shared_memory_usage; } -int64_t SharedMemoryUsage(const HloInstruction& instr, FusionInfoCache* cache) { +int64_t SharedMemoryUsage(const HloInstruction& instr, FusionInfoCache* cache, + const se::DeviceDescription& device_info) { if (!cache) { - return SharedMemoryUsageNoCache(instr); + return SharedMemoryUsageNoCache(instr, device_info); } return cache->GetSharedMemoryUsage(instr); } @@ -684,16 +702,17 @@ int64_t SharedMemoryUsage(const HloInstruction& instr, FusionInfoCache* cache) { constexpr int64_t kMaxUnnestedReductionOutputsPerFusion = 8; // Returns the number of unnested reductions in the instruction output. -static int64_t NumUnnestedReductionsNoCache(const HloInstruction& instr) { +static int64_t NumUnnestedReductionsNoCache( + const HloInstruction& instr, const se::DeviceDescription& device_info) { if (instr.opcode() == HloOpcode::kReduce && - IsReductionFromOrToContiguousDimensions(instr)) { + IsReductionFromOrToContiguousDimensions(instr, device_info)) { return 1; } if (instr.opcode() == HloOpcode::kFusion) { int64_t sum = 0; for (const HloInstruction* hlo : instr.fused_instructions_computation()->instructions()) { - sum += NumUnnestedReductionsNoCache(*hlo); + sum += NumUnnestedReductionsNoCache(*hlo, device_info); } return sum; } @@ -713,7 +732,8 @@ int64_t FusionInfoCache::GetNumUnnestedReductions(const HloInstruction& instr) { // instructions, not instructions inside fusion nodes. Therefore we can only // cache top-level instructions; it would not be valid to pass the cache to // NumUnnestedReductionsNoCache and use the cache *within* the fusion. 
- int64_t num_unnested_reductions = NumUnnestedReductionsNoCache(instr); + int64_t num_unnested_reductions = + NumUnnestedReductionsNoCache(instr, device_info_); absl::MutexLock lock(&mutex_); num_unnested_reductions_.emplace(&instr, num_unnested_reductions); @@ -721,9 +741,10 @@ int64_t FusionInfoCache::GetNumUnnestedReductions(const HloInstruction& instr) { } static int64_t NumUnnestedReductions(const HloInstruction& instr, - FusionInfoCache* cache) { + FusionInfoCache* cache, + const se::DeviceDescription& device_info) { if (!cache) { - return NumUnnestedReductionsNoCache(instr); + return NumUnnestedReductionsNoCache(instr, device_info); } return cache->GetNumUnnestedReductions(instr); @@ -757,15 +778,16 @@ FusionDecision FusionFitsInBudget(const HloInstruction& instr1, const se::DeviceDescription& device_info, bool is_consumer_producer_fusion, FusionInfoCache* cache /*=nullptr*/) { - if (SharedMemoryUsage(instr1, cache) + SharedMemoryUsage(instr2, cache) > + if (SharedMemoryUsage(instr1, cache, device_info) + + SharedMemoryUsage(instr2, cache, device_info) > device_info.shared_memory_per_block()) { return FusionDecision::Forbid( "shared memory usage would be over the budget of ") << device_info.shared_memory_per_block() << "B"; } - if (NumUnnestedReductions(instr1, cache) + - NumUnnestedReductions(instr2, cache) > + if (NumUnnestedReductions(instr1, cache, device_info) + + NumUnnestedReductions(instr2, cache, device_info) > kMaxUnnestedReductionOutputsPerFusion) { return FusionDecision::Forbid("over ") << kMaxUnnestedReductionOutputsPerFusion @@ -802,9 +824,10 @@ FusionDecision FusionFitsInBudget(const HloInstruction& instr1, MaxOperandsAndOutputsPerFusion()) { return FusionDecision::Allow(); } else { - VLOG(5) << "Operand count of " << "(" << instr1.ToString() - << " ) = " << instr1.operand_count() << " and ( " - << instr2.ToString() << " ) = " << instr2.operand_count() + VLOG(5) << "Operand count of " + << "(" << instr1.ToString() << " ) = " << instr1.operand_count() + << " and ( " << instr2.ToString() + << " ) = " << instr2.operand_count() << " and num_output_buffers = " << num_output_buffers << " is bigger than the bound of " << MaxOperandsAndOutputsPerFusion(); @@ -838,15 +861,16 @@ FusionDecision FusionFitsInBudget(const HloInstruction& instr1, } bool CreatesHeavyComputation(const HloInstruction& producer, - const HloInstruction& consumer) { + const HloInstruction& consumer, + const se::DeviceDescription& device_info) { // If producer's computation is not expensive to repeat even in the consumer // requests the same element multiple times there is nothing to do. auto producer_is_heavy = [&](const HloInstruction& instr) { if (producer.opcode() != HloOpcode::kFusion) { - return IsExpensiveToRepeat(producer); + return IsExpensiveToRepeat(producer, device_info); } for (const auto& instr : producer.fused_instructions()) { - if (IsExpensiveToRepeat(*instr)) { + if (IsExpensiveToRepeat(*instr, device_info)) { return true; } } @@ -901,21 +925,25 @@ bool CreatesHeavyComputation(const HloInstruction& producer, return false; } -bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) { +bool IsFusibleAsMultiOutputFusionRoot( + const HloInstruction& instr, const se::DeviceDescription& device_info) { // We can fuse reduces and loop fusions. Elementwise instructions can be fused // with any other instruction. // Note that scatter cannot be the root of a multi-output fusion because // its emitter doesn't support it. 
return instr.IsFusible() && - (IsInputFusibleReduction(instr) || IsInputFusibleTranspose(instr) || + (IsInputFusibleReduction(instr, device_info) || + IsInputFusibleTranspose(instr) || instr.IsLoopFusion() || // TODO(b/130013493): Use IsLoopFusible here. instr.IsElementwise()); } -HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& producer, - const HloInstruction& consumer) { - return (IsInputFusible(consumer) || IsInputFusible(producer)) +HloInstruction::FusionKind ChooseFusionKind( + const HloInstruction& producer, const HloInstruction& consumer, + const se::DeviceDescription& device_info) { + return (IsInputFusible(consumer, device_info) || + IsInputFusible(producer, device_info)) ? HloInstruction::FusionKind::kInput : HloInstruction::FusionKind::kLoop; } diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.h b/third_party/xla/xla/service/gpu/gpu_fusible.h index 0dadbfa36f5476..e873504fce02fe 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.h +++ b/third_party/xla/xla/service/gpu/gpu_fusible.h @@ -45,7 +45,8 @@ bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr); // Check if the operation is memory or computationally expensive // to repeat. -bool IsExpensiveToRepeat(const HloInstruction& instr); +bool IsExpensiveToRepeat(const HloInstruction& instr, + const se::DeviceDescription& device_info); // Fusion passes frequently do checks across all pairs of "interesting" nodes. // Computing e.g. FusionFitsInBudget(a, b) requires computing expensive @@ -55,6 +56,8 @@ bool IsExpensiveToRepeat(const HloInstruction& instr); // Invariant: After modifying or removing a fusion node, call Invalidate(node). class FusionInfoCache { public: + explicit FusionInfoCache(const se::DeviceDescription& device_info) + : device_info_(device_info) {} // Must be called after modifying or removing a fusion node (or other node // that's part of this cache). void Invalidate(const HloInstruction* instr) { @@ -69,6 +72,8 @@ class FusionInfoCache { int64_t GetNumUnnestedReductions(const HloInstruction& instr); private: + const se::DeviceDescription& device_info_; + absl::Mutex mutex_; absl::flat_hash_map shared_memory_usage_; @@ -110,15 +115,18 @@ bool TransposesMinorDimension(const HloInstruction* instr); // Whether `instr` is an input fusion rooted at a reduction-to-vector op or a // multi-output input fusion with at least one reduction-to-vector op root. -bool IsReduceInputFusion(const HloInstruction& instr); +bool IsReduceInputFusion(const HloInstruction& instr, + const se::DeviceDescription& device_info); // Whether `instr` is fusible as root of a reduce input fusions, i.e. `instr` // is either an unfused reduction-to-vector op or a reduce input fusion. -bool IsInputFusibleReduction(const HloInstruction& instr); +bool IsInputFusibleReduction(const HloInstruction& instr, + const se::DeviceDescription& device_info); // Whether `instr` is a nestable variadic reduction // or a loop fusion rooted with such. -bool IsNestableVariadicReduction(const HloInstruction& instr); +bool IsNestableVariadicReduction(const HloInstruction& instr, + const se::DeviceDescription& device_info); // Whether `instr` is fusible as root of a scatter input fusions, i.e. `instr` // is either an unfused scatter op or a scatter input fusion. @@ -139,18 +147,20 @@ FusionDecision FusionFitsInBudget(const HloInstruction& instr1, // producer has a complex computation per output and consumer calls this // computations multiple times. 
bool CreatesHeavyComputation(const HloInstruction& producer, - const HloInstruction& consumer); + const HloInstruction& consumer, + const se::DeviceDescription& device_info); // Returns the instruction that determines the emitter used for lowering, // sometimes referred to as "the real hero". const HloInstruction* GetRealHeroForMultiOutputFusion( - const HloInstruction& instr); + const HloInstruction& instr, const se::DeviceDescription& device_info); // Whether 'hero1' and 'hero2' are compatible if the two fusions containing // 'hero1' and 'hero2' are merged together. For example merging two fusions with // a reduction hero and a transpose here, respectively, does not work. -FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, - const HloInstruction* hero2); +FusionDecision FusionHeroesAreCompatible( + const HloInstruction* hero1, const HloInstruction* hero2, + const se::DeviceDescription& device_info); // Whether instruction shapes are compatible for multi-output fusion, i.e. // whether the emitters support lowering the resulting fusion. @@ -160,7 +170,8 @@ FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, // input fusions only. It is up to the caller to ensure the instructions // themselves are fusible! FusionDecision ShapesCompatibleForMultiOutputFusion( - const HloInstruction& instr1, const HloInstruction& instr2); + const HloInstruction& instr1, const HloInstruction& instr2, + const se::DeviceDescription& device_info); // Whether fusing producer into consumer creates a scatter fusion that cannot be // handled by the scatter emitter. @@ -171,20 +182,24 @@ FusionDecision CanEmitInputFusedScatter(const HloInstruction& producer, // i.e. whether the producer and consumer are loop/input fusible and // they are not library calls. // Used both by instruction fusion and fusion-fusion merging. -FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, - const HloInstruction& consumer); +FusionDecision IsProducerConsumerFusible( + const HloInstruction& producer, const HloInstruction& consumer, + const se::DeviceDescription& device_info); // Whether the producer is a valid candidate for a multi-output fusion. // That is, the root tuple of the multi-output fusion will contain the results // of both, the producer and consumer. -FusionDecision IsProducerMultiOutputFusible(const HloInstruction& producer); +FusionDecision IsProducerMultiOutputFusible( + const HloInstruction& producer, const se::DeviceDescription& device_info); // Whether `instr` is a candidate for sibling fusion or as a consumer in // a producer-consumer multi-output fusion. -bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr); +bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr, + const se::DeviceDescription& device_info); // Determines the fusion kind to be used when fusing into `consumer`. -HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& producer, - const HloInstruction& consumer); +HloInstruction::FusionKind ChooseFusionKind( + const HloInstruction& producer, const HloInstruction& consumer, + const se::DeviceDescription& device_info); // Returns whether `consumer` is the only non-root user of `instr`. 
bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, diff --git a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc index 735709cbd346f8..95da4bf3940de4 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc @@ -16,22 +16,80 @@ limitations under the License. #include "xla/service/gpu/gpu_fusible.h" #include -#include #include #include "absl/strings/str_cat.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_parser.h" +#include "xla/service/hlo_runner.h" +#include "xla/service/instruction_fusion.h" +#include "xla/service/platform_util.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" namespace xla { namespace gpu { +namespace { using ::testing::ElementsAre; -using GpuFusibleTest = HloTestBase; +auto MakeDeviceDescription() { + stream_executor::DeviceDescription device_description{ + stream_executor::GpuDeviceInfoProto{}}; + device_description.set_threads_per_warp(32); + return device_description; +} + +class GpuFusibleTest : public HloTestBase { + public: + GpuFusibleTest() : device_description_(MakeDeviceDescription()) {} + + bool IsReduceInputFusion(const HloInstruction& instr) const { + return ::xla::gpu::IsReduceInputFusion(instr, device_description_); + } + + bool IsInputFusibleReduction(const HloInstruction& instr) const { + return ::xla::gpu::IsInputFusibleReduction(instr, device_description_); + } + + FusionDecision IsProducerConsumerFusible( + const HloInstruction& producer, const HloInstruction& consumer) const { + return ::xla::gpu::IsProducerConsumerFusible(producer, consumer, + device_description_); + } + + FusionDecision IsProducerMultiOutputFusible( + const HloInstruction& producer) const { + return ::xla::gpu::IsProducerMultiOutputFusible(producer, + device_description_); + } + + bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) const { + return ::xla::gpu::IsFusibleAsMultiOutputFusionRoot(instr, + device_description_); + } + + FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, + const HloInstruction* hero2) const { + return ::xla::gpu::FusionHeroesAreCompatible(hero1, hero2, + device_description_); + } + + FusionDecision ShapesCompatibleForMultiOutputFusion( + const HloInstruction& instr1, const HloInstruction& instr2) const { + return ::xla::gpu::ShapesCompatibleForMultiOutputFusion( + instr1, instr2, device_description_); + } + + const se::DeviceDescription& device_description() const { + return device_description_; + } + + private: + const se::DeviceDescription device_description_; +}; const char kModulePrefix[] = R"( HloModule test_module @@ -1383,7 +1441,8 @@ TEST_F(GpuFusibleTest, CreatesHeavyComputation_NonfusionInstr) { const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* producer = root->operand(0); const HloInstruction* consumer = root->operand(1); - EXPECT_TRUE(CreatesHeavyComputation(*producer, *consumer)); + EXPECT_TRUE( + CreatesHeavyComputation(*producer, *consumer, device_description())); } TEST_F(GpuFusibleTest, DoesNotCreateHeavyComputation_NonfusionInstr) { @@ -1404,7 +1463,8 @@ TEST_F(GpuFusibleTest, DoesNotCreateHeavyComputation_NonfusionInstr) { const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* producer = root->operand(0); const HloInstruction* consumer = 
root->operand(1); - EXPECT_FALSE(CreatesHeavyComputation(*producer, *consumer)); + EXPECT_FALSE( + CreatesHeavyComputation(*producer, *consumer, device_description())); } TEST_F(GpuFusibleTest, @@ -1427,7 +1487,8 @@ TEST_F(GpuFusibleTest, const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* producer = root->operand(0); const HloInstruction* consumer = root->operand(1); - EXPECT_FALSE(CreatesHeavyComputation(*producer, *consumer)); + EXPECT_FALSE( + CreatesHeavyComputation(*producer, *consumer, device_description())); } TEST_F(GpuFusibleTest, CreatesHeavyComputation_ReduceWindowGather) { @@ -1448,9 +1509,10 @@ TEST_F(GpuFusibleTest, CreatesHeavyComputation_ReduceWindowGather) { EXPECT_EQ(gather->opcode(), HloOpcode::kGather); EXPECT_EQ(reduce_window->opcode(), HloOpcode::kReduceWindow); EXPECT_FALSE(IfFusedReadsElementsMultipleTimes(*reduce_window)); - EXPECT_TRUE(IsExpensiveToRepeat(*reduce_window)); + EXPECT_TRUE(IsExpensiveToRepeat(*reduce_window, device_description())); EXPECT_TRUE(IfFusedReadsElementsMultipleTimes(*gather)); - EXPECT_TRUE(CreatesHeavyComputation(*reduce_window, *gather)); + EXPECT_TRUE( + CreatesHeavyComputation(*reduce_window, *gather, device_description())); } TEST_F(GpuFusibleTest, CreatesHeavyComputation_FusionInstr) { @@ -1483,7 +1545,8 @@ TEST_F(GpuFusibleTest, CreatesHeavyComputation_FusionInstr) { const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* producer = root->operand(0); const HloInstruction* consumer = root->operand(1); - EXPECT_TRUE(CreatesHeavyComputation(*producer, *consumer)); + EXPECT_TRUE( + CreatesHeavyComputation(*producer, *consumer, device_description())); } TEST_F(GpuFusibleTest, DoesNotCreateHeavyComputation_FusionInstr) { @@ -1516,7 +1579,8 @@ TEST_F(GpuFusibleTest, DoesNotCreateHeavyComputation_FusionInstr) { const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* producer = root->operand(0); const HloInstruction* consumer = root->operand(1); - EXPECT_FALSE(CreatesHeavyComputation(*producer, *consumer)); + EXPECT_FALSE( + CreatesHeavyComputation(*producer, *consumer, device_description())); } TEST_F(GpuFusibleTest, ChooseFusionKind) { @@ -1532,7 +1596,7 @@ ENTRY computation { .value(); const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* producer = root->operand(0); - EXPECT_EQ(ChooseFusionKind(*producer, *root), + EXPECT_EQ(ChooseFusionKind(*producer, *root, device_description()), HloInstruction::FusionKind::kInput); } @@ -1775,10 +1839,11 @@ TEST_F(GpuFusibleTest, GetSharedMemoryUsage) { .value(); auto& debug_options = module->mutable_config().mutable_debug_options(); debug_options.set_xla_gpu_mlir_emitter_level(3); - FusionInfoCache cache; + FusionInfoCache cache(device_description()); auto fusion = module->entry_computation()->root_instruction(); EXPECT_EQ(cache.GetSharedMemoryUsage(*fusion), 32 * 33 * 2 * 4); } +} // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_offloading_test.cc b/third_party/xla/xla/service/gpu/gpu_offloading_test.cc index 3aa9d79977bd81..61b2a996ea14a6 100644 --- a/third_party/xla/xla/service/gpu/gpu_offloading_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_offloading_test.cc @@ -218,7 +218,9 @@ TEST_F(GpuOffloadingTest, CopyIRCreationTest) { RunHloRematerialization( /*memory_limit_bytes=*/10 * 1024, module.get())); ASSERT_TRUE(changed); - StreamAttributeAnnotator attr_annotator; 
+ stream_executor::StreamExecutor* executor = + backend().default_stream_executor(); + StreamAttributeAnnotator attr_annotator(executor->GetDeviceDescription()); TF_ASSERT_OK_AND_ASSIGN(bool changed_attr, attr_annotator.Run(module.get())); EXPECT_TRUE(changed_attr); // Verify that the stream attribute for a copy-start is annotated diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc index 94b4c55a802c67..a6e87bb662421a 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc @@ -256,7 +256,8 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() std::optional first_reduce_hero; for (auto [root, hero] : llvm::zip(fusion_roots_, fusion_heroes_)) { - if (IsRealReductionHero(root.instruction(), hero.instruction())) { + if (IsRealReductionHero(root.instruction(), hero.instruction(), + *device_info_)) { first_reduce_hero = hero; break; } @@ -268,7 +269,8 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() if (root == *first_reduce_hero) { continue; } - if (!IsRealReductionHero(root.instruction(), hero.instruction())) { + if (!IsRealReductionHero(root.instruction(), hero.instruction(), + *device_info_)) { // Needs to have a compatible shape to the reduce operand (compatible // meaning same number of elements). if (ShapeUtil::ElementsIn(root.shape()) != @@ -322,7 +324,8 @@ const HloInstruction* HloFusionAnalysis::FindHeroReduction() const { // have the same shape and layout as verified by // `IsFusedReductionOutputConsistent()`. for (auto [root, hero] : llvm::zip(roots, fusion_heroes_)) { - if (IsRealReductionHero(root.instruction(), hero.instruction())) { + if (IsRealReductionHero(root.instruction(), hero.instruction(), + *device_info_)) { return &hero.instruction(); } } diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc index 7328bc6dad0ec9..26b86b6c1f7d51 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc @@ -318,6 +318,7 @@ TEST_F(HloFusionAnalysisTest, InvalidDevice) { stream_executor::GpuDeviceInfoProto device_info_proto; stream_executor::DeviceDescription device_info(device_info_proto); + device_info.set_threads_per_warp(32); auto* root = module->entry_computation()->root_instruction(); auto analysis_fused = diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc index 406fcd9534a9dc..8525b02c396b97 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc @@ -261,7 +261,8 @@ llvm::Value* EmitAMDGPUShflDownSwizzle(llvm::Value* value, llvm::Value* offset, // Helper function to emit call to NVPTX shfl_down intrinsic. 
llvm::Value* EmitNVPTXShflDown(llvm::Value* value, llvm::Value* offset, - llvm::IRBuilder<>* b) { + llvm::IRBuilder<>* b, + const se::DeviceDescription& gpu_device_info) { llvm::Module* module = b->GetInsertBlock()->getModule(); llvm::Intrinsic::ID llvm_intrinsic_id; CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32); @@ -272,8 +273,8 @@ llvm::Value* EmitNVPTXShflDown(llvm::Value* value, llvm::Value* offset, } llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(module, llvm_intrinsic_id, {}); - return b->CreateCall( - intrinsic, {b->getInt32(-1), value, offset, b->getInt32(WarpSize() - 1)}); + return b->CreateCall(intrinsic, {b->getInt32(-1), value, offset, + b->getInt32(WarpSize(gpu_device_info) - 1)}); } // Helper function to emit call to SPIR shfl_down intrinsic. @@ -309,7 +310,7 @@ llvm::Value* EmitFullWarpShuffleDown( // Special case for efficiency if (value->getType()->isFloatTy() && bit_width == 32) { if (target_triple.isNVPTX()) { - return EmitNVPTXShflDown(value, offset, builder); + return EmitNVPTXShflDown(value, offset, builder, gpu_device_info); } else if (target_triple.getArch() == llvm::Triple::amdgcn) { if (gpu_device_info.rocm_compute_capability().gfx9_mi100_or_later()) { return EmitAMDGPUShflDownSwizzle(value, offset, builder); @@ -334,7 +335,7 @@ llvm::Value* EmitFullWarpShuffleDown( llvm::Value* insert_val; if (target_triple.isNVPTX()) { insert_val = EmitNVPTXShflDown(builder->CreateExtractElement(x, i), - offset, builder); + offset, builder, gpu_device_info); } else if (target_triple.getArch() == llvm::Triple::amdgcn) { if (gpu_device_info.rocm_compute_capability().gfx9_mi100_or_later()) { insert_val = EmitAMDGPUShflDownSwizzle( diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h index eef3943fcd8ee8..3e520d5030f7a3 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.h +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h @@ -64,7 +64,10 @@ inline constexpr int64_t kMaxBytesInMostMinorDimension = 8; bool IsMatrixMultiplication(const HloInstruction& dot); bool IsMatrixVectorMultiplication(const HloInstruction& dot); -inline constexpr int64_t WarpSize() { return 32; } +inline constexpr int64_t WarpSize( + const se::DeviceDescription& gpu_device_info) { + return gpu_device_info.threads_per_warp(); +} // Fusions that implemented with pre-compiled device kernels have // FusionBackendConfig.kind requel to this string. 
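The hunks above replace the hard-coded 32-lane warp with a DeviceDescription query. Below is a minimal sketch (not part of the patch) of how the parameterized WarpSize() is consumed, using only APIs visible in this diff (WarpSize(device_info), threads_per_warp(), set_threads_per_warp(), and the DeviceDescription-from-proto construction that the updated tests use); the helper names MakeTestDeviceDescription and FullWarpShuffleBound are illustrative only.

#include <cstdint>

#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/stream_executor/device_description.h"

namespace xla::gpu {

// Builds a device description the way the updated tests do: a default proto,
// then an explicit warp width.
stream_executor::DeviceDescription MakeTestDeviceDescription(
    int threads_per_warp) {
  stream_executor::DeviceDescription device_info{
      stream_executor::GpuDeviceInfoProto{}};
  device_info.set_threads_per_warp(threads_per_warp);
  return device_info;
}

// WarpSize(device_info) now returns device_info.threads_per_warp(), so the
// same emitter logic yields a 31-lane shuffle bound on a 32-wide warp and a
// 63-lane bound on a 64-wide wavefront, instead of always 31.
int64_t FullWarpShuffleBound(const stream_executor::DeviceDescription& info) {
  return WarpSize(info) - 1;
}

}  // namespace xla::gpu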
diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 4cd504d8c78bcb..dae8adb567981d 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -731,9 +731,8 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk( TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue blas_lt_epilogue, gpublas_lt::AsBlasLtEpilogue(epilogue)); auto thunk = std::make_unique( - instr, std::move(gemm_config), - blas_lt_epilogue, algorithm, a, b, c, d, bias, aux, a_scale, b_scale, - c_scale, d_scale, d_amax, workspace_buffer); + instr, std::move(gemm_config), blas_lt_epilogue, algorithm, a, b, c, d, + bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax, workspace_buffer); AddThunkToThunkSequence(std::move(thunk)); return absl::OkStatus(); } @@ -822,9 +821,8 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8( TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue blas_lt_epilogue, gpublas_lt::AsBlasLtEpilogue(epilogue)); auto thunk = std::make_unique( - instr, std::move(gemm_config), - blas_lt_epilogue, algorithm, a, b, c, d, bias, aux, a_scale, b_scale, - c_scale, d_scale, d_amax, workspace_buffer); + instr, std::move(gemm_config), blas_lt_epilogue, algorithm, a, b, c, d, + bias, aux, a_scale, b_scale, c_scale, d_scale, d_amax, workspace_buffer); AddThunkToThunkSequence(std::move(thunk)); return absl::OkStatus(); } @@ -1432,9 +1430,11 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall( KernelArguments::Create(ir_emitter_context_->buffer_assignment(), instr, instr->operands(), /*dedup=*/false)); - auto launch_dimensions = - LaunchDimensions(se::BlockDim(call.grid_x, call.grid_y, call.grid_z), - se::ThreadDim(call.num_warps * 32)); + auto launch_dimensions = LaunchDimensions( + se::BlockDim(call.grid_x, call.grid_y, call.grid_z), + se::ThreadDim( + call.num_warps * + ir_emitter_context_->gpu_device_info().threads_per_warp())); std::string sanitized_kernel_name = GetSanitizedUniqueName(*ir_emitter_context_, kernel_name); diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc index 9612a5566a3c69..084216dbe9e1ce 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc @@ -54,6 +54,7 @@ namespace gpu { // producer and consumer are considered as one fusion, otherwise it's only the // producer. bool IsReadCoalescedHeuristic(HloFusionAnalysis::EmitterFusionKind fusion_kind, + const se::DeviceDescription& device_info, const HloInstruction* producer, const HloInstruction* consumer) { // Transposing minor dimension breaks coalescing. @@ -91,8 +92,8 @@ bool IsReadCoalescedHeuristic(HloFusionAnalysis::EmitterFusionKind fusion_kind, } // Fusing two row reductions breaks coalescing. if (fusion_kind == HloFusionAnalysis::EmitterFusionKind::kReduction && - IsInputFusibleReduction(*producer) && consumer && - IsInputFusibleReduction(*consumer)) { + IsInputFusibleReduction(*producer, device_info) && consumer && + IsInputFusibleReduction(*consumer, device_info)) { return false; } return true; @@ -586,7 +587,8 @@ CoalescingAnalysis::CoalescingAnalysis( } // If ComputeCoalescingForAllOperands fails, fallback to using the heuristic. 
is_coalesced_computed_by_heuristic_ = - IsReadCoalescedHeuristic(fusion_analysis.GetEmitterFusionKind(), instr); + IsReadCoalescedHeuristic(fusion_analysis.GetEmitterFusionKind(), + fusion_analysis.device_info(), instr); } CoalescingAnalysis::CoalescingAnalysis( @@ -604,7 +606,8 @@ CoalescingAnalysis::CoalescingAnalysis( } // If ComputeCoalescingForAllOperands fails, fallback to using the heuristic. is_coalesced_computed_by_heuristic_ = IsReadCoalescedHeuristic( - fusion_analysis.GetEmitterFusionKind(), producer, consumer); + fusion_analysis.GetEmitterFusionKind(), fusion_analysis.device_info(), + producer, consumer); } bool CoalescingAnalysis::ComputeCoalescingForAllOperands( diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis.h b/third_party/xla/xla/service/gpu/model/coalescing_analysis.h index 5e82b6455afcd7..e097b8cd46b4df 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis.h +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis.h @@ -70,6 +70,7 @@ class CoalescingAnalysis { // producer and consumer are considered as one fusion, otherwise it's only the // producer. bool IsReadCoalescedHeuristic(HloFusionAnalysis::EmitterFusionKind fusion_kind, + const se::DeviceDescription& device_info, const HloInstruction* producer, const HloInstruction* consumer = nullptr); diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc index 70cca31981174e..202f47cda9fe30 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc @@ -78,8 +78,8 @@ class CoalescingTest : public HloTestBase { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); HloInstruction* root = module->entry_computation()->root_instruction(); auto analysis = HloFusionAnalysis::Create(*root, device_info_); - return xla::gpu::IsReadCoalescedHeuristic(analysis.GetEmitterFusionKind(), - root->operand(0), root); + return xla::gpu::IsReadCoalescedHeuristic( + analysis.GetEmitterFusionKind(), device_info_, root->operand(0), root); } protected: diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index 7798b80d17a681..61e879fa82c0ff 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -314,7 +314,7 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForInstruction( auto fusion_analysis = HloFusionAnalysis::Create(*producer, *device_info_); bool is_coalesced = IsReadCoalescedHeuristic( - fusion_analysis.GetEmitterFusionKind(), producer); + fusion_analysis.GetEmitterFusionKind(), *device_info_, producer); return EstimateRunTimeForFusion(fusion_analysis, is_coalesced); } @@ -324,8 +324,9 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForProducerConsumer( auto fusion_analysis = HloFusionAnalysis::Create(*producer, *consumer, *device_info_); - bool is_coalesced = IsReadCoalescedHeuristic( - fusion_analysis.GetEmitterFusionKind(), producer, consumer); + bool is_coalesced = + IsReadCoalescedHeuristic(fusion_analysis.GetEmitterFusionKind(), + *device_info_, producer, consumer); return EstimateRunTimeForFusion(fusion_analysis, is_coalesced); } @@ -504,13 +505,14 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTriton( /*static*/ LaunchDimensions 
GpuPerformanceModelWithIndexingAnalysis::GetLaunchDimensionsForTiledFusion( - const TiledHloComputation& tiled_hlo_computation) { + const TiledHloComputation& tiled_hlo_computation, + const se::DeviceDescription& device_info) { const auto* tiled_root = tiled_hlo_computation.GetRoot(); int64_t num_blocks = tiled_hlo_computation.num_output_tiles(); int64_t num_warps = GetNumWarps(GetPaddedTileSize(tiled_root->tile_sizes())); return {static_cast(num_blocks), - static_cast(num_warps * WarpSize())}; + static_cast(num_warps * WarpSize(device_info))}; } absl::StatusOr @@ -538,7 +540,7 @@ GpuPerformanceModelWithIndexingAnalysis::TryFindBestTilingForFusion( analysis.ComputeTiledHloInstructions(tiling)); LaunchDimensions launch_dimensions = - GetLaunchDimensionsForTiledFusion(tiled_hlo_computation); + GetLaunchDimensionsForTiledFusion(tiled_hlo_computation, *device_info_); TF_ASSIGN_OR_RETURN( EstimateRunTimeData estimate_run_time_data, @@ -552,7 +554,7 @@ GpuPerformanceModelWithIndexingAnalysis::TryFindBestTilingForFusion( block_level_parameters.output_tile_sizes = std::vector(tiling.begin(), tiling.end()); block_level_parameters.num_warps = - launch_dimensions.num_threads_per_block() / WarpSize(); + launch_dimensions.num_threads_per_block() / WarpSize(*device_info_); best_tiled_run_time_data = TiledRunTimeData{estimate_run_time_data, block_level_parameters}; diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h index 499b75ea61bffe..a8e7fed4314a30 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h @@ -73,7 +73,8 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { // Returns the launch dimensions for the given tiled HLO computation. static LaunchDimensions GetLaunchDimensionsForTiledFusion( - const TiledHloComputation& tiled_hlo_computation); + const TiledHloComputation& tiled_hlo_computation, + const se::DeviceDescription& device_info); EstimateRunTimeData EstimateRunTimeForFusion( const HloFusionAnalysis& fusion_analysis, bool is_coalesced = true); diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc index dddf7b1d428f9d..2e61dd28df014e 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/model/gpu_indexing_performance_model.h" +#include #include #include #include @@ -72,6 +73,8 @@ class GpuIndexingPerformanceModelTest : public HloTestBase { &device_info_, &fusion_analysis_cache_, ShapeSizeBytesFunction(), &mlir_context_}; + size_t WarpSize() const { return ::xla::gpu::WarpSize(device_info_); } + GpuIndexingPerformanceModelTest() : HloTestBase() {} }; @@ -613,7 +616,7 @@ ENTRY main { .ComputeTiledHloInstructions(/*tile_parameters=*/{9, 9, 9})); LaunchDimensions launch_dimensions = GpuPerformanceModelWithIndexingAnalysis:: - GetLaunchDimensionsForTiledFusion(tiled_hlo_computation); + GetLaunchDimensionsForTiledFusion(tiled_hlo_computation, device_info_); EXPECT_EQ(launch_dimensions.num_blocks(), 1); // Tile size is 9 * 9 * 9 = 729 that corresponds to 2 warps. 
But we estimate diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc index e29420c8bfb8b3..3c619e7f03e6d8 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc @@ -569,8 +569,12 @@ NVPTXCompiler::NVPTXCompiler() : GpuCompiler(stream_executor::cuda::kCudaPlatformId, nvptx::TargetTriple(), nvptx::DataLayout()) {} -HloDataflowAnalysis::CanShareBuffer NVPTXCompiler::GetCanShareBuffer() const { - return &CanShareBufferHint; +HloDataflowAnalysis::CanShareBuffer NVPTXCompiler::GetCanShareBuffer( + const se::DeviceDescription& device_description) const { + return [&](const HloInstruction* user, const HloInstruction* operand, + const ShapeIndex& user_index) { + return CanShareBufferHint(user, operand, user_index, device_description); + }; } constexpr const uint8_t kPtxPrefix[] = {'P', 'T', 'X', ':', ' '}; diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.h b/third_party/xla/xla/service/gpu/nvptx_compiler.h index 78591bb2c42a7d..50807a39515535 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.h +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.h @@ -89,7 +89,8 @@ class NVPTXCompiler : public GpuCompiler { se::StreamExecutor* stream_exec, BinaryMap* dnn_compiled_graphs) override; - HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() const override; + HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer( + const se::DeviceDescription& device_description) const override; absl::StatusOr CompileTargetBinary( const HloModuleConfig& module_config, llvm::Module* llvm_module, diff --git a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc index 50c627981ce7c5..6591a369a52735 100644 --- a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc +++ b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc @@ -32,14 +32,15 @@ limitations under the License. #include "xla/service/hlo_verifier.h" #include "xla/service/layout_assignment.h" #include "xla/service/loop_schedule_linearizer.h" +#include "xla/stream_executor/device_description.h" #include "xla/xla.pb.h" namespace xla { namespace gpu { HloPassPipeline PrepareHloModuleForIrEmittingPipeline( - HloModule& hlo_module, - HloDataflowAnalysis::CanShareBuffer can_share_buffer) { + HloModule& hlo_module, HloDataflowAnalysis::CanShareBuffer can_share_buffer, + const se::DeviceDescription& device_description) { const DebugOptions& debug_options = hlo_module.config().debug_options(); // In some cases, we have to place the result of an instruction in a temporary @@ -83,8 +84,8 @@ HloPassPipeline PrepareHloModuleForIrEmittingPipeline( auto& sub_pipeline = pipeline.AddPass("horizontal-loop-fusion-for-copy"); // To fuse the copy. - sub_pipeline.AddPass(); - sub_pipeline.AddPass("copy_"); + sub_pipeline.AddPass(device_description); + sub_pipeline.AddPass(device_description, "copy_"); sub_pipeline.AddPass(); pipeline.AddPass(); return pipeline; diff --git a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h index 095907f39794ac..b90697b90fed62 100644 --- a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h +++ b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h @@ -19,6 +19,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/service/hlo_dataflow_analysis.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -27,8 +28,8 @@ namespace gpu { // pipeline. This pipeline must be run right before IR emission to ensure // correctness of the input module. HloPassPipeline PrepareHloModuleForIrEmittingPipeline( - HloModule& hlo_module, - HloDataflowAnalysis::CanShareBuffer can_share_buffer); + HloModule& hlo_module, HloDataflowAnalysis::CanShareBuffer can_share_buffer, + const se::DeviceDescription& device_description); } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/reduction_utils.cc b/third_party/xla/xla/service/gpu/reduction_utils.cc index 447c0427bbb07a..7abc6401deda87 100644 --- a/third_party/xla/xla/service/gpu/reduction_utils.cc +++ b/third_party/xla/xla/service/gpu/reduction_utils.cc @@ -34,6 +34,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/semantic_version.h" +#include "xla/stream_executor/device_description.h" #include "xla/util.h" #include "tsl/platform/logging.h" @@ -142,24 +143,27 @@ Vector3 GetReductionTiling(const ReductionDimensions& reduction_dimensions) { int64_t ReductionDimensionRaceFreeBound( const HloModuleConfig& hlo_module_config, - const ReductionDimensions& reduction_dimensions) { + const ReductionDimensions& reduction_dimensions, + const se::DeviceDescription& device_description) { Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions); if (reduction_dimensions.is_row_reduction) { return MinThreadsXRowReduction(hlo_module_config) * reduction_tiling[2]; } - return WarpSize() * reduction_tiling[1]; + return WarpSize(device_description) * reduction_tiling[1]; } bool IsUnnestedReductionFasterThanElemental( - const ReductionDimensions& reduction_dimensions) { + const ReductionDimensions& reduction_dimensions, + const se::DeviceDescription& device_description) { + const int64_t warp_size = WarpSize(device_description); if (reduction_dimensions.is_row_reduction) { // For row reduction, the tile block is 1 x tile_size_x, and we are reducing // along tile_size_x which needs to be large enough to make the tiling // implementation efficient. // For very small reductions with a power-of-two size, we can fit multiple // reductions inside a single warp, which is more efficient than a loop. - return (reduction_dimensions.dimensions[2] >= WarpSize()) || - ((WarpSize() % reduction_dimensions.dimensions[2]) == 0); + return (reduction_dimensions.dimensions[2] >= warp_size) || + ((warp_size % reduction_dimensions.dimensions[2]) == 0); } // For column reduction, the tile block is tile_size_y x tile_size_x, and we @@ -170,15 +174,17 @@ bool IsUnnestedReductionFasterThanElemental( // Rule generated by sweeping the search space of small column reductions. 
bool prefer_elemental_emitter = - (major_size < WarpSize()) || - (major_size < 2 * WarpSize() && minor_size < WarpSize()) || - (major_size < 4 * WarpSize() && minor_size < 8) || - (major_size < 8 * WarpSize() && minor_size < 3); + (major_size < warp_size) || + (major_size < 2 * warp_size && minor_size < warp_size) || + (major_size < 4 * warp_size && minor_size < 8) || + (major_size < 8 * warp_size && minor_size < 3); return !prefer_elemental_emitter; } -bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) { +bool IsReductionFromOrToContiguousDimensions( + const HloInstruction& reduce, + const se::DeviceDescription& device_description) { if (reduce.opcode() != HloOpcode::kReduce) { return false; } @@ -201,15 +207,18 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) { LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(), dims_to_reduce)) && IsUnnestedReductionFasterThanElemental( - GetReductionKindAndContiguousComponents(reduce)); + GetReductionKindAndContiguousComponents(reduce), + device_description); } bool ReductionIsRaceFree(const HloModuleConfig& hlo_module_config, - const ReductionDimensions& reduction_dimensions) { + const ReductionDimensions& reduction_dimensions, + const se::DeviceDescription& device_description) { if (reduction_dimensions.is_row_reduction) { return reduction_dimensions.dimensions[2] <= ReductionDimensionRaceFreeBound(hlo_module_config, - reduction_dimensions) && + reduction_dimensions, + device_description) && reduction_dimensions.dimensions[0] <= BatchedReductionRaceFreeBound(); } @@ -217,7 +226,8 @@ bool ReductionIsRaceFree(const HloModuleConfig& hlo_module_config, // Column reduction. return reduction_dimensions.dimensions[1] <= ReductionDimensionRaceFreeBound(hlo_module_config, - reduction_dimensions); + reduction_dimensions, + device_description); } std::ostream& operator<<(std::ostream& os, @@ -275,14 +285,15 @@ ReductionDimensions GetReductionKindAndContiguousComponents( return {/*is_row_reduction=*/false, shape_partition}; } -bool IsRealReductionHero(const HloInstruction& root, - const HloInstruction& hero) { - if (!IsReductionFromOrToContiguousDimensions(hero)) { +bool IsRealReductionHero(const HloInstruction& root, const HloInstruction& hero, + const se::DeviceDescription& device_description) { + if (!IsReductionFromOrToContiguousDimensions(hero, device_description)) { return false; } return &root == &hero || ReductionIsRaceFree(hero.GetModule()->config(), - GetReductionKindAndContiguousComponents(hero)); + GetReductionKindAndContiguousComponents(hero), + device_description); } bool AreReductionsMultiOutputFusionCompatible( diff --git a/third_party/xla/xla/service/gpu/reduction_utils.h b/third_party/xla/xla/service/gpu/reduction_utils.h index 7e5e31bc464ce3..c7158856a32fe4 100644 --- a/third_party/xla/xla/service/gpu/reduction_utils.h +++ b/third_party/xla/xla/service/gpu/reduction_utils.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/hlo_module_config.h" +#include "xla/stream_executor/device_description.h" #include "xla/util.h" namespace xla { @@ -79,11 +80,14 @@ std::ostream& operator<<(std::ostream& os, // Returns true if using the reduction emitter is estimated to be faster than // using the elemental emitter. 
bool IsUnnestedReductionFasterThanElemental( - const ReductionDimensions& reduction_dimensions); + const ReductionDimensions& reduction_dimensions, + const se::DeviceDescription& device_description); // Returns true if either the dimensions being reduced or the dimensions being // kept are contiguous in the input of the reduce instruction. -bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce); +bool IsReductionFromOrToContiguousDimensions( + const HloInstruction& reduce, + const se::DeviceDescription& device_description); // Given the input shape and dimensions to reduce for a reduction, returns // ReductionDimensions. @@ -100,16 +104,18 @@ Vector3 GetReductionTiling(const ReductionDimensions& reduction_dimensions); // How big the reduction dimension can be to be race free. int64_t ReductionDimensionRaceFreeBound( const HloModuleConfig& hlo_module_config, - const ReductionDimensions& reduction_dimensions); + const ReductionDimensions& reduction_dimensions, + const se::DeviceDescription& device_description); // Returns whether the given reduction can be safely generated without atomics : // that is, at most one block will write to every output element. bool ReductionIsRaceFree(const HloModuleConfig& hlo_module_config, - const ReductionDimensions& reduction_dimensions); + const ReductionDimensions& reduction_dimensions, + const se::DeviceDescription& device_description); // Whether the instruction is a reduction hero for the given root. -bool IsRealReductionHero(const HloInstruction& root, - const HloInstruction& hero); +bool IsRealReductionHero(const HloInstruction& root, const HloInstruction& hero, + const se::DeviceDescription& device_description); // Whether `reduction_hero` is compatible with `first_reduce`. bool AreReductionsMultiOutputFusionCompatible( diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 31c99ac6984708..7aa7e5229686e6 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -641,6 +641,7 @@ cc_library( "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:reduction_utils", + "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", @@ -658,6 +659,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", @@ -1547,6 +1549,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service/gpu:gpu_fusible", + "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -1857,6 +1860,7 @@ cc_library( "//xla/service:hlo_creation_utils", "//xla/service:sub_byte_normalization", "//xla/service/gpu:gpu_fusible", + "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -2367,6 +2371,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service/gpu:reduction_utils", + "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", 
"@com_google_absl//absl/log:check", @@ -2388,6 +2393,7 @@ xla_cc_test( "//xla/service:hlo_parser", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", ], @@ -2702,6 +2708,7 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", @@ -2720,6 +2727,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", "//xla/tests:filecheck", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings:string_view", diff --git a/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc b/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc index eb43ca2364f0c8..9912316f5e945d 100644 --- a/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc @@ -75,7 +75,7 @@ absl::StatusOr CopyFusion::DoCopyFusion(HloComputation* computation) { continue; } HloInstruction* root = fused_computation->root_instruction(); - if (IsReductionFromOrToContiguousDimensions(*root) || + if (IsReductionFromOrToContiguousDimensions(*root, device_description_) || root->opcode() == HloOpcode::kScatter || (hlo->IsMultiOutputFusion() && absl::c_all_of(root->operands(), [](const HloInstruction* slice) { @@ -89,7 +89,8 @@ absl::StatusOr CopyFusion::DoCopyFusion(HloComputation* computation) { if (copy_user->opcode() == HloOpcode::kGetTupleElement && copy_user->user_count() == 1) { if (IsReductionFromOrToContiguousDimensions( - *(root->operand(copy_user->tuple_index())))) { + *(root->operand(copy_user->tuple_index())), + device_description_)) { other_users.push_back(user); continue; } diff --git a/third_party/xla/xla/service/gpu/transforms/copy_fusion.h b/third_party/xla/xla/service/gpu/transforms/copy_fusion.h index 8350935c8982d5..96ff095c81826d 100644 --- a/third_party/xla/xla/service/gpu/transforms/copy_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion.h @@ -22,6 +22,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -30,7 +31,8 @@ namespace gpu { // those copies to the fusion, replacing the copies with get_tuple_elements. class CopyFusion : public HloModulePass { public: - CopyFusion() = default; + explicit CopyFusion(const se::DeviceDescription& device_description) + : device_description_(device_description) {} absl::string_view name() const override { return "copy_fusion"; } @@ -41,6 +43,8 @@ class CopyFusion : public HloModulePass { private: absl::StatusOr DoCopyFusion(HloComputation* computation); + + const se::DeviceDescription& device_description_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc index 1bd2d11fe7ddc7..7d5edd9d06728a 100644 --- a/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" namespace xla { @@ -28,8 +29,18 @@ namespace gpu { namespace m = ::xla::match; +auto MakeDeviceDescriptor() { + stream_executor::DeviceDescription device_description{ + stream_executor::GpuDeviceInfoProto{}}; + device_description.set_threads_per_warp(32); + return device_description; +} + class CopyFusionTest : public HloTestBase { public: + CopyFusionTest() + : device_description_(MakeDeviceDescriptor()), cf_(device_description_) {} + const stream_executor::DeviceDescription device_description_; CopyFusion cf_; }; diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc b/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc index d132cf6f3ae682..7c1741c21d371a 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc @@ -58,7 +58,8 @@ class FusionInstructionMerger { dump_fusion_visualization_(computation->parent() ->config() .debug_options() - .xla_dump_fusion_visualization()) {} + .xla_dump_fusion_visualization()), + fusion_info_cache_(gpu_device_info) {} absl::Status Run(); @@ -113,7 +114,8 @@ absl::Status FusionInstructionMerger::FuseIntoAllUsers( HloInstruction* consumer = user; if (consumer->opcode() != HloOpcode::kFusion) { consumer = computation_->AddInstruction(HloInstruction::CreateFusion( - user->shape(), ChooseFusionKind(*producer, *user), user)); + user->shape(), ChooseFusionKind(*producer, *user, gpu_device_info_), + user)); TF_CHECK_OK(computation_->ReplaceInstruction(user, consumer)); } @@ -223,7 +225,8 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { return FusionDecision::Forbid("not a loop fusion"); } - auto producer_hero = GetRealHeroForMultiOutputFusion(*producer); + auto producer_hero = + GetRealHeroForMultiOutputFusion(*producer, gpu_device_info_); bool has_reduction_user = false; for (const HloInstruction* user : producer->users()) { @@ -235,19 +238,21 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { ++num_fail_merge_all_users_; return FusionDecision::Forbid("not fusing custom fusions"); } - auto consumer_hero = GetRealHeroForMultiOutputFusion(*user); - if (auto compatible = - FusionHeroesAreCompatible(producer_hero, consumer_hero); + auto consumer_hero = + GetRealHeroForMultiOutputFusion(*user, gpu_device_info_); + if (auto compatible = FusionHeroesAreCompatible( + producer_hero, consumer_hero, gpu_device_info_); !compatible) { return compatible; } - FusionDecision fusible = IsProducerConsumerFusible(*producer, *user); + FusionDecision fusible = + IsProducerConsumerFusible(*producer, *user, gpu_device_info_); if (!fusible) { ++num_fail_merge_all_users_; VLOG(9) << user->ToString(); return fusible; } - if (IsInputFusibleReduction(*user)) { + if (IsInputFusibleReduction(*user, gpu_device_info_)) { has_reduction_user = true; } } diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc index 16957f80d370e0..71b09e084ec7f7 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc @@ -117,7 +117,9 @@ absl::StatusOr FusionWrapper::Run( auto* fusion_instruction = computation->AddInstruction(HloInstruction::CreateFusion( 
instruction->shape(), - ChooseFusionKind(*instruction, *instruction), instruction)); + ChooseFusionKind(*instruction, *instruction, + device_description_), + instruction)); const absl::string_view wrapped_opcode = HloOpcodeString(instruction->opcode()); module->SetAndUniquifyInstrName( diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.h b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.h index 1e9a085fbb0b26..804d0590ee33df 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.h +++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -28,12 +29,17 @@ namespace gpu { // have no LHLO equivalent in fusions containing just that instruction. class FusionWrapper : public HloModulePass { public: + explicit FusionWrapper(const se::DeviceDescription& device_description) + : device_description_(device_description) {} absl::string_view name() const override { return "fusion-wrapper"; } using HloPassInterface::Run; absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; + + private: + const se::DeviceDescription& device_description_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc index be5f0d7dfd49c6..984f9a246c323f 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "xla/service/gpu/transforms/fusion_wrapper.h" +#include #include #include @@ -23,7 +24,25 @@ namespace xla { namespace gpu { namespace { -class FusionWrapperTest : public HloTestBase {}; +auto MakeDeviceDescription() { + stream_executor::DeviceDescription device_description{ + stream_executor::GpuDeviceInfoProto{}}; + device_description.set_threads_per_warp(32); + return device_description; +} + +class FusionWrapperTest : public HloTestBase { + public: + using HloTestBase::HloTestBase; + + const stream_executor::DeviceDescription& device_description() const { + return device_description_; + } + + private: + const stream_executor::DeviceDescription device_description_{ + MakeDeviceDescription()}; +}; TEST_F(FusionWrapperTest, ConvolutionWorks) { RunAndFilecheckHloRewrite(R"(HloModule TestModule @@ -33,7 +52,7 @@ ENTRY TestComputation { kernel = f32[20,1,2,1,4,15]{5,4,3,2,1,0} parameter(1) ROOT conv = f32[15,1,9,1,7,5]{5,4,3,2,1,0} convolution(input, kernel), dim_labels=0123bf_i0123o->f0123b, window={size=1x2x1x4} })", - FusionWrapper(), R"( + FusionWrapper(device_description()), R"( // CHECK: %wrapped_convolution_computation (param_0: f32[1,10,1,10,5,20], param_1: f32[20,1,2,1,4,15]) -> f32[15,1,9,1,7,5] { // CHECK: %param_0 = f32[1,10,1,10,5,20]{5,4,3,2,1,0} parameter(0) // CHECK: %param_1 = f32[20,1,2,1,4,15]{5,4,3,2,1,0} parameter(1) @@ -56,7 +75,7 @@ TEST_F(FusionWrapperTest, SimpleOp) { p1 = f16[30,41] parameter(1) ROOT result = f16[60, 41] concatenate(p0, p1), dimensions={0} })", - FusionWrapper(), R"( + FusionWrapper(device_description()), R"( // CHECK: %wrapped_concatenate_computation (param_0: f16[30,41], param_1: f16[30,41]) -> f16[60,41] { // CHECK: %param_0 = f16[30,41]{1,0} parameter(0) // CHECK: %param_1 = f16[30,41]{1,0} parameter(1) @@ -90,7 +109,7 @@ TEST_F(FusionWrapperTest, Scatter) { index_vector_dim=0, to_apply=update_s32 })", - FusionWrapper(), R"( + FusionWrapper(device_description()), R"( // CHECK: wrapped_scatter_computation // CHECK: %[[param_0:.*]] = s32[] parameter(0) // CHECK: %[[param_1:.*]] = s32[0]{0} parameter(1) @@ -119,7 +138,7 @@ TEST_F(FusionWrapperTest, ControlDependency) { constant_one = f32[] constant(1) ROOT add = f32[] add(param, constant_one), control-predecessors={fusion} })", - FusionWrapper(), R"( + FusionWrapper(device_description()), R"( // CHECK: ROOT %wrapped_add = f32[] fusion(%param.1, %constant_one), // CHECK-SAME: control-predecessors={%fusion})"); } @@ -146,7 +165,7 @@ TEST_F(FusionWrapperTest, While) { %tuple = (f32[5]{0}) tuple(f32[5]{0} %copy.3) ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body })", - FusionWrapper(), R"( + FusionWrapper(device_description()), R"( // CHECK: %wrapped_broadcast_computation {{.*}} { // CHECK: %param_0.1 = f32[] parameter(0) // CHECK: ROOT %broadcast.0 = f32[5]{0} broadcast(%param_0.1), dimensions={} @@ -200,7 +219,7 @@ TEST_F(FusionWrapperTest, WhileInFusion) { %parameter.1 = f32[5]{0} parameter(0) ROOT %fusion = (f32[5]{0}) fusion(f32[5]{0} %parameter.1), kind=kLoop, calls=%fusion })", - FusionWrapper(), + FusionWrapper(device_description()), // No change std::nullopt); } diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc index befe869ac072df..5523792c096b6d 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc +++ 
b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc @@ -42,9 +42,11 @@ namespace gpu { namespace { // Gets the representative input shape of the multi-output fusion. -Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) { +Shape GetInputShapeForMultiOutputFusion( + const HloInstruction& instr, const se::DeviceDescription& device_info) { // Get the HLO that determines the emitter used for lowering. - const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr); + const HloInstruction* real_hero = + GetRealHeroForMultiOutputFusion(instr, device_info); if (real_hero->operands().empty()) { // Simply return an empty shape if the representative node has no input // operands. @@ -87,7 +89,7 @@ bool CompareShapeDimsFromLeftToRight(const Shape& shape_a, } std::vector FindAndSortFusionCandidates( - HloInstruction* consumer) { + HloInstruction* consumer, const se::DeviceDescription& device_info) { absl::flat_hash_set fusion_instr_set; std::vector fusion_instrs; for (HloInstruction* opnd : consumer->operands()) { @@ -95,7 +97,7 @@ std::vector FindAndSortFusionCandidates( // Find out the input fusion instructions whose only consumer is `consumer`. // This guarantees that fusing these candidates will never create cycles, as // there is no back edge. - if (IsInputFusibleReduction(*predecessor) && + if (IsInputFusibleReduction(*predecessor, device_info) && IsConsumerTheOnlyNonRootUser(*predecessor, *consumer)) { if (fusion_instr_set.insert(predecessor).second) { fusion_instrs.push_back(predecessor); @@ -105,8 +107,10 @@ std::vector FindAndSortFusionCandidates( std::sort(fusion_instrs.begin(), fusion_instrs.end(), [&](const HloInstruction* a, const HloInstruction* b) { - Shape shape_a = GetInputShapeForMultiOutputFusion(*a); - Shape shape_b = GetInputShapeForMultiOutputFusion(*b); + Shape shape_a = + GetInputShapeForMultiOutputFusion(*a, device_info); + Shape shape_b = + GetInputShapeForMultiOutputFusion(*b, device_info); if (!ShapeUtil::EqualIgnoringElementType(shape_a, shape_b)) { // Sort shapes according to dimensions, so that the same input // shapes will be placed adjacent each other. @@ -128,7 +132,7 @@ absl::StatusOr HorizontalInputFusionImpl::Run() { std::vector def_to_use_order = computation_->MakeInstructionPostOrder(); for (HloInstruction* consumer : def_to_use_order) { - auto candidates = FindAndSortFusionCandidates(consumer); + auto candidates = FindAndSortFusionCandidates(consumer, device_info_); if (candidates.size() <= 1) { continue; } @@ -149,7 +153,8 @@ absl::StatusOr HorizontalInputFusionImpl::Run() { for (size_t j = 1; j < candidates.size(); ++j) { HloInstruction* fusion_anchor = candidates[fusion_anchor_id]; HloInstruction* fused = candidates[j]; - if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) && + if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused, + device_info_) && FusionFitsInBudget(*fusion_anchor, *fused, device_info_)) { VLOG(3) << "Fuse " << fused->ToString() << " into " << fusion_anchor->ToString(); diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc index 0a3d705103c416..1cef7a986804b9 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc @@ -43,6 +43,7 @@ limitations under the License. 
#include "xla/service/sub_byte_normalization.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -70,9 +71,12 @@ PrimitiveType GetUniqueOutputTypeOfFusible(const HloInstruction& fusible) { class HorizontalLoopFusionImpl { public: - explicit HorizontalLoopFusionImpl(HloComputation* computation, - absl::string_view prefix) - : computation_(computation), prefix_(prefix) {} + explicit HorizontalLoopFusionImpl( + HloComputation* computation, + const se::DeviceDescription& device_description, absl::string_view prefix) + : computation_(computation), + device_description_(device_description), + prefix_(prefix) {} ~HorizontalLoopFusionImpl() = default; @@ -116,18 +120,20 @@ class HorizontalLoopFusionImpl { class FusionCandidates { public: explicit FusionCandidates(HloInstruction* consumer, - bool sliced_input_fusion) + bool sliced_input_fusion, + const se::DeviceDescription& device_description) : fusible_instrs_(), pos_(0), sliced_input_fusion_(sliced_input_fusion) { - Initialize(consumer); + Initialize(consumer, device_description); } // Gets a span of fusions to be fused. absl::Span GetNextSpanOfFusions(); private: - void Initialize(HloInstruction*); + void Initialize(HloInstruction* consumer, + const se::DeviceDescription& device_description); std::vector fusible_instrs_; // `pos_` points to the start position of the next span. @@ -138,17 +144,19 @@ class HorizontalLoopFusionImpl { }; HloComputation* computation_; + const se::DeviceDescription& device_description_; std::string prefix_; }; // HorizontalLoopFusionImpl -bool IsFusibleCandidate(const HloInstruction& instr) { +bool IsFusibleCandidate(const HloInstruction& instr, + const se::DeviceDescription& device_description) { // For now, we do not support fusing instruction with control flow. if (!instr.control_successors().empty() || !instr.control_predecessors().empty()) { return false; } - if (IsNestableVariadicReduction(instr)) { + if (IsNestableVariadicReduction(instr, device_description)) { return false; } @@ -266,7 +274,7 @@ bool AnyOpndIsParamSharedAmongFusions( } void HorizontalLoopFusionImpl::FusionCandidates::Initialize( - HloInstruction* consumer) { + HloInstruction* consumer, const se::DeviceDescription& device_description) { // First, find out all potential target candidates. We will filter out // unsupported/non-profitable cases below. absl::flat_hash_set fusible_candidates; @@ -275,7 +283,7 @@ void HorizontalLoopFusionImpl::FusionCandidates::Initialize( HloInstruction* predecessor = opnd->LatestNonGteAncestor(); // We support kLoop fusion and element-wise HLOs now. We may extend the // support list if needs arise. - if (IsFusibleCandidate(*predecessor)) { + if (IsFusibleCandidate(*predecessor, device_description)) { if (fusible_candidates.insert(predecessor).second) { // Add unseen fusion to ordered list. 
ordered_fusible_candidates.push_back(predecessor); @@ -423,7 +431,8 @@ absl::StatusOr HorizontalLoopFusionImpl::FuseConsumerOperands( HloInstruction* consumer, bool sliced_input_fusion, std::vector& to_fuse_candidates) { bool changed = false; - FusionCandidates loop_fusion_candidates(consumer, sliced_input_fusion); + FusionCandidates loop_fusion_candidates(consumer, sliced_input_fusion, + device_description_); while (true) { auto fusibles = loop_fusion_candidates.GetNextSpanOfFusions(); if (fusibles.empty()) { @@ -715,7 +724,8 @@ absl::StatusOr HorizontalLoopFusionImpl::Run() { absl::StatusOr HorizontalLoopFusion::RunOnComputation( HloComputation* computation) { - HorizontalLoopFusionImpl horizontal_fusion_impl(computation, prefix_); + HorizontalLoopFusionImpl horizontal_fusion_impl(computation, + device_description_, prefix_); return horizontal_fusion_impl.Run(); } diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h index a602a516b724b6..d7add7aff840d7 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h @@ -25,6 +25,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -124,8 +125,9 @@ namespace gpu { // Note, reshapes are added only if the tensors isn't already a vector. class HorizontalLoopFusion : public HloModulePass { public: - HorizontalLoopFusion() = default; - explicit HorizontalLoopFusion(absl::string_view prefix) : prefix_(prefix) {} + explicit HorizontalLoopFusion(const se::DeviceDescription& device_description, + absl::string_view prefix = "") + : device_description_(device_description), prefix_(prefix) {} absl::string_view name() const override { return "horizontal_loop_fusion"; } @@ -136,6 +138,8 @@ class HorizontalLoopFusion : public HloModulePass { private: absl::StatusOr RunOnComputation(HloComputation*); + + const se::DeviceDescription& device_description_; std::string prefix_; }; diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc index d3fb82e9d4b05f..6628825bb8e337 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc @@ -47,11 +47,19 @@ namespace { namespace m = ::xla::match; +auto MakeDeviceDescription() { + stream_executor::DeviceDescription device_description{ + stream_executor::GpuDeviceInfoProto{}}; + device_description.set_threads_per_warp(32); + return device_description; +} + class HorizontalLoopFusionTest : public HloTestBase { public: static bool IsFusion(const HloInstruction* instr) { return instr->opcode() == HloOpcode::kFusion; } + const se::DeviceDescription device_description_{MakeDeviceDescription()}; }; TEST_F(HorizontalLoopFusionTest, BasicTest) { @@ -85,7 +93,8 @@ TEST_F(HorizontalLoopFusionTest, BasicTest) { )") .value(); - EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE( + HorizontalLoopFusion{device_description_}.Run(module.get()).value()); TF_ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_FALSE(HloDCE().Run(module.get()).value()); @@ -136,7 +145,8 @@ TEST_F(HorizontalLoopFusionTest, 
NegativeTestForCycle) { )") .value(); - EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_FALSE( + HorizontalLoopFusion{device_description_}.Run(module.get()).value()); } TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) { @@ -172,7 +182,8 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) { )") .value(); - EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_FALSE( + HorizontalLoopFusion{device_description_}.Run(module.get()).value()); } TEST_F(HorizontalLoopFusionTest, FusingIntoKLoopAndKInputTogether) { @@ -259,7 +270,8 @@ TEST_F(HorizontalLoopFusionTest, FusingIntoKLoopAndKInputTogether) { )") .value(); - EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE( + HorizontalLoopFusion{device_description_}.Run(module.get()).value()); int input_fusion_count = 0; int loop_fusion_count = 0; @@ -308,7 +320,8 @@ TEST_F(HorizontalLoopFusionTest, HorizontalLoopFusionAfterVerticalFusion) { fusion.AddPass(/*may_duplicate=*/true, device_info); EXPECT_TRUE(fusion.Run(module.get()).value()); - EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE( + HorizontalLoopFusion{device_description_}.Run(module.get()).value()); TF_ASSERT_OK(verifier().Run(module.get()).status()); VLOG(2) << "Dump after horizontal fusion:"; @@ -415,7 +428,8 @@ TEST_F(HorizontalLoopFusionTest, FusingDifferentOutputs) { )") .value(); - EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE( + HorizontalLoopFusion{device_description_}.Run(module.get()).value()); TF_ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_FALSE(HloDCE().Run(module.get()).value()); @@ -545,7 +559,8 @@ TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) { })") .value(); - EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE( + HorizontalLoopFusion{device_description_}.Run(module.get()).value()); TF_ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_FALSE(HloDCE().Run(module.get()).value()); @@ -586,7 +601,8 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) { )") .value(); - EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_FALSE( + HorizontalLoopFusion{device_description_}.Run(module.get()).value()); } TEST_F(HorizontalLoopFusionTest, IterativeHorizontalFusion) { @@ -627,7 +643,7 @@ TEST_F(HorizontalLoopFusionTest, IterativeHorizontalFusion) { .value(); HloPassFix iterative_h_fusion("iterative_h_fusion"); - iterative_h_fusion.AddPass(); + iterative_h_fusion.AddPass(device_description_); iterative_h_fusion.AddPass(); EXPECT_TRUE(iterative_h_fusion.Run(module.get()).value()); @@ -699,7 +715,7 @@ TEST_F(HorizontalLoopFusionTest, TraversalOrder) { .value(); HloPassFix iterative_h_fusion("iterative_h_fusion"); - iterative_h_fusion.AddPass(); + iterative_h_fusion.AddPass(device_description_); EXPECT_TRUE(iterative_h_fusion.Run(module.get()).value()); // Verify that the total number of fusion instructions is 2 so that we @@ -773,7 +789,8 @@ ENTRY main { )"; auto module = ParseAndReturnUnverifiedModule(hlo_text).value(); - EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE( + HorizontalLoopFusion{device_description_}.Run(module.get()).value()); VLOG(2) << module->ToString(); @@ -843,7 +860,8 @@ TEST_F(HorizontalLoopFusionTest, DoNotMergeVariadicReductions) { })") .value(); - EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_FALSE( + 
HorizontalLoopFusion{device_description_}.Run(module.get()).value()); } } // namespace diff --git a/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc index bfd8c5bbb6b0a9..c2649e7083cbfa 100644 --- a/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc @@ -50,7 +50,7 @@ class EmptyFusionQueue : public FusionQueue { DequeueNextInstructionAndOperandsToFuseInOrder() override { return {nullptr, {}}; } - void RemoveInstruction(HloInstruction* instruction) override {}; + void RemoveInstruction(HloInstruction* instruction) override{}; const std::vector* FusionConfiguration() override { return nullptr; }; }; @@ -108,15 +108,16 @@ FusionDecision GpuInstructionFusion::ShouldFuseInexpensiveChecks( // Do not fuse into fusions if the resulting kernel would suffer from // uncoalesced reads due to a transposed memory access pattern. - if (IsInputFusibleReduction(*consumer) && + if (IsInputFusibleReduction(*consumer, device_info_) && IsPhysicallyTransposing(*producer)) { return FusionDecision::Forbid( "fusing the producer would break read coalescing"); } - RETURN_IF_NOT_FUSIBLE(IsProducerConsumerFusible(*producer, *consumer)); + RETURN_IF_NOT_FUSIBLE( + IsProducerConsumerFusible(*producer, *consumer, device_info_)); - if (CreatesHeavyComputation(*producer, *consumer)) { + if (CreatesHeavyComputation(*producer, *consumer, device_info_)) { return FusionDecision::Forbid( "the fusion would create a heavy computation"); } @@ -160,7 +161,7 @@ FusionDecision GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) { - return ChooseFusionKind(*producer, *consumer); + return ChooseFusionKind(*producer, *consumer, device_info_); } HloInstruction* GpuInstructionFusion::FuseInstruction( diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc index 9af3e7e04d4d47..8847322cfc26fa 100644 --- a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc @@ -586,7 +586,8 @@ bool GpuLayoutAssignment::PropagateReductionLayoutToOperand( } int64_t kept_dimension_size = ShapeUtil::ElementsIn(user->shape()); return IsUnnestedReductionFasterThanElemental( - {/*is_row_reduction=*/true, {1, kept_dimension_size, reduction_size}}); + {/*is_row_reduction=*/true, {1, kept_dimension_size, reduction_size}}, + device_description_); } bool GpuLayoutAssignment::InstructionCanChangeLayoutInstance( diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment.h b/third_party/xla/xla/service/gpu/transforms/layout_assignment.h index efa58f3f8c3c72..55e6b8de457965 100644 --- a/third_party/xla/xla/service/gpu/transforms/layout_assignment.h +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.h @@ -39,10 +39,12 @@ class GpuLayoutAssignment : public LayoutAssignment { ComputationLayout* entry_computation_layout, const se::GpuComputeCapability& gpu_version, const se::dnn::VersionInfo& dnn_version, + const se::DeviceDescription& device_description, ChannelLayoutConstraints* channel_constraints = nullptr) : LayoutAssignment(entry_computation_layout, channel_constraints), gpu_version_(gpu_version), - dnn_version_(dnn_version) {} + dnn_version_(dnn_version), + 
device_description_(device_description) {} ~GpuLayoutAssignment() override = default; protected: @@ -73,6 +75,7 @@ class GpuLayoutAssignment : public LayoutAssignment { const se::GpuComputeCapability gpu_version_; const se::dnn::VersionInfo dnn_version_; + const se::DeviceDescription& device_description_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc index 4dbd453e1d4850..9b299b541142c2 100644 --- a/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc @@ -110,7 +110,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) { ShapeLayout(result_shape_with_layout); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); for (const HloInstruction* operand : add->operands()) { @@ -140,7 +141,8 @@ TEST_F(LayoutAssignmentTest, DotLayoutUnchangedIfValid) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT(module->entry_computation()->root_instruction(), GmockMatch(m::Dot(m::Op().WithShape(F32, {5, 2, 3}, {1, 2, 0}), @@ -166,7 +168,8 @@ TEST_F(LayoutAssignmentTest, DotLayoutSetToDefaultIfDefaultValid) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -193,7 +196,8 @@ TEST_F(LayoutAssignmentTest, DotOperandLayoutSetToBatchRowsColsOtherwise) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -219,7 +223,8 @@ TEST_F(LayoutAssignmentTest, DotOperandInconsistentDimLayouts) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT( @@ -247,7 +252,8 @@ TEST_F(LayoutAssignmentTest, TransposedDotLayout) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT( @@ 
-280,7 +286,8 @@ TEST_F(LayoutAssignmentTest, TransposedDotOfDotLayout) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); // The transpose layout is not supported by dot.2. Also, we need a copy @@ -316,7 +323,8 @@ TEST_F(LayoutAssignmentTest, DotLayoutS8) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -351,7 +359,8 @@ TEST_F(LayoutAssignmentTest, SortLayout) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); @@ -394,7 +403,8 @@ TEST_F(LayoutAssignmentTest, TopKLayout) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); @@ -421,7 +431,8 @@ TEST_F(LayoutAssignmentTest, FftLayout) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -449,7 +460,8 @@ ENTRY entry { m->entry_computation()->ComputeProgramShape()); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true)); @@ -481,7 +493,8 @@ ENTRY entry { m->entry_computation()->ComputeProgramShape()); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true)); @@ -509,7 +522,8 @@ ENTRY entry { m->entry_computation()->ComputeProgramShape()); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true)); @@ -617,7 +631,8 @@ ENTRY main { ComputationLayout computation_layout( m->entry_computation()->ComputeProgramShape()); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + 
&computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true)); auto reduce = m->entry_computation()->root_instruction(); @@ -647,7 +662,8 @@ ENTRY main { ComputationLayout computation_layout( m->entry_computation()->ComputeProgramShape()); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true)); auto reduce = m->entry_computation()->root_instruction(); @@ -683,7 +699,8 @@ ENTRY main { ComputationLayout computation_layout( m->entry_computation()->ComputeProgramShape()); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true)); auto reduce = m->entry_computation()->root_instruction(); @@ -744,7 +761,7 @@ ENTRY %main { RunAndFilecheckHloRewrite( hlo, GpuLayoutAssignment{&computation_layout, GetGpuComputeCapability(), - GetDnnVersion()}, + GetDnnVersion(), GetDeviceDescription()}, R"( // CHECK: (f32[100,100]{1,0}, u32[], token[]) recv // CHECK: (f32[100,100]{1,0}, token[]) recv-done diff --git a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc index 9ab9729b3b9202..684f3954daaef0 100644 --- a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc @@ -189,13 +189,13 @@ FusionDecision ProducerCandidateIsFusible( const HloDfsReachability& reachability, FusionInfoCache* fusion_info_cache, const se::DeviceDescription& device_info, GpuHloCostAnalysis* cost_analysis) { - if (!IsFusibleAsMultiOutputFusionRoot(consumer)) { + if (!IsFusibleAsMultiOutputFusionRoot(consumer, device_info)) { return FusionDecision::Forbid( "consumer not eligible as multi-output fusion root."); } RETURN_IF_NOT_FUSIBLE( - ShapesCompatibleForMultiOutputFusion(consumer, producer)); + ShapesCompatibleForMultiOutputFusion(consumer, producer, device_info)); RETURN_IF_NOT_FUSIBLE( OperandReachableFromProducer(producer, consumer, reachability)); @@ -233,7 +233,7 @@ std::vector GetProducerConsumerMultiOutputFusionCandidates( // If the producer is not a valid candidate for MOF, no need to check any of // its users. - if (!IsProducerMultiOutputFusible(*producer)) { + if (!IsProducerMultiOutputFusible(*producer, device_info)) { return fusion_candidates; } @@ -265,9 +265,11 @@ std::vector GetProducerConsumerMultiOutputFusionCandidates( return fusion_candidates; } -bool IsSiblingFusionCandidate(const HloInstruction* instr) { - if (instr->users().empty() || !IsFusibleAsMultiOutputFusionRoot(*instr) || - IsNestableVariadicReduction(*instr)) { +bool IsSiblingFusionCandidate(const HloInstruction* instr, + const se::DeviceDescription& device_info) { + if (instr->users().empty() || + !IsFusibleAsMultiOutputFusionRoot(*instr, device_info) || + IsNestableVariadicReduction(*instr, device_info)) { return false; } // Check if the users of multioutput fusion is not a get-tuple-element. 
@@ -292,7 +294,7 @@ FusionDecision CanFuseSiblings(const HloInstruction& sibling_consumer_1, } RETURN_IF_NOT_FUSIBLE(ShapesCompatibleForMultiOutputFusion( - sibling_consumer_1, sibling_consumer_2)); + sibling_consumer_1, sibling_consumer_2, device_info)); // Technically, this check is order-dependent (e.g. siblings A, B, C where // {A, B} and {B, C} overlap, but {A, C} do not. If the priority order is @@ -331,7 +333,9 @@ bool MultiOutputFusion::FuseSiblings(HloInstruction* parent, std::vector siblings; // Only consider siblings that are fusion candidates. absl::c_copy_if(parent->users(), std::back_inserter(siblings), - IsSiblingFusionCandidate); + [&](const HloInstruction* instr) { + return IsSiblingFusionCandidate(instr, device_info_); + }); // Sort the siblings such that multi-output fusion ops occur first, followed // by fusion ops, followed by unfused ops. absl::c_stable_sort(siblings, @@ -418,7 +422,7 @@ absl::StatusOr MultiOutputFusion::DoMultiOutputFusion() { std::vector defs_before_uses = computation_->MakeInstructionPostOrder(); - FusionInfoCache fusion_info_cache; + FusionInfoCache fusion_info_cache(device_info_); // Traverse the HLO in uses-before-defs order. for (auto it = defs_before_uses.rbegin(); it != defs_before_uses.rend(); ++it) { @@ -467,7 +471,7 @@ absl::StatusOr MultiOutputFusion::DoMultiOutputFusion() { } else { input_fusion = computation_->AddInstruction(HloInstruction::CreateFusion( consumer_for_fusion->shape(), - ChooseFusionKind(*producer, *consumer_for_fusion), + ChooseFusionKind(*producer, *consumer_for_fusion, device_info_), consumer_for_fusion)); VLOG(2) << "Fuse producer " << producer->name() << " and its consumer " << consumer_for_fusion->name() << " into " diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc index c0e86818d36fe8..d09798c0c78e37 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc @@ -161,6 +161,7 @@ class PriorityFusionQueue { mlir_context_(mlir_context), fusion_analysis_cache_(fusion_analysis_cache), fusion_deduplication_cache_(fusion_deduplication_cache), + fusion_info_cache_(*device_info_), triton_softmax_priority_fusion_enabled_( triton_softmax_priority_fusion_enabled) { VLOG(2) << "Running full HLO cost analysis for " << computation_->name(); diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc index dce9288888a8a5..bb61f8fd110f8e 100644 --- a/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "xla/service/gpu/reduction_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" #include "tsl/platform/statusor.h" namespace xla { @@ -42,14 +43,16 @@ namespace gpu { class ReductionSplitterVisitor : public DfsHloRewriteVisitor { public: - explicit ReductionSplitterVisitor(bool ignore_small_dims) - : ignore_small_dims_(ignore_small_dims) {} + ReductionSplitterVisitor(const se::DeviceDescription &device_description, + bool ignore_small_dims) + : device_description_(device_description), + ignore_small_dims_(ignore_small_dims) {} absl::Status HandleReduce(HloInstruction *reduce) override { VLOG(4) << "Input: " << reduce->ToString(); // Reductions with contiguous dimensions are lowered to efficient code. No // need to split such ops. - if (IsReductionFromOrToContiguousDimensions(*reduce)) { + if (IsReductionFromOrToContiguousDimensions(*reduce, device_description_)) { VLOG(4) << "Reduction with contiguous dimensions. Return."; return absl::OkStatus(); } @@ -124,15 +127,17 @@ class ReductionSplitterVisitor : public DfsHloRewriteVisitor { } private: - bool ignore_small_dims_; + const se::DeviceDescription &device_description_; + const bool ignore_small_dims_; }; absl::StatusOr ReductionSplitter::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { - TF_ASSIGN_OR_RETURN(bool changed, - ReductionSplitterVisitor(ignore_small_dims_) - .RunOnModule(module, execution_threads)); + TF_ASSIGN_OR_RETURN( + bool changed, + ReductionSplitterVisitor(device_description_, ignore_small_dims_) + .RunOnModule(module, execution_threads)); return changed; } diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h index 74a5a1d6f31a71..f5abe00c4014e7 100644 --- a/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h +++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -36,12 +37,14 @@ namespace gpu { // fixpoint to split reduce ops along multiple dimensions. // // Precondition: ReductionDimensionGrouper has been run and adjacent reduce -// dimentsions have been grouped. Reduction layouts have been normalized. +// dimensions have been grouped. Reduction layouts have been normalized. 
class ReductionSplitter : public HloModulePass { public: - explicit ReductionSplitter(bool ignore_small_dims) - : ignore_small_dims_(ignore_small_dims) {} + ReductionSplitter(const se::DeviceDescription& device_description, + bool ignore_small_dims) + : device_description_(device_description), + ignore_small_dims_(ignore_small_dims) {} absl::string_view name() const override { return "reduction-splitter"; } using HloPassInterface::Run; @@ -50,7 +53,8 @@ class ReductionSplitter : public HloModulePass { const absl::flat_hash_set& execution_threads) override; private: - bool ignore_small_dims_; + const se::DeviceDescription& device_description_; + const bool ignore_small_dims_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc index 4b9f6fb130ed0f..a7f8214e6e7c28 100644 --- a/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" @@ -32,7 +33,26 @@ namespace { namespace m = ::xla::match; -class ReductionSplitterTest : public HloTestBase {}; +auto MakeDeviceDescription() { + stream_executor::DeviceDescription device_description{ + stream_executor::GpuDeviceInfoProto{}}; + device_description.set_threads_per_warp(32); + return device_description; +} + +class ReductionSplitterTest : public HloTestBase { + public: + using HloTestBase::HloTestBase; + + auto MakeReductionSplitter(bool ignore_small_dims) const { + return ReductionSplitter{device_description_, + /*ignore_small_dims=*/ignore_small_dims}; + } + + private: + const stream_executor::DeviceDescription device_description_{ + MakeDeviceDescription()}; +}; TEST_F(ReductionSplitterTest, SplitReductionAtDimensionTwo) { auto module = ParseAndReturnVerifiedModule(R"( @@ -54,8 +74,9 @@ TEST_F(ReductionSplitterTest, SplitReductionAtDimensionTwo) { } )") .value(); - ASSERT_TRUE( - ReductionSplitter(/*ignore_small_dims=*/true).Run(module.get()).value()); + ASSERT_TRUE(MakeReductionSplitter(/*ignore_small_dims=*/true) + .Run(module.get()) + .value()); SCOPED_TRACE(module->ToString()); const HloInstruction* root_reduction = module->entry_computation()->root_instruction(); @@ -86,8 +107,9 @@ TEST_F(ReductionSplitterTest, SplitReductionAtDimensionZero) { } )") .value(); - ASSERT_TRUE( - ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value()); + ASSERT_TRUE(MakeReductionSplitter(/*ignore_small_dims=*/false) + .Run(module.get()) + .value()); SCOPED_TRACE(module->ToString()); const HloInstruction* root_reduction = module->entry_computation()->root_instruction(); @@ -119,10 +141,12 @@ TEST_F(ReductionSplitterTest, DontSplitReductionWithSmallDimensions) { } )") .value(); - EXPECT_FALSE( - ReductionSplitter(/*ignore_small_dims=*/true).Run(module.get()).value()); - EXPECT_TRUE( - ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value()); + EXPECT_FALSE(MakeReductionSplitter(/*ignore_small_dims=*/true) + .Run(module.get()) + .value()); + EXPECT_TRUE(MakeReductionSplitter(/*ignore_small_dims=*/false) + .Run(module.get()) + .value()); } TEST_F(ReductionSplitterTest, DontSplitReductionsWithContiguousDimensions) { @@ -143,8 +167,9 @@ 
TEST_F(ReductionSplitterTest, DontSplitReductionsWithContiguousDimensions) {
 }
 )")
                     .value();
-  EXPECT_FALSE(
-      ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value());
+  EXPECT_FALSE(MakeReductionSplitter(/*ignore_small_dims=*/false)
+                   .Run(module.get())
+                   .value());
 }
 }  // namespace
diff --git a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc
index 4b2e12c1ce36b8..3c42961f95ce73 100644
--- a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc
+++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc
@@ -438,9 +438,9 @@ absl::Status RunFusionPipeline(
   // transform reductions.
   reduction_pipeline.AddPass<ReductionDimensionGrouper>();
   reduction_pipeline.AddPass<ReductionSplitter>(
+      device_info,
       /*ignore_small_reduce_dims=*/false);
-  reduction_pipeline.AddPass<TreeReductionRewriter>(
-      device_info.gpu_compute_capability());
+  reduction_pipeline.AddPass<TreeReductionRewriter>(device_info);
   TF_RETURN_IF_ERROR(reduction_pipeline.Run(module).status());
diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc
index 68805b1ddc3c0c..866b8a11fd7d6c 100644
--- a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc
+++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc
@@ -30,6 +30,7 @@ limitations under the License.
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/gpu/gpu_fusible.h"
 #include "xla/service/gpu/runtime/thunk.h"
+#include "xla/stream_executor/device_description.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/platform/errors.h"
@@ -105,12 +106,14 @@ absl::StatusOr<bool> AnnotateStreamAttributesForCopyStart(
 absl::StatusOr<bool> WrapIntoFusionAndAnnotateStreamAttributes(
     HloInstruction* instruction, int64_t channel_id,
-    GpuBackendConfig& instr_gpu_config) {
+    GpuBackendConfig& instr_gpu_config,
+    const se::DeviceDescription& device_description) {
   auto* computation = instruction->parent();
   auto* module = computation->parent();
   auto* fusion_instruction =
       computation->AddInstruction(HloInstruction::CreateFusion(
-          instruction->shape(), ChooseFusionKind(*instruction, *instruction),
+          instruction->shape(),
+          ChooseFusionKind(*instruction, *instruction, device_description),
           instruction));
   const absl::string_view wrapped_opcode =
       HloOpcodeString(instruction->opcode());
@@ -206,7 +209,8 @@ absl::StatusOr<bool> StreamAttributeAnnotator::Run(
         instr->opcode() == HloOpcode::kDynamicUpdateSlice)) {
       TF_ASSIGN_OR_RETURN(bool comp_result,
                           WrapIntoFusionAndAnnotateStreamAttributes(
-                              instr, channel_id, instr_gpu_config.value()));
+                              instr, channel_id, instr_gpu_config.value(),
+                              device_description_));
       changed |= comp_result;
       continue;
     }
diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h
index 84428f359491fc..74b2002670afca 100644
--- a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h
+++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" namespace xla::gpu { @@ -45,6 +46,10 @@ namespace xla::gpu { class StreamAttributeAnnotator : public HloModulePass { public: + explicit StreamAttributeAnnotator( + const se::DeviceDescription& device_description) + : device_description_(device_description) {} + absl::string_view name() const override { return "stream-attribute-annotator"; } @@ -53,6 +58,9 @@ class StreamAttributeAnnotator : public HloModulePass { absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; + + private: + const se::DeviceDescription& device_description_; }; } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc index c7d2ca59cff0e9..386099337ff88a 100644 --- a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc @@ -27,13 +27,29 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/tests/filecheck.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" namespace xla::gpu { namespace { -using StreamAttributeAnnotatorTest = HloTestBase; +auto MakeDeviceDescription() { + stream_executor::DeviceDescription device_description{ + stream_executor::GpuDeviceInfoProto{}}; + device_description.set_threads_per_warp(32); + return device_description; +} + +class StreamAttributeAnnotatorTest : public HloTestBase { + public: + const se::DeviceDescription& device_description() const { + return device_description_; + } + + private: + const se::DeviceDescription device_description_{MakeDeviceDescription()}; +}; TEST_F(StreamAttributeAnnotatorTest, AllUsersAreAnnotated) { constexpr absl::string_view kHloString = R"( @@ -53,7 +69,7 @@ TEST_F(StreamAttributeAnnotatorTest, AllUsersAreAnnotated) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - StreamAttributeAnnotator attr_annotator; + StreamAttributeAnnotator attr_annotator{device_description()}; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get())); EXPECT_TRUE(changed); @@ -85,7 +101,7 @@ TEST_F(StreamAttributeAnnotatorTest, MultipleStreamsAreCombined) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - StreamAttributeAnnotator attr_annotator; + StreamAttributeAnnotator attr_annotator{device_description()}; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get())); EXPECT_TRUE(changed); @@ -122,7 +138,7 @@ TEST_F(StreamAttributeAnnotatorTest, GTEUserIsAnnotated) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - StreamAttributeAnnotator attr_annotator; + StreamAttributeAnnotator attr_annotator{device_description()}; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get())); EXPECT_TRUE(changed); @@ -154,7 +170,7 @@ TEST_F(StreamAttributeAnnotatorTest, FusionIsAnnotated) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - StreamAttributeAnnotator attr_annotator; + StreamAttributeAnnotator attr_annotator{device_description()}; bool changed; 
TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get())); EXPECT_TRUE(changed); @@ -195,7 +211,7 @@ TEST_F(StreamAttributeAnnotatorTest, CopyStartIsAnnotated) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - StreamAttributeAnnotator attr_annotator; + StreamAttributeAnnotator attr_annotator{device_description()}; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get())); EXPECT_TRUE(changed); @@ -231,8 +247,9 @@ TEST_F(StreamAttributeAnnotatorTest, DynamicUpdateSliceWrappedAndAnnotated) { ParseAndReturnVerifiedModule(kHloString)); EXPECT_TRUE(module->has_schedule()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, - StreamAttributeAnnotator().Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + StreamAttributeAnnotator{device_description()}.Run(module.get())); EXPECT_TRUE(changed); // Check that the dynamic-update-slice instruction is wrapped in a fusion @@ -294,8 +311,9 @@ TEST_F(StreamAttributeAnnotatorTest, DynamicSliceWrappedAndAnnotated) { ParseAndReturnVerifiedModule(kHloString)); EXPECT_TRUE(module->has_schedule()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, - StreamAttributeAnnotator().Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + StreamAttributeAnnotator{device_description()}.Run(module.get())); EXPECT_TRUE(changed); // Check that the dynamic-slice instruction is wrapped in a fusion diff --git a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc index fb023fc8cc693f..a76c3276c57f20 100644 --- a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc @@ -71,8 +71,9 @@ bool IsMinMaxReduction(HloReduceInstruction *reduce) { class ReductionRewriterVisitor : public DfsHloRewriteVisitor { public: - explicit ReductionRewriterVisitor(se::GpuComputeCapability gpu_version) - : gpu_version_(gpu_version) {} + explicit ReductionRewriterVisitor( + const se::DeviceDescription &device_description) + : device_description_(device_description) {} absl::Status HandleReduce(HloInstruction *hlo) override { auto *reduce = Cast(hlo); @@ -84,7 +85,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { } ReductionDimensions reduction_dims = GetReductionKindAndContiguousComponents(*hlo); - if (ReductionIsRaceFree(config, reduction_dims)) { + if (ReductionIsRaceFree(config, reduction_dims, device_description_)) { VLOG(3) << "Base case: dimensions fit"; return absl::OkStatus(); } @@ -121,7 +122,8 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { << " since min/max operations are associative"; return false; } - if (!IsReductionFromOrToContiguousDimensions(*reduce)) { + if (!IsReductionFromOrToContiguousDimensions(*reduce, + device_description_)) { VLOG(3) << "Is not a reduction from or to contiguous dimensions"; return false; } @@ -136,7 +138,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { uint64_t n, int64_t race_free_bound, bool is_row_reduction) { - CHECK(k1 >= k2); + CHECK_GE(k1, k2); // Keep inner reduction as race free. 
if (k1 > race_free_bound) { return false; @@ -201,7 +203,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { uint64_t k2 = static_cast(std::floor(std::sqrt(reduced_dim_size))); int64_t race_free_bound = ReductionDimensionRaceFreeBound( - reduce->GetModule()->config(), reduction_dims); + reduce->GetModule()->config(), reduction_dims, device_description_); if (k2 > race_free_bound) { // This means we need more than one split. It is best to limit the n/k // dimension to the maximum size that doesn't require further splitting. @@ -371,7 +373,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { return ReplaceWithNewInstruction(hlo, std::move(out)); } - se::GpuComputeCapability gpu_version_; + const se::DeviceDescription &device_description_; }; absl::StatusOr TreeReductionRewriter::Run( @@ -379,7 +381,7 @@ absl::StatusOr TreeReductionRewriter::Run( const absl::flat_hash_set &execution_threads) { VLOG(5) << "Rewriter input: " << module->ToString(); TF_ASSIGN_OR_RETURN(bool changed, - ReductionRewriterVisitor(gpu_version_) + ReductionRewriterVisitor(device_description_) .RunOnModule(module, execution_threads)); VLOG(5) << "Rewriter output: " << module->ToString(); return changed; diff --git a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h index 4002ca94d585f2..864965836910db 100644 --- a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h @@ -75,8 +75,9 @@ namespace gpu { // class TreeReductionRewriter : public HloModulePass { public: - explicit TreeReductionRewriter(se::GpuComputeCapability gpu_version) - : gpu_version_(gpu_version) {} + explicit TreeReductionRewriter( + const se::DeviceDescription& device_description) + : device_description_(device_description) {} ~TreeReductionRewriter() override = default; absl::string_view name() const override { return "tree-reduction-rewriter"; } @@ -87,7 +88,7 @@ class TreeReductionRewriter : public HloModulePass { const absl::flat_hash_set& execution_threads) override; private: - se::GpuComputeCapability gpu_version_; + const se::DeviceDescription& device_description_; }; } // end namespace gpu diff --git a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc index 91f4481a202885..c5f72969346b86 100644 --- a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc @@ -31,16 +31,11 @@ class TreeReductionRewriterTest : public HloTestBase { void CheckTreeRewriter(absl::string_view hlo, std::optional expected) { #if GOOGLE_CUDA + stream_executor::DeviceDescription device_description{ + stream_executor::GpuDeviceInfoProto{}}; + device_description.set_threads_per_warp(32); RunAndFilecheckHloRewrite( - hlo, -#if TENSORFLOW_USE_ROCM - gpu::TreeReductionRewriter{se::RocmComputeCapability { - "908" - }}, -#else - gpu::TreeReductionRewriter{se::CudaComputeCapability{8, 1}}, -#endif - expected); + hlo, gpu::TreeReductionRewriter{device_description}, expected); #elif TENSORFLOW_USE_ROCM RunAndFilecheckHloRewrite( hlo, gpu::GpuTreeReductionRewriter{se::RocmComputeCapability{"908"}}, diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc index 095cd94dc60ac7..f79bcce74bb335 100644 --- 
a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc
+++ b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc
@@ -119,7 +119,8 @@ class GpuOptProvider : public OptProvider {
         xla::gpu::CompileModuleToLlvmIr(
             optimized_module, &llvm_context, gpu_compiler->GetTargetTriple(),
             gpu_compiler->GetDataLayout(), platform->Name(), platform->id(),
-            target_config.device_description, gpu_compiler->GetCanShareBuffer(),
+            target_config.device_description,
+            gpu_compiler->GetCanShareBuffer(target_config.device_description),
             gpu_compiler->BufferSizeBytesFunction()));
     return llvm_ir::DumpToString(results.llvm_module.get());
   }
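The test changes above repeatedly use the same small helper to obtain a DeviceDescription with an explicit warp size instead of the previously hard-coded 32. A minimal standalone sketch of that pattern for reviewers who want to exercise one of the updated passes locally; the parameterized helper signature and the includes below are illustrative assumptions, not part of the patch itself:

#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_description.pb.h"

namespace {

// Build a bare DeviceDescription from an empty GpuDeviceInfoProto and set the
// warp size explicitly. The tests in this patch use 32; an AMD target would
// typically use 64.
stream_executor::DeviceDescription MakeDeviceDescription(int threads_per_warp) {
  stream_executor::DeviceDescription device_description{
      stream_executor::GpuDeviceInfoProto{}};
  device_description.set_threads_per_warp(threads_per_warp);
  return device_description;
}

}  // namespace

// Example: constructing one of the passes touched above with a 64-wide warp.
//   xla::gpu::ReductionSplitter splitter(MakeDeviceDescription(64),
//                                        /*ignore_small_dims=*/true);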