From 37cbbf900ea0165600cbe83983505c7a24d33413 Mon Sep 17 00:00:00 2001
From: Abhiram Iyer
Date: Fri, 31 Jul 2020 20:32:01 -0700
Subject: [PATCH] feat(): Loop conversion

Signed-off-by: Abhiram Iyer
Signed-off-by: Abhiram Iyer
---
 core/conversion/conversion.cpp      | 157 +++++++++++++++++++++++++++-
 core/lowering/lowering.cpp          |   2 +-
 tests/core/converters/BUILD         |   7 +-
 tests/core/converters/test_loop.cpp |  65 ++++++++++++
 4 files changed, 228 insertions(+), 3 deletions(-)
 create mode 100755 tests/core/converters/test_loop.cpp

diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp
index 3ee5640c64..441a30d04f 100644
--- a/core/conversion/conversion.cpp
+++ b/core/conversion/conversion.cpp
@@ -297,6 +297,148 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
     }
 }
 
+void ConvertLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
+    auto block = n->blocks()[0];
+
+    // max_trip_count and start_cond already evaluated
+    auto max_trip_count = ctx->evaluated_value_map[n->input(0)];
+    auto start_cond = ctx->evaluated_value_map[n->input(1)];
+
+    ctx->evaluated_value_map[block->inputs()[0]] = torch::jit::IValue(0);
+    auto trip_count = ctx->evaluated_value_map[block->inputs()[0]];
+
+    // map node inputs [recurrent values] -> node outputs [recurrent values]
+    MapIValues(ctx, n->inputs(), n->outputs(), 2, 0);
+
+    LOG_DEBUG(ctx->logger, "(Loop Conversion) Evaluating loop " << *n);
+    LOG_DEBUG(ctx->logger, "(Loop Conversion) Max Trip Count: " << max_trip_count.toInt());
+    LOG_DEBUG(ctx->logger, "(Loop Conversion) Start Condition: " << start_cond.toBool());
+    LOG_DEBUG(ctx->logger, "(Loop Conversion) Current Trip Count: " << trip_count.toInt());
+
+    // map node outputs [recurrent values] -> block inputs [recurrent values]
+    MapIValues(ctx, n->outputs(), block->inputs(), 0, 1);
+
+    auto loop = ctx->net->addLoop();
+
+    // trip limit layer: max_trip_limit
+    auto count_weight = converters::Weights(ctx, (int32_t) max_trip_count.toInt());
+    auto for_const = ctx->net->addConstant(count_weight.shape, count_weight.data);
+    TRTORCH_CHECK(for_const, "Unable to create constant layer from node: " << *n);
+
+    auto count_limit = loop->addTripLimit(*for_const->getOutput(0), nvinfer1::TripLimit::kCOUNT);
+    TRTORCH_CHECK(count_limit, "Unable to create trip limit layer from node: " << *n);
+    count_limit->setName((n->input(0)->debugName() + " [Trip Limit Layer]").c_str());
+
+    // recurrence layer and trip limit layer: loop condition
+    auto cond_weight = converters::Weights(ctx, (int32_t) (start_cond.toBool() ? 1 : 0));
+    auto while_const = ctx->net->addIdentity(*ctx->net->addConstant(cond_weight.shape, cond_weight.data)->getOutput(0));
+    TRTORCH_CHECK(while_const, "Unable to create identity layer from node: " << *n);
+    while_const->setOutputType(0, nvinfer1::DataType::kBOOL);
+
+    auto recurrent_cond = loop->addRecurrence(*while_const->getOutput(0));
+    TRTORCH_CHECK(recurrent_cond, "Unable to create recurrence layer from node: " << *n);
+    recurrent_cond->setName((n->input(1)->debugName() + " [Recurrence Layer]").c_str());
+
+    auto cond_limit = loop->addTripLimit(*recurrent_cond->getOutput(0), nvinfer1::TripLimit::kWHILE);
+    TRTORCH_CHECK(cond_limit, "Unable to create trip limit layer from node: " << *n);
+    cond_limit->setName((n->input(1)->debugName() + " [Trip Limit Layer]").c_str());
+
+    // recurrence layer: trip_count
+    auto trip_weight = converters::Weights(ctx, (int32_t) trip_count.toInt());
+    auto trip_const = ctx->net->addConstant(trip_weight.shape, trip_weight.data);
+    TRTORCH_CHECK(trip_const, "Unable to create constant layer from node: " << *n);
+
+    auto recurrent_trip = loop->addRecurrence(*trip_const->getOutput(0));
+    TRTORCH_CHECK(recurrent_trip, "Unable to create recurrence layer from node: " << *n);
+    recurrent_trip->setName((block->inputs()[0]->debugName() + " [Recurrence Layer]").c_str());
+
+    // add recurrence layers to loop
+    std::vector<nvinfer1::IRecurrenceLayer*> recurrent_tensors;
+
+    // loop through all recurrent inputs
+    for (unsigned int i = 2; i < n->inputs().size(); i++) {
+        auto inp = n->inputs()[i];
+
+        if (inp->type()->isSubtypeOf(c10::TensorType::get())) {
+            auto recur = loop->addRecurrence(*ctx->value_tensor_map[inp]);
+            TRTORCH_CHECK(recur, "Unable to create recurrent layer from node: " << *n);
+            recur->setName((inp->debugName() + " [Recurrence Layer]").c_str());
+
+            recurrent_tensors.push_back(recur);
+        } else {
+            TRTORCH_THROW_ERROR("Only recurrent Tensors allowed as input to Loop");
+        }
+    }
+
+    // evaluate/convert all nodes inside block
+    for (auto bn : block->nodes()) {
+        if (bn->kind() == torch::jit::prim::Loop) {
+            bool returns_tensor = false;
+
+            // if even a single output of the loop returns a tensor, use ConvertLoopBlock
+            for (unsigned int i = 0; i < bn->outputs().size(); i++) {
+                if (bn->output(i)->type()->isSubtypeOf(c10::TensorType::get())) {
+                    returns_tensor = true;
+                }
+            }
+
+            if (returns_tensor) {
+                ConvertLoopBlock(ctx, bn);
+            } else {
+                EvaluateLoopBlock(ctx, bn);
+            }
+        } else if (bn->kind() == torch::jit::prim::If) {
+            EvaluateConditionalBlock(ctx, bn, true);
+        } else if (evaluators::shouldEvalAtConversionTime(bn)) {
+            auto eval = EvaluateNode(ctx, bn);
+            ctx->AssociateValueAndIValue(bn->output(0), eval.value());
+        } else if (!isNodeConversionIgnored(bn)) {
+            AddLayer(ctx, bn);
+        }
+    }
+
+    // recurrent backedge input for loop condition and input for condition TripLimit (cond_limit)
+    auto iter_cond = ctx->evaluated_value_map[block->outputs()[0]];
+    auto iter_cond_weight = converters::Weights(ctx, (int32_t) (iter_cond.toBool() ? 1 : 0));
+    auto new_while_const = ctx->net->addIdentity(*ctx->net->addConstant(iter_cond_weight.shape, iter_cond_weight.data)->getOutput(0));
+    TRTORCH_CHECK(new_while_const, "Unable to create identity layer from node: " << *n);
+    new_while_const->setOutputType(0, nvinfer1::DataType::kBOOL);
+
+    recurrent_cond->setInput(1, *new_while_const->getOutput(0));
+    cond_limit->setInput(0, *recurrent_cond->getOutput(0));
+    ctx->AssociateValueAndTensor(block->outputs()[0], recurrent_cond->getOutput(0));
+
+    // recurrent backedge input for trip_count
+    auto one_weight = converters::Weights(ctx, (int32_t) 1);
+    auto one_const = ctx->net->addConstant(one_weight.shape, one_weight.data);
+    TRTORCH_CHECK(one_const, "Unable to create constant layer from node: " << *n);
+    auto add_layer = ctx->net->addElementWise(*recurrent_trip->getOutput(0), *one_const->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);
+    TRTORCH_CHECK(add_layer, "Unable to create add layer from node: " << *n);
+
+    recurrent_trip->setInput(1, *add_layer->getOutput(0));
+    ctx->AssociateValueAndTensor(block->inputs()[0], recurrent_trip->getOutput(0));
+
+    // recurrent backedge input for each tensor in recurrent_tensor
+    for (unsigned int i = 1; i < block->outputs().size(); i++) {
+        auto out = block->outputs()[i];
+
+        if (out->type()->isSubtypeOf(c10::TensorType::get())) {
+            recurrent_tensors[i-1]->setInput(1, *ctx->value_tensor_map[out]);
+        } else {
+            TRTORCH_THROW_ERROR("Only recurrent Tensors allowed as output to block");
+        }
+    }
+
+    // map recurrent tensors --> n->outputs()
+    for (unsigned int i = 0; i < recurrent_tensors.size(); i++) {
+        auto out = loop->addLoopOutput(*recurrent_tensors[i]->getOutput(0), nvinfer1::LoopOutput::kLAST_VALUE);
+        TRTORCH_CHECK(out, "Unable to create loop output layer from node: " << *n);
+        ctx->AssociateValueAndTensor(n->outputs()[i], out->getOutput(0));
+    }
+
+    LOG_DEBUG(ctx->logger, "(Loop Conversion) Finished evaluating loop " << *n);
+}
+
 void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) {
     LOG_INFO(ctx->logger, "Converting Block");
 
@@ -310,7 +452,20 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver
         bool to_eval = evaluators::shouldEvalAtConversionTime(n);
         bool ignored = isNodeConversionIgnored(n);
         if (n->kind() == torch::jit::prim::Loop) {
-            EvaluateLoopBlock(ctx, n);
+            bool returns_tensor = false;
+
+            // if even a single output of the loop returns a tensor, use ConvertLoopBlock
+            for (unsigned int i = 0; i < n->outputs().size(); i++) {
+                if (n->output(i)->type()->isSubtypeOf(c10::TensorType::get())) {
+                    returns_tensor = true;
+                }
+            }
+
+            if (returns_tensor) {
+                ConvertLoopBlock(ctx, n);
+            } else {
+                EvaluateLoopBlock(ctx, n);
+            }
         } else if (n->kind() == torch::jit::prim::If) {
             EvaluateConditionalBlock(ctx, n);
         } else if (to_eval) {
diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index eea21a265b..4b47ef398e 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -34,7 +34,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
     passes::Conv2DToConvolution(g);
     passes::FuseAddMMBranches(g);
     torch::jit::EliminateCommonSubexpression(g);
-    torch::jit::UnrollLoops(g);
+    //torch::jit::UnrollLoops(g);
     torch::jit::EliminateCommonSubexpression(g);
     passes::UnpackAddMM(g);
     //passes::UnpackBatchNorm(g);
diff --git a/tests/core/converters/BUILD b/tests/core/converters/BUILD
index 3e0ed92dc9..129bed3880 100644
--- a/tests/core/converters/BUILD
+++ b/tests/core/converters/BUILD
@@ -67,6 +67,10 @@ converter_test(
     name = "test_stack"
 )
 
+converter_test(
+    name = "test_loop"
+)
+
 test_suite(
     name = "test_converters",
     tests = [
@@ -83,6 +87,7 @@ test_suite(
         ":test_unary",
         ":test_interpolate",
         ":test_select",
-        ":test_stack"
+        ":test_stack",
+        ":test_loop"
     ]
 )
diff --git a/tests/core/converters/test_loop.cpp b/tests/core/converters/test_loop.cpp
new file mode 100755
index 0000000000..13f3587fe1
--- /dev/null
+++ b/tests/core/converters/test_loop.cpp
@@ -0,0 +1,65 @@
+#include "gtest/gtest.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "tests/util/util.h"
+#include "core/compiler.h"
+
+TEST(Converters, ATenLoopConvertsCorrectly) {
+    const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor, %2 : Tensor, %3 : Tensor, %4 : Tensor, %5 : Tensor, %8 : Tensor):
+        %22 : int = prim::Constant[value=1]()
+        %10 : bool = prim::Constant[value=1]()
+        %6 : int = prim::Constant[value=0]()
+        %98 : Tensor = aten::tanh(%1)
+        %7 : int = aten::size(%0, %6)
+        %99 : Tensor, %95 : Tensor = prim::Loop(%7, %10, %98, %1)
+          block0(%90 : int, %96 : Tensor, %93 : Tensor):
+            %16 : Tensor = aten::select(%0, %6, %90)
+            %18 : Tensor = aten::matmul(%16, %2)
+            %21 : Tensor = aten::matmul(%93, %3)
+            %23 : Tensor = aten::add(%18, %21, %22)
+            %26 : Tensor = aten::add(%23, %4, %22)
+            %94 : Tensor = aten::tanh(%26)
+            %31 : Tensor = aten::matmul(%94, %5)
+            %34 : Tensor = aten::add(%31, %8, %22)
+            %97 : Tensor = aten::tanh(%34)
+            -> (%10, %97, %94)
+        return (%99))IR";
+
+    auto g = std::make_shared<torch::jit::Graph>();
+
+    torch::jit::parseIR(graph, &*g);
+
+    auto x = at::randn({5, 5, 3}, {at::kCUDA});
+    auto h = at::randn({5, 5}, {at::kCUDA});
+    auto Wh = at::randn({3, 5}, {at::kCUDA});
+    auto Uh = at::randn({5, 5}, {at::kCUDA});
+    auto bh = at::randn({5, 5}, {at::kCUDA});
+    auto Wy = at::randn({5, 5}, {at::kCUDA});
+    auto by = at::randn({5, 5}, {at::kCUDA});
+
+    auto jit_x = at::clone(x);
+    auto jit_h = at::clone(h);
+    auto jit_Wh = at::clone(Wh);
+    auto jit_Uh = at::clone(Uh);
+    auto jit_bh = at::clone(bh);
+    auto jit_Wy = at::clone(Wy);
+    auto jit_by = at::clone(by);
+
+    auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+    auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_x, jit_h, jit_Wh, jit_Uh, jit_bh, jit_Wy, jit_by});
+
+    auto trt_x = at::clone(x);
+    auto trt_h = at::clone(h);
+    auto trt_Wh = at::clone(Wh);
+    auto trt_Uh = at::clone(Uh);
+    auto trt_bh = at::clone(bh);
+    auto trt_Wy = at::clone(Wy);
+    auto trt_by = at::clone(by);
+
+    params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+    auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_x, trt_h, trt_Wh, trt_Uh, trt_bh, trt_Wy, trt_by});
+
+    auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
\ No newline at end of file
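
Note for reviewers (not part of the patch itself): ConvertLoopBlock maps a TorchScript prim::Loop onto TensorRT's ILoop API — an ITripLimitLayer(kCOUNT) for max_trip_count, a kWHILE trip limit fed by an IRecurrenceLayer over the loop condition, one IRecurrenceLayer per loop-carried tensor whose backedge is closed with setInput(1, ...) once the body has been converted, and an ILoopOutputLayer(kLAST_VALUE) per loop output. The sketch below shows that construction pattern in isolation for a simple counted loop; the helper name, the fixed trip count, and the pre-existing INetworkDefinition/ITensor are illustrative assumptions, not code from this patch.

#include "NvInfer.h"

// Hypothetical helper: run x <- x + 1 for a fixed number of iterations inside a
// TensorRT ILoop and return the value from the final iteration.
nvinfer1::ITensor* addCountedIncrementLoop(nvinfer1::INetworkDefinition* net, nvinfer1::ITensor* x) {
    auto loop = net->addLoop();

    // Trip limit: a 0-D INT32 constant wired into an ITripLimitLayer(kCOUNT),
    // mirroring how the converter handles max_trip_count.
    static int32_t kTripCount = 10; // weight storage must outlive engine building
    nvinfer1::Dims scalar{};
    scalar.nbDims = 0;
    auto count = net->addConstant(scalar, nvinfer1::Weights{nvinfer1::DataType::kINT32, &kTripCount, 1});
    loop->addTripLimit(*count->getOutput(0), nvinfer1::TripLimit::kCOUNT);

    // Recurrence: input 0 is the initial value; input 1 (the backedge) can only be
    // connected after the loop body exists, which is why the converter wires it last.
    auto rec = loop->addRecurrence(*x);

    // Loop body: add a broadcastable constant 1 to the recurrent value.
    static float kOne = 1.f;
    nvinfer1::Dims ones{};
    ones.nbDims = x->getDimensions().nbDims;
    for (int32_t i = 0; i < ones.nbDims; i++) {
        ones.d[i] = 1;
    }
    auto one = net->addConstant(ones, nvinfer1::Weights{nvinfer1::DataType::kFLOAT, &kOne, 1});
    auto body = net->addElementWise(*rec->getOutput(0), *one->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);

    // Close the backedge and emit the final iteration's value, the same
    // LoopOutput::kLAST_VALUE kind the converter uses for the loop's recurrent outputs.
    rec->setInput(1, *body->getOutput(0));
    auto out = loop->addLoopOutput(*rec->getOutput(0), nvinfer1::LoopOutput::kLAST_VALUE);
    return out->getOutput(0);
}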