Skip to content

Commit 91a8d43

Browse files
committed
fix: Add lowering pass to remove output repacking
- Automatically remove output repacking for `convert_method_to_trt_engine` calls, to improve parity between models which can be converted directly to TRT engines and models which can be fully compiled
- Add a new internal `CompileSpec` argument for lowering which indicates whether the lowering passes originate from a `convert_method_to_trt_engine` call or a regular `compile` call; this determines whether the new lowering pass is applied
- Regular TorchScript graphs cannot have this pass applied, as it can otherwise break the output graph; newer versions of Torch disallow graph outputs with 0 or 2+ arguments which are not packed in a struct
- The current lowering pass detects outputs which are flat Lists or Tuples of Tensors and returns the outputs as-is (direct from the TRT Engine), so the entire model can be converted to a single TRT engine
1 parent c60070b commit 91a8d43

File tree

10 files changed

+146
-6
lines changed

10 files changed

+146
-6
lines changed

core/lowering/lowering.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
143143
passes::SiluToSigmoidMultipication(g);
144144
passes::RemoveSingleUse0DTensors(g);
145145
passes::RemoveUnnecessaryCasts(g);
146+
if (lower_info.converting_to_trt_engine) {
147+
passes::RemoveCollectionCast(g);
148+
}
146149
passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
147150
passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
148151
passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());

core/lowering/lowering.h

+4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ struct LowerInfo {
1616
// Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering
1717
// pass. Disable this in order to not disturb TensorRT's QAT optimizations.
1818
bool disable_cse = false;
19+
20+
// Whether the originating caller is `convert_method_to_trt_engine` (true) or `compile` (false)
21+
bool converting_to_trt_engine = false;
22+
1923
ir::Device target_device;
2024
std::vector<std::string> forced_fallback_modules;
2125
friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);

core/lowering/passes/passes.h

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
3232
void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
3333
void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g);
3434
void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
35+
void RemoveCollectionCast(std::shared_ptr<torch::jit::Graph>& g);
3536
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
3637
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
3738
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);

core/lowering/passes/remove_unnecessary_casts.cpp

+40
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "torch/csrc/jit/ir/constants.h"
2+
#include "torch/csrc/jit/passes/dead_code_elimination.h"
23
#include "torch/csrc/jit/passes/subgraph_rewrite.h"
34

45
#include "core/util/prelude.h"
@@ -211,6 +212,45 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
211212
LOG_GRAPH("Post removing single use 0-dim Tensor operations: " << *g);
212213
}
213214

215+
void RemoveCollectionCast(std::shared_ptr<torch::jit::Graph>& g) {
  // Removes unnecessary packing (TupleConstruct/ListConstruct) of graph outputs,
  // registering the packed Tensors as graph outputs directly.
  // Only to be used if the overall output is intended to be a TRT Engine.
  // Will cause errors if used directly as a TorchScript graph, since newer
  // versions of Torch disallow graphs returning 0 or 2+ values that are not
  // packed in a collection.

  // Validate the graph has a single output produced by a Tuple or List construct
  if (g->outputs().size() != 1) {
    return;
  }

  auto collection_node = g->outputs()[0]->node();
  if (collection_node->kind() != torch::jit::prim::TupleConstruct &&
      collection_node->kind() != torch::jit::prim::ListConstruct) {
    return;
  }

  // Ensure all inputs to the Tuple/List Construct operator are regular Tensors
  // (nested structures cannot be preserved in TensorRT); bail out on the first
  // non-Tensor input rather than scanning the full list
  auto collection_inputs = collection_node->inputs();
  for (auto* input : collection_inputs) {
    if (!input->type()->isSubtypeOf(c10::TensorType::get())) {
      return;
    }
  }

  // For each input to the collection packing operator, add its value directly
  // as an output of the graph
  for (auto* input : collection_inputs) {
    g->registerOutput(input);
  }

  // Remove the original output value of the graph (the collection object)
  g->eraseOutput(0);

  // Clean up the now-unused collection construct node in the graph
  torch::jit::EliminateDeadCode(g);
  LOG_GRAPH("Post removing collection casting operations: " << *g);
}
253+
214254
} // namespace passes
215255
} // namespace lowering
216256
} // namespace core

cpp/src/compile_spec.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,11 @@ torchtrt::core::CompileSpec init_compile_spec(CompileSpec& external) {
7878
}
7979
}
8080

81-
torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
81+
torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external, bool converting_to_trt_engine) {
8282
torchtrt::core::CompileSpec internal = init_compile_spec(external);
8383

84+
internal.lower_info.converting_to_trt_engine = converting_to_trt_engine;
85+
8486
for (auto p : external.enabled_precisions) {
8587
internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
8688
}

cpp/src/torch_tensorrt.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace torch_tensorrt {
1010
torch_tensorrt::core::runtime::RTDevice to_internal_rt_device(Device device);
1111
namespace torchscript {
1212
// Defined in compile_spec.cpp
13-
torch_tensorrt::core::CompileSpec to_internal_compile_spec(CompileSpec external);
13+
torch_tensorrt::core::CompileSpec to_internal_compile_spec(CompileSpec external, bool converting_to_trt_engine = false);
1414

1515
bool check_method_operator_support(const torch::jit::script::Module& module, std::string method_name) {
1616
return torch_tensorrt::core::CheckMethodOperatorSupport(module, method_name);
@@ -23,7 +23,8 @@ std::string convert_method_to_trt_engine(
2323
LOG_DEBUG(get_build_info());
2424
// Want to export a much simpler (non TRT header dependent) API so doing the
2525
// type conversion here
26-
return torch_tensorrt::core::ConvertGraphToTRTEngine(module, method_name, to_internal_compile_spec(info));
26+
return torch_tensorrt::core::ConvertGraphToTRTEngine(
27+
module, method_name, to_internal_compile_spec(info, /*bool converting_to_trt_engine=*/true));
2728
}
2829

2930
torch::jit::script::Module compile(const torch::jit::script::Module& module, CompileSpec info) {

py/torch_tensorrt/csrc/tensorrt_classes.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,11 @@ core::CompileSpec init_compile_spec(CompileSpec external) {
326326
}
327327
}
328328

329-
core::CompileSpec CompileSpec::toInternalCompileSpec() {
329+
core::CompileSpec CompileSpec::toInternalCompileSpec(bool converting_to_trt_engine) {
330330
core::CompileSpec info = init_compile_spec(*this);
331331

332+
info.lower_info.converting_to_trt_engine = converting_to_trt_engine;
333+
332334
for (auto p : enabled_precisions) {
333335
info.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
334336
}

py/torch_tensorrt/csrc/tensorrt_classes.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ std::string to_str(EngineCapability value);
123123
nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value);
124124

125125
struct CompileSpec : torch::CustomClassHolder {
126-
core::CompileSpec toInternalCompileSpec();
126+
core::CompileSpec toInternalCompileSpec(bool converting_to_trt_engine = false);
127127
std::string stringify();
128128
void appendInput(const c10::intrusive_ptr<Input>& ir) {
129129
inputs.push_back(*ir);

py/torch_tensorrt/csrc/torch_tensorrt_py.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec& info
158158

159159
py::bytes ConvertGraphToTRTEngine(const torch::jit::Module& mod, const std::string& method_name, CompileSpec& info) {
160160
py::gil_scoped_acquire gil;
161-
auto trt_engine = core::ConvertGraphToTRTEngine(mod, method_name, info.toInternalCompileSpec());
161+
auto trt_engine = core::ConvertGraphToTRTEngine(
162+
mod, method_name, info.toInternalCompileSpec(/*bool converting_to_trt_engine=*/true));
162163
return py::bytes(trt_engine);
163164
}
164165

tests/core/lowering/test_remove_unnecessary_casts.cpp

+86
Original file line numberDiff line numberDiff line change
@@ -437,3 +437,89 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatValuesAgree) {
437437
ASSERT_TRUE(
438438
torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
439439
}
440+
441+
TEST(LoweringPasses, RemoveCollectionCastTuple) {
  // The RemoveCollectionCast lowering pass should elide the TupleConstruct
  // packing of the outputs, transforming the source graph into the target graph
  std::string source_graph = R"IR(
    graph(%x.1 : Tensor):
      %3 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=2]()
      %a.1 : Tensor = aten::mul(%x.1, %2)
      %b.1 : Tensor = aten::add(%a.1, %2, %3)
      %c.1 : Tensor = aten::relu(%b.1)
      %d.1 : Tensor = aten::sqrt(%c.1)
      %8 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%c.1, %d.1, %b.1)
      return (%8))IR";

  std::string target_graph = R"IR(
    graph(%x.1 : Tensor):
      %3 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=2]()
      %a.1 : Tensor = aten::mul(%x.1, %2)
      %b.1 : Tensor = aten::add(%a.1, %2, %3)
      %c.1 : Tensor = aten::relu(%b.1)
      %d.1 : Tensor = aten::sqrt(%c.1)
      return (%c.1, %d.1, %b.1))IR";

  // Log full graphs so transformation failures are easy to diagnose
  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);

  // Parse the source graph and apply the lowering pass under test
  auto parsed_source = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, parsed_source.get());
  torch_tensorrt::core::lowering::passes::RemoveCollectionCast(parsed_source);

  // Parse the expected (target) graph
  auto parsed_target = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, parsed_target.get());

  // Normalize both graphs (pool constants, canonicalize) before comparison
  torch::jit::ConstantPooling(parsed_source);
  parsed_source = torch::jit::Canonicalize(parsed_source, false);
  torch::jit::ConstantPooling(parsed_target);
  parsed_target = torch::jit::Canonicalize(parsed_target, false);

  // The lowered source graph must match the target graph exactly
  ASSERT_TRUE((parsed_target->toString() == parsed_source->toString()));
}
483+
484+
TEST(LoweringPasses, RemoveCollectionCastList) {
  // Ensure the lowering pass transforms the first graph into the second,
  // eliding the ListConstruct packing of the graph outputs.
  // Note: a prim::ListConstruct of Tensors produces a Tensor[] (List type),
  // not a Tuple type — the output annotation below reflects that.
  std::string source_graph = R"IR(
    graph(%x.1 : Tensor):
      %3 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=2]()
      %a.1 : Tensor = aten::mul(%x.1, %2)
      %b.1 : Tensor = aten::add(%a.1, %2, %3)
      %c.1 : Tensor = aten::relu(%b.1)
      %d.1 : Tensor = aten::sqrt(%c.1)
      %8 : Tensor[] = prim::ListConstruct(%b.1, %c.1, %d.1)
      return (%8))IR";

  std::string target_graph = R"IR(
    graph(%x.1 : Tensor):
      %3 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=2]()
      %a.1 : Tensor = aten::mul(%x.1, %2)
      %b.1 : Tensor = aten::add(%a.1, %2, %3)
      %c.1 : Tensor = aten::relu(%b.1)
      %d.1 : Tensor = aten::sqrt(%c.1)
      return (%b.1, %c.1, %d.1))IR";

  // Log full graphs so transformation failures are easy to diagnose
  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());

  torch_tensorrt::core::lowering::passes::RemoveCollectionCast(sg);
  torch::jit::ConstantPooling(sg);
  sg = torch::jit::Canonicalize(sg, false);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  torch::jit::ConstantPooling(tg);
  tg = torch::jit::Canonicalize(tg, false);

  // Validate identical graphs after pooling constants and canonicalizing
  ASSERT_TRUE((tg->toString() == sg->toString()));
}

0 commit comments

Comments
 (0)