7
7
#include " torch/csrc/jit/passes/lower_graph.h"
8
8
#include " torch/csrc/jit/passes/lower_tuples.h"
9
9
#include " torch/csrc/jit/passes/peephole.h"
10
- #include " torch/csrc/jit/passes/quantization.h"
11
10
12
11
#include " core/util/prelude.h"
13
12
#include " core/lowering/lowering.h"
@@ -50,8 +49,7 @@ torch::jit::Module LowerModule(const torch::jit::script::Module& mod) {
50
49
return mod_;
51
50
}
52
51
53
- std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower (const torch::jit::script::Module& mod,
54
- std::string method_name) {
52
+ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower (const torch::jit::script::Module& mod, std::string method_name) {
55
53
auto lowered_mod = LowerModule (mod);
56
54
auto g = lowered_mod.get_method (method_name).graph ();
57
55
LOG_GRAPH (*g);
@@ -62,9 +60,14 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower(con
62
60
lowering::LowerGraph (g);
63
61
// =[torch::jit::FoldConvBatchNorm2d(lowered_mod);
64
62
LOG_GRAPH (" LibTorch Lowering" );
65
- auto graph_and_parameters = torch::jit::LowerGraph (*g, lowered_mod._ivalue ());
63
+ auto graph_and_ivalues = torch::jit::LowerGraph (*g, lowered_mod._ivalue ());
66
64
// Is this necessary?
67
65
lowering::LowerBlock (g->block ());
66
+ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> graph_and_parameters;
67
+ for (auto i : graph_and_ivalues.second ) {
68
+ graph_and_parameters.second .push_back (i.toTensor ());
69
+ }
70
+ graph_and_parameters.first = graph_and_ivalues.first ;
68
71
return graph_and_parameters;
69
72
}
70
73
0 commit comments