
Commit 2879114

feat(): Loop conversion
Signed-off-by: Abhiram Iyer <[email protected]>

1 parent f8bea3b
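In short: this commit adds a ConvertLoopBlock path in core/conversion so that prim::Loop nodes whose outputs include Tensors are converted into a TensorRT loop (trip limit, recurrence, and loop-output layers) rather than being evaluated and unrolled at conversion time, disables the UnrollLoops pass in lowering so such loops survive to conversion, and adds a converter test that exercises a simple recurrent loop.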

File tree: 4 files changed (+228, -3 lines)

core/conversion/conversion.cpp (+156, -1)
@@ -291,6 +291,148 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
   }
 }
 
+void ConvertLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
+  auto block = n->blocks()[0];
+
+  // max_trip_count and start_cond already evaluated
+  auto max_trip_count = ctx->evaluated_value_map[n->input(0)];
+  auto start_cond = ctx->evaluated_value_map[n->input(1)];
+
+  ctx->evaluated_value_map[block->inputs()[0]] = torch::jit::IValue(0);
+  auto trip_count = ctx->evaluated_value_map[block->inputs()[0]];
+
+  // map node inputs [recurrent values] -> node outputs [recurrent values]
+  MapIValues(ctx, n->inputs(), n->outputs(), 2, 0);
+
+  LOG_DEBUG(ctx->logger, "(Loop Conversion) Evaluating loop " << *n);
+  LOG_DEBUG(ctx->logger, "(Loop Conversion) Max Trip Count: " << max_trip_count.toInt());
+  LOG_DEBUG(ctx->logger, "(Loop Conversion) Start Condition: " << start_cond.toBool());
+  LOG_DEBUG(ctx->logger, "(Loop Conversion) Current Trip Count: " << trip_count.toInt());
+
+  // map node outputs [recurrent values] -> block inputs [recurrent values]
+  MapIValues(ctx, n->outputs(), block->inputs(), 0, 1);
+
+  auto loop = ctx->net->addLoop();
+
+  // trip limit layer: max_trip_count
+  auto count_weight = converters::Weights(ctx, (int32_t) max_trip_count.toInt());
+  auto for_const = ctx->net->addConstant(count_weight.shape, count_weight.data);
+  TRTORCH_CHECK(for_const, "Unable to create constant layer from node: " << *n);
+
+  auto count_limit = loop->addTripLimit(*for_const->getOutput(0), nvinfer1::TripLimit::kCOUNT);
+  TRTORCH_CHECK(count_limit, "Unable to create trip limit layer from node: " << *n);
+  count_limit->setName((n->input(0)->debugName() + " [Trip Limit Layer]").c_str());
+
+  // recurrence layer and trip limit layer: loop condition
+  auto cond_weight = converters::Weights(ctx, (int32_t) (start_cond.toBool() ? 1 : 0));
+  auto while_const = ctx->net->addIdentity(*ctx->net->addConstant(cond_weight.shape, cond_weight.data)->getOutput(0));
+  TRTORCH_CHECK(while_const, "Unable to create identity layer from node: " << *n);
+  while_const->setOutputType(0, nvinfer1::DataType::kBOOL);
+
+  auto recurrent_cond = loop->addRecurrence(*while_const->getOutput(0));
+  TRTORCH_CHECK(recurrent_cond, "Unable to create recurrence layer from node: " << *n);
+  recurrent_cond->setName((n->input(1)->debugName() + " [Recurrence Layer]").c_str());
+
+  auto cond_limit = loop->addTripLimit(*recurrent_cond->getOutput(0), nvinfer1::TripLimit::kWHILE);
+  TRTORCH_CHECK(cond_limit, "Unable to create trip limit layer from node: " << *n);
+  cond_limit->setName((n->input(1)->debugName() + " [Trip Limit Layer]").c_str());
+
+  // recurrence layer: trip_count
+  auto trip_weight = converters::Weights(ctx, (int32_t) trip_count.toInt());
+  auto trip_const = ctx->net->addConstant(trip_weight.shape, trip_weight.data);
+  TRTORCH_CHECK(trip_const, "Unable to create constant layer from node: " << *n);
+
+  auto recurrent_trip = loop->addRecurrence(*trip_const->getOutput(0));
+  TRTORCH_CHECK(recurrent_trip, "Unable to create recurrence layer from node: " << *n);
+  recurrent_trip->setName((block->inputs()[0]->debugName() + " [Recurrence Layer]").c_str());
+
+  // add recurrence layers to loop
+  std::vector<nvinfer1::IRecurrenceLayer*> recurrent_tensors;
+
+  // loop through all recurrent inputs
+  for (unsigned int i = 2; i < n->inputs().size(); i++) {
+    auto inp = n->inputs()[i];
+
+    if (inp->type()->isSubtypeOf(c10::TensorType::get())) {
+      auto recur = loop->addRecurrence(*ctx->value_tensor_map[inp]);
+      TRTORCH_CHECK(recur, "Unable to create recurrence layer from node: " << *n);
+      recur->setName((inp->debugName() + " [Recurrence Layer]").c_str());
+
+      recurrent_tensors.push_back(recur);
+    } else {
+      TRTORCH_THROW_ERROR("Only recurrent Tensors allowed as input to Loop");
+    }
+  }
+
+  // evaluate/convert all nodes inside block
+  for (auto bn : block->nodes()) {
+    if (bn->kind() == torch::jit::prim::Loop) {
+      bool returns_tensor = false;
+
+      // if even a single output of the loop returns a tensor, use ConvertLoopBlock
+      for (unsigned int i = 0; i < bn->outputs().size(); i++) {
+        if (bn->output(i)->type()->isSubtypeOf(c10::TensorType::get())) {
+          returns_tensor = true;
+        }
+      }
+
+      if (returns_tensor) {
+        ConvertLoopBlock(ctx, bn);
+      } else {
+        EvaluateLoopBlock(ctx, bn);
+      }
+    } else if (bn->kind() == torch::jit::prim::If) {
+      EvaluateConditionalBlock(ctx, bn, true);
+    } else if (evaluators::shouldEvalAtConversionTime(bn)) {
+      auto eval = EvaluateNode(ctx, bn);
+      ctx->AssociateValueAndIValue(bn->output(0), eval.value());
+    } else if (!isNodeConversionIgnored(bn)) {
+      AddLayer(ctx, bn);
+    }
+  }
+
+  // recurrent backedge input for loop condition and input for condition TripLimit (cond_limit)
+  auto iter_cond = ctx->evaluated_value_map[block->outputs()[0]];
+  auto iter_cond_weight = converters::Weights(ctx, (int32_t) (iter_cond.toBool() ? 1 : 0));
+  auto new_while_const = ctx->net->addIdentity(*ctx->net->addConstant(iter_cond_weight.shape, iter_cond_weight.data)->getOutput(0));
+  TRTORCH_CHECK(new_while_const, "Unable to create identity layer from node: " << *n);
+  new_while_const->setOutputType(0, nvinfer1::DataType::kBOOL);
+
+  recurrent_cond->setInput(1, *new_while_const->getOutput(0));
+  cond_limit->setInput(0, *recurrent_cond->getOutput(0));
+  ctx->AssociateValueAndTensor(block->outputs()[0], recurrent_cond->getOutput(0));
+
+  // recurrent backedge input for trip_count
+  auto one_weight = converters::Weights(ctx, (int32_t) 1);
+  auto one_const = ctx->net->addConstant(one_weight.shape, one_weight.data);
+  TRTORCH_CHECK(one_const, "Unable to create constant layer from node: " << *n);
+  auto add_layer = ctx->net->addElementWise(*recurrent_trip->getOutput(0), *one_const->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);
+  TRTORCH_CHECK(add_layer, "Unable to create add layer from node: " << *n);
+
+  recurrent_trip->setInput(1, *add_layer->getOutput(0));
+  ctx->AssociateValueAndTensor(block->inputs()[0], recurrent_trip->getOutput(0));
+
+  // recurrent backedge input for each tensor in recurrent_tensors
+  for (unsigned int i = 1; i < block->outputs().size(); i++) {
+    auto out = block->outputs()[i];
+
+    if (out->type()->isSubtypeOf(c10::TensorType::get())) {
+      recurrent_tensors[i - 1]->setInput(1, *ctx->value_tensor_map[out]);
+    } else {
+      TRTORCH_THROW_ERROR("Only recurrent Tensors allowed as output to block");
+    }
+  }
+
+  // map recurrent tensors -> n->outputs()
+  for (unsigned int i = 0; i < recurrent_tensors.size(); i++) {
+    auto out = loop->addLoopOutput(*recurrent_tensors[i]->getOutput(0), nvinfer1::LoopOutput::kLAST_VALUE);
+    TRTORCH_CHECK(out, "Unable to create loop output layer from node: " << *n);
+    ctx->AssociateValueAndTensor(n->outputs()[i], out->getOutput(0));
+  }
+
+  LOG_DEBUG(ctx->logger, "(Loop Conversion) Finished evaluating loop " << *n);
+}
+
 void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) {
   LOG_INFO(ctx->logger, "Converting Block");

@@ -304,7 +446,20 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver
     bool to_eval = evaluators::shouldEvalAtConversionTime(n);
     bool ignored = isNodeConversionIgnored(n);
     if (n->kind() == torch::jit::prim::Loop) {
-      EvaluateLoopBlock(ctx, n);
+      bool returns_tensor = false;
+
+      // if even a single output of the loop returns a tensor, use ConvertLoopBlock
+      for (unsigned int i = 0; i < n->outputs().size(); i++) {
+        if (n->output(i)->type()->isSubtypeOf(c10::TensorType::get())) {
+          returns_tensor = true;
+        }
+      }
+
+      if (returns_tensor) {
+        ConvertLoopBlock(ctx, n);
+      } else {
+        EvaluateLoopBlock(ctx, n);
+      }
     } else if (n->kind() == torch::jit::prim::If) {
       EvaluateConditionalBlock(ctx, n);
     } else if (to_eval) {
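For orientation on the new conversion path: a prim::Loop node takes (max_trip_count, initial_condition, initial loop-carried values...), its block receives (iteration_counter, loop-carried values...) and yields (continue_condition, updated loop-carried values...), and the node's outputs are the final loop-carried values. ConvertLoopBlock mirrors this with TensorRT's loop API: a kCOUNT trip limit for the trip count, a kWHILE trip limit fed by a recurrence for the condition, one IRecurrenceLayer per loop-carried tensor, and kLAST_VALUE loop outputs. The sketch below shows that wiring for a single recurrent tensor against the plain nvinfer1 interface; the function name, the identity body, and the pre-built trip-count tensor are assumptions for illustration, not part of this commit.

#include "NvInfer.h"

// Minimal sketch: a counted TensorRT loop with one recurrent tensor,
// mirroring the layer layout ConvertLoopBlock emits.
nvinfer1::ITensor* buildCountedLoop(nvinfer1::INetworkDefinition* net,
                                    nvinfer1::ITensor* init,         // initial loop-carried value
                                    nvinfer1::ITensor* trip_count) { // 0-d int32 tensor
  auto loop = net->addLoop();

  // Bound the number of iterations (prim::Loop's max_trip_count input).
  loop->addTripLimit(*trip_count, nvinfer1::TripLimit::kCOUNT);

  // Recurrence: input 0 is the initial value, input 1 (set below) is the backedge.
  auto recur = loop->addRecurrence(*init);

  // Loop body: an identity stands in here; ConvertLoopBlock converts the block's nodes instead.
  auto body_out = net->addIdentity(*recur->getOutput(0))->getOutput(0);

  // Close the backedge so each iteration consumes the previous iteration's result.
  recur->setInput(1, *body_out);

  // Expose the value produced by the final iteration (LoopOutput::kLAST_VALUE).
  return loop->addLoopOutput(*recur->getOutput(0), nvinfer1::LoopOutput::kLAST_VALUE)->getOutput(0);
}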

core/lowering/lowering.cpp (+1, -1)
@@ -35,7 +35,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
   passes::Conv2DToConvolution(g);
   passes::FuseAddMMBranches(g);
   torch::jit::EliminateCommonSubexpression(g);
-  torch::jit::UnrollLoops(g);
+  //torch::jit::UnrollLoops(g);
   torch::jit::EliminateCommonSubexpression(g);
   passes::UnpackAddMM(g);
   //passes::UnpackBatchNorm(g);
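Commenting out torch::jit::UnrollLoops keeps prim::Loop nodes intact through lowering; with the pass enabled, loops with statically known trip counts would be flattened before conversion and the new ConvertLoopBlock/EvaluateLoopBlock dispatch above would never see them.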

tests/core/converters/BUILD (+6, -1)
@@ -67,6 +67,10 @@ converter_test(
     name = "test_stack"
 )
 
+converter_test(
+    name = "test_loop"
+)
+
 test_suite(
     name = "test_converters",
     tests = [
@@ -83,6 +87,7 @@ test_suite(
         ":test_unary",
         ":test_interpolate",
        ":test_select",
-        ":test_stack"
+        ":test_stack",
+        ":test_loop"
     ]
 )
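With the target registered, the new test can be run on its own or as part of the converter suite; assuming the repository's usual Bazel workflow, something like:

bazel test //tests/core/converters:test_loop
bazel test //tests/core/converters:test_converters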

tests/core/converters/test_loop.cpp (+65, new file)
@@ -0,0 +1,65 @@
+#include "gtest/gtest.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "tests/util/util.h"
+#include "core/compiler.h"
+
+TEST(Converters, ATenLoopConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor, %1 : Tensor, %2 : Tensor, %3 : Tensor, %4 : Tensor, %5 : Tensor, %8 : Tensor):
+      %22 : int = prim::Constant[value=1]()
+      %10 : bool = prim::Constant[value=1]()
+      %6 : int = prim::Constant[value=0]()
+      %98 : Tensor = aten::tanh(%1)
+      %7 : int = aten::size(%0, %6)
+      %99 : Tensor, %95 : Tensor = prim::Loop(%7, %10, %98, %1)
+        block0(%90 : int, %96 : Tensor, %93 : Tensor):
+          %16 : Tensor = aten::select(%0, %6, %90)
+          %18 : Tensor = aten::matmul(%16, %2)
+          %21 : Tensor = aten::matmul(%93, %3)
+          %23 : Tensor = aten::add(%18, %21, %22)
+          %26 : Tensor = aten::add(%23, %4, %22)
+          %94 : Tensor = aten::tanh(%26)
+          %31 : Tensor = aten::matmul(%94, %5)
+          %34 : Tensor = aten::add(%31, %8, %22)
+          %97 : Tensor = aten::tanh(%34)
+          -> (%10, %97, %94)
+      return (%99))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, &*g);
+
+  auto x = at::randn({5, 5, 3}, {at::kCUDA});
+  auto h = at::randn({5, 5}, {at::kCUDA});
+  auto Wh = at::randn({3, 5}, {at::kCUDA});
+  auto Uh = at::randn({5, 5}, {at::kCUDA});
+  auto bh = at::randn({5, 5}, {at::kCUDA});
+  auto Wy = at::randn({5, 5}, {at::kCUDA});
+  auto by = at::randn({5, 5}, {at::kCUDA});
+
+  auto jit_x = at::clone(x);
+  auto jit_h = at::clone(h);
+  auto jit_Wh = at::clone(Wh);
+  auto jit_Uh = at::clone(Uh);
+  auto jit_bh = at::clone(bh);
+  auto jit_Wy = at::clone(Wy);
+  auto jit_by = at::clone(by);
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_x, jit_h, jit_Wh, jit_Uh, jit_bh, jit_Wy, jit_by});
+
+  auto trt_x = at::clone(x);
+  auto trt_h = at::clone(h);
+  auto trt_Wh = at::clone(Wh);
+  auto trt_Uh = at::clone(Uh);
+  auto trt_bh = at::clone(bh);
+  auto trt_Wy = at::clone(Wy);
+  auto trt_by = at::clone(by);
+
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_x, trt_h, trt_Wh, trt_Uh, trt_bh, trt_Wy, trt_by});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
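The graph under test is a small RNN-style loop: it runs aten::size(%0, 0) times, each iteration selects one slice of %0, combines it with the carried hidden state through matmul/add/tanh stages, and feeds the new hidden state and output back as the block's recurrent values, so the test exercises the trip-count limit, the recurrent backedges, and the last-value loop outputs added by ConvertLoopBlock.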
