@@ -79,6 +79,26 @@ auto logical_not_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
79
79
return true ;
80
80
}});
81
81
82
+ auto isfinite_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
83
+ {" aten::isfinite(Tensor self) -> Tensor" , [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
84
+ auto in = args[0 ].ITensorOrFreeze (ctx);
85
+ // assuming x-x = 0 for all values other than nan/inf/-inf where x-x = nan
86
+ // x==x for all non-nan values
87
+ auto inf_test_layer = ctx->net ->addElementWise (*in, *in, nvinfer1::ElementWiseOperation::kSUB );
88
+ TORCHTRT_CHECK (inf_test_layer, " Unable to create sub layer from node: " << *n);
89
+ inf_test_layer->setName ((util::node_info (n) + " _inf_test" ).c_str ());
90
+ auto inf_test_tensor = inf_test_layer->getOutput (0 );
91
+
92
+ auto nan_test_layer =
93
+ ctx->net ->addElementWise (*inf_test_tensor, *inf_test_tensor, nvinfer1::ElementWiseOperation::kEQUAL );
94
+ TORCHTRT_CHECK (nan_test_layer, " Unable to create eq layer from node: " << *n);
95
+ nan_test_layer->setName ((util::node_info (n) + " _nan_test" ).c_str ());
96
+
97
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], nan_test_layer->getOutput (0 ));
98
+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
99
+ return true ;
100
+ }});
101
+
82
102
#define convert (unary, trt_type ) \
83
103
auto unary##_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( \
84
104
{" aten::" #unary " (Tensor self) -> Tensor" , \
0 commit comments