Skip to content

Commit 4e15605

Browse files
committed
fix(to_backend): Clean up to_backend implementation
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent a5bc3b0 commit 4e15605

File tree

1 file changed

+6
-17
lines changed

1 file changed

+6
-17
lines changed

Diff for: py/trtorch/csrc/tensorrt_backend.cpp

+6-17
Original file line numberDiff line numberDiff line change
@@ -24,40 +24,29 @@ c10::IValue preprocess(const torch::jit::Module& mod, const c10::Dict<c10::IValu
2424

2525
c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::GenericDict method_compile_spec) {
2626
auto mod = mod_val.toModule();
27-
mod = core::lowering::LowerModule(mod);
28-
2927
auto spec = c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
30-
core::lowering::LowerInfo lower_info;
31-
for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
32-
const auto& method_name = it->key();
33-
auto method = mod.get_method(method_name);
34-
auto graph = method.graph();
35-
core::lowering::LowerGraph(graph, lower_info);
36-
}
3728

3829
auto handles = c10::impl::GenericDict(
3930
c10::StringType::get(), c10::getCustomClassType<c10::intrusive_ptr<core::runtime::TRTEngine>>());
4031

4132
for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
33+
auto mod_ = mod.clone();
4234
const auto& method_name = it->key();
43-
auto method = mod.get_method(method_name);
44-
auto g = method.graph();
45-
4635
auto raw_spec = it->value().toCustomClass<trtorch::pyapi::CompileSpec>();
4736
LOG_DEBUG(raw_spec->stringify());
4837
auto cfg = raw_spec->toInternalCompileSpec();
49-
auto convert_cfg = std::move(cfg.convert_info);
50-
auto graph_and_ivalues = torch::jit::LowerGraph(*g, mod._ivalue());
38+
auto graph_and_ivals = Lower(mod_, method_name, cfg.lower_info);
5139

52-
g = graph_and_ivalues.first;
53-
auto params = graph_and_ivalues.second;
40+
auto g = graph_and_ivals.first;
41+
auto params = graph_and_ivals.second;
5442
auto named_params = core::conversion::get_named_params(g->inputs(), params);
5543

44+
auto convert_cfg = std::move(cfg.convert_info);
5645
auto device_spec = convert_cfg.engine_settings.device;
5746
auto device = core::runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
5847
auto serialized_engine = core::conversion::ConvertBlockToEngine(g->block(), convert_cfg, named_params);
5948
auto engine_handle = c10::make_intrusive<core::runtime::TRTEngine>(it->key(), serialized_engine, device);
60-
handles.insert(method.name(), at::IValue(engine_handle));
49+
handles.insert(method_name, at::IValue(engine_handle));
6150
}
6251

6352
return c10::impl::toGenericDict(handles);

0 commit comments

Comments
 (0)