Skip to content

Commit 3afb38c

Browse files
authored
[CUDA] Add use_tf32 cuda provider option (for FP32 Conv) (#19426)
Follow-up of #19357: apply the use_tf32 option to FP32 cuDNN convolution. When use_tf32 = 0, TF32 is disabled in cuDNN convolution for FP32 inputs. Reference: https://docs.nvidia.com/deeplearning/cudnn/api/cudnn-graph-library.html#cudnnmathtype-t — **CUDNN_FMA_MATH**: restricted to only kernels that use FMA instructions. On pre-NVIDIA A100 GPU devices, CUDNN_DEFAULT_MATH and CUDNN_FMA_MATH have the same behavior: Tensor Core kernels will not be selected. With the NVIDIA Ampere architecture and CUDA toolkit 11, CUDNN_DEFAULT_MATH permits TF32 Tensor Core operation while CUDNN_FMA_MATH does not. The TF32 behavior of CUDNN_DEFAULT_MATH and the other Tensor Core math types can be explicitly disabled by the environment variable NVIDIA_TF32_OVERRIDE=0.
1 parent e5ce81a commit 3afb38c

File tree

7 files changed

+35
-12
lines changed

7 files changed

+35
-12
lines changed

onnxruntime/core/providers/cuda/nn/conv.cc

+14-3
Original file line number | Diff line number | Diff line change
@@ -326,7 +326,8 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
326326

327327
ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
328328
gsl::narrow_cast<int>(conv_attrs_.group),
329-
CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType<CudaT>()));
329+
CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType<CudaT>(),
330+
UseTF32()));
330331

331332
if (context->InputCount() >= 3) {
332333
const Tensor* B = context->Input<Tensor>(2);
@@ -351,8 +352,13 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
351352

352353
if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) {
353354
// set math type to tensor core before algorithm search
354-
if constexpr (std::is_same<T, MLFloat16>::value)
355+
if constexpr (std::is_same<T, MLFloat16>::value) {
355356
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH));
357+
} else if constexpr (std::is_same<T, float>::value) {
358+
if (!UseTF32()) {
359+
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH));
360+
}
361+
}
356362

357363
cudnnConvolutionFwdAlgoPerf_t perf;
358364
int algo_count = 1;
@@ -399,6 +405,8 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
399405
CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory));
400406
if (std::is_same<T, MLFloat16>::value) {
401407
perf.mathType = CUDNN_TENSOR_OP_MATH;
408+
} else if (std::is_same<T, float>::value && !UseTF32()) {
409+
perf.mathType = CUDNN_FMA_MATH;
402410
} else {
403411
perf.mathType = CUDNN_DEFAULT_MATH;
404412
}
@@ -480,7 +488,8 @@ Status CudnnConvolutionDescriptor::Set(
480488
const gsl::span<const int64_t>& dilations,
481489
int groups,
482490
cudnnConvolutionMode_t mode,
483-
cudnnDataType_t data_type) {
491+
cudnnDataType_t data_type,
492+
bool use_tf32) {
484493
if (!desc_)
485494
CUDNN_RETURN_IF_ERROR(cudnnCreateConvolutionDescriptor(&desc_));
486495

@@ -513,6 +522,8 @@ Status CudnnConvolutionDescriptor::Set(
513522
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_DEFAULT_MATH));
514523
if (data_type == CUDNN_DATA_HALF) {
515524
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_TENSOR_OP_MATH));
525+
} else if (data_type == CUDNN_DATA_FLOAT && !use_tf32) {
526+
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_FMA_MATH));
516527
}
517528

518529
return Status::OK();

onnxruntime/core/providers/cuda/nn/conv.h

+2-1
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,8 @@ class CudnnConvolutionDescriptor final {
2929
const gsl::span<const int64_t>& dilations,
3030
int groups,
3131
cudnnConvolutionMode_t mode,
32-
cudnnDataType_t data_type);
32+
cudnnDataType_t data_type,
33+
bool use_tf32);
3334

3435
operator cudnnConvolutionDescriptor_t() const { return desc_; }
3536

onnxruntime/core/providers/cuda/nn/conv_transpose.cc

+8-2
Original file line number | Diff line number | Diff line change
@@ -167,7 +167,8 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
167167
cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
168168
ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations,
169169
gsl::narrow_cast<int>(conv_transpose_attrs_.group), mode,
170-
CudnnTensor::GetDataType<CudaT>()));
170+
CudnnTensor::GetDataType<CudaT>(),
171+
UseTF32()));
171172

172173
if (has_bias) {
173174
const auto& b_shape = p.B->Shape();
@@ -187,8 +188,13 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
187188
GetScratchBuffer<void>(AlgoSearchWorkspaceSize, context->GetComputeStream());
188189

189190
// set math type to tensor core before algorithm search
190-
if constexpr (std::is_same<T, MLFloat16>::value)
191+
if constexpr (std::is_same<T, MLFloat16>::value) {
191192
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH));
193+
} else if constexpr (std::is_same<T, float>::value) {
194+
if (!UseTF32()) {
195+
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH));
196+
}
197+
}
192198

193199
cudnnConvolutionBwdDataAlgoPerf_t perf;
194200
int algo_count = 1;

orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc

+2-1
Original file line number | Diff line number | Diff line change
@@ -114,7 +114,8 @@ Status ConvGrad<T>::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor&
114114
ORT_RETURN_IF_ERROR(args_.y_tensor.Set(dy_dims, args_.params.data_type));
115115
ORT_RETURN_IF_ERROR(args_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
116116
gsl::narrow_cast<int>(conv_attrs_.group), CUDNN_CROSS_CORRELATION,
117-
args_.params.data_type));
117+
args_.params.data_type,
118+
UseTF32()));
118119

119120
if (dB) {
120121
const TensorShape& db_shape = dB->Shape();

orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc

+4-2
Original file line number | Diff line number | Diff line change
@@ -233,11 +233,13 @@ bool ConvParamsEqual::operator()(const ConvParams& a, const ConvParams& b) const
233233
}
234234

235235
template <typename T_Perf>
236-
Status AlgoIterator<T_Perf>::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results) {
236+
Status AlgoIterator<T_Perf>::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results, bool use_tf32) {
237237
perf_results.resize(1);
238238
perf_results[0].algo = AlgoSearch<T_Perf>::DEFAULT_ALGO;
239239
if (args.params.data_type == CUDNN_DATA_HALF) {
240240
perf_results[0].mathType = CUDNN_TENSOR_OP_MATH;
241+
} else if (args.params.data_type == CUDNN_DATA_FLOAT && !use_tf32) {
242+
perf_results[0].mathType = CUDNN_FMA_MATH;
241243
} else {
242244
perf_results[0].mathType = CUDNN_DEFAULT_MATH;
243245
}
@@ -256,7 +258,7 @@ Status AlgoIterator<T_Perf>::TryAll(const CUDAExecutionProvider* provider, const
256258

257259
std::vector<T_Perf> perf_results;
258260
ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault
259-
? OnlyDefaultAlgorithm(args_, perf_results)
261+
? OnlyDefaultAlgorithm(args_, perf_results, provider->UseTF32())
260262
: AlgoSearch<T_Perf>::FindAlgorithms(args_, provider, allocator, perf_results));
261263
for (auto& algo_perf : perf_results) {
262264
if (f(algo_perf) == Status::OK()) {

orttraining/orttraining/training_ops/cuda/nn/conv_shared.h

+1-1
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ class AlgoIterator {
7575
Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator,
7676
std::function<Status(const T_Perf& perf)> f);
7777

78-
static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results);
78+
static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results, bool use_tf32);
7979

8080
private:
8181
const ConvArgs& args_;

orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc

+4-2
Original file line number | Diff line number | Diff line change
@@ -182,7 +182,8 @@ Status ConvTransposeGrad<T>::PrepareConvForwardArgs(const Tensor& X, const Tenso
182182
ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type));
183183
ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
184184
gsl::narrow_cast<int>(conv_attrs_.group), CUDNN_CROSS_CORRELATION,
185-
args.params.data_type));
185+
args.params.data_type,
186+
UseTF32()));
186187
}
187188

188189
return Status::OK();
@@ -287,7 +288,8 @@ Status ConvTransposeGrad<T>::PrepareConvBackwardFilterArgs(const Tensor& X, cons
287288
ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type));
288289
ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
289290
gsl::narrow_cast<int>(conv_attrs_.group), CUDNN_CROSS_CORRELATION,
290-
args.params.data_type));
291+
args.params.data_type,
292+
UseTF32()));
291293

292294
if (dB) {
293295
const auto& b_shape = dB->Shape();

0 commit comments

Comments (0)