
Commit 3c1cff8

naromero77amd authored and AMD AMD committed
[release/2.5][ROCm][TunableOp] Improve identification of fastest solution (pytorch#144942) (#2018)
This PR addresses some stability issues with identifying the fastest solution on AMD GPUs, particularly the MI300. Changes include:

- An improved timer, StreamTimerNoSync
- More aggressive skipping of slow solutions
- Additional statistics that can be used for diagnostics via PYTORCH_TUNABLEOP_VERBOSE=3

Pull Request resolved: pytorch#144942
Approved by: https://github.com/jeffdaily
(cherry picked from commit fd0cd6a)
1 parent 9d0a4a1 commit 3c1cff8

6 files changed (+172 -26 lines)

aten/src/ATen/cuda/tunable/GemmHipblaslt.h

-7

@@ -625,13 +625,6 @@ auto GetHipBlasLtTypeStringAndOps() {
         heuristic_result));
     TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));
 
-    // Sort heuristic_result by algo index to make sure the order of returned algos is deterministic.
-    std::sort(heuristic_result.begin(),
-              heuristic_result.end(),
-              [](hipblasLtMatmulHeuristicResult_t& a, hipblasLtMatmulHeuristicResult_t& b) {
-                return hipblaslt_ext::getIndexFromAlgo(a.algo) < hipblaslt_ext::getIndexFromAlgo(b.algo);
-              });
-
     int returned_algo_count = heuristic_result.size();
     std::vector<std::pair<std::string, std::unique_ptr<Callable<ParamsT>>>> ret;
     for (int i = 0; i < returned_algo_count; i++) {

aten/src/ATen/cuda/tunable/GemmRocblas.h

-3

@@ -192,9 +192,6 @@ auto GetRocBlasGemmTypeStringAndOps() {
       rocblas_gemm_flags_none,
       solutions.data(),
       &solution_size));
-  // Sort the solutions in ascending order to make the solution vector deterministic across runs
-  std::sort(solutions.begin(), solutions.end());
-
   std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmParams<T>>>>> ret;
   for (size_t i = 0; i < solutions.size(); ++i) {
     auto callable = std::make_unique<RocblasGemmOp<T>>(solutions[i]);

aten/src/ATen/cuda/tunable/StreamTimer.cpp

+24 -1

@@ -24,7 +24,7 @@ StreamTimer::StreamTimer() {
 StreamTimer::~StreamTimer() = default;
 
 void StreamTimer::Start() {
-  AT_CUDA_CHECK(cudaDeviceSynchronize());
+  AT_CUDA_CHECK(cudaEventSynchronize(start_));
   AT_CUDA_CHECK(cudaEventRecord(start_, at::cuda::getCurrentCUDAStream()));
 }
 
@@ -40,4 +40,27 @@ float StreamTimer::Duration() {
   return time;
 }
 
+StreamTimerNoSync::StreamTimerNoSync() {
+  AT_CUDA_CHECK(cudaEventCreate(&start_));
+  AT_CUDA_CHECK(cudaEventCreate(&end_));
+}
+
+StreamTimerNoSync::~StreamTimerNoSync() = default;
+
+void StreamTimerNoSync::Start() {
+  AT_CUDA_CHECK(cudaEventRecord(start_, at::cuda::getCurrentCUDAStream()));
+}
+
+void StreamTimerNoSync::End() {
+  AT_CUDA_CHECK(cudaEventRecord(end_, at::cuda::getCurrentCUDAStream()));
+}
+
+float StreamTimerNoSync::Duration() {
+  auto time = std::numeric_limits<float>::quiet_NaN();
+  AT_CUDA_CHECK(cudaEventSynchronize(end_));
+  // time is in ms with a resolution of 1 us
+  AT_CUDA_CHECK(cudaEventElapsedTime(&time, start_, end_));
+  return time;
+}
+
 } // namespace at::cuda::tunable
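Note on the timing change: StreamTimer::Start() now waits only on its own start event instead of issuing a device-wide cudaDeviceSynchronize(), and the new StreamTimerNoSync does no synchronization at all until Duration() is read, when it blocks on the end event alone. A minimal standalone sketch of that event-based pattern, assuming raw CUDA runtime calls and a caller-supplied launch function (time_on_stream and launch are illustrative names, not from this commit):

#include <cuda_runtime.h>

// Sketch: time stream-ordered work with events only. There is no
// cudaDeviceSynchronize(); the host blocks solely on the end event,
// and only at the moment the elapsed time is actually read.
float time_on_stream(cudaStream_t stream, void (*launch)(cudaStream_t)) {
  cudaEvent_t start = nullptr, end = nullptr;
  cudaEventCreate(&start);
  cudaEventCreate(&end);

  cudaEventRecord(start, stream);  // mark the start on the stream
  launch(stream);                  // enqueue the work being timed
  cudaEventRecord(end, stream);    // mark the end on the same stream

  cudaEventSynchronize(end);       // wait for the end event only
  float ms = 0.0f;                 // milliseconds, ~1 us resolution
  cudaEventElapsedTime(&ms, start, end);

  cudaEventDestroy(start);
  cudaEventDestroy(end);
  return ms;
}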

aten/src/ATen/cuda/tunable/StreamTimer.h

+16

@@ -31,4 +31,20 @@ class StreamTimer : public ITimer {
   cudaEvent_t end_{};
 };
 
+class StreamTimerNoSync : public ITimer {
+ public:
+  StreamTimerNoSync();
+  ~StreamTimerNoSync() override;
+
+  void Start() override;
+
+  void End() override;
+
+  float Duration() override;
+
+ private:
+  cudaEvent_t start_{};
+  cudaEvent_t end_{};
+};
+
 } // namespace at::cuda::tunable

aten/src/ATen/cuda/tunable/TunableGemm.h

+12 -5

@@ -14,7 +14,6 @@
 #include <ATen/cuda/tunable/GemmHipblaslt.h>
 #include <ATen/cuda/tunable/GemmRocblas.h>
 #endif
-#include <ATen/cuda/tunable/StreamTimer.h>
 #include <ATen/cuda/tunable/TunableOp.h>
 #include <c10/cuda/CUDACachingAllocator.h>
 #include <c10/util/Float8_e4m3fn.h>
@@ -198,7 +197,7 @@ inline const char* TypeName(c10::complex<float> v) {
 }
 
 template <typename T, BlasOp ALayout, BlasOp BLayout>
-class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
+class GemmTunableOp : public TunableOp<GemmParams<T>> {
  public:
   GemmTunableOp() {
     this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
@@ -223,6 +222,8 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
     }
   }
 #endif
+
+    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
   }
 
   std::string Signature() override {
@@ -231,7 +232,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
 };
 
 template <typename T, BlasOp ALayout, BlasOp BLayout>
-class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer> {
+class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>> {
  public:
   GemmAndBiasTunableOp() {
     this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
@@ -249,6 +250,8 @@ class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer>
     }
   }
 #endif
+
+    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
   }
 
   std::string Signature() override {
@@ -257,7 +260,7 @@ class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer>
 };
 
 template <typename T, BlasOp ALayout, BlasOp BLayout>
-class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>, StreamTimer> {
+class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>> {
  public:
   GemmStridedBatchedTunableOp() {
     this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
@@ -282,6 +285,8 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
     }
   }
 #endif
+
+    this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
   }
 
   std::string Signature() override {
@@ -290,7 +295,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
 };
 
 template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
-class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer> {
+class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>> {
  public:
   ScaledGemmTunableOp() {
     this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
@@ -300,6 +305,8 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
     this->RegisterOp(std::move(name), std::move(op));
   }
 #endif
+
+    this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
   }
 
   std::string Signature() override {

aten/src/ATen/cuda/tunable/TunableOp.h

+120 -10

@@ -10,6 +10,7 @@
 #pragma once
 
 #include <ATen/cuda/tunable/Tunable.h>
+#include <ATen/cuda/tunable/StreamTimer.h>
 #include <ATen/cuda/Sleep.h>
 #include <c10/cuda/CUDACachingAllocator.h>
 
@@ -35,7 +36,57 @@ class Callable {
   }
 };
 
-template <typename ParamsT, typename TimerT>
+namespace {
+
+/** http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance */
+
+class Stats {
+  public:
+    Stats() {
+      _n = 0UL;
+      _mean = 0.0;
+      _M2 = 0.0;
+      _sum = 0.0;
+      _min = 0.0;
+      _max = 0.0;
+    }
+
+    void sample_value(const double x) {
+      double delta = 0;
+      _sum = _sum + x;
+      if (0UL == _n) {
+        _min = x;
+        _max = x;
+      }
+      else {
+        _min = _min < x ? _min : x;
+        _max = _max > x ? _max : x;
+      }
+      _n = _n + 1UL;
+      delta = x - _mean;
+      _mean = _mean + delta/_n;
+      _M2 = _M2 + delta * (x - _mean);
+    }
+
+    double variance() const {
+      return _M2/(_n-1);
+    }
+
+    double stddev() const {
+      return std::sqrt(variance());
+    }
+
+    unsigned long _n;
+    double _mean;
+    double _M2;
+    double _sum;
+    double _min;
+    double _max;
+};
+
+} // anonymous namespace
+
+template <typename ParamsT>
 class TunableOp {
  public:
   virtual ~TunableOp() = default;
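The Stats class added above is Welford's single-pass algorithm: each sample updates the running mean and the sum of squared deviations (_M2), so mean, variance, min, and max are available without storing the individual timings; note that variance() divides by _n-1 and therefore assumes at least two samples. A small self-contained check of the same update rule against a naive two-pass computation, using made-up timings:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical per-iteration timings in ms, for illustration only.
  std::vector<double> samples{1.02, 0.98, 1.10, 1.05, 0.95};

  // One pass (Welford), as in Stats::sample_value().
  unsigned long n = 0;
  double mean = 0.0, M2 = 0.0;
  for (double x : samples) {
    n += 1;
    double delta = x - mean;
    mean += delta / n;
    M2 += delta * (x - mean);  // uses the already-updated mean
  }
  double stddev = std::sqrt(M2 / (n - 1));  // sample variance, as in Stats

  // Two passes (reference): mean first, then squared deviations.
  double m = 0.0;
  for (double x : samples) m += x;
  m /= samples.size();
  double v = 0.0;
  for (double x : samples) v += (x - m) * (x - m);
  v /= samples.size() - 1;

  std::printf("welford:  mean=%.6f stddev=%.6f\n", mean, stddev);
  std::printf("two-pass: mean=%.6f stddev=%.6f\n", m, std::sqrt(v));
  return 0;
}

Both lines print the same mean and standard deviation; the one-pass form avoids keeping the raw samples and is less prone to cancellation than the textbook sum-of-squares shortcut.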
@@ -100,10 +151,17 @@ class TunableOp {
     }
   }
 
-  static double Profile(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
+  static double ProfileSimple(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
     TuningContext* ctx = getTuningContext();
     bool do_flush = ctx->IsICacheFlushEnabled();
-    TimerT timer{};
+    StreamTimerNoSync timer{};
+
+    // Small Mandatory Warmup
+    // Reduces outliers
+    for (size_t i = 0; i < 2; i++) {
+      TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
+    }
+
     timer.Start();
     for (size_t i = 0; i < num_iter; i++) {
       if (do_flush) {
@@ -115,6 +173,32 @@ class TunableOp {
       }
     }
     return timer.Duration() / num_iter;
   }
 
+  static Stats ProfileStats(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
+    TuningContext* ctx = getTuningContext();
+    bool do_flush = ctx->IsICacheFlushEnabled();
+    std::vector<StreamTimerNoSync> timer(num_iter);
+
+    // Small Mandatory Warmup
+    // Reduces outliers
+    for (size_t i = 0; i < 2; i++) {
+      TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
+    }
+
+    for (size_t i = 0; i < num_iter; i++) {
+      timer[i].Start();
+      TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
+      timer[i].End();
+      if (do_flush) {
+        at::cuda::flush_icache();
+      }
+    }
+    Stats s;
+    for (size_t i = 0; i < num_iter; i++) {
+      s.sample_value(timer[i].Duration());
+    }
+    return s;
+  }
+
 protected:
   virtual ResultEntry FindFastest(const ParamsT* params) {
     TuningContext* ctx = getTuningContext();
@@ -184,14 +268,25 @@ class TunableOp {
       }
 
       // collect a small profile
-      constexpr const int approx_num_iter = 3;
-      auto approx_duration = Profile(candidate, reusable_params, approx_num_iter, offset);
+      int approx_num_iter = 3;
+      auto s = ProfileStats(candidate, reusable_params, approx_num_iter, offset);
+      double approx_duration = s._mean;
       // bail if too slow
-      if (approx_duration > 2 * min_duration_ms) {
+      if (approx_duration > 1.5 * min_duration_ms) {
         TUNABLE_LOG3("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
         continue;
       }
 
+      // 2nd phase skip, more aggressive
+      approx_num_iter = 10;
+      s = ProfileStats(candidate, reusable_params, approx_num_iter, offset);
+      approx_duration = s._mean;
+      // bail if too slow
+      if (approx_duration > 1.15 * min_duration_ms) {
+        TUNABLE_LOG3("├──2nd skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
+        continue;
+      }
+
       // for warmup does user set max duration, max iters, or both?
       // warmup is allowed to be skipped by setting either iterations or duration to 0
       double max_warmup_duration = ctx->GetMaxWarmupDurationMs();
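To make the two-phase skip concrete: if the best candidate so far averages 1.0 ms, a new candidate is dropped when its 3-iteration mean exceeds 1.5 ms (the first gate, tightened from the previous 2x), and a survivor is dropped when its 10-iteration mean exceeds 1.15 ms (the second gate). Only candidates that pass both gates reach the full warmup and tuning loop. A small sketch of the filter with invented numbers:

#include <cstdio>
#include <vector>

int main() {
  double min_duration_ms = 1.0;                 // best mean seen so far
  // Hypothetical means from the 3-iteration and 10-iteration probes.
  std::vector<double> coarse{1.6, 1.4, 0.9};
  std::vector<double> refined{0.0, 1.3, 0.95};  // index 0 never reaches phase 2

  for (size_t i = 0; i < coarse.size(); i++) {
    if (coarse[i] > 1.5 * min_duration_ms) {
      std::printf("candidate %zu: skipped by 1st gate (%.2f ms)\n", i, coarse[i]);
      continue;
    }
    if (refined[i] > 1.15 * min_duration_ms) {
      std::printf("candidate %zu: skipped by 2nd gate (%.2f ms)\n", i, refined[i]);
      continue;
    }
    std::printf("candidate %zu: fully tuned (%.2f ms)\n", i, refined[i]);
  }
  return 0;
}

Here candidate 0 fails the first gate, candidate 1 fails the second, and only candidate 2 goes on to full tuning.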
@@ -238,12 +333,27 @@ class TunableOp {
           "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]);
       TUNABLE_LOG3("├──offset at ", offset);
       WarmUp(candidate, reusable_params, warmup_iter, offset);
-      auto duration_ms = Profile(candidate, reusable_params, tuning_iter, offset);
-      if (duration_ms < min_duration_ms) {
-        TUNABLE_LOG3("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]);
-        min_duration_ms = duration_ms;
+      s = ProfileStats(candidate, reusable_params, tuning_iter, offset);
+      auto s_stddev = s.stddev();
+      // Assume normal distribution.
+      // Solution with smallest mean + 2*sigma will be a better solution?
+      // if ((s._mean + 2*s_stddev) < (min_duration_ms + 2*min_stddev_ms)) {
+      if (s._mean < min_duration_ms) {
+        TUNABLE_LOG3("├──found better instance id=", i, ". " , s._mean, "ms. ", op_names_[i],
+            " min ", s._min,
+            " max ", s._max,
+            " mean ", s._mean,
+            " std ", s_stddev);
+        min_duration_ms = s._mean;
         id_name = op_names_[i];
       }
+      else {
+        TUNABLE_LOG3("├──found slower instance id=", i, ". " , s._mean, "ms. ", op_names_[i],
+            " min ", s._min,
+            " max ", s._max,
+            " mean ", s._mean,
+            " std ", s_stddev);
+      }
     }
 
     for (size_t i = 0; i < reusable_params.size(); i++) {
