diff --git a/core/conversion/converters/impl/normalize.cpp b/core/conversion/converters/impl/normalize.cpp
index 35a8d70122..4569c31110 100644
--- a/core/conversion/converters/impl/normalize.cpp
+++ b/core/conversion/converters/impl/normalize.cpp
@@ -53,6 +53,47 @@ void create_plugin(
   LOG_DEBUG("Normalize layer output tensor shape: " << layer_output->getDimensions());
 }
 
+int32_t axes_mask_from_axes_values(
+    const torch::jit::Node* n,
+    int32_t nb_dims,
+    const std::vector<int64_t>& axes_values) {
+  int32_t axes_mask = 0;
+  for (size_t i = 0UL; i < axes_values.size(); ++i) {
+    auto axis = axes_values[i];
+    if (axis < 0) {
+      axis += nb_dims;
+    }
+    TORCHTRT_CHECK(
+        axis < nb_dims, util::node_info(n) << " axis " << i << " with value: " << axis << " exceeds input rank");
+    axes_mask += 1 << axis;
+  }
+  return axes_mask;
+}
+
+nvinfer1::ITensor* frobenius_norm(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* self,
+    int32_t axes_mask,
+    bool keep_dims) {
+  auto squared_layer =
+      add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, self, util::node_info(n) + "_squared");
+  TORCHTRT_CHECK(squared_layer, "Unable to create square layer from node: " << *n);
+  auto squared_output = squared_layer->getOutput(0);
+
+  auto sum_layer = ctx->net->addReduce(*squared_output, nvinfer1::ReduceOperation::kSUM, axes_mask, keep_dims);
+  TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
+  sum_layer->setName((util::node_info(n) + "_sum").c_str());
+  auto sum_output = sum_layer->getOutput(0);
+  LOG_DEBUG("SUM SHAPE: " << sum_output->getDimensions());
+
+  auto sqrt_layer = ctx->net->addUnary(*sum_output, nvinfer1::UnaryOperation::kSQRT);
+  TORCHTRT_CHECK(sqrt_layer, "Unable to create sqrt layer from node: " << *n);
+  sqrt_layer->setName((util::node_info(n) + "_sqrt").c_str());
+  auto sqrt_output = sqrt_layer->getOutput(0);
+  return sqrt_output;
+}
+
 auto normalize_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern(
@@ -79,37 +120,48 @@ auto normalize_registrations TORCHTRT_UNUSED =
                auto axes_values = args[1].unwrapToIntList().vec();
                auto keep_dims = args[2].unwrapToBool();
 
-               int32_t axes_mask = 0;
-               auto self_nb_dims = self->getDimensions().nbDims;
-               for (size_t i = 0UL; i < axes_values.size(); ++i) {
-                 auto axis = axes_values[i];
-                 if (axis < 0) {
-                   axis += self_nb_dims;
-                 }
-                 TORCHTRT_CHECK(
-                     axis < self_nb_dims,
-                     "aten::frobenius_norm axis: " << i << " with value: " << axis << " exceeds input rank");
-                 axes_mask += 1 << axis;
-               }
+               auto axes_mask = axes_mask_from_axes_values(n, self->getDimensions().nbDims, axes_values);
 
-               auto squared_layer = add_elementwise(
-                   ctx, nvinfer1::ElementWiseOperation::kPROD, self, self, util::node_info(n) + "_squared");
-               TORCHTRT_CHECK(squared_layer, "Unabled to create square layer from node: " << *n);
-               auto squared_output = squared_layer->getOutput(0);
-
-               auto sum_layer =
-                   ctx->net->addReduce(*squared_output, nvinfer1::ReduceOperation::kSUM, axes_mask, keep_dims);
-               TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
-               sum_layer->setName((util::node_info(n) + "_sum").c_str());
-               auto sum_output = sum_layer->getOutput(0);
-
-               auto sqrt_layer = ctx->net->addUnary(*sum_output, nvinfer1::UnaryOperation::kSQRT);
-               TORCHTRT_CHECK(sqrt_layer, "Unable to create sqrt layer from node: " << *n);
-               sqrt_layer->setName((util::node_info(n) + "_sqrt").c_str());
-               auto sqrt_output = sqrt_layer->getOutput(0);
+               auto norm = frobenius_norm(ctx, n, self, axes_mask, keep_dims);
+               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], norm);
+               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+               return true;
+             }})
+        .pattern(
+            {"aten::linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, int? dtype=None) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               // https://pytorch.org/docs/stable/generated/torch.linalg.norm.html
+               auto self = args[0].ITensorOrFreeze(ctx);
+               TORCHTRT_CHECK(
+                   args[1].IValue()->isNone(),
+                   "aten::linalg_norm converter does not yet support non-None 'ord' arguments. Add aten::linalg_norm to torch_executed_ops to force it to fallback.");
+               auto keep_dims = args[3].unwrapToBool();
+               auto self_nb_dims = self->getDimensions().nbDims;
 
-               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], sqrt_layer->getOutput(0));
+               if (!args.back().IValue()->isNone()) {
+                 // If specified, the input tensor is cast to dtype before performing the operation, and the returned
+                 // tensor's type will be dtype
+                 auto dtype = args.back().unwrapToScalar().to<int64_t>();
+                 auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(dtype));
+                 self = castITensor(ctx, self, trt_dtype);
+               }
+               int32_t axes_mask = 0;
+               if (args[2].IValue()->isNone()) {
+                 // If dim = None and ord = None, self will be flattened to 1D and the 2-norm of the resulting vector
+                 // will be computed.
+                 axes_mask = 1;
+                 keep_dims = true; // the single output dim is always preserved
+                 auto flatten_layer = ctx->net->addShuffle(*self);
+                 TORCHTRT_CHECK(flatten_layer, "Unable to create shuffle layer from node: " << *n);
+                 flatten_layer->setReshapeDimensions(util::toDims(std::vector<int64_t>({-1})));
+                 flatten_layer->setName((util::node_info(n) + "_flatten").c_str());
+                 self = flatten_layer->getOutput(0);
+               } else {
+                 axes_mask = axes_mask_from_axes_values(n, self_nb_dims, args[2].unwrapToIntList().vec());
+               }
+               auto norm = frobenius_norm(ctx, n, self, axes_mask, keep_dims);
+               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], norm);
                LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                return true;
             }});
diff --git a/tests/core/conversion/converters/test_normalize.cpp b/tests/core/conversion/converters/test_normalize.cpp
index f939cb8b68..b665c0b6fb 100644
--- a/tests/core/conversion/converters/test_normalize.cpp
+++ b/tests/core/conversion/converters/test_normalize.cpp
@@ -138,4 +138,69 @@ TEST(Converters, ATenFrobeniusNormMatrix) {
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {x});
 
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
-}
\ No newline at end of file
+}
+
+TEST(Converters, ATenLinAlgNorm_None) {
+  const auto graph = R"IR(
+      graph(%x : Tensor):
+        %none : NoneType = prim::Constant()
+        %keep : bool = prim::Constant[value=0]()
+        %out : Tensor = aten::linalg_norm(%x, %none, %none, %keep, %none)
+        return (%out))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+  auto x = at::randn({5, 5, 5}, {at::kCUDA});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {x});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {x});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
+}
+
+TEST(Converters, ATenLinAlgNorm_1D) {
+  const auto graph = R"IR(
+      graph(%x : Tensor):
+        %1 : int = prim::Constant[value=1]()
+        %none : NoneType = prim::Constant()
+        %keep : bool = prim::Constant[value=0]()
+        %dims : int[] = prim::ListConstruct(%1)
+        %out : Tensor = aten::linalg_norm(%x, %none, %dims, %keep, %none)
+        return (%out))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto x = at::randn({5, 5, 5}, {at::kCUDA});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {x});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {x});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
+}
+
+TEST(Converters, ATenLinAlgNorm_2D) {
+  const auto graph = R"IR(
+      graph(%x : Tensor):
+        %0 : int = prim::Constant[value=0]()
+        %1 : int = prim::Constant[value=-1]()
+        %none : NoneType = prim::Constant()
+        %keep : bool = prim::Constant[value=1]()
+        %dims : int[] = prim::ListConstruct(%0, %1)
+        %float : int = prim::Constant[value=6]()
+        %out : Tensor = aten::linalg_norm(%x, %none, %dims, %keep, %float)
+        return (%out))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto x = at::randn({5, 5, 5}, {at::kCUDA}).to(at::kHalf);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {x});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {x});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
+}
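
For reference, the frobenius_norm() helper lowers the norm as square (kPROD) -> reduce-sum over the axes mask (kSUM) -> kSQRT, i.e. sqrt(sum(x*x)) over the selected dimensions, which is also what aten::linalg_norm computes when ord is None. Below is a minimal, illustrative libtorch sketch (not part of this patch; dims {0, 2} and keepdim=true are arbitrary) that checks this identity against the aten::frobenius_norm op handled by the same converter.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto x = at::randn({5, 5, 5});
  // Norm computed the way the converter lowers it: square, sum over dims, sqrt.
  auto manual = at::sqrt((x * x).sum({0, 2}, /*keepdim=*/true));
  // Reference from ATen's frobenius_norm over the same dims.
  auto reference = at::frobenius_norm(x, {0, 2}, /*keepdim=*/true);
  std::cout << std::boolalpha << at::allclose(reference, manual) << std::endl; // expect: true
  return 0;
}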