Skip to content

Commit 8e6136f

Browse files
gs-olivenarendasan
authored and committed
fix: Add lowering pass to remove output repacking in convert_method_to_trt_engine calls (#1945)
1 parent 840a23c commit 8e6136f

File tree

10 files changed

+145
-6
lines changed

10 files changed

+145
-6
lines changed

core/lowering/lowering.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
144144
passes::RemoveSingleUse0DTensors(g);
145145
passes::RemoveUnnecessaryCasts(g);
146146
passes::ReplaceAtenInt(g);
147+
if (lower_info.converting_to_trt_engine) {
148+
passes::RemoveCollectionCast(g);
149+
}
147150
passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
148151
passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
149152
passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());

core/lowering/lowering.h

+4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ struct LowerInfo {
1616
// Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering
1717
// pass. Disable this in order to not disturb TensorRT's QAT optimizations.
1818
bool disable_cse = false;
19+
20+
// Whether the originating caller is `convert_method_to_trt_engine` (true) or `compile` (false)
21+
bool converting_to_trt_engine = false;
22+
1923
ir::Device target_device;
2024
std::vector<std::string> forced_fallback_modules;
2125
friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);

core/lowering/passes/passes.h

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
3333
void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g);
3434
void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
3535
void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g);
36+
void RemoveCollectionCast(std::shared_ptr<torch::jit::Graph>& g);
3637
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
3738
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
3839
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);

core/lowering/passes/remove_unnecessary_casts.cpp

+39
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,45 @@ void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g) {
356356
LOG_GRAPH("Post removing aten.Int.Tensor operations: " << *g);
357357
}
358358

359+
void RemoveCollectionCast(std::shared_ptr<torch::jit::Graph>& g) {
  // Strips a Tuple/List packing node from the graph outputs so each Tensor is
  // returned directly. Only to be used when the graph is destined to become a
  // TRT engine — the resulting graph is NOT a valid TorchScript graph.

  // The pass applies only when the graph returns exactly one value and that
  // value is produced by a Tuple or List construction node.
  if (g->outputs().size() != 1) {
    return;
  }
  auto collection_node = g->outputs()[0]->node();
  auto node_kind = collection_node->kind();
  if (node_kind != torch::jit::prim::TupleConstruct && node_kind != torch::jit::prim::ListConstruct) {
    return;
  }

  // TensorRT cannot preserve nested structures, so bail out unless every value
  // packed into the collection is a plain Tensor.
  auto packed_values = collection_node->inputs();
  bool only_tensors = true;
  for (auto packed_value : packed_values) {
    only_tensors = only_tensors && packed_value->type()->isSubtypeOf(c10::TensorType::get());
  }
  if (!only_tensors) {
    return;
  }

  // Promote each packed value to a direct output of the graph...
  for (auto packed_value : packed_values) {
    g->registerOutput(packed_value);
  }

  // ...then drop the original collection output (index 0, since the collection
  // was the sole output before the new ones were appended).
  g->eraseOutput(0);

  // Dead-code elimination cleans up the now-unused collection node
  torch::jit::EliminateDeadCode(g);
  LOG_GRAPH("Post removing collection casting operations: " << *g);
}
397+
359398
} // namespace passes
360399
} // namespace lowering
361400
} // namespace core

cpp/src/compile_spec.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,11 @@ torchtrt::core::CompileSpec init_compile_spec(CompileSpec& external) {
7878
}
7979
}
8080

81-
torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
81+
torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external, bool converting_to_trt_engine) {
8282
torchtrt::core::CompileSpec internal = init_compile_spec(external);
8383

84+
internal.lower_info.converting_to_trt_engine = converting_to_trt_engine;
85+
8486
for (auto p : external.enabled_precisions) {
8587
internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
8688
}

cpp/src/torch_tensorrt.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace torch_tensorrt {
1010
torch_tensorrt::core::runtime::RTDevice to_internal_rt_device(Device device);
1111
namespace torchscript {
1212
// Defined in compile_spec.cpp
13-
torch_tensorrt::core::CompileSpec to_internal_compile_spec(CompileSpec external);
13+
torch_tensorrt::core::CompileSpec to_internal_compile_spec(CompileSpec external, bool converting_to_trt_engine = false);
1414

1515
bool check_method_operator_support(const torch::jit::script::Module& module, std::string method_name) {
1616
return torch_tensorrt::core::CheckMethodOperatorSupport(module, method_name);
@@ -23,7 +23,8 @@ std::string convert_method_to_trt_engine(
2323
LOG_DEBUG(get_build_info());
2424
// Want to export a much simpler (non TRT header dependent) API so doing the
2525
// type conversion here
26-
return torch_tensorrt::core::ConvertGraphToTRTEngine(module, method_name, to_internal_compile_spec(info));
26+
return torch_tensorrt::core::ConvertGraphToTRTEngine(
27+
module, method_name, to_internal_compile_spec(info, /*bool converting_to_trt_engine=*/true));
2728
}
2829

2930
torch::jit::script::Module compile(const torch::jit::script::Module& module, CompileSpec info) {

py/torch_tensorrt/csrc/tensorrt_classes.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,11 @@ core::CompileSpec init_compile_spec(CompileSpec external) {
326326
}
327327
}
328328

329-
core::CompileSpec CompileSpec::toInternalCompileSpec() {
329+
core::CompileSpec CompileSpec::toInternalCompileSpec(bool converting_to_trt_engine) {
330330
core::CompileSpec info = init_compile_spec(*this);
331331

332+
info.lower_info.converting_to_trt_engine = converting_to_trt_engine;
333+
332334
for (auto p : enabled_precisions) {
333335
info.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
334336
}

py/torch_tensorrt/csrc/tensorrt_classes.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ std::string to_str(EngineCapability value);
123123
nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value);
124124

125125
struct CompileSpec : torch::CustomClassHolder {
126-
core::CompileSpec toInternalCompileSpec();
126+
core::CompileSpec toInternalCompileSpec(bool converting_to_trt_engine = false);
127127
std::string stringify();
128128
void appendInput(const c10::intrusive_ptr<Input>& ir) {
129129
inputs.push_back(*ir);

py/torch_tensorrt/csrc/torch_tensorrt_py.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec& info
158158

159159
py::bytes ConvertGraphToTRTEngine(const torch::jit::Module& mod, const std::string& method_name, CompileSpec& info) {
  py::gil_scoped_acquire gil;
  // Calls arriving here are always producing a standalone TRT engine, so flag
  // the compile spec accordingly (enables engine-specific lowering passes).
  auto trt_engine = core::ConvertGraphToTRTEngine(
      mod, method_name, info.toInternalCompileSpec(/*converting_to_trt_engine=*/true));
  return py::bytes(trt_engine);
}
164165

tests/core/lowering/test_remove_unnecessary_casts.cpp

+86
Original file line numberDiff line numberDiff line change
@@ -589,3 +589,89 @@ TEST(LoweringPasses, RemoveAtenIntConstTensorValuesAgree) {
589589
// Validate identical graphs after pooling constants and canonicalizing
590590
ASSERT_TRUE((tg->toString() == sg->toString()));
591591
}
592+
593+
TEST(LoweringPasses, RemoveCollectionCastTuple) {
  // The pass should replace the TupleConstruct-wrapped output with the
  // individual tensors, returned directly in packing order
  std::string source_graph = R"IR(
    graph(%x.1 : Tensor):
      %3 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=2]()
      %a.1 : Tensor = aten::mul(%x.1, %2)
      %b.1 : Tensor = aten::add(%a.1, %2, %3)
      %c.1 : Tensor = aten::relu(%b.1)
      %d.1 : Tensor = aten::sqrt(%c.1)
      %8 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%c.1, %d.1, %b.1)
      return (%8))IR";

  std::string target_graph = R"IR(
    graph(%x.1 : Tensor):
      %3 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=2]()
      %a.1 : Tensor = aten::mul(%x.1, %2)
      %b.1 : Tensor = aten::add(%a.1, %2, %3)
      %c.1 : Tensor = aten::relu(%b.1)
      %d.1 : Tensor = aten::sqrt(%c.1)
      return (%c.1, %d.1, %b.1))IR";

  // Ensure the lowering pass transforms the first graph into the second
  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);

  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  torch_tensorrt::core::lowering::passes::RemoveCollectionCast(sg);
  torch::jit::ConstantPooling(sg);
  sg = torch::jit::Canonicalize(sg, false);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());
  torch::jit::ConstantPooling(tg);
  tg = torch::jit::Canonicalize(tg, false);

  // Validate identical graphs after pooling constants and canonicalizing
  ASSERT_TRUE((tg->toString() == sg->toString()));
}
635+
636+
TEST(LoweringPasses, RemoveCollectionCastList) {
  // The pass should replace the ListConstruct-wrapped output with the
  // individual tensors, returned directly in packing order
  std::string source_graph = R"IR(
    graph(%x.1 : Tensor):
      %3 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=2]()
      %a.1 : Tensor = aten::mul(%x.1, %2)
      %b.1 : Tensor = aten::add(%a.1, %2, %3)
      %c.1 : Tensor = aten::relu(%b.1)
      %d.1 : Tensor = aten::sqrt(%c.1)
      %8 : (Tensor, Tensor, Tensor) = prim::ListConstruct(%b.1, %c.1, %d.1)
      return (%8))IR";

  std::string target_graph = R"IR(
    graph(%x.1 : Tensor):
      %3 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=2]()
      %a.1 : Tensor = aten::mul(%x.1, %2)
      %b.1 : Tensor = aten::add(%a.1, %2, %3)
      %c.1 : Tensor = aten::relu(%b.1)
      %d.1 : Tensor = aten::sqrt(%c.1)
      return (%b.1, %c.1, %d.1))IR";

  // Ensure the lowering pass transforms the first graph into the second
  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);

  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  torch_tensorrt::core::lowering::passes::RemoveCollectionCast(sg);
  torch::jit::ConstantPooling(sg);
  sg = torch::jit::Canonicalize(sg, false);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());
  torch::jit::ConstantPooling(tg);
  tg = torch::jit::Canonicalize(tg, false);

  // Validate identical graphs after pooling constants and canonicalizing
  ASSERT_TRUE((tg->toString() == sg->toString()));
}

0 commit comments

Comments
 (0)