Skip to content

Commit 00984de

Browse files
tianleiwuankitm3k
authored andcommitted
[CUDA] Build nhwc ops by default (microsoft#22648)
### Description * Build cuda nhwc ops by default. * Deprecate `--enable_cuda_nhwc_ops` in build.py and add `--disable_cuda_nhwc_ops` option Note that it requires cuDNN 9.x. If you build with cuDNN 8, NHWC ops will be disabled automatically. ### Motivation and Context In general, NHWC is faster than NCHW for convolution in Nvidia GPUs with Tensor Cores, and this could improve performance for vision models. This is the first step to prefer NHWC for CUDA in 1.21 release. Next step is to do some tests on popular vision models. If it help in most models and devices, set `prefer_nhwc=1` as default cuda provider option.
1 parent 3059be4 commit 00984de

File tree

8 files changed

+85
-31
lines changed

8 files changed

+85
-31
lines changed

Diff for: cmake/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ option(onnxruntime_USE_CUDA "Build with CUDA support" OFF)
8686
# use. If you hit any problem with that, please do not report it to GTest. Turn OFF the following build option instead.
8787
cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS" OFF)
8888

89-
option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" OFF)
89+
cmake_dependent_option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" ON "onnxruntime_USE_CUDA" OFF)
9090
option(onnxruntime_CUDA_MINIMAL "Build CUDA without any operations apart from memcpy ops. Usefuel for a very minial TRT build" OFF)
9191
option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." OFF)
9292
option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF)

Diff for: dockerfiles/Dockerfile.cuda

-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ RUN cd /code \
5656
--build_shared_lib --skip_tests \
5757
--config Release --build_wheel --update --build --parallel \
5858
--cmake_generator Ninja \
59-
--enable_cuda_nhwc_ops \
6059
--cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) "CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}" onnxruntime_BUILD_UNIT_TESTS=OFF
6160

6261
# Start second stage to copy the build artifacts

Diff for: docs/OperatorKernels.md

+29
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,35 @@ Do not modify directly.*
925925
|WhisperBeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *in* cross_qk_layer_head:**I**<br> *in* extra_decoding_ids:**I**<br> *in* temperature:**T**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**<br> *out* cross_qk:**V**<br> *out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)|
926926
| |
927927
| |
928+
|**Operator Domain:** *com.ms.internal.nhwc*||||
929+
|AveragePool|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(float), tensor(float16)|
930+
|||10|**T** = tensor(float), tensor(float16)|
931+
|||[7, 9]|**T** = tensor(float), tensor(float16)|
932+
|BatchNormalization|*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* input_mean:**U**<br> *in* input_var:**U**<br> *out* Y:**T**<br> *out* running_mean:**U**<br> *out* running_var:**U**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* mean:**T**<br> *in* var:**T**<br> *out* Y:**T**<br> *out* mean:**T**<br> *out* var:**T**<br> *out* saved_mean:**T**<br> *out* saved_var:**T**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T1**<br> *in* B:**T1**<br> *in* input_mean:**T2**<br> *in* input_var:**T2**<br> *out* Y:**T**<br> *out* running_mean:**T2**<br> *out* running_var:**T2**|15+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(double), tensor(float), tensor(float16)|
933+
|||14|**T** = tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float), tensor(float16)|
934+
|||[9, 13]|**T** = tensor(double), tensor(float), tensor(float16)|
935+
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
936+
|Conv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|11+|**T** = tensor(float), tensor(float16)|
937+
|||[1, 10]|**T** = tensor(float), tensor(float16)|
938+
|ConvTranspose|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|11+|**T** = tensor(float), tensor(float16)|
939+
|||[1, 10]|**T** = tensor(float), tensor(float16)|
940+
|DepthToSpace|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
941+
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
942+
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
943+
|GlobalAveragePool|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
944+
|GlobalMaxPool|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
945+
|GridSample|*in* X:**T1**<br> *in* grid:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
946+
|LRN|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
947+
|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
948+
|MaxPool|*in* X:**T**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T**<br> *out* Y:**T**<br> *out* Indices:**I**|12+|**I** = tensor(int64)<br/> **T** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)|
949+
|||11|**I** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
950+
|||10|**I** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
951+
|||[8, 9]|**I** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
952+
|||[1, 7]|**T** = tensor(float), tensor(float16)|
953+
|SpaceToDepth|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
954+
|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
955+
| |
956+
| |
928957

929958

930959
<a name="dmlexecutionprovider"/>

Diff for: onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh

-1
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ build_onnxruntime_gpu_for_profiling() {
191191
--build_wheel --skip_tests \
192192
--cmake_generator Ninja \
193193
--compile_no_warning_as_error \
194-
--enable_cuda_nhwc_ops \
195194
--cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=$CUDA_ARCH \
196195
--cmake_extra_defines onnxruntime_ENABLE_NVTX_PROFILE=ON \
197196
--enable_cuda_line_info

Diff for: onnxruntime/test/providers/cpu/nn/conv_op_test.cc

+17-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
3-
3+
#include "core/graph/constants.h"
44
#include "gtest/gtest.h"
55
#include "test/providers/provider_test_utils.h"
6+
67
using namespace std;
78
namespace onnxruntime {
89
namespace test {
@@ -28,7 +29,8 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes,
2829
optional<float> epsilon = optional<float>(),
2930
OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
3031
const std::string& err_str = "",
31-
int opset = 7) {
32+
int opset = 7,
33+
bool exclude_cuda_nhwc = false) {
3234
OpTester test("Conv", opset);
3335
test.AddAttribute("group", attributes.group);
3436
test.AddAttribute("kernel_shape", attributes.kernel_shape);
@@ -65,6 +67,12 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes,
6567
// Disable TensorRT because weight as input is not supported
6668
excluded_providers.insert(kTensorrtExecutionProvider);
6769

70+
if (exclude_cuda_nhwc) {
71+
#ifdef ENABLE_CUDA_NHWC_OPS
72+
excluded_providers.insert(kCudaNHWCExecutionProvider);
73+
#endif
74+
}
75+
6876
// QNN SDK 2.10.0 has a bug that breaks support for dynamic bias inputs.
6977
excluded_providers.insert(kQnnExecutionProvider);
7078

@@ -197,10 +205,15 @@ TEST(ConvTest, Conv1D_Bias) {
197205
// as TF32 has a 10 bit mantissa.
198206
float epsilon = 1.1e-5f;
199207

200-
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, false, epsilon);
208+
// This case is not supported by cuDNN frontend, and the fallback (legacy code) requires weight to 4D tensor for NHWC.
209+
constexpr bool exclude_cuda_nhwc = true;
210+
211+
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, false, epsilon,
212+
OpTester::ExpectResult::kExpectSuccess, "", 10, exclude_cuda_nhwc);
201213

202214
// CoreML EP requires weight to be an initializer
203-
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true, epsilon);
215+
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true, epsilon,
216+
OpTester::ExpectResult::kExpectSuccess, "", 10, exclude_cuda_nhwc);
204217
}
205218

206219
// Conv47

Diff for: tools/ci_build/build.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import shutil
1414
import subprocess
1515
import sys
16+
import warnings
1617
from pathlib import Path
1718

1819

@@ -253,7 +254,12 @@ def convert_arg_line_to_args(self, arg_line):
253254
"--cudnn_home is not specified.",
254255
)
255256
parser.add_argument("--enable_cuda_line_info", action="store_true", help="Enable CUDA line info.")
256-
parser.add_argument("--enable_cuda_nhwc_ops", action="store_true", help="Enable CUDA NHWC ops in build.")
257+
258+
parser.add_argument(
259+
"--enable_cuda_nhwc_ops", action="store_true", help="Deprecated; default to enable CUDA NHWC ops in build."
260+
)
261+
262+
parser.add_argument("--disable_cuda_nhwc_ops", action="store_true", help="Disable CUDA NHWC ops in build.")
257263

258264
# Python bindings
259265
parser.add_argument("--enable_pybind", action="store_true", help="Enable Python Bindings.")
@@ -793,6 +799,11 @@ def convert_arg_line_to_args(self, arg_line):
793799
if args.cmake_generator is None and is_windows():
794800
args.cmake_generator = "Ninja" if args.build_wasm else "Visual Studio 17 2022"
795801

802+
if args.enable_cuda_nhwc_ops:
803+
warnings.warn(
804+
"The argument '--enable_cuda_nhwc_ops' is deprecated and is default to True. ", DeprecationWarning
805+
)
806+
796807
return args
797808

798809

@@ -1074,7 +1085,7 @@ def generate_build_tree(
10741085
"-Donnxruntime_USE_MPI=" + ("ON" if args.use_mpi else "OFF"),
10751086
"-Donnxruntime_ENABLE_MEMORY_PROFILE=" + ("ON" if args.enable_memory_profile else "OFF"),
10761087
"-Donnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO=" + ("ON" if args.enable_cuda_line_info else "OFF"),
1077-
"-Donnxruntime_USE_CUDA_NHWC_OPS=" + ("ON" if args.enable_cuda_nhwc_ops else "OFF"),
1088+
"-Donnxruntime_USE_CUDA_NHWC_OPS=" + ("ON" if args.use_cuda and not args.disable_cuda_nhwc_ops else "OFF"),
10781089
"-Donnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB=" + ("ON" if args.build_wasm_static_lib else "OFF"),
10791090
"-Donnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING="
10801091
+ ("OFF" if args.disable_wasm_exception_catching else "ON"),

Diff for: tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ stages:
123123
--parallel \
124124
--build_wheel \
125125
--enable_onnx_tests --use_cuda --cuda_version=11.8 --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8 \
126-
--enable_cuda_profiling --enable_cuda_nhwc_ops \
126+
--enable_cuda_profiling \
127127
--enable_pybind --build_java \
128128
--use_cache \
129129
--cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=75;86' ; \

Diff for: tools/ci_build/github/linux/build_cuda_ci.sh

+24-21
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,31 @@ set -ex
33
#Every cuda container has this $CUDA_VERSION env var set.
44
SHORT_CUDA_VERSION=$(echo $CUDA_VERSION | sed 's/\([[:digit:]]\+\.[[:digit:]]\+\)\.[[:digit:]]\+/\1/')
55

6-
BUILD_ARGS=('--config' 'Release' '--update' '--build'
7-
'--skip_submodule_sync'
8-
'--build_shared_lib'
9-
'--parallel' '--use_binskim_compliant_compile_flags'
10-
'--build_wheel'
11-
'--enable_onnx_tests'
12-
'--use_cuda'
13-
"--cuda_version=$SHORT_CUDA_VERSION"
14-
"--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION"
15-
"--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION"
16-
"--enable_cuda_profiling"
17-
"--enable_cuda_nhwc_ops"
18-
"--enable_pybind"
19-
"--build_java"
20-
"--cmake_extra_defines"
21-
"CMAKE_CUDA_ARCHITECTURES=75"
22-
"onnxruntime_BUILD_UNIT_TESTS=ON"
23-
"onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON")
6+
BUILD_ARGS=('--config'
7+
'Release'
8+
'--update'
9+
'--build'
10+
'--skip_submodule_sync'
11+
'--build_shared_lib'
12+
'--parallel'
13+
'--use_binskim_compliant_compile_flags'
14+
'--build_wheel'
15+
'--enable_onnx_tests'
16+
'--use_cuda'
17+
"--cuda_version=$SHORT_CUDA_VERSION"
18+
"--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION"
19+
"--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION"
20+
"--enable_cuda_profiling"
21+
"--enable_pybind"
22+
"--build_java"
23+
"--cmake_extra_defines"
24+
"CMAKE_CUDA_ARCHITECTURES=75"
25+
"onnxruntime_BUILD_UNIT_TESTS=ON"
26+
"onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON")
2427
if [ -x "$(command -v ninja)" ]; then
2528
BUILD_ARGS+=('--cmake_generator' 'Ninja')
2629
fi
27-
30+
2831
if [ -d /build ]; then
2932
BUILD_ARGS+=('--build_dir' '/build')
3033
else
@@ -40,7 +43,7 @@ if [ -f /opt/python/cp312-cp312/bin/python3 ]; then
4043
else
4144
python3 tools/ci_build/build.py "${BUILD_ARGS[@]}"
4245
fi
43-
if [ -x "$(command -v ccache)" ]; then
44-
ccache -sv
46+
if [ -x "$(command -v ccache)" ]; then
47+
ccache -sv
4548
ccache -z
4649
fi

0 commit comments

Comments
 (0)