
Commit 35cf89d

Merge branch 'main' into bose_fx2trt_converters_slice_select
2 parents: 8303cd5 + c2126b1


139 files changed: +1330 −532 lines


.circleci/config.yml (+146 −8)
@@ -263,7 +263,7 @@ commands:
     parameters:
       torch-build:
         type: string
-        default: "2.0.0.dev20230219+cu117"
+        default: "2.1.0.dev20230314+cu117"
      torch-build-index:
        type: string
        default: "https://download.pytorch.org/whl/nightly/cu117"
@@ -524,15 +524,47 @@ commands:
       - store_artifacts:
           path: /tmp/testlogs
 
-  test-fx_converters:
-    description: "Test the fx converters"
+  test-fx_converters_acc:
+    description: "Test the fx acc converters"
     steps:
       - run:
           name: Run FX converter tests
           command: |
             cd py/torch_tensorrt/fx/test
-            pushd converters/
-            pytest --junitxml=/tmp/artifacts/test_results/fx/converters/test_results.xml
+            pushd converters/acc_op/
+            pytest --junitxml=/tmp/artifacts/test_results/fx/converters/acc_op/test_results.xml
+            popd
+
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
+  test-fx_converters_aten:
+    description: "Test the fx aten converters"
+    steps:
+      - run:
+          name: Run FX converter tests
+          command: |
+            cd py/torch_tensorrt/fx/test
+            pushd converters/aten_op/
+            pytest --junitxml=/tmp/artifacts/test_results/fx/converters/aten_op/test_results.xml
+            popd
+
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
+  test-fx_converters_vanilla:
+    description: "Test the fx vanilla converters"
+    steps:
+      - run:
+          name: Run FX converter tests
+          command: |
+            cd py/torch_tensorrt/fx/test
+            pushd converters/vanilla/
+            pytest --junitxml=/tmp/artifacts/test_results/fx/converters/vanilla/test_results.xml
             popd
 
       - store_test_results:
@@ -587,7 +619,7 @@ commands:
           path: /tmp/testlogs
 
   test-fx_tracer:
-    description: "Test the fx tracer"
+    description: "Test all fx tracers"
     steps:
       - run:
           name: Run FX tracer
@@ -602,6 +634,22 @@ commands:
       - store_artifacts:
           path: /tmp/testlogs
 
+  test-fx_tracer_acc:
+    description: "Test the fx acc tracer only"
+    steps:
+      - run:
+          name: Run FX tracer
+          command: |
+            cd py/torch_tensorrt/fx/test
+            pushd tracer
+            list_tracer=$(ls | grep test_acc)
+            pytest $list_tracer --junitxml=/tmp/artifacts/test_results/fx/tracer/test_results.xml
+            popd
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
   test-fx_quant:
     description: "Test the fx quant"
     steps:
@@ -625,7 +673,9 @@ commands:
           name: Run fx tests
           command: |
             mkdir -p /tmp/artifacts/test_results
-      - test-fx_converters
+      - test-fx_converters_acc
+      - test-fx_converters_aten
+      - test-fx_converters_vanilla
       - test-fx_passes
       - test-fx_tools
       - test-fx_trt_lower
@@ -637,6 +687,26 @@ commands:
       - store_artifacts:
           path: /tmp/testlogs
 
+  test-fx-no-aten:
+    description: "Test the fx backend without aten operators"
+    steps:
+      - run:
+          name: Run fx tests without aten ops
+          command: |
+            mkdir -p /tmp/artifacts/test_results
+      - test-fx_converters_acc
+      - test-fx_converters_vanilla
+      - test-fx_passes
+      - test-fx_tools
+      - test-fx_trt_lower
+      - test-fx_tracer_acc
+      - test-fx_core
+      - test-fx_quant
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
 # Define a job to be invoked later in a workflow.
 # See: https://circleci.com/docs/2.0/configuration-reference/#jobs
 jobs:
@@ -782,6 +852,37 @@ jobs:
       - dump-test-env
       - test-fx
 
+  test-py-fx-x86_64-linux-no-aten:
+    parameters:
+      torch-build:
+        type: string
+      torch-build-index:
+        type: string
+      trt-version-long:
+        type: string
+    machine:
+      image: ubuntu-2004-cuda-11.4:202110-01
+    resource_class: gpu.nvidia.large
+    steps:
+      - checkout
+      - attach_workspace:
+          at: /tmp/dist/
+      - install-torch-from-index:
+          torch-build: << parameters.torch-build >>
+          torch-build-index: << parameters.torch-build-index >>
+      - create-py-env:
+          trt-version-long: << parameters.trt-version-long >>
+      - install-cudnn
+      # - run:
+      #     name: "Set LD_LIBRARY_PATH path to include the installed CUDNN"
+      #     command: export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH
+      - run:
+          name: "Install torch-tensorrt"
+          command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl
+      # We install torch after torch-trt because pip automatically enforces the version constraint otherwise
+      - dump-test-env
+      - test-fx-no-aten
+
   package-x86_64-linux:
     parameters:
       enabled:
@@ -1070,10 +1171,16 @@ parameters:
   # Nightly platform config
   torch-build:
     type: string
-    default: "2.0.0.dev20230219+cu117"
+    default: "2.1.0.dev20230314+cu117"
   torch-build-index:
     type: string
     default: "https://download.pytorch.org/whl/nightly/cu117"
+  torch-build-legacy:
+    type: string
+    default: "1.13.1+cu117"
+  torch-build-index-legacy:
+    type: string
+    default: "https://download.pytorch.org/whl/cu117"
   cudnn-version:
     type: string
     default: "8.5.0.96"
@@ -1127,6 +1234,7 @@ workflows:
             - release/**/*
     jobs:
       - build-x86_64-linux:
+          name: build-x86_64-linux
           torch-build: << pipeline.parameters.torch-build >>
           torch-build-index: << pipeline.parameters.torch-build-index >>
 
@@ -1153,6 +1261,36 @@ workflows:
          requires:
            - build-x86_64-linux
 
+      - build-x86_64-linux:
+          name: build-x86_64-linux-legacy
+          torch-build: << pipeline.parameters.torch-build-legacy >>
+          torch-build-index: << pipeline.parameters.torch-build-index-legacy >>
+
+      - test-core-cpp-x86_64-linux:
+          name: test-core-cpp-x86_64-linux-legacy
+          torch-build: << pipeline.parameters.torch-build-legacy >>
+          torch-build-index: << pipeline.parameters.torch-build-index-legacy >>
+          trt-version-short: << pipeline.parameters.trt-version-short >>
+          trt-version-long: << pipeline.parameters.trt-version-long >>
+          cudnn-version: << pipeline.parameters.cudnn-version >>
+          requires:
+            - build-x86_64-linux-legacy
+
+      - test-py-ts-x86_64-linux:
+          name: test-py-ts-x86_64-linux-legacy
+          torch-build: << pipeline.parameters.torch-build-legacy >>
+          torch-build-index: << pipeline.parameters.torch-build-index-legacy >>
+          trt-version-long: << pipeline.parameters.trt-version-long >>
+          requires:
+            - build-x86_64-linux-legacy
+
+      - test-py-fx-x86_64-linux-no-aten:
+          torch-build: << pipeline.parameters.torch-build-legacy >>
+          torch-build-index: << pipeline.parameters.torch-build-index-legacy >>
+          trt-version-long: << pipeline.parameters.trt-version-long >>
+          requires:
+            - build-x86_64-linux-legacy
+
   release:
     when: << pipeline.parameters.enable-packaging >>
     jobs:

core/compiler.cpp (+67 −16)
@@ -138,7 +138,8 @@ partitioning::GraphAndMapping BuildHybridGraph(
     torch::jit::Block* block,
     CompileSpec cfg,
     ir::StaticParams static_params,
-    ir::CollectionTypeMap first_use_types) {
+    ir::CollectionTypeMap first_use_types,
+    bool expect_full_compilation = false) {
   auto convert_info = cfg.convert_info;
   auto partitioning_info = cfg.partitioning_info;
 
@@ -149,17 +150,20 @@ partitioning::GraphAndMapping BuildHybridGraph(
   // TODO: Combine this within partition call
   partitioning::populateInputIValues(&partitioning_ctx);
 
-  partitioning::partition(&partitioning_ctx);
+  partitioning::partition(&partitioning_ctx, expect_full_compilation);
 
   for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
     partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
+    int num_torch_segments = 0;
+    int num_trt_segments = 0;
 
     for (auto& seg_block : segmented_blocks) {
       LOG_INFO("Block segment:" << seg_block);
       std::ostringstream trt_engine_id;
       trt_engine_id << reinterpret_cast<const int*>(&seg_block);
 
       if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+        num_trt_segments++;
         auto inputs = seg_block.construct_inputs_spec();
         // update the input ranges for each segments
         convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
@@ -180,8 +184,32 @@ partitioning::GraphAndMapping BuildHybridGraph(
             true);
 
         seg_block.update_graph(temp_g);
+      } else {
+        num_torch_segments++;
+
+        // If full compilation is expected, ensure that all operators in Torch blocks are
+        // for collections processing
+        if (expect_full_compilation) {
+          for (auto torch_node : seg_block.block()->nodes()) {
+            if (partitioning::CollectionNodeKinds.find(torch_node->kind()) == partitioning::CollectionNodeKinds.end()) {
+              TORCHTRT_THROW_ERROR(
+                  "Full compilation specified but node "
+                  << *torch_node
+                  << " is set to run in PyTorch due to either lack of support in TensorRT or graph partitioning rules."
+                  << " Try recompiling with require_full_compilation=False.");
+            }
+          }
+        }
       }
     }
+
+    // If full compilation is expected, cannot have more than 2 Torch segments
+    // (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
+    if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1)) {
+      TORCHTRT_THROW_ERROR(
+          "Full compilation was requested but unable to convert all operations to TensorRT."
+          << " Try recompiling with require_full_compilation=False.");
+    }
   }
 
   return partitioning::stitch(&partitioning_ctx, block);
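Editor's note: the segment-count check above can be restated as a small predicate. The following is an illustrative C++ sketch only; the helper name is hypothetical and does not appear in this commit:

// Illustration (hypothetical helper): restates the invariant enforced above.
// Under require_full_compilation, partitioning may still yield up to two
// Torch segments (one pre-processing collection inputs, one post-processing
// collection outputs) around exactly one TensorRT segment.
bool isValidFullCompilationPartition(int num_torch_segments, int num_trt_segments) {
  return num_torch_segments <= 2 && num_trt_segments == 1;
}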
@@ -191,7 +219,8 @@ ir::TypeMap MapInputsAndDetermineDTypes(
     CompileSpec& cfg,
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
-    ir::CollectionTypeMap& first_use_type_map) {
+    ir::CollectionTypeMap& first_use_type_map,
+    bool requires_collection_handling = false) {
   cfg.convert_info.collection_input_spec_map =
       std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
   cfg.partitioning_info.collection_input_spec_map =
@@ -226,7 +255,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
             "Cannot infer input type from calcuations in graph for input "
             << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
         spec[i].dtype = at::kFloat;
-      } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
+      } else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info.enabled || requires_collection_handling)) {
         if (!est_type_opt[i]) {
           LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
           std::stringstream ss;
@@ -297,6 +326,11 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
   return engine;
 }
 
+bool userRequestedFallback(CompileSpec& cfg) {
+  return cfg.lower_info.forced_fallback_modules.size() != 0 ||
+      cfg.partitioning_info.forced_fallback_operators.size() != 0;
+}
+
 torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
   torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");
 
@@ -315,8 +349,17 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Infer the type of an input from the weights of the calculation
   auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
 
+  // Determine if the block is convertible/has collection output, and based on the result,
+  // whether full compilation can be expected
+  auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
+  auto outputIsCollection = conversion::OutputIsCollection(g->block());
+  auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
+
+  // Determine whether user specifications necessitate partitioning
+  auto isFallbackRequested = userRequestedFallback(cfg);
+
   // Extract map of IValue to DType
-  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types, requires_collection_handling);
 
   // Check whether any of the input types are Long
   bool user_requested_long = false;
@@ -330,20 +373,28 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
     user_requested_long &= (casts_inserted > 0);
   }
 
-  auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
-  auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  if (cfg.partitioning_info.enabled && !user_requested_long &&
-      (cfg.lower_info.forced_fallback_modules.size() == 0 &&
-       cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
-      !outputIsCollection) {
+  // Partitioning is required if:
+  // 1. User requested some modules/operators fallback
+  // 2. The block (graph) cannot be converted due to operator coverage
+  // 3. The output of the graph is a collection
+  // 4. The user requested a non-TRT data type input
+  auto isPartitioningRequired =
+      (isFallbackRequested || !isBlockConvertible || outputIsCollection || user_requested_long);
+
+  // The user did not require full compilation, but the model can be fully compiled
+  if (cfg.partitioning_info.enabled && !isPartitioningRequired) {
     LOG_INFO("Skipping partitioning since model is fully supported");
   }
 
+  // The user did not require full compilation and the model cannot be fully compiled,
+  // or the user required full compilation but the I/O of the graph uses collections
+  if ((cfg.partitioning_info.enabled && isPartitioningRequired) || requires_collection_handling) {
+    // If the model is fully-compilable and the user has specified full compilation, run partitioning
+    // to generate collection-processing code in Torch
+    auto expect_full_compilation = (requires_collection_handling && !cfg.partitioning_info.enabled);
+
+    auto graph_and_mapping =
+        BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types, expect_full_compilation);
     new_g = graph_and_mapping.first;
     // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
     for (size_t i = 0; i < new_g->inputs().size(); ++i) {
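Editor's note: taken together, the rewritten conditions amount to a two-stage decision. The sketch below condenses that flow into standalone C++ for illustration; the struct and helper names are hypothetical stand-ins for the analysis results computed in CompileGraph and are not part of the source:

// Hypothetical condensation of the decision flow above (illustration only).
struct GraphAnalysis {
  bool partitioning_enabled;   // cfg.partitioning_info.enabled
  bool fallback_requested;     // userRequestedFallback(cfg)
  bool block_convertible;      // conversion::VerifyConverterSupportForBlock(...)
  bool output_is_collection;   // conversion::OutputIsCollection(...)
  bool user_requested_long;    // Long-typed inputs needing Torch handling
};

// Partitioning is required when any of the four triggers fires.
bool isPartitioningRequired(const GraphAnalysis& a) {
  return a.fallback_requested || !a.block_convertible || a.output_is_collection ||
      a.user_requested_long;
}

// A hybrid Torch/TRT graph is built when partitioning is enabled and required,
// or when a fully convertible graph still needs Torch code for collection I/O.
// In the latter case BuildHybridGraph runs with expect_full_compilation set,
// so any Torch segment other than collection processing raises an error.
bool shouldBuildHybridGraph(const GraphAnalysis& a, bool* expect_full_compilation) {
  bool requires_collection_handling = a.block_convertible && a.output_is_collection;
  *expect_full_compilation = requires_collection_handling && !a.partitioning_enabled;
  return (a.partitioning_enabled && isPartitioningRequired(a)) || requires_collection_handling;
}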
