Skip to content

Commit 3a44a6f

Browse files
mfeliz-cruisenarendasan
authored andcommitted
feat: Add converter for aten::isfinite (#1841)
1 parent 5e39222 commit 3a44a6f

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

core/conversion/converters/impl/unary.cpp

+20
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,26 @@ auto logical_not_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
7979
return true;
8080
}});
8181

82+
auto isfinite_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
83+
{"aten::isfinite(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
84+
auto in = args[0].ITensorOrFreeze(ctx);
85+
// assuming x-x = 0 for all values other than nan/inf/-inf where x-x = nan
86+
// x==x for all non-nan values
87+
auto inf_test_layer = ctx->net->addElementWise(*in, *in, nvinfer1::ElementWiseOperation::kSUB);
88+
TORCHTRT_CHECK(inf_test_layer, "Unable to create sub layer from node: " << *n);
89+
inf_test_layer->setName((util::node_info(n) + "_inf_test").c_str());
90+
auto inf_test_tensor = inf_test_layer->getOutput(0);
91+
92+
auto nan_test_layer =
93+
ctx->net->addElementWise(*inf_test_tensor, *inf_test_tensor, nvinfer1::ElementWiseOperation::kEQUAL);
94+
TORCHTRT_CHECK(nan_test_layer, "Unable to create eq layer from node: " << *n);
95+
nan_test_layer->setName((util::node_info(n) + "_nan_test").c_str());
96+
97+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], nan_test_layer->getOutput(0));
98+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
99+
return true;
100+
}});
101+
82102
#define convert(unary, trt_type) \
83103
auto unary##_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( \
84104
{"aten::" #unary "(Tensor self) -> Tensor", \

tests/core/conversion/converters/test_unary.cpp

+21
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,27 @@ TEST(Converters, ATenLogicalNotBoolConvertsCorrectly) {
111111
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
112112
}
113113

114+
TEST(Converters, ATenFiniteConvertsCorrectly) {
115+
const auto graph = gen_test_graph("isfinite");
116+
auto g = std::make_shared<torch::jit::Graph>();
117+
torch::jit::parseIR(graph, g.get());
118+
auto in = at::tensor(
119+
{float(0),
120+
std::nanf(""),
121+
float(2),
122+
std::numeric_limits<float>::infinity(),
123+
float(4),
124+
-std::numeric_limits<float>::infinity()},
125+
{at::kCUDA});
126+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
127+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
128+
129+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
130+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
131+
132+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
133+
}
134+
114135
#define test_unary(unary, name) \
115136
TEST(Converters, ATen##name##ConvertsCorrectly) { \
116137
const auto graph = gen_test_graph(#unary); \

0 commit comments

Comments
 (0)