Commit acd66a2

[ROCm][TunableOp] Improve identification of fastest solution (pytorch#144942)
This PR addresses some stability issues with identifying the fastest solution on AMD GPUs, particularly the MI300. Changes include:

- An improved timer, StreamTimerNoSync
- More aggressive skipping of slow solutions
- Additional statistics that can be used for diagnostics (PYTORCH_TUNABLEOP_VERBOSE=3)

Pull Request resolved: pytorch#144942
Approved by: https://github.com/jeffdaily
(cherry picked from commit fd0cd6a)
1 parent 640334b commit acd66a2
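
Note: the additional per-candidate statistics (min, max, mean, standard deviation) surface through the TUNABLE_LOG3 lines in the TunableOp.h diff below, so they are visible only when running with PYTORCH_TUNABLEOP_VERBOSE=3.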

File tree: 6 files changed (+172 -26 lines)

Diff for: aten/src/ATen/cuda/tunable/GemmHipblaslt.h (-7)

```diff
@@ -565,13 +565,6 @@ auto GetHipBlasLtTypeStringAndOps() {
       heuristic_result));
   TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));
 
-  // Sort heuristic_result by algo index to make sure the order of returned algos is deterministic.
-  std::sort(heuristic_result.begin(),
-            heuristic_result.end(),
-            [](hipblasLtMatmulHeuristicResult_t& a, hipblasLtMatmulHeuristicResult_t& b) {
-              return hipblaslt_ext::getIndexFromAlgo(a.algo) < hipblaslt_ext::getIndexFromAlgo(b.algo);
-            });
-
   int returned_algo_count = heuristic_result.size();
   std::vector<std::pair<std::string, std::unique_ptr<Callable<ParamsT>>>> ret;
   for (int i = 0; i < returned_algo_count; i++) {
```

Diff for: aten/src/ATen/cuda/tunable/GemmRocblas.h (-3)

```diff
@@ -191,9 +191,6 @@ auto GetRocBlasGemmTypeStringAndOps() {
       rocblas_gemm_flags_none,
       solutions.data(),
       &solution_size));
-  // Sort the solutions in ascending order to make the solution vector deterministic across runs
-  std::sort(solutions.begin(), solutions.end());
-
   std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmParams<T>>>>> ret;
   for (size_t i = 0; i < solutions.size(); ++i) {
     auto callable = std::make_unique<RocblasGemmOp<T>>(solutions[i]);
```
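
Both deletions above remove the sorting that kept the candidate solution list deterministic across runs, so candidates are now tried in the order the library returns them. A plausible reading, not stated in the commit message, is that with the statistics-based two-phase skipping added in TunableOp.h below, an enumeration order that tends to surface fast solutions first lets min_duration_ms drop early and prunes more slow candidates.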

Diff for: aten/src/ATen/cuda/tunable/StreamTimer.cpp (+24 -1)

```diff
@@ -24,7 +24,7 @@ StreamTimer::~StreamTimer() {
 }
 
 void StreamTimer::Start() {
-  AT_CUDA_CHECK(cudaDeviceSynchronize());
+  AT_CUDA_CHECK(cudaEventSynchronize(start_));
   AT_CUDA_CHECK(cudaEventRecord(start_, at::cuda::getCurrentCUDAStream()));
 }
 
@@ -40,4 +40,27 @@ float StreamTimer::Duration() {
   return time;
 }
 
+StreamTimerNoSync::StreamTimerNoSync() {
+  AT_CUDA_CHECK(cudaEventCreate(&start_));
+  AT_CUDA_CHECK(cudaEventCreate(&end_));
+}
+
+StreamTimerNoSync::~StreamTimerNoSync() = default;
+
+void StreamTimerNoSync::Start() {
+  AT_CUDA_CHECK(cudaEventRecord(start_, at::cuda::getCurrentCUDAStream()));
+}
+
+void StreamTimerNoSync::End() {
+  AT_CUDA_CHECK(cudaEventRecord(end_, at::cuda::getCurrentCUDAStream()));
+}
+
+float StreamTimerNoSync::Duration() {
+  auto time = std::numeric_limits<float>::quiet_NaN();
+  AT_CUDA_CHECK(cudaEventSynchronize(end_));
+  // time is in ms with a resolution of 1 us
+  AT_CUDA_CHECK(cudaEventElapsedTime(&time, start_, end_));
+  return time;
+}
+
 } // namespace at::cuda::tunable
```

Diff for: aten/src/ATen/cuda/tunable/StreamTimer.h (+16)

```diff
@@ -31,4 +31,20 @@ class StreamTimer : public ITimer {
   cudaEvent_t end_;
 };
 
+class StreamTimerNoSync : public ITimer {
+ public:
+  StreamTimerNoSync();
+  ~StreamTimerNoSync() override;
+
+  void Start() override;
+
+  void End() override;
+
+  float Duration() override;
+
+ private:
+  cudaEvent_t start_{};
+  cudaEvent_t end_{};
+};
+
 } // namespace at::cuda::tunable
```
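
The key difference between the two timers: StreamTimer::Start() previously stalled the whole device with cudaDeviceSynchronize() (now narrowed to cudaEventSynchronize(start_)), while StreamTimerNoSync only records events on the current stream and waits on the end event when Duration() is called. Below is a minimal standalone sketch of that event-based pattern using only the public CUDA runtime API; it is not PyTorch code, and it times an async host-to-device copy as a stand-in for a GEMM.

```cpp
// Sketch of the StreamTimerNoSync timing pattern: bracket the work with
// events on a stream, then synchronize only on the end event.
// Compile with nvcc; the same runtime API exists under HIP via hipify.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

#define CHECK(expr)                                              \
  do {                                                           \
    cudaError_t err_ = (expr);                                   \
    if (err_ != cudaSuccess) {                                   \
      std::printf("CUDA error: %s\n", cudaGetErrorString(err_)); \
      return 1;                                                  \
    }                                                            \
  } while (0)

int main() {
  const size_t n = 1 << 24;
  std::vector<float> host(n, 1.0f);
  float* dev = nullptr;
  CHECK(cudaMalloc(&dev, n * sizeof(float)));

  cudaEvent_t start, end;
  CHECK(cudaEventCreate(&start));
  CHECK(cudaEventCreate(&end));

  // Record events around the timed region on the default stream.
  CHECK(cudaEventRecord(start, 0));
  CHECK(cudaMemcpyAsync(dev, host.data(), n * sizeof(float),
                        cudaMemcpyHostToDevice, 0));
  CHECK(cudaEventRecord(end, 0));

  // Wait only for the end event; other streams keep running.
  // Elapsed time is reported in ms with roughly 1 us resolution.
  CHECK(cudaEventSynchronize(end));
  float ms = 0.0f;
  CHECK(cudaEventElapsedTime(&ms, start, end));
  std::printf("timed region: %.3f ms\n", ms);

  CHECK(cudaEventDestroy(start));
  CHECK(cudaEventDestroy(end));
  CHECK(cudaFree(dev));
  return 0;
}
```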

Diff for: aten/src/ATen/cuda/tunable/TunableGemm.h (+12 -5)

```diff
@@ -14,7 +14,6 @@
 #include <ATen/cuda/tunable/GemmHipblaslt.h>
 #include <ATen/cuda/tunable/GemmRocblas.h>
 #endif
-#include <ATen/cuda/tunable/StreamTimer.h>
 #include <ATen/cuda/tunable/TunableOp.h>
 #include <c10/cuda/CUDACachingAllocator.h>
 #include <c10/util/Float8_e4m3fn.h>
@@ -190,7 +189,7 @@ inline std::string TypeName(c10::complex<float> v) {
 }
 
 template <typename T, BlasOp ALayout, BlasOp BLayout>
-class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
+class GemmTunableOp : public TunableOp<GemmParams<T>> {
  public:
   GemmTunableOp() {
     this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
@@ -215,6 +214,8 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
     }
   }
 #endif
+
+    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
   }
 
   std::string Signature() override {
@@ -223,7 +224,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
 };
 
 template <typename T, BlasOp ALayout, BlasOp BLayout>
-class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer> {
+class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>> {
  public:
   GemmAndBiasTunableOp() {
     this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
@@ -241,6 +242,8 @@ class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer>
     }
   }
 #endif
+
+    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
   }
 
   std::string Signature() override {
@@ -249,7 +252,7 @@ class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer>
 };
 
 template <typename T, BlasOp ALayout, BlasOp BLayout>
-class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>, StreamTimer> {
+class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>> {
  public:
   GemmStridedBatchedTunableOp() {
     this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
@@ -274,6 +277,8 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
     }
   }
 #endif
+
+    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
   }
 
   std::string Signature() override {
@@ -282,7 +287,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
 };
 
 template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
-class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer> {
+class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>> {
  public:
   ScaledGemmTunableOp() {
     this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
@@ -292,6 +297,8 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
     this->RegisterOp(std::move(name), std::move(op));
   }
 #endif
+
+    this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
   }
 
   std::string Signature() override {
```
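
The changes in this file are mechanical fallout from the timer rework: the TimerT template parameter disappears from TunableOp (the StreamTimer.h include correspondingly moves into TunableOp.h below, where StreamTimerNoSync is now used directly), and each constructor additionally registers the "Default" op once more after the vendor-specific candidates, outside the #ifdef blocks. The commit message does not state the purpose of the second registration.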

Diff for: aten/src/ATen/cuda/tunable/TunableOp.h (+120 -10)

```diff
@@ -10,6 +10,7 @@
 #pragma once
 
 #include <ATen/cuda/tunable/Tunable.h>
+#include <ATen/cuda/tunable/StreamTimer.h>
 #include <ATen/cuda/Sleep.h>
 #include <c10/cuda/CUDACachingAllocator.h>
 
@@ -38,7 +39,57 @@ class Callable {
   }
 };
 
-template <typename ParamsT, typename TimerT>
+namespace {
+
+/** http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance */
+
+class Stats {
+ public:
+  Stats() {
+    _n = 0UL;
+    _mean = 0.0;
+    _M2 = 0.0;
+    _sum = 0.0;
+    _min = 0.0;
+    _max = 0.0;
+  }
+
+  void sample_value(const double x) {
+    double delta = 0;
+    _sum = _sum + x;
+    if (0UL == _n) {
+      _min = x;
+      _max = x;
+    }
+    else {
+      _min = _min < x ? _min : x;
+      _max = _max > x ? _max : x;
+    }
+    _n = _n + 1UL;
+    delta = x - _mean;
+    _mean = _mean + delta/_n;
+    _M2 = _M2 + delta * (x - _mean);
+  }
+
+  double variance() const {
+    return _M2/(_n-1);
+  }
+
+  double stddev() const {
+    return std::sqrt(variance());
+  }
+
+  unsigned long _n;
+  double _mean;
+  double _M2;
+  double _sum;
+  double _min;
+  double _max;
+};
+
+} // anonymous namespace
+
+template <typename ParamsT>
 class TunableOp {
  public:
   TunableOp() = default;
@@ -99,10 +150,17 @@ class TunableOp {
     }
   }
 
-  static double Profile(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
+  static double ProfileSimple(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
     TuningContext* ctx = getTuningContext();
     bool do_flush = ctx->IsICacheFlushEnabled();
-    TimerT timer{};
+    StreamTimerNoSync timer{};
+
+    // Small Mandatory Warmup
+    // Reduces outliers
+    for (size_t i = 0; i < 2; i++) {
+      TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
+    }
+
     timer.Start();
     for (size_t i = 0; i < num_iter; i++) {
       if (do_flush) {
@@ -114,6 +172,32 @@ class TunableOp {
     return timer.Duration() / num_iter;
   }
 
+  static Stats ProfileStats(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
+    TuningContext* ctx = getTuningContext();
+    bool do_flush = ctx->IsICacheFlushEnabled();
+    std::vector<StreamTimerNoSync> timer(num_iter);
+
+    // Small Mandatory Warmup
+    // Reduces outliers
+    for (size_t i = 0; i < 2; i++) {
+      TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
+    }
+
+    for (size_t i = 0; i < num_iter; i++) {
+      timer[i].Start();
+      TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
+      timer[i].End();
+      if (do_flush) {
+        at::cuda::flush_icache();
+      }
+    }
+    Stats s;
+    for (size_t i = 0; i < num_iter; i++) {
+      s.sample_value(timer[i].Duration());
+    }
+    return s;
+  }
+
  protected:
   virtual ResultEntry FindFastest(const ParamsT* params) {
     TuningContext* ctx = getTuningContext();
@@ -183,14 +267,25 @@
       }
 
       // collect a small profile
-      constexpr const int approx_num_iter = 3;
-      auto approx_duration = Profile(candidate, reusable_params, approx_num_iter, offset);
+      int approx_num_iter = 3;
+      auto s = ProfileStats(candidate, reusable_params, approx_num_iter, offset);
+      double approx_duration = s._mean;
       // bail if too slow
-      if (approx_duration > 2 * min_duration_ms) {
+      if (approx_duration > 1.5 * min_duration_ms) {
         TUNABLE_LOG3("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
         continue;
       }
 
+      // 2nd phase skip, more aggressive
+      approx_num_iter = 10;
+      s = ProfileStats(candidate, reusable_params, approx_num_iter, offset);
+      approx_duration = s._mean;
+      // bail if too slow
+      if (approx_duration > 1.15 * min_duration_ms) {
+        TUNABLE_LOG3("├──2nd skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
+        continue;
+      }
+
       // for warmup does user set max duration, max iters, or both?
       // warmup is allowed to be skipped by setting either iterations or duration to 0
       double max_warmup_duration = ctx->GetMaxWarmupDurationMs();
@@ -237,12 +332,27 @@
                    "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]);
       TUNABLE_LOG3("├──offset at ", offset);
       WarmUp(candidate, reusable_params, warmup_iter, offset);
-      auto duration_ms = Profile(candidate, reusable_params, tuning_iter, offset);
-      if (duration_ms < min_duration_ms) {
-        TUNABLE_LOG3("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]);
-        min_duration_ms = duration_ms;
+      s = ProfileStats(candidate, reusable_params, tuning_iter, offset);
+      auto s_stddev = s.stddev();
+      // Assume normal distribution.
+      // Solution with smallest mean + 2*sigma will be a better solution?
+      // if ((s._mean + 2*s_stddev) < (min_duration_ms + 2*min_stddev_ms)) {
+      if (s._mean < min_duration_ms) {
+        TUNABLE_LOG3("├──found better instance id=", i, ". " , s._mean, "ms. ", op_names_[i],
+                     " min ", s._min,
+                     " max ", s._max,
+                     " mean ", s._mean,
+                     " std ", s_stddev);
+        min_duration_ms = s._mean;
         id_name = op_names_[i];
       }
+      else {
+        TUNABLE_LOG3("├──found slower instance id=", i, ". " , s._mean, "ms. ", op_names_[i],
+                     " min ", s._min,
+                     " max ", s._max,
+                     " mean ", s._mean,
+                     " std ", s_stddev);
+      }
     }
 
     for (size_t i = 0; i < reusable_params.size(); i++) {
```
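
To see the new selection loop end to end: ProfileStats times every iteration individually with StreamTimerNoSync and folds the samples into a Welford accumulator, and FindFastest prunes candidates in two phases (skip after 3 iterations if the mean exceeds 1.5x the best so far, skip after 10 iterations if it exceeds 1.15x). The sketch below replays that pruning on made-up latency samples; the candidate data and the profile() helper are illustrative, not part of the PyTorch sources, and warmup and parameter rotation are omitted.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Welford's online mean/variance accumulator, mirroring the Stats class
// added in TunableOp.h.
class Stats {
 public:
  void sample_value(double x) {
    _sum += x;
    if (_n == 0) {
      _min = x;
      _max = x;
    } else {
      _min = std::min(_min, x);
      _max = std::max(_max, x);
    }
    ++_n;
    double delta = x - _mean;
    _mean += delta / _n;
    _M2 += delta * (x - _mean);
  }
  double variance() const { return _M2 / (_n - 1); }
  double stddev() const { return std::sqrt(variance()); }
  unsigned long _n = 0;
  double _mean = 0.0, _M2 = 0.0, _sum = 0.0, _min = 0.0, _max = 0.0;
};

// Accumulate the first n samples of a candidate, as ProfileStats would.
static Stats profile(const std::vector<double>& samples, size_t n) {
  Stats s;
  for (size_t i = 0; i < n && i < samples.size(); ++i) {
    s.sample_value(samples[i]);
  }
  return s;
}

int main() {
  // Made-up per-iteration latencies (ms) for three candidate solutions.
  const std::vector<std::vector<double>> candidates = {
      {1.00, 1.02, 0.98, 1.01, 0.99, 1.00, 1.02, 0.98, 1.00, 1.01},
      {2.50, 2.40, 2.60, 2.50, 2.45, 2.55, 2.50, 2.48, 2.52, 2.50},
      {1.10, 1.08, 1.12, 1.09, 1.11, 1.10, 1.08, 1.12, 1.10, 1.09}};

  double min_duration_ms = std::numeric_limits<double>::max();
  int best = -1;
  for (size_t i = 0; i < candidates.size(); ++i) {
    // 1st phase skip: 3 iterations, bail if > 1.5x the best mean so far.
    Stats s = profile(candidates[i], 3);
    if (s._mean > 1.5 * min_duration_ms) {
      std::printf("skip slow candidate %zu\n", i);
      continue;
    }
    // 2nd phase skip, more aggressive: 10 iterations, bail if > 1.15x.
    s = profile(candidates[i], 10);
    if (s._mean > 1.15 * min_duration_ms) {
      std::printf("2nd skip slow candidate %zu\n", i);
      continue;
    }
    std::printf("candidate %zu: mean %.4f min %.2f max %.2f std %.4f\n",
                i, s._mean, s._min, s._max, s.stddev());
    if (s._mean < min_duration_ms) {
      min_duration_ms = s._mean;
      best = static_cast<int>(i);
    }
  }
  std::printf("fastest: candidate %d at %.4f ms\n", best, min_duration_ms);
  return 0;
}
```

Running this prints full statistics for candidates 0 and 2, prunes candidate 1 in the first phase, and selects candidate 0 as fastest, matching the "found better" / "skip slow" log lines in the diff.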
