Commit 88f0e81

Merge pull request #431 from NVIDIA/arbitrary_trt_engines
New API to register arbitrary TRT engines as TorchScript modules
2 parents: bbf997e + cb7a547

File tree

10 files changed: +189 −9 lines changed
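Taken together, the diffs below expose one new entry point, EmbedEngineInNewModule / embed_engine_in_new_module, at each layer of the stack. As a minimal sketch of the intended round trip (mirroring the Python test added in this commit; the model choice and shapes are illustrative):

import torch
import torchvision.models as models
import trtorch

# Script a model and lower its forward method to a serialized TensorRT engine
ts_model = torch.jit.script(models.resnet18(pretrained=True).eval().to("cuda"))
compile_spec = {
    "input_shapes": [(1, 3, 224, 224)],
    "device": {
        "device_type": trtorch.DeviceType.GPU,
        "gpu_id": 0,
        "dla_core": 0,
        "allow_gpu_fallback": False,
        "disable_tf32": False
    }
}
trt_engine = trtorch.convert_method_to_trt_engine(ts_model, "forward", compile_spec)

# New in this PR: wrap the raw engine back up as a TorchScript module
trt_mod = trtorch.embed_engine_in_new_module(trt_engine)
out = trt_mod(torch.randn((1, 3, 224, 224)).to("cuda"))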

core/compiler.cpp

+15-1
@@ -46,7 +46,7 @@ c10::FunctionSchema GenerateGraphSchema(
 void AddEngineToGraph(
     torch::jit::script::Module mod,
     std::shared_ptr<torch::jit::Graph>& g,
-    std::string& serialized_engine) {
+    const std::string& serialized_engine) {
   auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine);
   // Get required metadata about the engine out
   auto num_io = engine_ptr->num_io;
@@ -173,6 +173,20 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
   return new_mod;
 }
 
+torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine) {
+  std::ostringstream engine_id;
+  engine_id << reinterpret_cast<const int*>(&engine);
+  torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id.str());
+  auto new_g = std::make_shared<torch::jit::Graph>();
+  AddEngineToGraph(new_mod, new_g, engine);
+  auto new_method = new_mod._ivalue()->compilation_unit()->create_function("forward", new_g);
+  auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
+  new_mod.type()->addMethod(new_method);
+  new_method->setSchema(schema);
+
+  return new_mod;
+}
+
 void set_device(const int gpu_id) {
   TRTORCH_ASSERT(cudaSetDevice(gpu_id) == cudaSuccess, "Unable to set CUDA device: " << gpu_id);
 }
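One implementation detail worth flagging: the new module's name is derived from the address of the engine string, so it is unique per call but not stable across runs; it is an identifier, not a content hash. In Python terms the naming scheme is roughly (illustrative only, not part of the API):

# The C++ code streams reinterpret_cast<const int*>(&engine) into the name,
# i.e. it names the module after the engine buffer's address
engine = b"...serialized engine bytes..."            # placeholder
mod_name = "tensorrt_engine_mod_" + hex(id(engine))  # rough Python analogue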

core/compiler.h

+2
@@ -19,6 +19,8 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
 
 torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);
 
+torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine);
+
 void set_device(const int gpu_id);
 
 } // namespace core

cpp/api/include/trtorch/trtorch.h

+15
@@ -480,6 +480,21 @@ TRTORCH_API std::string ConvertGraphToTRTEngine(
     const torch::jit::Module& module,
     std::string method_name,
     CompileSpec info);
+
+/**
+ * @brief Take a previously created TensorRT engine and embed it
+ * in a TorchScript module
+ *
+ * @param engine: std::string - Pre-built serialized TensorRT engine
+ *
+ * Takes a pre-built serialized TensorRT engine and embeds it in a TorchScript
+ * module. Registers execution of the engine as the forward method of the module.
+ * Forward is defined as: forward(Tensor[]) -> Tensor[]
+ *
+ * @return: A new module targeting a TensorRT engine
+ */
+TRTORCH_API torch::jit::Module EmbedEngineInNewModule(const std::string& engine);
+
 /**
  * @brief Set gpu device id
  *

cpp/api/src/trtorch.cpp

+4
@@ -31,6 +31,10 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module
   return core::CompileGraph(module, to_internal_compile_spec(info));
 }
 
+torch::jit::Module EmbedEngineInNewModule(const std::string& engine) {
+  return core::EmbedEngineInNewModule(engine);
+}
+
 std::string get_build_info() {
   auto info = core::util::get_build_info();
   return std::string("TRTorch Version: ") + TRTORCH_VERSION + '\n' + info;

py/trtorch/_compiler.py

+20
@@ -124,6 +124,26 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
     return trtorch._C.convert_graph_to_trt_engine(module._c, method_name, _parse_compile_spec(compile_spec))
 
 
+def embed_engine_in_new_module(serialized_engine: bytes) -> torch.jit.ScriptModule:
+    """Takes a pre-built serialized TensorRT engine and embeds it within a TorchScript module
+
+    Takes a pre-built serialized TensorRT engine (as bytes) and embeds it within a TorchScript module.
+    Registers the forward method to execute the TensorRT engine with the function signature:
+
+        forward(Tensor[]) -> Tensor[]
+
+    The module can be saved with the engine embedded using torch.jit.save and moved / loaded
+    according to TRTorch portability rules
+
+    Args:
+        serialized_engine (bytes): Serialized TensorRT engine from either TRTorch or TensorRT APIs
+
+    Returns:
+        torch.jit.ScriptModule: New TorchScript module with engine embedded
+    """
+    cpp_mod = trtorch._C.embed_engine_in_new_module(serialized_engine)
+    return torch.jit._recursive.wrap_cpp_module(cpp_mod)
+
+
 def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> bool:
     """Checks to see if a method is fully supported by TRTorch
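Since the docstring points at torch.jit.save and the TRTorch portability rules, a short sketch of that serialization round trip may help (the engine and module file names here are hypothetical):

import torch
import trtorch

# Any serialized TensorRT engine works as input, e.g. one written to disk
# earlier by convert_method_to_trt_engine
with open("resnet18_fp32.engine", "rb") as f:
    trt_mod = trtorch.embed_engine_in_new_module(f.read())

# The wrapped module behaves like any other ScriptModule for serialization,
# so it can be saved and reloaded without rebuilding the engine
torch.jit.save(trt_mod, "trt_engine_mod.ts")
reloaded = torch.jit.load("trt_engine_mod.ts")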

py/trtorch/csrc/trtorch_py.cpp

+8
@@ -119,6 +119,10 @@ bool CheckMethodOperatorSupport(const torch::jit::Module& module, const std::str
   return core::CheckMethodOperatorSupport(module, method_name);
 }
 
+torch::jit::Module EmbedEngineInNewModule(const py::bytes& engine) {
+  return core::EmbedEngineInNewModule(engine);
+}
+
 std::string get_build_info() {
   auto info = core::util::get_build_info();
   return info;
@@ -270,6 +274,10 @@ PYBIND11_MODULE(_C, m) {
       "check_method_op_support",
       &trtorch::pyapi::CheckMethodOperatorSupport,
       "Takes a module and a method name and checks if the method graph contains purely convertable operators");
+  m.def(
+      "embed_engine_in_new_module",
+      &trtorch::pyapi::EmbedEngineInNewModule,
+      "Takes a serialized TensorRT engine and wraps it in the forward method of a new TorchScript module");
   m.def("get_build_info", &get_build_info, "Returns build info about the compiler as a string");
 
   m.def("_get_logging_prefix", &logging::get_logging_prefix, "Get the current prefix for the logging output");

tests/modules/test_modules_as_engines.cpp

+28
@@ -16,6 +16,34 @@ TEST_P(ModuleTests, ModuleAsEngineIsClose) {
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-5));
 }
 
+TEST_P(ModuleTests, ModuleToEngineToModuleIsClose) {
+  std::vector<at::Tensor> inputs;
+  std::vector<torch::jit::IValue> inputs_ivalues;
+  for (auto in_shape : input_shapes) {
+    inputs.push_back(at::randint(5, in_shape, {at::kCUDA}));
+    inputs_ivalues.push_back(inputs[inputs.size() - 1].clone());
+  }
+
+  torch::jit::IValue jit_results_ivalues = trtorch::tests::util::RunModuleForward(mod, inputs_ivalues);
+  std::vector<at::Tensor> jit_results;
+  jit_results.push_back(jit_results_ivalues.toTensor());
+
+  auto forward_graph = mod.get_method("forward");
+  std::vector<c10::ArrayRef<int64_t>> input_ranges;
+  for (auto in : inputs) {
+    input_ranges.push_back(in.sizes());
+  }
+
+  auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", input_ranges);
+  auto trt_mod = trtorch::EmbedEngineInNewModule(engine);
+
+  // Run the embedded-engine module (not the original) so the comparison is meaningful
+  torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, inputs_ivalues);
+  std::vector<at::Tensor> trt_results;
+  trt_results.push_back(trt_results_ivalues.toTensor());
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-5));
+}
+
 INSTANTIATE_TEST_SUITE_P(
     ModuleAsEngineForwardIsCloseSuite,
     ModuleTests,

tests/py/BUILD

+14-5
@@ -30,7 +30,7 @@ py_test(
     srcs = [
         "test_ptq_dataloader_calibrator.py",
         "model_test_case.py"
-    ]
+    ],
     deps = [
         requirement("torchvision")
     ]
@@ -43,7 +43,7 @@ py_test(
     srcs = [
         "test_ptq_trt_calibrator.py",
         "model_test_case.py"
-    ]
+    ],
     deps = [
         requirement("torchvision")
     ]
@@ -56,8 +56,6 @@ py_test(
         "test_multi_gpu.py",
         "model_test_case.py"
     ],
-    "//conditions:default" : []
-    }),
     deps = [
         requirement("torchvision")
     ]
@@ -74,12 +72,23 @@ py_test(
     ]
 )
 
+py_test(
+    name = "test_trt_intercompatability",
+    srcs = [
+        "test_trt_intercompatability.py",
+        "model_test_case.py"
+    ],
+    deps = [
+        requirement("torchvision")
+    ]
+)
+
 py_test(
     name = "test_ptq_to_backend",
     srcs = [
         "test_ptq_to_backend.py",
         "model_test_case.py"
-    ]
+    ],
     deps = [
         requirement("torchvision")
     ]
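Assuming the standard Bazel workflow used for the rest of this test tree, the new target should be runnable with something like:

bazel test //tests/py:test_trt_intercompatability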

tests/py/test_api.py

+28-3
@@ -46,6 +46,30 @@ def test_compile_script(self):
         self.assertTrue(same < 2e-3)
 
 
+class TestPTtoTRTtoPT(ModelTestCase):
+
+    def setUp(self):
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.ts_model = torch.jit.script(self.model)
+
+    def test_pt_to_trt_to_pt(self):
+        compile_spec = {
+            "input_shapes": [self.input.shape],
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": 0,
+                "dla_core": 0,
+                "allow_gpu_fallback": False,
+                "disable_tf32": False
+            }
+        }
+
+        trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", compile_spec)
+        trt_mod = trtorch.embed_engine_in_new_module(trt_engine)
+        same = (trt_mod(self.input) - self.ts_model(self.input)).abs().max()
+        self.assertTrue(same < 2e-3)
+
+
 class TestCheckMethodOpSupport(unittest.TestCase):
 
     def setUp(self):
@@ -59,13 +83,13 @@ def test_check_support(self):
 class TestLoggingAPIs(unittest.TestCase):
 
     def test_logging_prefix(self):
-        new_prefix = "TEST"
+        new_prefix = "Python API Test: "
         trtorch.logging.set_logging_prefix(new_prefix)
         logging_prefix = trtorch.logging.get_logging_prefix()
         self.assertEqual(new_prefix, logging_prefix)
 
     def test_reportable_log_level(self):
-        new_level = trtorch.logging.Level.Warning
+        new_level = trtorch.logging.Level.Error
        trtorch.logging.set_reportable_log_level(new_level)
         level = trtorch.logging.get_reportable_log_level()
         self.assertEqual(new_level, level)
@@ -78,10 +102,11 @@ def test_is_colored_output_on(self):
 
 def test_suite():
     suite = unittest.TestSuite()
+    suite.addTest(unittest.makeSuite(TestLoggingAPIs))
     suite.addTest(TestCompile.parametrize(TestCompile, model=models.resnet18(pretrained=True)))
     suite.addTest(TestCompile.parametrize(TestCompile, model=models.mobilenet_v2(pretrained=True)))
+    suite.addTest(TestPTtoTRTtoPT.parametrize(TestPTtoTRTtoPT, model=models.mobilenet_v2(pretrained=True)))
     suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
-    suite.addTest(unittest.makeSuite(TestLoggingAPIs))
 
     return suite

tests/py/test_trt_intercompatability.py

+55

@@ -0,0 +1,55 @@
+import unittest
+import trtorch
+import torch
+import torchvision.models as models
+import tensorrt as trt
+
+from model_test_case import ModelTestCase
+
+
+class TestPyTorchToTRTEngine(ModelTestCase):
+
+    def setUp(self):
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda:0")
+        self.ts_model = torch.jit.script(self.model)
+
+    def test_pt_to_trt(self):
+        compile_spec = {
+            "input_shapes": [self.input.shape],
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": 0,
+                "dla_core": 0,
+                "allow_gpu_fallback": False,
+                "disable_tf32": False
+            }
+        }
+
+        trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", compile_spec)
+
+        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
+        with trt.Runtime(TRT_LOGGER) as rt:
+            engine = rt.deserialize_cuda_engine(trt_engine)
+            with engine.create_execution_context() as ctx:
+                out = torch.empty(size=tuple(engine.get_binding_shape(1))).to("cuda:0")
+                bindings = [self.input.contiguous().data_ptr(), out.contiguous().data_ptr()]
+                ctx.execute_async(batch_size=1,
+                                  bindings=bindings,
+                                  stream_handle=torch.cuda.current_stream(device='cuda:0').cuda_stream)
+                same = (out - self.ts_model(self.input)).abs().max()
+                self.assertTrue(same < 2e-3)
+
+
+def test_suite():
+    suite = unittest.TestSuite()
+    suite.addTest(TestPyTorchToTRTEngine.parametrize(TestPyTorchToTRTEngine, model=models.resnet18(pretrained=True)))
+
+    return suite
+
+
+suite = test_suite()
+
+runner = unittest.TextTestRunner()
+result = runner.run(suite)
+
+exit(int(not result.wasSuccessful()))
