@@ -190,6 +190,57 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
190
190
}
191
191
}
192
192
193
+ void MapIValues (ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_list, c10::ArrayRef<const torch::jit::Value*> out_list, int64_t in_offset, int64_t out_offset) {
194
+ std::vector<std::pair<const torch::jit::Value*, const torch::jit::Value*>> input_output_pairs;
195
+ std::transform (in_list.begin () + in_offset, in_list.end (), out_list.begin () + out_offset,
196
+ std::back_inserter (input_output_pairs),
197
+ [](auto in, auto out){
198
+ return std::make_pair (in, out);
199
+ });
200
+
201
+ for (auto p : input_output_pairs) {
202
+ auto input = ctx->evaluated_value_map [p.first ];
203
+ ctx->evaluated_value_map [p.second ] = torch::jit::IValue (input);
204
+ }
205
+ }
206
+
207
+ // TODO: With functionalization pass we may be able to make this into a regular evaluator later
208
+ void EvaluateLoopBlock (ConversionCtx* ctx, const torch::jit::Node* n) {
209
+ auto max_trip_count = ctx->evaluated_value_map [n->input (0 )];
210
+ auto start_cond = ctx->evaluated_value_map [n->input (1 )];
211
+ ctx->evaluated_value_map [n->blocks ()[0 ]->inputs ()[0 ]] = torch::jit::IValue (0 );
212
+ auto trip_count = ctx->evaluated_value_map [n->blocks ()[0 ]->inputs ()[0 ]];
213
+
214
+ MapIValues (ctx, n->inputs (), n->outputs (), 2 , 0 );
215
+
216
+ LOG_DEBUG (" (Loop Evaluation) Evaluating loop " << *n);
217
+ LOG_DEBUG (" (Loop Evaluation) Max Trip Count: " << max_trip_count.toInt ());
218
+ LOG_DEBUG (" (Loop Evaluation) Start Condition: " << start_cond.toBool ());
219
+ LOG_DEBUG (" (Loop Evaluation) Current Trip Count: " << trip_count.toInt ());
220
+
221
+ while (start_cond.toBool () && trip_count.toInt () < max_trip_count.toInt ()) {
222
+ MapIValues (ctx, n->outputs (), n->blocks ()[0 ]->inputs (), 0 , 1 );
223
+ for (auto bn : n->blocks ()[0 ]->nodes ()) {
224
+ auto eval = EvaluateNode (ctx, bn);
225
+ if (eval) {
226
+ if (!eval.value ().isTensor ()) {
227
+ LOG_DEBUG (ctx->logger , " (Loop Evaluation) Found the value to be: " << eval.value ());
228
+ } else {
229
+ LOG_DEBUG (ctx->logger , " (Loop Evaluation) Found the value to be a tensor (shape " << eval.value ().toTensor ().sizes () << ' )' );
230
+ }
231
+ ctx->AssociateValueAndIValue (bn->output (0 ), eval.value ());
232
+ }
233
+ }
234
+
235
+ MapIValues (ctx, n->blocks ()[0 ]->outputs (), n->outputs (), 1 , 0 );
236
+ start_cond = ctx->evaluated_value_map [n->blocks ()[0 ]->outputs ()[0 ]];
237
+ auto new_trip_count = torch::jit::IValue (trip_count.toInt () + 1 );
238
+ trip_count.swap (new_trip_count);
239
+ LOG_DEBUG (" (Loop Evaluation) Condition: " << start_cond.toBool ());
240
+ LOG_DEBUG (" (Loop Evaluation) Current Trip Count: " << trip_count.toInt ());
241
+ }
242
+ }
243
+
193
244
void ConvertBlockToNetDef (ConversionCtx* ctx, const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) {
194
245
LOG_INFO (ctx->logger , " Converting Block" );
195
246
@@ -202,7 +253,19 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver
202
253
for (const auto n : nodes) {
203
254
bool to_eval = evaluators::shouldEvalAtConversionTime (n);
204
255
bool blacklisted = isNodeConversionBlacklisted (n);
205
- if (!to_eval && !blacklisted) {
256
+ if (n->kind () == torch::jit::prim::Loop) {
257
+ EvaluateLoopBlock (ctx, n);
258
+ } else if (to_eval) {
259
+ auto eval = EvaluateNode (ctx, n);
260
+ if (eval) {
261
+ if (!eval.value ().isTensor ()) {
262
+ LOG_DEBUG (ctx->logger , " Found the value to be: " << eval.value ());
263
+ } else {
264
+ LOG_DEBUG (ctx->logger , " Found the value to be a tensor (shape " << eval.value ().toTensor ().sizes () << ' )' );
265
+ }
266
+ ctx->AssociateValueAndIValue (n->output (0 ), eval.value ());
267
+ }
268
+ } else if (!blacklisted) {
206
269
// Should error out if something fails
207
270
AddLayer (ctx, n);
208
271
} else {
@@ -237,22 +300,29 @@ std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo buil
237
300
return engine;
238
301
}
239
302
240
- bool VerifyConverterSupportForBlock (const torch::jit::Block* b) {
241
- bool supported = true ;
303
+ std::set<std::string> GetUnsupportedOpsInBlock (const torch::jit::Block* b ) {
242
304
std::set<std::string> unsupported_ops;
243
305
for (const auto n : b->nodes ()) {
244
- if (!OpSupported (n)) {
306
+ if (!OpSupported (n) && n-> kind () != torch::jit::prim::Loop ) {
245
307
auto schema = n->maybeSchema ();
246
308
TRTORCH_CHECK (schema, " Unable to get schema for Node " << util::node_info (n) \
247
309
<< " (conversion.VerifyCoverterSupportForBlock" );
248
310
std::stringstream ss;
249
311
ss << *schema;
250
312
unsupported_ops.insert (ss.str ());
251
- supported = false ;
313
+ }
314
+ for (const auto sub_b : n->blocks ()) {
315
+ auto sub_b_unsupported_ops = GetUnsupportedOpsInBlock (sub_b);
316
+ unsupported_ops.insert (sub_b_unsupported_ops.begin (), sub_b_unsupported_ops.end ());
252
317
}
253
318
}
319
+ return unsupported_ops;
320
+ }
321
+
322
+ bool VerifyConverterSupportForBlock (const torch::jit::Block* b) {
323
+ auto unsupported_ops = GetUnsupportedOpsInBlock (b);
254
324
255
- if (!supported ) {
325
+ if (unsupported_ops. size () != 0 ) {
256
326
std::stringstream unsupported_msg;
257
327
unsupported_msg << " Method requested cannot be compiled by TRTorch.\n Unsupported operators listed below:" << std::endl;
258
328
for (auto s : unsupported_ops) {
@@ -261,8 +331,10 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
261
331
unsupported_msg << " You can either implement converters for these ops in your application or request implementation" << std::endl;
262
332
unsupported_msg << " https://www.github.com/nvidia/TRTorch/issues" << std::endl;
263
333
LOG_ERROR (unsupported_msg.str ());
334
+ return false ;
335
+ } else {
336
+ return true ;
264
337
}
265
- return supported;
266
338
}
267
339
268
340
} // namespace conversion
0 commit comments