Skip to content

[feat]Add converter support for aten::frobenius_norm #1422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 60 additions & 17 deletions core/conversion/converters/impl/normalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,66 @@ void create_plugin(
LOG_DEBUG("Normalize layer output tensor shape: " << layer_output->getDimensions());
}

auto normalize_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
{"aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto in_shape = util::toVec(in->getDimensions());
auto order = args[1].unwrapToScalar().to<int32_t>();
auto axes_values = args[2].unwrapToIntList().vec();
std::vector<int32_t> axes(axes_values.begin(), axes_values.end());
auto keep_dims = (int32_t)args[3].unwrapToBool();
LOG_DEBUG("Order of normalize_plugin: " << order);
LOG_DEBUG("Axis: " << axes);
LOG_DEBUG("keep_dims: " << keep_dims);
create_plugin(ctx, n, in, order, axes, keep_dims, "NormalizePluginTorchTRT");
return true;
}

});
auto normalize_registrations TORCHTRT_UNUSED =
RegisterNodeConversionPatterns()
.pattern(
{"aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensorOrFreeze(ctx);
auto in_shape = util::toVec(in->getDimensions());
auto order = args[1].unwrapToScalar().to<int32_t>();
auto axes_values = args[2].unwrapToIntList().vec();
std::vector<int32_t> axes(axes_values.begin(), axes_values.end());
auto keep_dims = (int32_t)args[3].unwrapToBool();
LOG_DEBUG("Order of normalize_plugin: " << order);
LOG_DEBUG("Axis: " << axes);
LOG_DEBUG("keep_dims: " << keep_dims);
create_plugin(ctx, n, in, order, axes, keep_dims, "NormalizePluginTorchTRT");
return true;
}

})
.pattern(
{"aten::frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto self = args[0].ITensorOrFreeze(ctx);
auto axes_values = args[1].unwrapToIntList().vec();
auto keep_dims = args[2].unwrapToBool();

int32_t axes_mask = 0;
auto self_nb_dims = self->getDimensions().nbDims;
for (size_t i = 0UL; i < axes_values.size(); ++i) {
auto axis = axes_values[i];
if (axis < 0) {
axis += self_nb_dims;
}
TORCHTRT_CHECK(
axis < self_nb_dims,
"aten::frobenius_norm axis: " << i << " with value: " << axis << " exceeds input rank");
axes_mask += 1 << axis;
}

auto squared_layer = add_elementwise(
ctx, nvinfer1::ElementWiseOperation::kPROD, self, self, util::node_info(n) + "_squared");
TORCHTRT_CHECK(squared_layer, "Unabled to create square layer from node: " << *n);
auto squared_output = squared_layer->getOutput(0);

auto sum_layer =
ctx->net->addReduce(*squared_output, nvinfer1::ReduceOperation::kSUM, axes_mask, keep_dims);
TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
sum_layer->setName((util::node_info(n) + "_sum").c_str());
auto sum_output = sum_layer->getOutput(0);

auto sqrt_layer = ctx->net->addUnary(*sum_output, nvinfer1::UnaryOperation::kSQRT);
TORCHTRT_CHECK(sqrt_layer, "Unable to create sqrt layer from node: " << *n);
sqrt_layer->setName((util::node_info(n) + "_sqrt").c_str());
auto sqrt_output = sqrt_layer->getOutput(0);

auto out = ctx->AssociateValueAndTensor(n->outputs()[0], sqrt_layer->getOutput(0));

LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}});

} // namespace
} // namespace impl
Expand Down
64 changes: 64 additions & 0 deletions tests/core/conversion/converters/test_normalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,67 @@ ATEN_INTERPOLATE_TESTS(
%5 : Tensor = aten::norm(%x.1, %3, %2, %4)
return (%5))IR",
std::vector<int64_t>({3, 4, 3}));

TEST(Converters, ATenFrobeniusNorm) {
const auto graph = R"IR(
graph(%x : Tensor):
%0 : int = prim::Constant[value=0]()
%keep : bool = prim::Constant[value=0]()
%dims : int[] = prim::ListConstruct(%0)
%out : Tensor = aten::frobenius_norm(%x, %dims, %keep)
return (%out))IR";
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto x = at::randn({5, 5, 5}, {at::kCUDA});
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {x});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {x});

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
}

TEST(Converters, ATenFrobeniusNormKeep) {
const auto graph = R"IR(
graph(%x : Tensor):
%1 : int = prim::Constant[value=-1]()
%keep : bool = prim::Constant[value=1]()
%dims : int[] = prim::ListConstruct(%1)
%out : Tensor = aten::frobenius_norm(%x, %dims, %keep)
return (%out))IR";
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto x = at::randn({5, 5, 5}, {at::kCUDA});
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {x});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {x});

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
}

TEST(Converters, ATenFrobeniusNormMatrix) {
const auto graph = R"IR(
graph(%x : Tensor):
%0 : int = prim::Constant[value=0]()
%1 : int = prim::Constant[value=-1]()
%keep : bool = prim::Constant[value=0]()
%dims : int[] = prim::ListConstruct(%0, %1)
%out : Tensor = aten::frobenius_norm(%x, %dims, %keep)
return (%out))IR";
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto x = at::randn({3, 5, 7, 11, 13}, {at::kCUDA});
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {x});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {x});

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
}