Skip to content

Commit 20d6f58

Browse files
authored
Merge pull request #2302 from pytorch/2.1-staging
Cherry-pick changes from main into release/2.1
2 parents 94d1bdd + adf4e32 commit 20d6f58

File tree

319 files changed

+13939
-2628
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

319 files changed

+13939
-2628
lines changed

.circleci/config.yml

+17-2
Original file line numberDiff line numberDiff line change
@@ -802,7 +802,7 @@ commands:
802802
- store_artifacts:
803803
path: /tmp/testlogs
804804

805-
test-dynamo-models_torch_export:
805+
test-dynamo-models_export:
806806
description: "Test the Dynamo models via torch_export path"
807807
steps:
808808
- run:
@@ -818,6 +818,20 @@ commands:
818818
- store_artifacts:
819819
path: /tmp/testlogs
820820

821+
# Reusable command: runs the export serialize/deserialize tests for Dynamo models.
# The run step changes into tests/py/dynamo/models and invokes pytest on
# test_export_serde.py with `--ir dynamo`, writing a JUnit XML report under
# /tmp/artifacts. Afterwards /tmp/artifacts is stored as test results and
# /tmp/testlogs is stored as build artifacts.
# NOTE(review): the report is written to .../dynamo/backend/test_results.xml;
# confirm this does not overwrite the report of another command that uses the same path.
test-dynamo-export_serde:
822+
description: "Test the export serialize/deserialize functionality for Dynamo models"
823+
steps:
824+
- run:
825+
name: Run Dynamo models and test export serde with TRT compiled modules
826+
command: |
827+
cd tests/py/dynamo/models
828+
pytest test_export_serde.py --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir dynamo
829+
830+
- store_test_results:
831+
path: /tmp/artifacts
832+
- store_artifacts:
833+
path: /tmp/testlogs
834+
821835
test-dynamo-converters:
822836
description: "Test the Dynamo aten converters"
823837
steps:
@@ -1122,7 +1136,8 @@ jobs:
11221136
- test-dynamo-backend
11231137
- test-dynamo-shared_utilities
11241138
- test-dynamo-models_torch_compile
1125-
- test-dynamo-models_torch_export
1139+
- test-dynamo-models_export
1140+
- test-dynamo-export_serde
11261141

11271142
package-x86_64-linux:
11281143
parameters:

.github/workflows/build-test.yml

+36-33
Original file line numberDiff line numberDiff line change
@@ -54,39 +54,40 @@ jobs:
5454
AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }}
5555
AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }}
5656

57-
# tests-py-torchscript-fe:
58-
# name: Test torchscript frontend [Python]
59-
# needs: [generate-matrix, build]
60-
# strategy:
61-
# fail-fast: false
62-
# matrix:
63-
# include:
64-
# - repository: pytorch/tensorrt
65-
# package-name: torch_tensorrt
66-
# pre-script: packaging/pre_build_script.sh
67-
# uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
68-
# with:
69-
# job-name: tests-py-torchscript-fe
70-
# repository: "pytorch/tensorrt"
71-
# ref: ""
72-
# test-infra-repository: pytorch/test-infra
73-
# test-infra-ref: main
74-
# build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
75-
# pre-script: ${{ matrix.pre-script }}
76-
# script: |
77-
# export USE_HOST_DEPS=1
78-
# pushd .
79-
# cd tests/modules
80-
# ${CONDA_RUN} python -m pip install -r requirements.txt
81-
# ${CONDA_RUN} python hub.py
82-
# popd
83-
# pushd .
84-
# cd tests/py/ts
85-
# ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest --use-deprecated=legacy-resolver
86-
# ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_api_test_results.xml api/
87-
# ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_models_test_results.xml models/
88-
# ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_integrations_test_results.xml integrations/
89-
# popd
57+
# Job: TorchScript frontend Python tests, run through the reusable linux-test.yml
# workflow once the generate-matrix and build jobs have finished.
# The script installs the requirements in tests/modules and runs hub.py, then runs
# pytest on api/, models/ and integrations/ under tests/py/ts, writing one JUnit XML
# file per suite into RUNNER_TEST_RESULTS_DIR.
# NOTE(review): the reusable workflow (`@main`) and `test-infra-ref: main` both track
# main; confirm that is intended for a release/2.1 cherry-pick rather than a pinned ref.
tests-py-torchscript-fe:
58+
name: Test torchscript frontend [Python]
59+
needs: [generate-matrix, build]
60+
strategy:
61+
fail-fast: false
62+
matrix:
63+
include:
64+
- repository: pytorch/tensorrt
65+
package-name: torch_tensorrt
66+
pre-script: packaging/pre_build_script.sh
67+
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
68+
with:
69+
job-name: tests-py-torchscript-fe
70+
repository: "pytorch/tensorrt"
71+
ref: ""
72+
test-infra-repository: pytorch/test-infra
73+
test-infra-ref: main
74+
build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
75+
pre-script: ${{ matrix.pre-script }}
76+
script: |
77+
export USE_HOST_DEPS=1
78+
export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH
79+
pushd .
80+
cd tests/modules
81+
${CONDA_RUN} python -m pip install --pre -r requirements.txt --use-deprecated=legacy-resolver
82+
${CONDA_RUN} python hub.py
83+
popd
84+
pushd .
85+
cd tests/py/ts
86+
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest --use-deprecated=legacy-resolver
87+
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_api_test_results.xml api/
88+
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_models_test_results.xml models/
89+
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_integrations_test_results.xml integrations/
90+
popd
9091
9192
tests-py-dynamo-converters:
9293
name: Test dynamo converters [Python]
@@ -140,6 +141,8 @@ jobs:
140141
cd tests/py/dynamo
141142
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest --use-deprecated=legacy-resolver
142143
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
144+
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
145+
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
143146
popd
144147
145148
tests-py-torch-compile-be:

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ repos:
4040
rev: 'v1.4.1'
4141
hooks:
4242
- id: mypy
43-
exclude: "^py/torch_tensorrt/fx|^examples|^tests|^tools|^docs|noxfile.py|setup.py|versions.py"
43+
exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
4444
- repo: https://github.com/astral-sh/ruff-pre-commit
4545
# Ruff version.
4646
rev: v0.0.278

core/conversion/converters/impl/shuffle.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
2020
auto in_shape = util::toVec(in->getDimensions());
2121
std::vector<int64_t> out_shape;
2222
if (ctx->input_is_dynamic) {
23-
end_dim = (end_dim == -1) ? in_shape.size() - 1 : end_dim;
23+
if (start_dim < 0) {
24+
start_dim = start_dim + in_shape.size();
25+
}
26+
if (end_dim < 0) {
27+
end_dim = end_dim + in_shape.size();
28+
}
2429
int nbDynamicFlattenedDims = 0;
2530
int nbDynamicUnflattenedDims = 0;
2631
for (int i = 0; i < (int)in_shape.size(); i++) {

core/runtime/execute_engine.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ bool is_switch_required(const RTDevice& curr_device, const RTDevice& engine_devi
4343
return false;
4444
}
4545

46-
RTDevice select_rt_device(const RTDevice& engine_device) {
47-
auto new_target_device_opt = get_most_compatible_device(engine_device);
46+
RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_device) {
47+
auto new_target_device_opt = get_most_compatible_device(engine_device, curr_device);
4848

4949
// REVIEW: THIS DOES NOT LIST DLA PROBABLY, WHICH WE SHOULD
5050
// TODO: I think this logic could be way simpler at execution time since if the tensors arent on the right
@@ -89,7 +89,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
8989

9090
if (is_switch_required(curr_device, compiled_engine->device_info)) {
9191
// Scan through available CUDA devices and set the CUDA device context correctly
92-
RTDevice device = select_rt_device(compiled_engine->device_info);
92+
RTDevice device = select_rt_device(compiled_engine->device_info, curr_device);
9393
set_rt_device(device);
9494

9595
// Target device is new device

core/runtime/runtime.cpp

+22-5
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,16 @@ namespace torch_tensorrt {
77
namespace core {
88
namespace runtime {
99

10-
c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device) {
10+
c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device, const RTDevice& curr_device) {
1111
LOG_DEBUG("Target Device: " << target_device);
1212
auto device_options = find_compatible_devices(target_device);
13+
RTDevice current_device; // device whose ID gets first priority in the selection loop below
14+
if (curr_device.id == -1) { // fixed: test the caller-supplied device, not the local declared just above
15+
current_device = get_current_device(); // no device supplied (presumably id == -1 marks a default RTDevice; confirm)
16+
} else {
17+
current_device = curr_device; // honor the device passed in by the caller
18+
}
19+
1320
if (device_options.size() == 0) {
1421
return {};
1522
} else if (device_options.size() == 1) {
@@ -21,10 +28,20 @@ c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device
2128
dev_list << "[" << std::endl;
2229
for (auto device : device_options) {
2330
dev_list << " " << device << ',' << std::endl;
24-
if (device.device_name == target_device.device_name && best_match.device_name != target_device.device_name) {
25-
best_match = device;
26-
} else if (device.device_name == target_device.device_name && best_match.device_name == target_device.device_name) {
27-
if (device.id == target_device.id && best_match.id != target_device.id) {
31+
if (device.device_name == target_device.device_name) {
32+
// First priority is selecting a candidate which agrees with the current device ID
33+
// If such a device is found, we can select it and break out of the loop
34+
if (device.id == current_device.id && best_match.id != current_device.id) {
35+
best_match = device;
36+
break;
37+
}
38+
// Second priority is selecting a candidate which agrees with the target device ID
39+
// At deserialization time, the current device and target device may not agree
40+
else if (device.id == target_device.id && best_match.id != target_device.id) {
41+
best_match = device;
42+
}
43+
// If no such GPU ID is found, select the first available candidate GPU
44+
else if (best_match.device_name != target_device.device_name) {
2845
best_match = device;
2946
}
3047
}

core/runtime/runtime.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ typedef enum {
2626
SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
2727
} SerializedInfoIndex;
2828

29-
c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device);
29+
// Selects the device to run on from those compatible with target_device.
// Returns an empty optional when no compatible device is found.
// curr_device may be omitted by callers; it then defaults to a default-constructed RTDevice.
c10::optional<RTDevice> get_most_compatible_device(
30+
const RTDevice& target_device,
31+
const RTDevice& curr_device = RTDevice());
3032
std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device);
3133

3234
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);

core/util/trt_util.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros, boo
216216
// Replace all instances of -1, indicating dynamic dimension
217217
// with 0, indicating copy the dimension from another tensor
218218
// (Generally used for reshape operations)
219-
if (use_zeros && d.d[i] == -1) {
219+
if (use_zeros && d.d[i] == -1 && i < pos) {
220220
dims.d[j] = 0;
221221
// If zeros already exist in the dimensions (empty tensor),
222222
// Replace all instances of 0, indicating empty dimension

cpp/include/torch_tensorrt/torch_tensorrt.h

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ class DataType {
6060
enum Value : int8_t {
6161
/// INT64
6262
kLong,
63+
/// FP64
64+
kDouble,
6365
/// FP32
6466
kFloat,
6567
/// FP16

cpp/src/types.cpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ at::ScalarType toAtenDataType(DataType value) {
9797
return at::kInt;
9898
case DataType::kLong:
9999
return at::kLong;
100+
case DataType::kDouble:
101+
return at::kDouble;
100102
case DataType::kBool:
101103
return at::kBool;
102104
case DataType::kFloat:
@@ -119,7 +121,8 @@ nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value) {
119121

120122
DataType::DataType(c10::ScalarType t) {
121123
TORCHTRT_CHECK(
122-
t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kLong || t == at::kInt || t == at::kBool,
124+
t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kLong || t == at::kDouble || t == at::kInt ||
125+
t == at::kBool,
123126
"Data type is unsupported (" << t << ")");
124127
switch (t) {
125128
case at::kHalf:
@@ -134,6 +137,9 @@ DataType::DataType(c10::ScalarType t) {
134137
case at::kLong:
135138
value = DataType::kLong;
136139
break;
140+
case at::kDouble:
141+
value = DataType::kDouble;
142+
break;
137143
case at::kBool:
138144
value = DataType::kBool;
139145
break;

docker/WORKSPACE.ngc

+19-19
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,28 @@ http_archive(
99
sha256 = "778197e26c5fbeb07ac2a2c5ae405b30f6cb7ad1f5510ea6fdac03bded96cc6f",
1010
)
1111

12-
load("@rules_python//python:pip.bzl", "pip_install")
12+
load("@rules_python//python:repositories.bzl", "py_repositories")
13+
14+
py_repositories()
1315

1416
http_archive(
1517
name = "rules_pkg",
18+
sha256 = "8f9ee2dc10c1ae514ee599a8b42ed99fa262b757058f65ad3c384289ff70c4b8",
1619
urls = [
17-
"https://mirror.bazel.build/github.com/bazelbuild/rules_pkg/releases/download/0.4.0/rules_pkg-0.4.0.tar.gz",
18-
"https://github.com/bazelbuild/rules_pkg/releases/download/0.4.0/rules_pkg-0.4.0.tar.gz",
20+
"https://mirror.bazel.build/github.com/bazelbuild/rules_pkg/releases/download/0.9.1/rules_pkg-0.9.1.tar.gz",
21+
"https://github.com/bazelbuild/rules_pkg/releases/download/0.9.1/rules_pkg-0.9.1.tar.gz",
1922
],
20-
sha256 = "038f1caa773a7e35b3663865ffb003169c6a71dc995e39bf4815792f385d837d",
2123
)
24+
2225
load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies")
26+
2327
rules_pkg_dependencies()
2428

25-
git_repository(
29+
http_archive(
2630
name = "googletest",
27-
remote = "https://github.com/google/googletest",
28-
commit = "703bd9caab50b139428cea1aaff9974ebee5742e",
29-
shallow_since = "1570114335 -0400"
31+
sha256 = "755f9a39bc7205f5a0c428e920ddad092c33c8a1b46997def3f1d4a82aded6e1",
32+
strip_prefix = "googletest-5ab508a01f9eb089207ee87fd547d290da39d015",
33+
urls = ["https://github.com/google/googletest/archive/5ab508a01f9eb089207ee87fd547d290da39d015.zip"],
3034
)
3135

3236
# External dependency for torch_tensorrt if you already have precompiled binaries.
@@ -80,17 +84,13 @@ new_local_repository(
8084
#########################################################################
8185
# Testing Dependencies (optional - comment out on aarch64)
8286
#########################################################################
83-
pip_install(
84-
name = "torch_tensorrt_py_deps",
85-
requirements = "//py:requirements.txt",
86-
)
87+
load("@rules_python//python:pip.bzl", "pip_parse")
8788

88-
pip_install(
89-
name = "py_test_deps",
90-
requirements = "//tests/py:requirements.txt",
89+
pip_parse(
90+
name = "devtools_deps",
91+
requirements_lock = "//:requirements-dev.txt",
9192
)
9293

93-
pip_install(
94-
name = "pylinter_deps",
95-
requirements = "//tools/linter:requirements.txt",
96-
)
94+
load("@devtools_deps//:requirements.bzl", "install_deps")
95+
96+
install_deps()

docs/_cpp_api/classtorch__tensorrt_1_1DataType.html

+11-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Class DataType &mdash; Torch-TensorRT v2.0.0.dev0+1fec519 documentation</title>
13+
<title>Class DataType &mdash; Torch-TensorRT v2.2.0.dev0+50ab2c1 documentation</title>
1414

1515

1616

@@ -225,7 +225,7 @@
225225

226226

227227
<div class="version">
228-
v2.0.0.dev0+1fec519
228+
v2.2.0.dev0+50ab2c1
229229
</div>
230230

231231

@@ -269,6 +269,8 @@
269269
<li class="toctree-l1"><a class="reference internal" href="../user_guide/getting_started_with_fx_path.html">Torch-TensorRT (FX Frontend) User Guide</a></li>
270270
<li class="toctree-l1"><a class="reference internal" href="../user_guide/ptq.html">Post Training Quantization (PTQ)</a></li>
271271
<li class="toctree-l1"><a class="reference internal" href="../user_guide/runtime.html">Deploying Torch-TensorRT Programs</a></li>
272+
<li class="toctree-l1"><a class="reference internal" href="../user_guide/saving_models.html">Saving models compiled with Torch-TensorRT</a></li>
273+
<li class="toctree-l1"><a class="reference internal" href="../user_guide/dynamic_shapes.html">Dynamic shapes with Torch-TensorRT</a></li>
272274
<li class="toctree-l1"><a class="reference internal" href="../user_guide/use_from_pytorch.html">Using Torch-TensorRT Directly From PyTorch</a></li>
273275
<li class="toctree-l1"><a class="reference internal" href="../user_guide/using_dla.html">DLA</a></li>
274276
</ul>
@@ -304,6 +306,7 @@
304306
<ul>
305307
<li class="toctree-l1"><a class="reference internal" href="../contributors/system_overview.html">System Overview</a></li>
306308
<li class="toctree-l1"><a class="reference internal" href="../contributors/writing_converters.html">Writing Converters</a></li>
309+
<li class="toctree-l1"><a class="reference internal" href="../contributors/writing_dynamo_aten_lowering_passes.html">Writing Dynamo ATen Lowering Passes</a></li>
307310
<li class="toctree-l1"><a class="reference internal" href="../contributors/useful_links.html">Useful Links for Torch-TensorRT Development</a></li>
308311
</ul>
309312
<p class="caption" role="heading"><span class="caption-text">Indices</span></p>
@@ -414,6 +417,12 @@ <h2>Class Documentation<a class="headerlink" href="#class-documentation" title="
414417
<dd><p>INT64. </p>
415418
</dd></dl>
416419

420+
<dl class="cpp enumerator">
421+
<dt class="sig sig-object cpp" id="_CPPv4N14torch_tensorrt8DataType5Value7kDoubleE">
422+
<span class="target" id="classtorch__tensorrt_1_1DataType_1a6335c0e206340d85a1382a5df17bf684aacf5b40b44995643185a977d2d1ce1bf"></span><span class="k"><span class="pre">enumerator</span></span><span class="w"> </span><span class="sig-name descname"><span class="n"><span class="pre">kDouble</span></span></span><a class="headerlink" href="#_CPPv4N14torch_tensorrt8DataType5Value7kDoubleE" title="Permalink to this definition"></a><br /></dt>
423+
<dd><p>FP64. </p>
424+
</dd></dl>
425+
417426
<dl class="cpp enumerator">
418427
<dt class="sig sig-object cpp" id="_CPPv4N14torch_tensorrt8DataType5Value6kFloatE">
419428
<span class="target" id="classtorch__tensorrt_1_1DataType_1a6335c0e206340d85a1382a5df17bf684a45ceda04c1ab50695a4a6aeaeae99817"></span><span class="k"><span class="pre">enumerator</span></span><span class="w"> </span><span class="sig-name descname"><span class="n"><span class="pre">kFloat</span></span></span><a class="headerlink" href="#_CPPv4N14torch_tensorrt8DataType5Value6kFloatE" title="Permalink to this definition"></a><br /></dt>

0 commit comments

Comments
 (0)