@@ -291,6 +291,164 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
   }
 }
 
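+// Converts a prim::Loop node whose carried state includes Tensors into a
+// TensorRT ILoop. Scalar bookkeeping (max trip count, loop condition) is
+// evaluated at conversion time; recurrent Tensors become IRecurrenceLayers.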
+void ConvertLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
+  auto block = n->blocks()[0];
+
+  // max_trip_count and start_cond already evaluated
+  auto max_trip_count = ctx->evaluated_value_map[n->input(0)];
+  auto start_cond = ctx->evaluated_value_map[n->input(1)];
+
+  ctx->evaluated_value_map[block->inputs()[0]] = torch::jit::IValue(0);
+  auto trip_count = ctx->evaluated_value_map[block->inputs()[0]];
+
+  // map node inputs [recurrent values] -> node outputs [recurrent values]
+  MapIValues(ctx, n->inputs(), n->outputs(), 2, 0);
+
+  LOG_DEBUG(ctx->logger, "(Loop Conversion) Converting loop " << *n);
+  LOG_DEBUG(ctx->logger, "(Loop Conversion) Max Trip Count: " << max_trip_count.toInt());
+  LOG_DEBUG(ctx->logger, "(Loop Conversion) Start Condition: " << start_cond.toBool());
+  LOG_DEBUG(ctx->logger, "(Loop Conversion) Current Trip Count: " << trip_count.toInt());
+
+  // map node outputs [recurrent values] -> block inputs [recurrent values]
+  MapIValues(ctx, n->outputs(), block->inputs(), 0, 1);
+
+  auto loop = ctx->net->addLoop();
+
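+  // TensorRT exits the loop when either trip limit fires: kCOUNT caps the
+  // number of iterations, kWHILE stops once the condition tensor is false.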
+  // trip limit layer: max_trip_count
+  auto count_weight = converters::Weights(ctx, (int32_t) max_trip_count.toInt());
+  auto for_const = ctx->net->addConstant(count_weight.shape, count_weight.data);
+  TRTORCH_CHECK(for_const, "Unable to create constant layer from node: " << *n);
+
+  auto count_limit = loop->addTripLimit(*for_const->getOutput(0), nvinfer1::TripLimit::kCOUNT);
+  TRTORCH_CHECK(count_limit, "Unable to create trip limit layer from node: " << *n);
+  count_limit->setName((n->input(0)->debugName() + " [Trip Limit Layer]").c_str());
+
+  // recurrence layer and trip limit layer: loop condition
+  auto cond_weight = converters::Weights(ctx, (int32_t) (start_cond.toBool() ? 1 : 0));
+  auto while_const = ctx->net->addIdentity(*ctx->net->addConstant(cond_weight.shape, cond_weight.data)->getOutput(0));
+  TRTORCH_CHECK(while_const, "Unable to create identity layer from node: " << *n);
+  while_const->setOutputType(0, nvinfer1::DataType::kBOOL);
+
+  auto recurrent_cond = loop->addRecurrence(*while_const->getOutput(0));
+  TRTORCH_CHECK(recurrent_cond, "Unable to create recurrence layer from node: " << *n);
+  recurrent_cond->setName((n->input(1)->debugName() + " [Recurrence Layer]").c_str());
+
+  auto cond_limit = loop->addTripLimit(*recurrent_cond->getOutput(0), nvinfer1::TripLimit::kWHILE);
+  TRTORCH_CHECK(cond_limit, "Unable to create trip limit layer from node: " << *n);
+  cond_limit->setName((n->input(1)->debugName() + " [Trip Limit Layer]").c_str());
+
+  // recurrence layer: trip_count
+  auto trip_weight = converters::Weights(ctx, (int32_t) trip_count.toInt());
+  auto trip_const = ctx->net->addConstant(trip_weight.shape, trip_weight.data);
+  TRTORCH_CHECK(trip_const, "Unable to create constant layer from node: " << *n);
+
+  auto recurrent_trip = loop->addRecurrence(*trip_const->getOutput(0));
+  TRTORCH_CHECK(recurrent_trip, "Unable to create recurrence layer from node: " << *n);
+  recurrent_trip->setName((block->inputs()[0]->debugName() + " [Recurrence Layer]").c_str());
+
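+  // Every value carried across iterations needs an IRecurrenceLayer: input 0
+  // is the initial value and input 1 (the back-edge) is the next-iteration
+  // value, wired up after the loop body has been converted below.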
+  // add recurrence layers to loop
+  std::vector<nvinfer1::IRecurrenceLayer*> recurrent_tensors;
+
+  // loop through all recurrent inputs
+  for (unsigned int i = 2; i < n->inputs().size(); i++) {
+    auto inp = n->inputs()[i];
+
+    if (inp->type()->isSubtypeOf(c10::TensorType::get())) {
+      auto recur = loop->addRecurrence(*ctx->value_tensor_map[inp]);
+      TRTORCH_CHECK(recur, "Unable to create recurrent layer from node: " << *n);
+      recur->setName((inp->debugName() + " [Recurrence Layer]").c_str());
+
+      recurrent_tensors.push_back(recur);
+    } else {
+      TRTORCH_THROW_ERROR("Only recurrent Tensors allowed as input to Loop");
+    }
+  }
+
+  // evaluate/convert all nodes inside block
+  for (auto bn : block->nodes()) {
+    if (bn->kind() == torch::jit::prim::Loop) {
+      bool returns_tensor = false;
+
+      // if even a single output of the loop is a Tensor, use ConvertLoopBlock
+      for (unsigned int i = 0; i < bn->outputs().size(); i++) {
+        if (bn->output(i)->type()->isSubtypeOf(c10::TensorType::get())) {
+          returns_tensor = true;
+        }
+      }
+
+      if (returns_tensor) {
+        ConvertLoopBlock(ctx, bn);
+      } else {
+        EvaluateLoopBlock(ctx, bn);
+      }
+    } else if (bn->kind() == torch::jit::prim::If) {
+      EvaluateConditionalBlock(ctx, bn, true);
+    } else if (evaluators::shouldEvalAtConversionTime(bn)) {
+      auto eval = EvaluateNode(ctx, bn);
+      ctx->AssociateValueAndIValue(bn->output(0), eval.value());
+    } else if (!isNodeConversionIgnored(bn)) {
+      AddLayer(ctx, bn);
+    }
+  }
+
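+  // With the body converted, close the back-edges: each recurrence layer's
+  // input 1 receives the corresponding block output.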
+  // recurrent backedge input for loop condition and input for condition TripLimit (cond_limit)
+  auto iter_cond = ctx->evaluated_value_map[block->outputs()[0]];
+  auto iter_cond_weight = converters::Weights(ctx, (int32_t) (iter_cond.toBool() ? 1 : 0));
+  auto new_while_const = ctx->net->addIdentity(*ctx->net->addConstant(iter_cond_weight.shape, iter_cond_weight.data)->getOutput(0));
+  TRTORCH_CHECK(new_while_const, "Unable to create identity layer from node: " << *n);
+  new_while_const->setOutputType(0, nvinfer1::DataType::kBOOL);
+
+  recurrent_cond->setInput(1, *new_while_const->getOutput(0));
+  cond_limit->setInput(0, *recurrent_cond->getOutput(0));
+  ctx->AssociateValueAndTensor(block->outputs()[0], recurrent_cond->getOutput(0));
+
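+  // The block's condition output was evaluated at conversion time, so the
+  // back-edge feeding the kWHILE limit is a constant boolean tensor.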
+  // recurrent backedge input for trip_count
+  auto one_weight = converters::Weights(ctx, (int32_t) 1);
+  auto one_const = ctx->net->addConstant(one_weight.shape, one_weight.data);
+  TRTORCH_CHECK(one_const, "Unable to create constant layer from node: " << *n);
+  auto add_layer = ctx->net->addElementWise(*recurrent_trip->getOutput(0), *one_const->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);
+  TRTORCH_CHECK(add_layer, "Unable to create add layer from node: " << *n);
+
+  recurrent_trip->setInput(1, *add_layer->getOutput(0));
+  ctx->AssociateValueAndTensor(block->inputs()[0], recurrent_trip->getOutput(0));
+
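+  // Block outputs are offset by one relative to recurrent_tensors because
+  // output 0 is the loop condition handled above.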
+  // recurrent backedge input for each tensor in recurrent_tensors
+  for (unsigned int i = 1; i < block->outputs().size(); i++) {
+    auto out = block->outputs()[i];
+
+    if (out->type()->isSubtypeOf(c10::TensorType::get())) {
+      recurrent_tensors[i - 1]->setInput(1, *ctx->value_tensor_map[out]);
+    } else {
+      TRTORCH_THROW_ERROR("Only recurrent Tensors allowed as output to block");
+    }
+  }
+
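+  // kLAST_VALUE emits each recurrent tensor's value after the final
+  // iteration, matching the output semantics of prim::Loop.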
+  // map recurrent tensors --> n->outputs()
+  for (unsigned int i = 0; i < recurrent_tensors.size(); i++) {
+    auto out = loop->addLoopOutput(*recurrent_tensors[i]->getOutput(0), nvinfer1::LoopOutput::kLAST_VALUE);
+    TRTORCH_CHECK(out, "Unable to create loop output layer from node: " << *n);
+    ctx->AssociateValueAndTensor(n->outputs()[i], out->getOutput(0));
+  }
+
+  LOG_DEBUG(ctx->logger, "(Loop Conversion) Finished converting loop " << *n);
+}
+
 
 void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) {
   LOG_INFO(ctx->logger, "Converting Block");
 
@@ -304,7 +462,20 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) {
     bool to_eval = evaluators::shouldEvalAtConversionTime(n);
     bool ignored = isNodeConversionIgnored(n);
     if (n->kind() == torch::jit::prim::Loop) {
-      EvaluateLoopBlock(ctx, n);
+      bool returns_tensor = false;
+
+      // if even a single output of the loop is a Tensor, use ConvertLoopBlock
+      for (unsigned int i = 0; i < n->outputs().size(); i++) {
+        if (n->output(i)->type()->isSubtypeOf(c10::TensorType::get())) {
+          returns_tensor = true;
+        }
+      }
+
+      if (returns_tensor) {
+        ConvertLoopBlock(ctx, n);
+      } else {
+        EvaluateLoopBlock(ctx, n);
+      }
     } else if (n->kind() == torch::jit::prim::If) {
       EvaluateConditionalBlock(ctx, n);
     } else if (to_eval) {