2
2
#include " core/util/prelude.h"
3
3
#include " torch/csrc/jit/api/module.h"
4
4
#include " core/util/prelude.h"
5
+ #include " core/lowering/passes/passes.h"
6
+
5
7
6
8
7
9
namespace trtorch {
@@ -20,9 +22,9 @@ torch::jit::Value* getOrAddInputForValue(torch::jit::Value* old_value, std::shar
20
22
}
21
23
auto new_value = graph->block ()->addInput ();
22
24
old_to_new[old_value] = new_value;
25
+ new_value->copyMetadata (old_value);
23
26
// mapping from new graph input Values to original graph values
24
27
old_to_new[new_value] = old_value;
25
- new_value->copyMetadata (old_value);
26
28
return new_value;
27
29
} else {
28
30
return old_to_new[old_value];
@@ -40,7 +42,6 @@ torch::jit::Node* cloneNode(torch::jit::Node* node, std::shared_ptr<torch::jit::
40
42
auto no = new_node->outputs ()[i];
41
43
old_to_new[oo] = no;
42
44
}
43
-
44
45
return new_node;
45
46
}
46
47
@@ -58,10 +59,13 @@ c10::FunctionSchema getFunctionSchema(std::string method_name, std::shared_ptr<t
58
59
return c10::FunctionSchema (method_name, method_name, args, returns);
59
60
}
60
61
61
- void registerSegmentInOutShape (SegmentedBlock &seg_block, std::unordered_map<torch::jit::Value*, nvinfer1::Dims > &input_shape_map ) {
62
+ void registerSegmentInOutIValues (SegmentedBlock &seg_block, std::unordered_map<torch::jit::Value*, torch::jit::IValue > &ivalues_maps ) {
62
63
// create a module to run the graph
63
64
auto g = seg_block.g ();
64
65
auto copy_g = g->copy ();
66
+ lowering::passes::RemoveInplaceAdd (copy_g);
67
+
68
+ // create tuple for multiple outputs
65
69
if (seg_block.raw_outputs ().size () > 1 ) {
66
70
auto new_output_node = copy_g->appendNode (copy_g->createTuple (copy_g->outputs ()));
67
71
for (int idx = copy_g->outputs ().size () - 1 ; idx >= 0 ; --idx) {
@@ -84,46 +88,60 @@ void registerSegmentInOutShape(SegmentedBlock &seg_block, std::unordered_map<tor
84
88
85
89
// set inputs ivalues
86
90
for (auto &input : seg_block.raw_inputs ()) {
87
- std::vector<int64_t > shape;
88
- nvinfer1::Dims cur_shape = input_shape_map[input];
89
- shape.insert (shape.begin (), std::begin (cur_shape.d ), std::begin (cur_shape.d ) + cur_shape.nbDims );
90
- auto in = at::randint (5 , shape, {at::kCUDA });
91
- jit_inputs_ivalues.push_back (in.clone ());
91
+ if (!ivalues_maps.count (input)) {
92
+ std::cerr << " could find graph input ivalues\n " ;
93
+ }
94
+ if (input->type ()->isSubtypeOf (torch::jit::TensorType::get ())) {
95
+ jit_inputs_ivalues.push_back (ivalues_maps[input].toTensor ());
96
+ } else if (input->type ()->isSubtypeOf (torch::jit::IntType::get ())) {
97
+ jit_inputs_ivalues.push_back (ivalues_maps[input].toInt ());
98
+ }
92
99
}
93
100
94
- std::vector<at::Tensor > jit_results;
101
+ std::vector<torch::jit::IValue > jit_results;
95
102
torch::jit::IValue jit_results_ivalues = cur_mod.forward (jit_inputs_ivalues);
96
- if (jit_results_ivalues.isTensor ()) {
97
- jit_results.push_back (jit_results_ivalues.toTensor ());
98
- } else {
103
+ if (jit_results_ivalues.isTuple ()) {
99
104
auto results = jit_results_ivalues.toTuple ()->elements ();
100
105
for (auto r : results) {
101
- jit_results.push_back (r. toTensor () );
106
+ jit_results.push_back (r);
102
107
}
108
+ } else {
109
+ jit_results.push_back (jit_results_ivalues);
103
110
}
104
111
105
112
size_t idx = 0 ;
106
113
for (auto &output : seg_block.raw_outputs ()) {
107
- input_shape_map [output] = util::toDims ( jit_results[idx++]. sizes ()) ;
114
+ ivalues_maps [output] = jit_results[idx++];
108
115
}
109
116
117
+ // set input shape for each segmented block so we wil use it in conversion process
110
118
std::vector<nvinfer1::Dims> input_shape;
111
119
for (auto &i : seg_block.raw_inputs ()) {
112
- input_shape.push_back (input_shape_map[i]);
120
+ if (ivalues_maps[i].isTensor ()) {
121
+ input_shape.push_back (util::toDims (ivalues_maps[i].toTensor ().sizes ()));
122
+ }
113
123
}
114
124
115
125
seg_block.register_inshape (input_shape);
116
126
}
117
127
118
- std::vector<nvinfer1::Dims> extractNvinfer1Dims (std::vector<conversion::InputRange>& input_ranges) {
119
- std::vector<nvinfer1::Dims> res;
128
+
129
+ std::vector<torch::jit::IValue> generateRandomInputs (std::vector<conversion::InputRange>& input_ranges) {
130
+ std::vector<torch::jit::IValue> random_inputs;
120
131
for (auto &input_range : input_ranges) {
121
- res.push_back (input_range.input_shape );
132
+ auto cur_shape = input_range.input_shape ;
133
+ std::vector<int64_t > shape;
134
+ shape.insert (shape.begin (), std::begin (cur_shape.d ), std::begin (cur_shape.d ) + cur_shape.nbDims );
135
+ auto in = at::randint (5 , shape, {at::kCUDA });
136
+ random_inputs.push_back (in.clone ());
137
+ printf (" is tensor: %d\n " , random_inputs.back ().isTensor ());
122
138
}
123
- return res ;
139
+ return random_inputs ;
124
140
}
125
141
142
+
126
143
void registerSegmentsInputsOutputs (std::vector<SegmentedBlock> &segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
144
+ // find the corresponding raw values in original global graph for this segmented block's inputs/outputs
127
145
std::set<torch::jit::Value*> input_values;
128
146
for (auto &seg_block : segmented_blocks) {
129
147
seg_block.registerInputs ();
@@ -176,6 +194,7 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
176
194
177
195
for (const auto n : nodes) {
178
196
if (n->kind () == torch::jit::prim::Constant) continue ;
197
+
179
198
std::string node_string (n->kind ().toQualString ());
180
199
181
200
if (conversion::OpSupported (n) && !forced_fallback_operators.count (node_string)) {
@@ -186,19 +205,21 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
186
205
}
187
206
}
188
207
merge_nodes (pytorch_nodes, tensorrt_nodes, segmented_blocks, min_block_size);
189
- if (!pytorch_nodes.empty ()) segmented_blocks.emplace_back (SegmentedBlock::kTorch , pytorch_nodes);
208
+ if (!pytorch_nodes.empty ()) {
209
+ segmented_blocks.emplace_back (SegmentedBlock::kTorch , pytorch_nodes);
210
+ }
190
211
191
212
registerSegmentsInputsOutputs (segmented_blocks, g);
192
213
193
- std::vector<nvinfer1::Dims> graph_inputs_shape = extractNvinfer1Dims (input_ranges);
194
- std::unordered_map<torch::jit::Value*, nvinfer1::Dims> input_shape_map;
214
+ std::unordered_map<torch::jit::Value*, torch::jit::IValue> ivalues_maps;
195
215
216
+ std::vector<torch::jit::IValue> random_inputs = generateRandomInputs (input_ranges);
196
217
for (size_t i = 0 ; i < g->inputs ().size (); ++i) {
197
- input_shape_map [g->inputs ()[i]] = graph_inputs_shape [i];
218
+ ivalues_maps [g->inputs ()[i]] = random_inputs [i];
198
219
}
199
220
200
221
for (auto &seg_block : segmented_blocks) {
201
- registerSegmentInOutShape (seg_block, input_shape_map );
222
+ registerSegmentInOutIValues (seg_block, ivalues_maps );
202
223
}
203
224
204
225
return segmented_blocks;
0 commit comments