Loop Conversion #164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 156 additions & 1 deletion core/conversion/conversion.cpp
@@ -297,6 +297,148 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
  }
}

void ConvertLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
  auto block = n->blocks()[0];

  // max_trip_count and start_cond already evaluated
  auto max_trip_count = ctx->evaluated_value_map[n->input(0)];
  auto start_cond = ctx->evaluated_value_map[n->input(1)];

  ctx->evaluated_value_map[block->inputs()[0]] = torch::jit::IValue(0);
  auto trip_count = ctx->evaluated_value_map[block->inputs()[0]];

  // map node inputs [recurrent values] -> node outputs [recurrent values]
  MapIValues(ctx, n->inputs(), n->outputs(), 2, 0);

  LOG_DEBUG(ctx->logger, "(Loop Conversion) Evaluating loop " << *n);
  LOG_DEBUG(ctx->logger, "(Loop Conversion) Max Trip Count: " << max_trip_count.toInt());
  LOG_DEBUG(ctx->logger, "(Loop Conversion) Start Condition: " << start_cond.toBool());
  LOG_DEBUG(ctx->logger, "(Loop Conversion) Current Trip Count: " << trip_count.toInt());

  // map node outputs [recurrent values] -> block inputs [recurrent values]
  MapIValues(ctx, n->outputs(), block->inputs(), 0, 1);

  auto loop = ctx->net->addLoop();

  // trip limit layer: max_trip_count
  auto count_weight = converters::Weights(ctx, (int32_t) max_trip_count.toInt());
  auto for_const = ctx->net->addConstant(count_weight.shape, count_weight.data);
  TRTORCH_CHECK(for_const, "Unable to create constant layer from node: " << *n);

  auto count_limit = loop->addTripLimit(*for_const->getOutput(0), nvinfer1::TripLimit::kCOUNT);
  TRTORCH_CHECK(count_limit, "Unable to create trip limit layer from node: " << *n);
  count_limit->setName((n->input(0)->debugName() + " [Trip Limit Layer]").c_str());

  // recurrence layer and trip limit layer: loop condition
  auto cond_weight = converters::Weights(ctx, (int32_t) (start_cond.toBool() ? 1 : 0));
  auto while_const = ctx->net->addIdentity(*ctx->net->addConstant(cond_weight.shape, cond_weight.data)->getOutput(0));
  TRTORCH_CHECK(while_const, "Unable to create identity layer from node: " << *n);
  while_const->setOutputType(0, nvinfer1::DataType::kBOOL);

  auto recurrent_cond = loop->addRecurrence(*while_const->getOutput(0));
  TRTORCH_CHECK(recurrent_cond, "Unable to create recurrence layer from node: " << *n);
  recurrent_cond->setName((n->input(1)->debugName() + " [Recurrence Layer]").c_str());

  auto cond_limit = loop->addTripLimit(*recurrent_cond->getOutput(0), nvinfer1::TripLimit::kWHILE);
  TRTORCH_CHECK(cond_limit, "Unable to create trip limit layer from node: " << *n);
  cond_limit->setName((n->input(1)->debugName() + " [Trip Limit Layer]").c_str());

  // recurrence layer: trip_count
  auto trip_weight = converters::Weights(ctx, (int32_t) trip_count.toInt());
  auto trip_const = ctx->net->addConstant(trip_weight.shape, trip_weight.data);
  TRTORCH_CHECK(trip_const, "Unable to create constant layer from node: " << *n);

  auto recurrent_trip = loop->addRecurrence(*trip_const->getOutput(0));
  TRTORCH_CHECK(recurrent_trip, "Unable to create recurrence layer from node: " << *n);
  recurrent_trip->setName((block->inputs()[0]->debugName() + " [Recurrence Layer]").c_str());

  // recurrence layers for the loop-carried tensors
  std::vector<nvinfer1::IRecurrenceLayer*> recurrent_tensors;

  // loop through all recurrent inputs
  for (unsigned int i = 2; i < n->inputs().size(); i++) {
    auto inp = n->inputs()[i];

    if (inp->type()->isSubtypeOf(c10::TensorType::get())) {
      auto recur = loop->addRecurrence(*ctx->value_tensor_map[inp]);
      TRTORCH_CHECK(recur, "Unable to create recurrence layer from node: " << *n);
      recur->setName((inp->debugName() + " [Recurrence Layer]").c_str());

      recurrent_tensors.push_back(recur);
    } else {
      TRTORCH_THROW_ERROR("Only recurrent Tensors allowed as input to Loop");
    }
  }

  // evaluate/convert all nodes inside block
  for (auto bn : block->nodes()) {
    if (bn->kind() == torch::jit::prim::Loop) {
      bool returns_tensor = false;

      // if even a single output of the loop returns a tensor, use ConvertLoopBlock
      for (unsigned int i = 0; i < bn->outputs().size(); i++) {
        if (bn->output(i)->type()->isSubtypeOf(c10::TensorType::get())) {
          returns_tensor = true;
        }
      }

      if (returns_tensor) {
        ConvertLoopBlock(ctx, bn);
      } else {
        EvaluateLoopBlock(ctx, bn);
      }
    } else if (bn->kind() == torch::jit::prim::If) {
      EvaluateConditionalBlock(ctx, bn, true);
    } else if (evaluators::shouldEvalAtConversionTime(bn)) {
      auto eval = EvaluateNode(ctx, bn);
      ctx->AssociateValueAndIValue(bn->output(0), eval.value());
    } else if (!isNodeConversionIgnored(bn)) {
      AddLayer(ctx, bn);
    }
  }

  // recurrent backedge input for loop condition and input for condition TripLimit (cond_limit)
  auto iter_cond = ctx->evaluated_value_map[block->outputs()[0]];
  auto iter_cond_weight = converters::Weights(ctx, (int32_t) (iter_cond.toBool() ? 1 : 0));
  auto new_while_const = ctx->net->addIdentity(*ctx->net->addConstant(iter_cond_weight.shape, iter_cond_weight.data)->getOutput(0));
  TRTORCH_CHECK(new_while_const, "Unable to create identity layer from node: " << *n);
  new_while_const->setOutputType(0, nvinfer1::DataType::kBOOL);

  recurrent_cond->setInput(1, *new_while_const->getOutput(0));
  cond_limit->setInput(0, *recurrent_cond->getOutput(0));
  ctx->AssociateValueAndTensor(block->outputs()[0], recurrent_cond->getOutput(0));

  // recurrent backedge input for trip_count
  auto one_weight = converters::Weights(ctx, (int32_t) 1);
  auto one_const = ctx->net->addConstant(one_weight.shape, one_weight.data);
  TRTORCH_CHECK(one_const, "Unable to create constant layer from node: " << *n);
  auto add_layer = ctx->net->addElementWise(*recurrent_trip->getOutput(0), *one_const->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);
  TRTORCH_CHECK(add_layer, "Unable to create add layer from node: " << *n);

  recurrent_trip->setInput(1, *add_layer->getOutput(0));
  ctx->AssociateValueAndTensor(block->inputs()[0], recurrent_trip->getOutput(0));

  // recurrent backedge input for each tensor in recurrent_tensors
  for (unsigned int i = 1; i < block->outputs().size(); i++) {
    auto out = block->outputs()[i];

    if (out->type()->isSubtypeOf(c10::TensorType::get())) {
      recurrent_tensors[i - 1]->setInput(1, *ctx->value_tensor_map[out]);
    } else {
      TRTORCH_THROW_ERROR("Only recurrent Tensors allowed as output to block");
    }
  }

  // map recurrent tensors --> n->outputs()
  for (unsigned int i = 0; i < recurrent_tensors.size(); i++) {
    auto out = loop->addLoopOutput(*recurrent_tensors[i]->getOutput(0), nvinfer1::LoopOutput::kLAST_VALUE);
    TRTORCH_CHECK(out, "Unable to create loop output layer from node: " << *n);
    ctx->AssociateValueAndTensor(n->outputs()[i], out->getOutput(0));
  }

  LOG_DEBUG(ctx->logger, "(Loop Conversion) Finished evaluating loop " << *n);
}

void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) {
  LOG_INFO(ctx->logger, "Converting Block");

@@ -310,7 +452,20 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver
    bool to_eval = evaluators::shouldEvalAtConversionTime(n);
    bool ignored = isNodeConversionIgnored(n);
    if (n->kind() == torch::jit::prim::Loop) {
      EvaluateLoopBlock(ctx, n);
      bool returns_tensor = false;

      // if even a single output of the loop returns a tensor, use ConvertLoopBlock
      for (unsigned int i = 0; i < n->outputs().size(); i++) {
        if (n->output(i)->type()->isSubtypeOf(c10::TensorType::get())) {
          returns_tensor = true;
        }
      }

      if (returns_tensor) {
        ConvertLoopBlock(ctx, n);
      } else {
        EvaluateLoopBlock(ctx, n);
      }
    } else if (n->kind() == torch::jit::prim::If) {
      EvaluateConditionalBlock(ctx, n);
    } else if (to_eval) {
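Note: the TensorRT loop API that ConvertLoopBlock drives follows a fixed pattern: add an ILoop to the network, attach one or more trip limits (kCOUNT fed by a scalar int32 tensor and/or kWHILE fed by a boolean recurrence), add an IRecurrenceLayer per loop-carried tensor, wire each recurrence's second input to the value computed inside the body (the backedge), and read results out through addLoopOutput. The sketch below is only an illustration of that pattern, not code from this PR; the function name, its parameters, and the body callback are hypothetical.

#include "NvInfer.h"

// Illustrative only: a counted loop that carries one tensor and returns its
// value after the last iteration. `count` is a 0-d int32 tensor, `init` is the
// initial carried value, and `body` stands in for whatever layers the loop
// body would add.
nvinfer1::ITensor* BuildCountedLoop(
    nvinfer1::INetworkDefinition* net,
    nvinfer1::ITensor* count,
    nvinfer1::ITensor* init,
    nvinfer1::ITensor* (*body)(nvinfer1::INetworkDefinition*, nvinfer1::ITensor*)) {
  auto loop = net->addLoop();

  // Trip limit: execute the body `count` times.
  loop->addTripLimit(*count, nvinfer1::TripLimit::kCOUNT);

  // Recurrence: input 0 is the initial value of the carried tensor.
  auto rec = loop->addRecurrence(*init);

  // Build the body from the current iteration's value.
  auto next = body(net, rec->getOutput(0));

  // Backedge: input 1 of the recurrence feeds the next iteration.
  rec->setInput(1, *next);

  // Read the carried value back out after the final iteration.
  auto out = loop->addLoopOutput(*rec->getOutput(0), nvinfer1::LoopOutput::kLAST_VALUE);
  return out->getOutput(0);
}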
2 changes: 1 addition & 1 deletion core/lowering/lowering.cpp
@@ -34,7 +34,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
  passes::Conv2DToConvolution(g);
  passes::FuseAddMMBranches(g);
  torch::jit::EliminateCommonSubexpression(g);
  torch::jit::UnrollLoops(g);
  //torch::jit::UnrollLoops(g);
Collaborator review comment: Does the unroll loops behavior still not work even with the non-compliant converters?

  torch::jit::EliminateCommonSubexpression(g);
  passes::UnpackAddMM(g);
  //passes::UnpackBatchNorm(g);
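With torch::jit::UnrollLoops disabled here, prim::Loop nodes survive lowering and reach the conversion stage, where ConvertLoopBlock above can pick them up. As a rough sketch (the helper name is made up, and it only inspects the graph's top-level block), one way to confirm that a lowered graph still contains loop nodes:

#include <memory>

#include "torch/csrc/jit/ir/ir.h"

// Illustrative helper: returns true if any node in the graph's top-level
// block is a prim::Loop, e.g. to check what LowerGraph leaves behind when
// UnrollLoops is commented out.
bool ContainsLoopNodes(const std::shared_ptr<torch::jit::Graph>& g) {
  for (const auto n : g->nodes()) {
    if (n->kind() == torch::jit::prim::Loop) {
      return true;
    }
  }
  return false;
}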
7 changes: 6 additions & 1 deletion tests/core/converters/BUILD
@@ -67,6 +67,10 @@ converter_test(
    name = "test_stack"
)

converter_test(
    name = "test_loop"
)

test_suite(
    name = "test_converters",
    tests = [
@@ -83,6 +87,7 @@ test_suite(
        ":test_unary",
        ":test_interpolate",
        ":test_select",
        ":test_stack"
        ":test_stack",
        ":test_loop"
    ]
)
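Presumably the new case can then be run on its own with something like bazel test //tests/core/converters:test_loop (exact flags depend on the local TensorRT and Bazel setup), or as part of the test_converters suite.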
65 changes: 65 additions & 0 deletions tests/core/converters/test_loop.cpp
@@ -0,0 +1,65 @@
#include "gtest/gtest.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "tests/util/util.h"
#include "core/compiler.h"

TEST(Converters, ATenLoopConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%0 : Tensor, %1 : Tensor, %2 : Tensor, %3 : Tensor, %4 : Tensor, %5 : Tensor, %8 : Tensor):
      %22 : int = prim::Constant[value=1]()
      %10 : bool = prim::Constant[value=1]()
      %6 : int = prim::Constant[value=0]()
      %98 : Tensor = aten::tanh(%1)
      %7 : int = aten::size(%0, %6)
      %99 : Tensor, %95 : Tensor = prim::Loop(%7, %10, %98, %1)
        block0(%90 : int, %96 : Tensor, %93 : Tensor):
          %16 : Tensor = aten::select(%0, %6, %90)
          %18 : Tensor = aten::matmul(%16, %2)
          %21 : Tensor = aten::matmul(%93, %3)
          %23 : Tensor = aten::add(%18, %21, %22)
          %26 : Tensor = aten::add(%23, %4, %22)
          %94 : Tensor = aten::tanh(%26)
          %31 : Tensor = aten::matmul(%94, %5)
          %34 : Tensor = aten::add(%31, %8, %22)
          %97 : Tensor = aten::tanh(%34)
          -> (%10, %97, %94)
      return (%99))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, &*g);

  auto x = at::randn({5, 5, 3}, {at::kCUDA});
  auto h = at::randn({5, 5}, {at::kCUDA});
  auto Wh = at::randn({3, 5}, {at::kCUDA});
  auto Uh = at::randn({5, 5}, {at::kCUDA});
  auto bh = at::randn({5, 5}, {at::kCUDA});
  auto Wy = at::randn({5, 5}, {at::kCUDA});
  auto by = at::randn({5, 5}, {at::kCUDA});

  auto jit_x = at::clone(x);
  auto jit_h = at::clone(h);
  auto jit_Wh = at::clone(Wh);
  auto jit_Uh = at::clone(Uh);
  auto jit_bh = at::clone(bh);
  auto jit_Wy = at::clone(Wy);
  auto jit_by = at::clone(by);

  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_x, jit_h, jit_Wh, jit_Uh, jit_bh, jit_Wy, jit_by});

  auto trt_x = at::clone(x);
  auto trt_h = at::clone(h);
  auto trt_Wh = at::clone(Wh);
  auto trt_Uh = at::clone(Uh);
  auto trt_bh = at::clone(bh);
  auto trt_Wy = at::clone(Wy);
  auto trt_by = at::clone(by);

  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_x, trt_h, trt_Wh, trt_Uh, trt_bh, trt_Wy, trt_by});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}
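For readers less used to the IR above: the graph is a small tanh RNN that prim::Loop steps over the first dimension of %0, carrying the hidden state and the last output between iterations. A rough eager-mode C++ equivalent, given only as an illustration (the function name and shape comments are assumptions; it is not used by the test):

#include "torch/torch.h"

// Illustrative reference: h_t = tanh(x_t * Wh + h_{t-1} * Uh + bh) and
// y_t = tanh(h_t * Wy + by); the final y_t corresponds to %99 in the graph.
torch::Tensor tanh_rnn_reference(
    const torch::Tensor& x,   // [T, B, I], e.g. {5, 5, 3} in the test
    torch::Tensor h,          // [B, H]
    const torch::Tensor& Wh,  // [I, H]
    const torch::Tensor& Uh,  // [H, H]
    const torch::Tensor& bh,  // [B, H]
    const torch::Tensor& Wy,  // [H, H]
    const torch::Tensor& by)  // [B, H]
{
  auto y = torch::tanh(h);
  for (int64_t t = 0; t < x.size(0); t++) {
    auto xt = x.select(0, t);
    h = torch::tanh(torch::matmul(xt, Wh) + torch::matmul(h, Uh) + bh);
    y = torch::tanh(torch::matmul(h, Wy) + by);
  }
  return y;
}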