@@ -24,40 +24,29 @@ c10::IValue preprocess(const torch::jit::Module& mod, const c10::Dict<c10::IValu
24
24
25
25
c10::impl::GenericDict TensorRTBackend::compile (c10::IValue mod_val, c10::impl::GenericDict method_compile_spec) {
26
26
auto mod = mod_val.toModule ();
27
- mod = core::lowering::LowerModule (mod);
28
-
29
27
auto spec = c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
30
- core::lowering::LowerInfo lower_info;
31
- for (auto it = spec.begin (), end = spec.end (); it != end; ++it) {
32
- const auto & method_name = it->key ();
33
- auto method = mod.get_method (method_name);
34
- auto graph = method.graph ();
35
- core::lowering::LowerGraph (graph, lower_info);
36
- }
37
28
38
29
auto handles = c10::impl::GenericDict (
39
30
c10::StringType::get (), c10::getCustomClassType<c10::intrusive_ptr<core::runtime::TRTEngine>>());
40
31
41
32
for (auto it = spec.begin (), end = spec.end (); it != end; ++it) {
33
+ auto mod_ = mod.clone ();
42
34
const auto & method_name = it->key ();
43
- auto method = mod.get_method (method_name);
44
- auto g = method.graph ();
45
-
46
35
auto raw_spec = it->value ().toCustomClass <trtorch::pyapi::CompileSpec>();
47
36
LOG_DEBUG (raw_spec->stringify ());
48
37
auto cfg = raw_spec->toInternalCompileSpec ();
49
- auto convert_cfg = std::move (cfg.convert_info );
50
- auto graph_and_ivalues = torch::jit::LowerGraph (*g, mod._ivalue ());
38
+ auto graph_and_ivals = Lower (mod_, method_name, cfg.lower_info );
51
39
52
- g = graph_and_ivalues .first ;
53
- auto params = graph_and_ivalues .second ;
40
+ auto g = graph_and_ivals .first ;
41
+ auto params = graph_and_ivals .second ;
54
42
auto named_params = core::conversion::get_named_params (g->inputs (), params);
55
43
44
+ auto convert_cfg = std::move (cfg.convert_info );
56
45
auto device_spec = convert_cfg.engine_settings .device ;
57
46
auto device = core::runtime::CudaDevice (device_spec.gpu_id , device_spec.device_type );
58
47
auto serialized_engine = core::conversion::ConvertBlockToEngine (g->block (), convert_cfg, named_params);
59
48
auto engine_handle = c10::make_intrusive<core::runtime::TRTEngine>(it->key (), serialized_engine, device);
60
- handles.insert (method. name () , at::IValue (engine_handle));
49
+ handles.insert (method_name , at::IValue (engine_handle));
61
50
}
62
51
63
52
return c10::impl::toGenericDict (handles);
0 commit comments