From 87259f08593484a928370ec2c1bd21a9fb6ab1c5 Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Wed, 26 Oct 2022 16:38:48 -0700 Subject: [PATCH] Add converter support for aten::frobenius_norm --- core/conversion/converters/impl/normalize.cpp | 77 +++++++++++++++---- .../conversion/converters/test_normalize.cpp | 64 +++++++++++++++ 2 files changed, 124 insertions(+), 17 deletions(-) diff --git a/core/conversion/converters/impl/normalize.cpp b/core/conversion/converters/impl/normalize.cpp index e96c6361a5..35a8d70122 100644 --- a/core/conversion/converters/impl/normalize.cpp +++ b/core/conversion/converters/impl/normalize.cpp @@ -53,23 +53,66 @@ void create_plugin( LOG_DEBUG("Normalize layer output tensor shape: " << layer_output->getDimensions()); } -auto normalize_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( - {"aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)", - [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - auto in = args[0].ITensor(); - auto in_shape = util::toVec(in->getDimensions()); - auto order = args[1].unwrapToScalar().to(); - auto axes_values = args[2].unwrapToIntList().vec(); - std::vector axes(axes_values.begin(), axes_values.end()); - auto keep_dims = (int32_t)args[3].unwrapToBool(); - LOG_DEBUG("Order of normalize_plugin: " << order); - LOG_DEBUG("Axis: " << axes); - LOG_DEBUG("keep_dims: " << keep_dims); - create_plugin(ctx, n, in, order, axes, keep_dims, "NormalizePluginTorchTRT"); - return true; - } - - }); +auto normalize_registrations TORCHTRT_UNUSED = + RegisterNodeConversionPatterns() + .pattern( + {"aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensorOrFreeze(ctx); + auto in_shape = util::toVec(in->getDimensions()); + auto order = args[1].unwrapToScalar().to(); + auto axes_values = args[2].unwrapToIntList().vec(); + std::vector axes(axes_values.begin(), axes_values.end()); + auto keep_dims = (int32_t)args[3].unwrapToBool(); + LOG_DEBUG("Order of normalize_plugin: " << order); + LOG_DEBUG("Axis: " << axes); + LOG_DEBUG("keep_dims: " << keep_dims); + create_plugin(ctx, n, in, order, axes, keep_dims, "NormalizePluginTorchTRT"); + return true; + } + + }) + .pattern( + {"aten::frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto self = args[0].ITensorOrFreeze(ctx); + auto axes_values = args[1].unwrapToIntList().vec(); + auto keep_dims = args[2].unwrapToBool(); + + int32_t axes_mask = 0; + auto self_nb_dims = self->getDimensions().nbDims; + for (size_t i = 0UL; i < axes_values.size(); ++i) { + auto axis = axes_values[i]; + if (axis < 0) { + axis += self_nb_dims; + } + TORCHTRT_CHECK( + axis < self_nb_dims, + "aten::frobenius_norm axis: " << i << " with value: " << axis << " exceeds input rank"); + axes_mask += 1 << axis; + } + + auto squared_layer = add_elementwise( + ctx, nvinfer1::ElementWiseOperation::kPROD, self, self, util::node_info(n) + "_squared"); + TORCHTRT_CHECK(squared_layer, "Unabled to create square layer from node: " << *n); + auto squared_output = squared_layer->getOutput(0); + + auto sum_layer = + ctx->net->addReduce(*squared_output, nvinfer1::ReduceOperation::kSUM, axes_mask, keep_dims); + TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n); + sum_layer->setName((util::node_info(n) + "_sum").c_str()); + auto sum_output = sum_layer->getOutput(0); + + auto sqrt_layer = ctx->net->addUnary(*sum_output, nvinfer1::UnaryOperation::kSQRT); + TORCHTRT_CHECK(sqrt_layer, "Unable to create sqrt layer from node: " << *n); + sqrt_layer->setName((util::node_info(n) + "_sqrt").c_str()); + auto sqrt_output = sqrt_layer->getOutput(0); + + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], sqrt_layer->getOutput(0)); + + LOG_DEBUG("Output tensor shape: " << out->getDimensions()); + return true; + }}); } // namespace } // namespace impl diff --git a/tests/core/conversion/converters/test_normalize.cpp b/tests/core/conversion/converters/test_normalize.cpp index 473749b3e7..f939cb8b68 100644 --- a/tests/core/conversion/converters/test_normalize.cpp +++ b/tests/core/conversion/converters/test_normalize.cpp @@ -75,3 +75,67 @@ ATEN_INTERPOLATE_TESTS( %5 : Tensor = aten::norm(%x.1, %3, %2, %4) return (%5))IR", std::vector({3, 4, 3})); + +TEST(Converters, ATenFrobeniusNorm) { + const auto graph = R"IR( + graph(%x : Tensor): + %0 : int = prim::Constant[value=0]() + %keep : bool = prim::Constant[value=0]() + %dims : int[] = prim::ListConstruct(%0) + %out : Tensor = aten::frobenius_norm(%x, %dims, %keep) + return (%out))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto x = at::randn({5, 5, 5}, {at::kCUDA}); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {x}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {x}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0])); +} + +TEST(Converters, ATenFrobeniusNormKeep) { + const auto graph = R"IR( + graph(%x : Tensor): + %1 : int = prim::Constant[value=-1]() + %keep : bool = prim::Constant[value=1]() + %dims : int[] = prim::ListConstruct(%1) + %out : Tensor = aten::frobenius_norm(%x, %dims, %keep) + return (%out))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto x = at::randn({5, 5, 5}, {at::kCUDA}); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {x}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {x}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0])); +} + +TEST(Converters, ATenFrobeniusNormMatrix) { + const auto graph = R"IR( + graph(%x : Tensor): + %0 : int = prim::Constant[value=0]() + %1 : int = prim::Constant[value=-1]() + %keep : bool = prim::Constant[value=0]() + %dims : int[] = prim::ListConstruct(%0, %1) + %out : Tensor = aten::frobenius_norm(%x, %dims, %keep) + return (%out))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto x = at::randn({3, 5, 7, 11, 13}, {at::kCUDA}); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {x}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {x}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0])); +} \ No newline at end of file