Skip to content

Commit 3ec836e

Browse files
committed
feat(//core): New API to register arbitrary TRT engines in TorchScript
Modules Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent bbf997e commit 3ec836e

File tree

5 files changed

+62
-0
lines changed

5 files changed

+62
-0
lines changed

Diff for: core/compiler.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,20 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
173173
return new_mod;
174174
}
175175

176+
torch::jit::script::Module EmbedEngineInNewModule(std::string& engine) {
177+
std::ostringstream engine_id;
178+
engine_id << reinterpret_cast<int*>(&engine);
179+
torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id.str());
180+
auto new_g = std::make_shared<torch::jit::Graph>();
181+
AddEngineToGraph(new_mod, new_g, engine);
182+
auto new_method = new_mod._ivalue()->compilation_unit()->create_function("forward", new_g);
183+
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
184+
new_mod.type()->addMethod(new_method);
185+
new_method->setSchema(schema);
186+
187+
return new_mod;
188+
}
189+
176190
void set_device(const int gpu_id) {
177191
TRTORCH_ASSERT(cudaSetDevice(gpu_id) == cudaSuccess, "Unable to set CUDA device: " << gpu_id);
178192
}

Diff for: core/compiler.h

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
1919

2020
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);
2121

22+
torch::jit::script::Module EmbedEngineInNewModule(std::string& engine);
23+
2224
void set_device(const int gpu_id);
2325

2426
} // namespace core

Diff for: cpp/api/include/trtorch/trtorch.h

+14
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,20 @@ TRTORCH_API std::string ConvertGraphToTRTEngine(
480480
const torch::jit::Module& module,
481481
std::string method_name,
482482
CompileSpec info);
483+
484+
/**
485+
* @brief Take a previously created TensorRT engine and embed it in
486+
* in a TorchScript module
487+
*
488+
* @param engine: std::string - Precompiled serialized TensorRT engine
489+
*
490+
* Takes a prebuilt serialized TensorRT engine and embeds it in a TorchScript
491+
* graph. Registers the engine as the forward method of the module
492+
*
493+
* @return: A new module trageting a TensorRT engine
494+
*/
495+
TRTORCH_API torch::jit::Module EmbedEngineInNewModule(std::string& engine);
496+
483497
/**
484498
* @brief Set gpu device id
485499
*

Diff for: cpp/api/src/trtorch.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module
3131
return core::CompileGraph(module, to_internal_compile_spec(info));
3232
}
3333

34+
torch::jit::Module EmbedEngineInNewModule(std::string& engine) {
35+
return core::EmbedEngineInNewModule(engine);
36+
}
37+
3438
std::string get_build_info() {
3539
auto info = core::util::get_build_info();
3640
return std::string("TRTorch Version: ") + TRTORCH_VERSION + '\n' + info;

Diff for: tests/modules/test_modules_as_engines.cpp

+28
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,34 @@ TEST_P(ModuleTests, ModuleAsEngineIsClose) {
1616
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-5));
1717
}
1818

19+
TEST_P(ModuleTests, ModuleToModuleIsClose) {
20+
std::vector<at::Tensor> inputs;
21+
std::vector<torch::jit::IValue> inputs_ivalues;
22+
for (auto in_shape : input_shapes) {
23+
inputs.push_back(at::randint(5, in_shape, {at::kCUDA}));
24+
inputs_ivalues.push_back(inputs[inputs.size() - 1].clone());
25+
}
26+
27+
torch::jit::IValue jit_results_ivalues = trtorch::tests::util::RunModuleForward(mod, inputs_ivalues);
28+
std::vector<at::Tensor> jit_results;
29+
jit_results.push_back(jit_results_ivalues.toTensor());
30+
31+
auto forward_graph = mod.get_method("forward");
32+
std::vector<c10::ArrayRef<int64_t>> input_ranges;
33+
for (auto in : inputs) {
34+
input_ranges.push_back(in.sizes());
35+
}
36+
37+
auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", input_ranges);
38+
auto trt_mod = trtorch::EmbedEngineInNewModule(engine);
39+
40+
torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(mod, inputs_ivalues);
41+
std::vector<at::Tensor> trt_results;
42+
trt_results.push_back(trt_results_ivalues.toTensor());
43+
44+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-5));
45+
}
46+
1947
INSTANTIATE_TEST_SUITE_P(
2048
ModuleAsEngineForwardIsCloseSuite,
2149
ModuleTests,

0 commit comments

Comments
 (0)