@@ -380,7 +380,7 @@ absl::StatusOr<std::unique_ptr<HloModule>> TritonGemmAutotuneExtractor(
380
380
381
381
// If the priority fusion pass above skipped some instructions, turn them
382
382
// into fusions.
383
- FusionWrapper fusion_wrapper;
383
+ FusionWrapper fusion_wrapper (gpu_device_info) ;
384
384
TF_RETURN_IF_ERROR (fusion_wrapper.Run (new_module.get ()).status ());
385
385
}
386
386
return new_module;
@@ -528,7 +528,7 @@ absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config,
528
528
TritonGemmConfig::FromProto (result.triton ()));
529
529
}
530
530
const se::DeviceDescription& device_desc =
531
- autotune_config.GetExecutor ()-> GetDeviceDescription ();
531
+ autotune_config.GetDeviceDescription ();
532
532
TF_ASSIGN_OR_RETURN (
533
533
std::unique_ptr<HloModule> module,
534
534
util.ExtractModule ([&](const DebugOptions& debug_opts) {
@@ -693,12 +693,12 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
693
693
// a sufficient number of thread block programs to occupy all available cores.
694
694
// Around 5 full waves completely avoid the need for split-K.
695
695
// n_tiles = split_k * (M * N) / (block_m * block_n)
696
- const int kCoreCount =
697
- !config_.IsDeviceless ()
698
- ? config_.GetExecutor ()->GetDeviceDescription ().core_count ()
699
- : 100 ; // some sensible default
696
+ const int kCoreCount = config_.GetDeviceDescription ().core_count ();
697
+ CHECK_GE (kCoreCount , 1 );
700
698
const int64_t kSufficientNumberOfTiles = kMaxWavesForSplitK * kCoreCount ;
701
699
const int64_t result_size = ShapeUtil::ElementsIn (dot.shape ());
700
+ const int64_t threads_per_warp =
701
+ config_.GetDeviceDescription ().threads_per_warp ();
702
702
703
703
// Triton configurations are adjusted and deduplicated.
704
704
absl::flat_hash_set<TritonGemmConfig> added;
@@ -735,7 +735,7 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
735
735
2 * std::max (kMinTileSize , kLdmatrixGranularity / minBitWidth));
736
736
int meta_elements = config.block_m * config.block_k / 16 ;
737
737
config.num_warps =
738
- std::min<int >(config.num_warps , meta_elements / WarpSize () );
738
+ std::min<int >(config.num_warps , meta_elements / threads_per_warp );
739
739
}
740
740
741
741
if (added.insert (config).second ) {
@@ -783,13 +783,13 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util,
783
783
-> absl::StatusOr<bool > {
784
784
std::unique_ptr<Executable> executable;
785
785
if (std::holds_alternative<TritonGemmConfig>(config)) {
786
- TF_ASSIGN_OR_RETURN (
787
- executable, compile_util.Compile ([&](const DebugOptions& opts) {
788
- return TritonGemmAutotuneExtractor (
789
- std::get<TritonGemmConfig>(config),
790
- config_.GetExecutor ()-> GetDeviceDescription (), fusion, opts,
791
- allow_filtering_kernels_spilling_registers);
792
- }));
786
+ TF_ASSIGN_OR_RETURN (executable,
787
+ compile_util.Compile ([&](const DebugOptions& opts) {
788
+ return TritonGemmAutotuneExtractor (
789
+ std::get<TritonGemmConfig>(config),
790
+ config_.GetDeviceDescription (), fusion, opts,
791
+ allow_filtering_kernels_spilling_registers);
792
+ }));
793
793
} else if (std::holds_alternative<CuDnnConfig>(config)) {
794
794
executable =
795
795
compile_util
@@ -801,9 +801,9 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util,
801
801
} else if (std::holds_alternative<CuBlasConfig>(config)) {
802
802
TF_ASSIGN_OR_RETURN (
803
803
executable, compile_util.Compile ([&](const DebugOptions& opts) {
804
- return CublasGemmAutotuneExtractor (
805
- config_, config_.GetExecutor ()-> GetDeviceDescription (),
806
- toolkit_version_, fusion, opts);
804
+ return CublasGemmAutotuneExtractor (config_,
805
+ config_.GetDeviceDescription (),
806
+ toolkit_version_, fusion, opts);
807
807
}));
808
808
} else {
809
809
LOG (FATAL) << " Unsupported config type: " << config.index ();
@@ -1005,6 +1005,9 @@ GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const {
1005
1005
bool tune_ctas =
1006
1006
debug_options_.xla_gpu_enable_triton_hopper () && cc.IsAtLeastHopper ();
1007
1007
1008
+ const int64_t threads_per_warp =
1009
+ config_.GetDeviceDescription ().threads_per_warp ();
1010
+
1008
1011
for (int num_stages : kNumStages ) {
1009
1012
// Volta doesn't support num_stages > 2.
1010
1013
if (!cc.IsAtLeastAmpere () && num_stages > 2 ) {
@@ -1017,7 +1020,7 @@ GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const {
1017
1020
const int tile_rhs = tile_k * tile_n;
1018
1021
for (int num_warps : kNumWarps ) {
1019
1022
// Each thread should read at least one input element.
1020
- if (num_warps * WarpSize () > std::min (tile_lhs, tile_rhs)) {
1023
+ if (num_warps * threads_per_warp > std::min (tile_lhs, tile_rhs)) {
1021
1024
break ;
1022
1025
}
1023
1026
for (int split_k : kSplitK ) {
0 commit comments