[feat] add support for aten::reciprocal(int) #1308

Merged

16 changes: 15 additions & 1 deletion core/conversion/converters/impl/unary.cpp
@@ -49,6 +49,21 @@ auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
       }
     }});

auto reciprocal_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
    {"aten::reciprocal(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       auto in = args[0].ITensorOrFreeze(ctx);
       if (in->getType() == nvinfer1::DataType::kINT32) {

Collaborator:
This might be a large change, but would it make sense to just add this to the macro for the other unary ops? @peri044, thoughts on what the repercussions would be?

@mfeliz-cruise (Contributor Author), Aug 25, 2022:
In this case the behavior matches PyTorch. For other ops (e.g. abs, implemented element-wise above), this behavior would be incorrect. I have not checked any of the other ops.
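
As a quick illustration of the promotion being described, here is a standalone libtorch snippet (not part of this PR; it assumes a working libtorch install) showing that an Int32 input to reciprocal produces a Float output:

#include <torch/torch.h>
#include <iostream>

int main() {
  // PyTorch promotes integer inputs to the default float dtype for reciprocal,
  // which is the behavior the converter above mirrors with its explicit cast.
  auto t = torch::tensor({2, 4}, torch::kInt32);
  auto r = t.reciprocal();
  std::cout << r.dtype() << "\n";  // float
  std::cout << r << "\n";          // 0.5000, 0.2500 [ CPUFloatType{2} ]
  return 0;
}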

Collaborator:
There are two ops with specific restrictions according to the docs, but the other unary ops must have floating-point inputs: https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#a77831224c9a72ad02587a56ded93c672

"Generally the input must have a floating-point type (or kINT8 as a quantized float), except for the following operations:
kSIGN accepts a floating-point or Int32 tensor.
kNOT requires a Bool tensor."

Collaborator:
We could add the above restrictions in the code in addition to what Michael added, to cover the cases in the doc completely.
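
A minimal sketch of such a guard, using only the nvinfer1 enums quoted above (the helper name and where it would be called from are hypothetical, not part of this PR):

#include <NvInfer.h>

// Hypothetical helper: check a unary op's input dtype against the
// restrictions in the TensorRT documentation quoted above.
static bool unary_input_type_ok(nvinfer1::DataType t, nvinfer1::UnaryOperation op) {
  using nvinfer1::DataType;
  using nvinfer1::UnaryOperation;
  switch (op) {
    case UnaryOperation::kSIGN:
      // kSIGN accepts a floating-point or Int32 tensor
      return t == DataType::kFLOAT || t == DataType::kHALF || t == DataType::kINT32;
    case UnaryOperation::kNOT:
      // kNOT requires a Bool tensor
      return t == DataType::kBOOL;
    default:
      // generally the input must be floating point (or kINT8 as a quantized float)
      return t == DataType::kFLOAT || t == DataType::kHALF || t == DataType::kINT8;
  }
}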

Collaborator:
I was mostly talking from the PyTorch perspective, since it seems like there are at least a few ops where int inputs are valid.

         // pytorch implicitly casts to float for aten::reciprocal(int)
         in = castITensor(ctx, in, nvinfer1::DataType::kFLOAT);
       }
       auto unary_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kRECIP);
       TORCHTRT_CHECK(unary_layer, "Unable to create recip layer from node: " << *n);
       unary_layer->setName(util::node_info(n).c_str());
       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], unary_layer->getOutput(0));
       LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
       return true;
     }});

#define convert(unary, trt_type) \
auto unary##_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( \
{"aten::" #unary "(Tensor self) -> Tensor", \
@@ -74,7 +89,6 @@ convert(sinh, kSINH);
 convert(tan, kTAN);
 convert(atan, kATAN);
 convert(floor, kFLOOR);
-convert(reciprocal, kRECIP);
 convert(log, kLOG);
 convert(ceil, kCEIL);
 convert(sqrt, kSQRT);
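
Tying back to the first review comment above: folding the cast into the macro itself might look roughly like the sketch below. This is a hypothetical variant, not what the PR does, and per the author's reply it would only be correct for ops where PyTorch itself promotes integer inputs to float:

// Hypothetical convert-with-cast macro (illustrative only): same body as the
// existing convert macro, plus the kINT32 -> kFLOAT cast used by the
// reciprocal converter above.
#define convert_with_cast(unary, trt_type)                                                    \
  auto unary##_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(      \
      {"aten::" #unary "(Tensor self) -> Tensor",                                             \
       [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {                \
         auto in = args[0].ITensorOrFreeze(ctx);                                              \
         if (in->getType() == nvinfer1::DataType::kINT32) {                                   \
           in = castITensor(ctx, in, nvinfer1::DataType::kFLOAT);                             \
         }                                                                                    \
         auto unary_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::trt_type);      \
         TORCHTRT_CHECK(unary_layer, "Unable to create " #unary " layer from node: " << *n);  \
         unary_layer->setName(util::node_info(n).c_str());                                    \
         auto out = ctx->AssociateValueAndTensor(n->outputs()[0], unary_layer->getOutput(0)); \
         LOG_DEBUG("Output tensor shape: " << out->getDimensions());                          \
         return true;                                                                         \
       }});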
16 changes: 16 additions & 0 deletions tests/core/conversion/converters/test_unary.cpp
@@ -31,6 +31,22 @@ TEST(Converters, ATenAbsIntConvertsCorrectly) {
  ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0]));
}

TEST(Converters, ATenReciprocalIntConvertsCorrectly) {
  const auto graph = gen_test_graph("reciprocal");
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in = at::tensor({-1, 1, -2, 2, -3, 3}, {at::kCUDA}).to(torch::kInt32);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

  in = at::clone(in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

  ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0]));
}

#define test_unary(unary, name)                     \
  TEST(Converters, ATen##name##ConvertsCorrectly) { \
    const auto graph = gen_test_graph(#unary);      \