diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index 4184b2f6be..cb1fd97327 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -144,6 +144,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params, LowerInfo lower_info) {
   passes::RemoveSingleUse0DTensors(g);
   passes::RemoveUnnecessaryCasts(g);
   passes::ReplaceAtenInt(g);
+  if (lower_info.converting_to_trt_engine) {
+    passes::RemoveCollectionCast(g);
+  }
   passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
   passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
   passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());
diff --git a/core/lowering/lowering.h b/core/lowering/lowering.h
--- a/core/lowering/lowering.h
+++ b/core/lowering/lowering.h
@@ ... @@ struct LowerInfo {
   std::vector<std::string> forced_fallback_modules;
+  // Whether the caller is converting the graph directly to a TensorRT engine
+  bool converting_to_trt_engine = false;
   friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h
index a4493379ad..77ff842198 100644
--- a/core/lowering/passes/passes.h
+++ b/core/lowering/passes/passes.h
@@ -33,6 +33,7 @@ void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
 void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g);
 void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
 void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g);
+void RemoveCollectionCast(std::shared_ptr<torch::jit::Graph>& g);
 void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
diff --git a/core/lowering/passes/remove_unnecessary_casts.cpp b/core/lowering/passes/remove_unnecessary_casts.cpp
index 672c30409d..9f16b07741 100644
--- a/core/lowering/passes/remove_unnecessary_casts.cpp
+++ b/core/lowering/passes/remove_unnecessary_casts.cpp
@@ -356,6 +356,45 @@ void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g) {
   LOG_GRAPH("Post removing aten.Int.Tensor operations: " << *g);
 }
 
+void RemoveCollectionCast(std::shared_ptr<torch::jit::Graph>& g) {
+  // Removes unnecessary collection-casting of graph outputs
+  // Only to be used if the overall output is intended to be a TRT Engine
+  // Will cause errors if used directly as a TorchScript graph
+
+  // Validate the output is a single value with type Tuple or List
+  if (!(g->outputs().size() == 1 &&
+        (g->outputs()[0]->node()->kind() == torch::jit::prim::TupleConstruct ||
+         g->outputs()[0]->node()->kind() == torch::jit::prim::ListConstruct))) {
+    return;
+  }
+
+  // Ensure all inputs to the Tuple/List Construct operator are regular Tensors
+  // (nested structures cannot be preserved in TensorRT)
+  auto all_tensors = true;
+  auto collection_inputs = g->outputs()[0]->node()->inputs();
+
+  for (size_t i = 0; i < collection_inputs.size(); ++i) {
+    all_tensors &= collection_inputs[i]->type()->isSubtypeOf(c10::TensorType::get());
+  }
+
+  if (!all_tensors) {
+    return;
+  }
+
+  // For each input to the collection packing operator, add its value directly
+  // as an output of the graph
+  for (size_t i = 0; i < collection_inputs.size(); ++i) {
+    g->registerOutput(collection_inputs[i]);
+  }
+
+  // Remove the original output value of the graph (the collection object)
+  g->eraseOutput(0);
+
+  // Clean up remnant collection node in graph
+  torch::jit::EliminateDeadCode(g);
+  LOG_GRAPH("Post removing collection casting operations: " << *g);
+}
+
 } // namespace passes
 } // namespace lowering
 } // namespace core
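For reviewers, here is a minimal sketch of what the new pass does to a collection-typed output. This is illustrative only, not part of the patch; it assumes the Torch-TensorRT source tree and libtorch headers are available, and uses the existing TorchScript `parseIR` utility:

```cpp
// Minimal sketch of the new pass on a tuple-returning graph (illustrative,
// not part of the patch). Include paths assume the Torch-TensorRT source tree.
#include <iostream>
#include <memory>
#include "core/lowering/passes/passes.h"
#include "torch/csrc/jit/ir/irparser.h"

int main() {
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(R"IR(
    graph(%x : Tensor):
      %y : Tensor = aten::relu(%x)
      %out : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
      return (%out))IR",
      g.get());

  // Before: one graph output of type (Tensor, Tensor)
  // After: two Tensor outputs; the dead TupleConstruct is removed by DCE
  torch_tensorrt::core::lowering::passes::RemoveCollectionCast(g);
  std::cout << *g << std::endl;
  return 0;
}
```

Because the unpacked graph no longer matches the original method's return type, the pass is gated behind `lower_info.converting_to_trt_engine` and never runs in the TorchScript-wrapped `compile` path.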
diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp
index 41dae65114..68a25b3912 100644
--- a/cpp/src/compile_spec.cpp
+++ b/cpp/src/compile_spec.cpp
@@ -78,9 +78,11 @@ torchtrt::core::CompileSpec init_compile_spec(CompileSpec& external) {
   }
 }
 
-torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
+torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external, bool converting_to_trt_engine) {
   torchtrt::core::CompileSpec internal = init_compile_spec(external);
 
+  internal.lower_info.converting_to_trt_engine = converting_to_trt_engine;
+
   for (auto p : external.enabled_precisions) {
     internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
   }
diff --git a/cpp/src/torch_tensorrt.cpp b/cpp/src/torch_tensorrt.cpp
index 8c54cf3a29..3f241f205e 100644
--- a/cpp/src/torch_tensorrt.cpp
+++ b/cpp/src/torch_tensorrt.cpp
@@ -10,7 +10,7 @@ namespace torch_tensorrt {
 torch_tensorrt::core::runtime::RTDevice to_internal_rt_device(Device device);
 namespace torchscript {
 // Defined in compile_spec.cpp
-torch_tensorrt::core::CompileSpec to_internal_compile_spec(CompileSpec external);
+torch_tensorrt::core::CompileSpec to_internal_compile_spec(CompileSpec external, bool converting_to_trt_engine = false);
 
 bool check_method_operator_support(const torch::jit::script::Module& module, std::string method_name) {
   return torch_tensorrt::core::CheckMethodOperatorSupport(module, method_name);
@@ -23,7 +23,8 @@ std::string convert_method_to_trt_engine(
   LOG_DEBUG(get_build_info());
   // Want to export a much simpler (non TRT header dependent) API so doing the
   // type conversion here
-  return torch_tensorrt::core::ConvertGraphToTRTEngine(module, method_name, to_internal_compile_spec(info));
+  return torch_tensorrt::core::ConvertGraphToTRTEngine(
+      module, method_name, to_internal_compile_spec(info, /*bool converting_to_trt_engine=*/true));
 }
 
 torch::jit::script::Module compile(const torch::jit::script::Module& module, CompileSpec info) {
diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp
index 9488b963cf..ac2dffb4b8 100644
--- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp
+++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp
@@ -326,9 +326,11 @@ core::CompileSpec init_compile_spec(CompileSpec external) {
   }
 }
 
-core::CompileSpec CompileSpec::toInternalCompileSpec() {
+core::CompileSpec CompileSpec::toInternalCompileSpec(bool converting_to_trt_engine) {
   core::CompileSpec info = init_compile_spec(*this);
 
+  info.lower_info.converting_to_trt_engine = converting_to_trt_engine;
+
   for (auto p : enabled_precisions) {
     info.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
   }
diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.h b/py/torch_tensorrt/csrc/tensorrt_classes.h
index b570e456e9..28321b0571 100644
--- a/py/torch_tensorrt/csrc/tensorrt_classes.h
+++ b/py/torch_tensorrt/csrc/tensorrt_classes.h
@@ -123,7 +123,7 @@ std::string to_str(EngineCapability value);
 nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value);
 
 struct CompileSpec : torch::CustomClassHolder {
-  core::CompileSpec toInternalCompileSpec();
+  core::CompileSpec toInternalCompileSpec(bool converting_to_trt_engine = false);
   std::string stringify();
   void appendInput(const c10::intrusive_ptr<Input>& ir) {
     inputs.push_back(*ir);
diff --git a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
index f39888eb0f..b3880b335a 100644
--- a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
+++ b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
@@ -158,7 +158,8 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec& info
 
 py::bytes ConvertGraphToTRTEngine(const torch::jit::Module& mod, const std::string& method_name, CompileSpec& info) {
   py::gil_scoped_acquire gil;
-  auto trt_engine = core::ConvertGraphToTRTEngine(mod, method_name, info.toInternalCompileSpec());
+  auto trt_engine = core::ConvertGraphToTRTEngine(
+      mod, method_name, info.toInternalCompileSpec(/*bool converting_to_trt_engine=*/true));
   return py::bytes(trt_engine);
 }
 
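Caller-facing effect: only the engine-conversion entry points opt in; both the C++ and Python `compile` paths keep the default `converting_to_trt_engine = false`, so compiled modules still return their original tuple/list structure. A hedged usage sketch of the C++ API that now sets the flag internally (the model file and input shape below are placeholders):

```cpp
// Hedged usage sketch: with this PR, convert_method_to_trt_engine() enables the
// RemoveCollectionCast lowering pass internally, so a forward() that returns a
// tuple of Tensors lowers to a multi-output engine instead of failing on the
// prim::TupleConstruct output node.
#include <fstream>
#include <torch/script.h>
#include "torch_tensorrt/torch_tensorrt.h"

int main() {
  auto mod = torch::jit::load("multi_output_model.ts"); // placeholder module
  auto spec = torch_tensorrt::ts::CompileSpec({torch_tensorrt::Input({1, 3, 224, 224})});

  // Each element of the returned tuple becomes its own engine output binding
  std::string engine = torch_tensorrt::ts::convert_method_to_trt_engine(mod, "forward", spec);

  std::ofstream out("multi_output_model.engine", std::ios::binary);
  out << engine;
  return 0;
}
```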
diff --git a/tests/core/lowering/test_remove_unnecessary_casts.cpp b/tests/core/lowering/test_remove_unnecessary_casts.cpp
index 488d7988ea..47a48ba82f 100644
--- a/tests/core/lowering/test_remove_unnecessary_casts.cpp
+++ b/tests/core/lowering/test_remove_unnecessary_casts.cpp
@@ -589,3 +589,89 @@ TEST(LoweringPasses, RemoveAtenIntConstTensorValuesAgree) {
   // Validate identical graphs after pooling constants and canonicalizing
   ASSERT_TRUE((tg->toString() == sg->toString()));
 }
+
+TEST(LoweringPasses, RemoveCollectionCastTuple) {
+  // Ensure the lowering pass transforms the first graph into the second
+  std::string source_graph = R"IR(
+    graph(%x.1 : Tensor):
+      %3 : int = prim::Constant[value=1]()
+      %2 : int = prim::Constant[value=2]()
+      %a.1 : Tensor = aten::mul(%x.1, %2)
+      %b.1 : Tensor = aten::add(%a.1, %2, %3)
+      %c.1 : Tensor = aten::relu(%b.1)
+      %d.1 : Tensor = aten::sqrt(%c.1)
+      %8 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%c.1, %d.1, %b.1)
+      return (%8))IR";
+
+  std::string target_graph = R"IR(
+    graph(%x.1 : Tensor):
+      %3 : int = prim::Constant[value=1]()
+      %2 : int = prim::Constant[value=2]()
+      %a.1 : Tensor = aten::mul(%x.1, %2)
+      %b.1 : Tensor = aten::add(%a.1, %2, %3)
+      %c.1 : Tensor = aten::relu(%b.1)
+      %d.1 : Tensor = aten::sqrt(%c.1)
+      return (%c.1, %d.1, %b.1))IR";
+
+  // Ensure the lowering pass transforms the first graph into the second
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+
+  torch_tensorrt::core::lowering::passes::RemoveCollectionCast(sg);
+  torch::jit::ConstantPooling(sg);
+  sg = torch::jit::Canonicalize(sg, false);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  torch::jit::ConstantPooling(tg);
+  tg = torch::jit::Canonicalize(tg, false);
+
+  // Validate identical graphs after pooling constants and canonicalizing
+  ASSERT_TRUE((tg->toString() == sg->toString()));
+}
+
+TEST(LoweringPasses, RemoveCollectionCastList) {
+  // Ensure the lowering pass transforms the first graph into the second
+  std::string source_graph = R"IR(
+    graph(%x.1 : Tensor):
+      %3 : int = prim::Constant[value=1]()
+      %2 : int = prim::Constant[value=2]()
+      %a.1 : Tensor = aten::mul(%x.1, %2)
+      %b.1 : Tensor = aten::add(%a.1, %2, %3)
+      %c.1 : Tensor = aten::relu(%b.1)
+      %d.1 : Tensor = aten::sqrt(%c.1)
+      %8 : Tensor[] = prim::ListConstruct(%b.1, %c.1, %d.1)
+      return (%8))IR";
+
+  std::string target_graph = R"IR(
+    graph(%x.1 : Tensor):
+      %3 : int = prim::Constant[value=1]()
+      %2 : int = prim::Constant[value=2]()
+      %a.1 : Tensor = aten::mul(%x.1, %2)
+      %b.1 : Tensor = aten::add(%a.1, %2, %3)
+      %c.1 : Tensor = aten::relu(%b.1)
+      %d.1 : Tensor = aten::sqrt(%c.1)
+      return (%b.1, %c.1, %d.1))IR";
+
+  // Ensure the lowering pass transforms the first graph into the second
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+
+  torch_tensorrt::core::lowering::passes::RemoveCollectionCast(sg);
+  torch::jit::ConstantPooling(sg);
+  sg = torch::jit::Canonicalize(sg, false);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  torch::jit::ConstantPooling(tg);
+  tg = torch::jit::Canonicalize(tg, false);
+
+  // Validate identical graphs after pooling constants and canonicalizing
+  ASSERT_TRUE((tg->toString() == sg->toString()));
+}
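One case the new tests do not cover is the early-return guard for non-Tensor collection elements. A hypothetical additional test (not part of this PR, same test file and includes assumed) could pin that behavior down:

```cpp
// Hypothetical extra test: the pass should be a no-op when the collection
// holds a non-Tensor element, per the all_tensors guard in the pass.
TEST(LoweringPasses, RemoveCollectionCastNonTensorNoOp) {
  std::string source_graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : int = prim::Constant[value=2]()
      %3 : (Tensor, int) = prim::TupleConstruct(%x.1, %2)
      return (%3))IR";

  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  std::string before = sg->toString();

  torch_tensorrt::core::lowering::passes::RemoveCollectionCast(sg);

  // The int element cannot become a TRT engine output, so the graph is unchanged
  ASSERT_TRUE(sg->toString() == before);
}
```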